quack-kernels 0.1.9-py3-none-any.whl → 0.1.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/topk.py ADDED
@@ -0,0 +1,221 @@
+ # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao.
+
+ import math
+ import torch
+ from typing import Type
+
+ import cuda.bindings.driver as cuda
+
+ import cutlass
+ import cutlass.cute as cute
+ from cutlass.cute.runtime import from_dlpack
+ from cutlass import const_expr
+
+ import quack.utils as utils
+ from quack.reduction_base import torch2cute_dtype_map
+ from quack.sort.bitonic_sort import bitonic_topk
+
+
+ class TopK:
+     def __init__(self, dtype: Type[cutlass.Numeric], N: int, k: int):
+         self.dtype = dtype
+         self.N = N
+         self.vecsize = 128 // dtype.width
+         self.k = k
+         assert N == 2 ** int(math.log2(N)), "N must be a power of 2"
+         assert k == 2 ** int(math.log2(k)), "k must be a power of 2"
+         assert k <= 128
+         assert N <= 4096
+
+     def _calculate_threads_per_row(self):
+         # we want num_elems_per_thread >= self.k
+         # and each thread can handle at most 64 elements
+         N = self.N
+         num_threads_per_row = max(min(N // self.k, 32, N // 64), 1)
+         return num_threads_per_row
+
+     def _get_tv_layout(self):
+         N = self.N
+         vecsize = self.vecsize
+         num_threads = 128 if N <= 16384 else 256
+         threads_per_row = self._calculate_threads_per_row()
+         cols_per_block = num_threads // threads_per_row
+         num_blocks_N = cute.ceil_div(min(N, 16384) // vecsize, threads_per_row)
+         tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
+         tv_layout = cute.make_layout(
+             ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
+             stride=(
+                 (vecsize * cols_per_block, 1),
+                 (cols_per_block, cols_per_block * vecsize * threads_per_row),
+             ),
+         )
+         return tiler_mn, tv_layout
+
+     @cute.jit
+     def __call__(
+         self,
+         mX: cute.Tensor,
+         mValues: cute.Tensor,
+         mIndices: cute.Tensor,
+         stream: cuda.CUstream,
+     ):
+         assert mX.element_type == self.dtype
+         assert mValues.element_type == self.dtype
+         assert mIndices.element_type == cutlass.Int32
+         tiler_mn, tv_layout = self._get_tv_layout()
+         num_threads = cute.size(tv_layout, mode=[0])
+         self.kernel(mX, mValues, mIndices, tv_layout, tiler_mn).launch(
+             grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), 1, 1],
+             block=[num_threads, 1, 1],
+             stream=stream,
+         )
+
+     @cute.kernel
+     def kernel(
+         self,
+         mX: cute.Tensor,
+         mValues: cute.Tensor,
+         mIndices: cute.Tensor,
+         tv_layout: cute.Layout,
+         tiler_mn: cute.Shape,
+     ):
+         tidx, _, _ = cute.arch.thread_idx()
+         bidx, _, _ = cute.arch.block_idx()
+
+         shape = mX.shape
+         idX = cute.make_identity_tensor(shape)
+         # slice for CTAs
+         # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+         mX = utils.domain_offset_i64((bidx * tiler_mn[0], 0), mX)
+         gX = cute.local_tile(mX, tiler_mn, (0, 0))
+         cX = cute.local_tile(idX, tiler_mn, (bidx, 0))
+
+         # declare the atoms which will be used later for memory copy
+         copy_atom_load_X = cute.make_copy_atom(
+             cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128
+         )
+         thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+         tXgX = thr_copy_X.partition_S(gX)
+         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+
+         # allocate fragments for gmem->rmem
+         tXrX = cute.make_fragment_like(tXgX)
+
+         is_even_N = const_expr(shape[1] == tiler_mn[1])
+         tXpX = (
+             utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
+         )
+         if tXcX[0][0] < shape[0]:
+             cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+         tXrX_f32 = cute.make_fragment(tXrX.shape, cutlass.Float32)
+         tXrX_f32.store(tXrX.load().to(cutlass.Float32))
+
+         # Encode the indices into the bottom bits of values.
+         log_N = int(math.log2(self.N))
+         idx_mask = (1 << log_N) - 1
+         vecsize = cutlass.const_expr(tv_layout.shape[1][0])
+         tXrX_u32 = cute.recast_tensor(tXrX_f32, cutlass.Uint32)
+         # Encode indices into the last log_N bits of tXrX_u32
+         for i in cutlass.range(cute.size(tXrX_u32), unroll_full=True):
+             # tXcX only keeps track of the indices for every @vecsize elements
+             col_idx = cutlass.Uint32(tXcX[i // vecsize][1] + i % vecsize)
+             # If positive, invert the bits of the index, so that if there's a tie,
+             # indices coming from an earlier column will win.
+             encoded_idx = ~col_idx if tXrX_f32[i] >= 0 else col_idx
+             # Mask to keep only the last log_N bits of the encoded index
+             encoded_idx = encoded_idx & idx_mask
+             # Clear the last log_N bits and set them to our encoded index
+             tXrX_u32[i] = (tXrX_u32[i] & ~idx_mask) | encoded_idx
+
+         # Fill OOB values with -inf for top-k
+         if const_expr(not is_even_N):
+             utils.fill_oob(tXrX_f32, tXpX, -tXrX_f32.element_type.inf)
+
+         threads_per_row = tv_layout.shape[0][0]
+         topk_vals = bitonic_topk(tXrX_f32, self.k, warp_width=threads_per_row)
+         # Extract indices and clean values
+         topk_vals_u32 = cute.recast_tensor(topk_vals, cutlass.Uint32)
+         topk_indices = cute.make_fragment(self.k, cutlass.Int32)
+         for i in cutlass.range(self.k):
+             # Extract the encoded index from the last log_N bits
+             encoded_idx = topk_vals_u32[i] & idx_mask
+             # Check if original value was positive by looking at the cleaned value
+             topk_vals_u32[i] = topk_vals_u32[i] & ~idx_mask  # Clear last log_N bits
+             # If positive, we need to invert the bits back to get original index
+             col_idx = ~encoded_idx if topk_vals[i] >= 0 else encoded_idx
+             topk_indices[i] = cutlass.Int32(col_idx & idx_mask)
+
+         # Convert cleaned values to output type
+         topk_vals_out = cute.make_fragment_like(topk_vals, mValues.element_type)
+         topk_vals_out.store(topk_vals.load().to(mValues.element_type))
+
+         row = tXcX[0][0]
+         # Only the 1st thread in this row writes the top-k values and indices
+         if row < shape[0] and tXcX[0][1] == 0:
+             # for i in cutlass.range(self.k):
+             #     mValues[row, i] = topk_vals_out[i]
+             #     mIndices[row, i] = topk_indices[i]
+             # Vectorized write
+             elems_per_store = const_expr(math.gcd(vecsize, self.k))
+             mValues_store = cute.tiled_divide(mValues[row, None], (elems_per_store,))
+             mIndices_store = cute.tiled_divide(mIndices[row, None], (elems_per_store,))
+             topk_vals_out_store = cute.tiled_divide(topk_vals_out, (elems_per_store,))
+             topk_indices_store = cute.tiled_divide(topk_indices, (elems_per_store,))
+             for i in cutlass.range(cute.size(topk_vals_out_store.shape, [1]), unroll_full=True):
+                 cute.autovec_copy(topk_vals_out_store[None, i], mValues_store[None, i])
+                 cute.autovec_copy(topk_indices_store[None, i], mIndices_store[None, i])
+
+
+ def _topk_fwd(x: torch.Tensor, k: int):
+     """Top-k forward pass.
+     Args:
+         x: Input tensor of shape (M, N)
+         k: Number of top elements to return
+     Returns:
+         Tuple of (values tensor of shape (M, k), indices tensor of shape (M, k))
+     """
+     assert x.dim() == 2, "Input must be 2D"
+     assert x.is_cuda, "Tensor must be on CUDA device"
+     assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
+     assert k > 0 and k <= x.shape[1], "k must be positive and <= N"
+
+     M, N = x.shape
+     values = torch.empty((M, k), dtype=x.dtype, device=x.device)
+     indices = torch.empty((M, k), dtype=torch.int32, device=x.device)
+
+     dtype = torch2cute_dtype_map[x.dtype]
+     convert_from_dlpack = lambda tensor: (
+         from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+             mode=0, stride_order=(0, 1)
+         )
+     )
+
+     x_tensor, values_tensor, indices_tensor = [
+         convert_from_dlpack(tensor) for tensor in (x, values, indices)
+     ]
+     current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+     compile_key = (dtype, N, k)
+     if compile_key not in _topk_fwd.compile_cache:
+         topk_op = TopK(dtype, N, k)
+         _topk_fwd.compile_cache[compile_key] = cute.compile(
+             topk_op, x_tensor, values_tensor, indices_tensor, current_stream
+         )
+     _topk_fwd.compile_cache[compile_key](x_tensor, values_tensor, indices_tensor, current_stream)
+
+     return values, indices
+
+
+ _topk_fwd.compile_cache = {}
+
+
+ def topk(x: torch.Tensor, k: int):
+     """Top-k operation.
+
+     Args:
+         x: Input tensor of shape (M, N)
+         k: Number of top elements to return
+
+     Returns:
+         Tuple of (values tensor of shape (M, k), indices tensor of shape (M, k))
+     """
+     return _topk_fwd(x, k)
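
Note on the kernel above: each float32 value carries its own column index in its low log2(N) bits (bit-inverted when the value is non-negative, so ties resolve to the earlier column), which is what lets bitonic_topk sort values and indices in a single pass. A minimal host-side sketch of that encode/decode round trip in plain Python; the helper names are illustrative only and not part of the package:

    import struct

    log_N = 3                    # assume a row of N = 8 columns
    idx_mask = (1 << log_N) - 1  # low bits of the float32 pattern hold the column index

    def f32_bits(x): return struct.unpack("<I", struct.pack("<f", x))[0]
    def bits_f32(b): return struct.unpack("<f", struct.pack("<I", b))[0]

    def encode(val, col):
        # invert the index for non-negative values so earlier columns win ties
        enc = (~col if val >= 0 else col) & idx_mask
        return (f32_bits(val) & ~idx_mask & 0xFFFFFFFF) | enc

    def decode(bits):
        cleaned = bits_f32(bits & ~idx_mask & 0xFFFFFFFF)  # clear the index bits
        enc = bits & idx_mask
        return cleaned, (~enc if cleaned >= 0 else enc) & idx_mask

    print(decode(encode(1.5, 2)))   # (1.5, 2)
    print(decode(encode(-3.0, 5)))  # (-3.0, 5)
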
quack/utils.py CHANGED
@@ -2,14 +2,14 @@

  import operator
  import math
- from typing import Callable, Optional, Tuple
+ from typing import Callable, Optional, Tuple, Type, Union

  import cutlass
  import cutlass.cute as cute

- from cutlass import Float32
+ from cutlass import Float32, Int32
  from cutlass.cutlass_dsl import T, dsl_user_op
- from cutlass._mlir.dialects import llvm, vector
+ from cutlass._mlir.dialects import llvm, nvvm, vector
  from cutlass.cute.runtime import from_dlpack


@@ -100,13 +100,14 @@ def store_shared_remote(
      ).ir_value()
      if cutlass.const_expr(isinstance(val, float)):
          val = Float32(val)
-     assert isinstance(val, (Float32, cutlass.Int64)), "val must be Float32 or Int64"
-     suffix = "f32" if cutlass.const_expr(isinstance(val, Float32)) else "s64"
+     assert isinstance(val, (Float32, Int32, cutlass.Int64)), "val must be Float32, Int32, or Int64"
+     suffix = {Float32: "f32", Int32: "s32", cutlass.Int64: "s64"}[type(val)]
+     constraint = {Float32: "f", Int32: "r", cutlass.Int64: "l"}[type(val)]
      llvm.inline_asm(
          None,
          [remote_smem_ptr_i32, val.ir_value(loc=loc, ip=ip), remote_mbar_ptr_i32],
          f"st.async.shared::cluster.mbarrier::complete_tx::bytes.{suffix} [$0], $1, [$2];",
-         f"r,{'f' if cutlass.const_expr(isinstance(val, Float32)) else 'l'},r",
+         f"r,{constraint},r",
          has_side_effects=True,
          is_align_stack=False,
          asm_dialect=llvm.AsmDialect.AD_ATT,
@@ -198,9 +199,9 @@ def row_reduce(
          hook_fn()
      if cutlass.const_expr(reduction_buffer is not None):
          warps_per_row, cluster_n = reduction_buffer.shape[1]
-         assert (
-             cluster_n == 1 or mbar_ptr is not None
-         ), "mbar_ptr must be provided for cluster reduction"
+         assert cluster_n == 1 or mbar_ptr is not None, (
+             "mbar_ptr must be provided for cluster reduction"
+         )
          if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
              val = block_or_cluster_reduce(
                  val, warp_op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val
@@ -237,13 +238,13 @@ def online_softmax_reduce(
          hook_fn()
      if cutlass.const_expr(reduction_buffer is not None):
          rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
-         assert (
-             cluster_n == 1 or mbar_ptr is not None
-         ), "mbar_ptr must be provided for cluster reduction"
+         assert cluster_n == 1 or mbar_ptr is not None, (
+             "mbar_ptr must be provided for cluster reduction"
+         )
          if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
-             assert (
-                 reduction_buffer.element_type == cutlass.Int64
-             ), "reduction_buffer must be of type cute.Int64"
+             assert reduction_buffer.element_type == cutlass.Int64, (
+                 "reduction_buffer must be of type cute.Int64"
+             )
              lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
              row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
              if cutlass.const_expr(mbar_ptr is None):
@@ -304,6 +305,19 @@ def online_softmax_reduce(
      return max_x, sum_exp_x, (exp_x if cutlass.const_expr(return_exp_x) else None)


+ @dsl_user_op
+ def fmin(a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=None) -> Float32:
+     return Float32(
+         nvvm.fmin(
+             T.f32(),
+             Float32(a).ir_value(loc=loc, ip=ip),
+             Float32(b).ir_value(loc=loc, ip=ip),
+             loc=loc,
+             ip=ip,
+         )
+     )
+
+
  @cute.jit
  def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32:
      """exp2f calculation for both vector and scalar.
@@ -315,7 +329,7 @@ def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32:
      if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
          res = cute.make_fragment(x.shape, Float32)
          res.store(x)
-         for i in cutlass.range_constexpr(cute.size(x.shape)):
+         for i in cutlass.range(cute.size(x.shape), unroll_full=True):
              res[i] = cute.arch.exp2(res[i])
          return res.load()
      else:
@@ -337,6 +351,21 @@ def log2f(a: float | Float32, *, loc=None, ip=None) -> Float32:
      )


+ @dsl_user_op
+ def sqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
+     return Float32(
+         llvm.inline_asm(
+             T.f32(),
+             [Float32(a).ir_value(loc=loc, ip=ip)],
+             "sqrt.approx.ftz.f32 $0, $1;",
+             "=f,f",
+             has_side_effects=False,
+             is_align_stack=False,
+             asm_dialect=llvm.AsmDialect.AD_ATT,
+         )
+     )
+
+
  @dsl_user_op
  def rsqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
      return Float32(
@@ -352,6 +381,98 @@ def rsqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
      )


+ @dsl_user_op
+ def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
+     return Float32(
+         llvm.inline_asm(
+             T.f32(),
+             [Float32(a).ir_value(loc=loc, ip=ip)],
+             "tanh.approx.f32 $0, $1;",
+             "=f,f",
+             has_side_effects=False,
+             is_align_stack=False,
+             asm_dialect=llvm.AsmDialect.AD_ATT,
+         )
+     )
+
+
+ @dsl_user_op
+ def ceil(a: float | Float32, *, loc=None, ip=None) -> Int32:
+     return Int32(
+         llvm.inline_asm(
+             T.i32(),
+             [Float32(a).ir_value(loc=loc, ip=ip)],
+             "cvt.rpi.ftz.s32.f32 $0, $1;",
+             "=r,f",
+             has_side_effects=False,
+             is_align_stack=False,
+             asm_dialect=llvm.AsmDialect.AD_ATT,
+         )
+     )
+
+
+ @dsl_user_op
+ def silu(a: float | Float32, *, loc=None, ip=None) -> Float32:
+     """
+     silu(a) = a * sigmoid(a) = a * (1 + tanh(a / 2)) / 2 = (0.5 * a) * tanh(0.5 * a) + (0.5 * a)
+     This compiles down to 3 SASS instructions: FMUL to get 0.5 * a, MUFU.TANH, and FFMA.
+     """
+     a_half = 0.5 * a
+     return a_half * tanh(a_half) + a_half
+
+
+ @dsl_user_op
+ def prmt(a: int | Int32, b: int | Int32, c: int | Int32, *, loc=None, ip=None) -> Int32:
+     return Int32(
+         llvm.inline_asm(
+             T.i32(),
+             [
+                 Int32(a).ir_value(loc=loc, ip=ip),
+                 Int32(b).ir_value(loc=loc, ip=ip),
+                 Int32(c).ir_value(loc=loc, ip=ip),
+             ],
+             "prmt.b32 $0, $1, $2, $3;",
+             "=r,r,r,r",
+             has_side_effects=False,
+             is_align_stack=False,
+             asm_dialect=llvm.AsmDialect.AD_ATT,
+         )
+     )
+
+
+ @cute.jit
+ def permute_gated_Cregs_b16(t: cute.Tensor) -> None:
+     assert t.element_type.width == 16
+     assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b16 permutation"
+     t_u32 = cute.recast_tensor(t, Int32)
+
+     quad_idx = cute.arch.lane_idx() % 4
+     lane_03 = quad_idx == 0 or quad_idx == 3
+     selector_upper = Int32(0x5410) if lane_03 else Int32(0x1054)
+     selector_lower = Int32(0x7632) if lane_03 else Int32(0x3276)
+     # upper_map = [0, 3, 1, 2]
+     # lower_map = [1, 2, 0, 3]
+     # upper_idx = upper_map[quad_idx]
+     # indexing isn't supported so we have to do arithmetic
+     upper_idx = quad_idx // 2 if quad_idx % 2 == 0 else 3 - quad_idx // 2
+     lower_idx = upper_idx ^ 1
+
+     # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
+     width = 4
+     mask = cute.arch.WARP_SIZE - width
+     clamp = cute.arch.WARP_SIZE - 1
+     mask_and_clamp = mask << 8 | clamp
+
+     for i in cutlass.range(cute.size(t_u32.shape) // 2, unroll_full=True):
+         upper, lower = t_u32[i * 2 + 0], t_u32[i * 2 + 1]
+         upper0 = upper if lane_03 else lower
+         lower0 = lower if lane_03 else upper
+         upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp)
+         lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp)
+         t_u32[i * 2 + 0] = prmt(upper0, lower0, selector_upper)
+         t_u32[i * 2 + 1] = prmt(upper0, lower0, selector_lower)
+
+
  @cute.jit
  def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
      # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
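
The silu added above leans on the identity a * sigmoid(a) = (0.5 * a) * tanh(0.5 * a) + 0.5 * a, which is what reduces it to an FMUL, a MUFU.TANH, and an FFMA. A quick host-side check of that identity in plain Python (illustrative only, not part of the package):

    import math

    def silu_ref(a):   # a * sigmoid(a)
        return a / (1.0 + math.exp(-a))

    def silu_tanh(a):  # the rewritten form used by the device function
        h = 0.5 * a
        return h * math.tanh(h) + h

    for a in (-4.0, -0.5, 0.0, 0.3, 2.0):
        assert abs(silu_ref(a) - silu_tanh(a)) < 1e-12
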
@@ -417,9 +538,9 @@ def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float
  def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
      flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
      flat_stride = cute.flatten_to_tuple(tensor.stride)
-     assert len(flat_coord_i64) == len(
-         flat_stride
-     ), "Coordinate and stride must have the same length"
+     assert len(flat_coord_i64) == len(flat_stride), (
+         "Coordinate and stride must have the same length"
+     )
      offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride))
      assert isinstance(tensor.iterator, cute.Pointer)
      # HACK: we assume that applying the offset does not change the pointer alignment
@@ -446,3 +567,100 @@ def coord_offset_i64(
          assumed_align=tensor.iterator.max_alignment,
      )
      return cute.make_tensor(new_ptr, tensor.layout)
+
+
+ @cute.jit
+ def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> cutlass.Int32:
+     if cutlass.const_expr(lane is None):
+         lane = cute.arch.lane_idx()
+     for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))):
+         offset = 1 << i
+         # Very important that we set mask_and_clamp to 0
+         partial_sum = cute.arch.shuffle_sync_up(val, offset=offset, mask_and_clamp=0)
+         if lane >= offset:
+             val += partial_sum
+     return val
+
+
+ def convert_layout_acc_mn(acc_layout: cute.Layout) -> cute.Layout:
+     """
+     For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...).
+     For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...).
+     """
+     acc_layout_col_major = cute.make_layout(acc_layout.shape)
+     acc_layout_mn = cute.make_layout(
+         (
+             (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]),  # MMA_M
+             (
+                 acc_layout_col_major.shape[0][0],
+                 *acc_layout_col_major.shape[0][2:],
+                 acc_layout_col_major.shape[2],
+             ),  # MMA_N
+             *acc_layout_col_major.shape[3:],
+         ),
+         stride=(
+             (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]),  # MMA_M
+             (
+                 acc_layout_col_major.stride[0][0],
+                 *acc_layout_col_major.stride[0][2:],
+                 acc_layout_col_major.stride[2],
+             ),  # MMA_N
+             *acc_layout_col_major.stride[3:],
+         ),
+     )
+     return cute.composition(acc_layout, acc_layout_mn)
+
+
+ def make_acc_tensor_mn_view(acc: cute.Tensor) -> cute.Tensor:
+     return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout))
+
+
+ @dsl_user_op
+ def sm90_get_smem_load_op(
+     layout_c: cutlass.utils.LayoutEnum,
+     elem_ty_c: Type[cutlass.Numeric],
+     *,
+     loc=None,
+     ip=None,
+ ) -> cute.CopyAtom:
+     """
+     Selects the largest vectorized smem load atom available subject to constraint of gmem layout.
+
+     Parameters:
+     -----------
+     layout_c : LayoutEnum
+         The layout enum of the output tensor D.
+
+     elem_ty_c : Type[Numeric]
+         The element type for output tensor D.
+
+     Returns:
+     --------
+     Either SmemLoadMatrix or SimtSyncCopy, based on the input parameters.
+     """
+
+     if not isinstance(elem_ty_c, cutlass.cutlass_dsl.NumericMeta):
+         raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}")
+     is_m_major = layout_c.is_m_major_c()
+     if elem_ty_c.width == 16:
+         return cute.make_copy_atom(
+             cute.nvgpu.warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip
+         )
+     else:
+         return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip)
+
+
+ @dsl_user_op
+ def atomic_add_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> Int32:
+     return nvvm.atomicrmw(
+         res=T.i32(), op=nvvm.AtomicOpKind.ADD, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value()
+     )
+
+
+ @dsl_user_op
+ def atomic_inc_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> Int32:
+     return nvvm.atomicrmw(
+         res=T.i32(), op=nvvm.AtomicOpKind.INC, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value()
+     )
+
+
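
The warp_prefix_sum added above is a Kogge-Stone style inclusive scan built from warp shuffle-up: at step i every lane adds the value held by the lane 2**i below it, if such a lane exists. A pure-Python simulation of the same schedule over one 32-lane warp (illustrative only, not part of the package):

    import math

    WARP_SIZE = 32

    def warp_prefix_sum_sim(vals):
        # vals[lane] plays the role of `val` held by each lane
        vals = list(vals)
        for i in range(int(math.log2(WARP_SIZE))):
            offset = 1 << i
            # shuffle_sync_up: each lane reads the value of lane - offset
            shifted = [vals[lane - offset] if lane >= offset else 0 for lane in range(WARP_SIZE)]
            vals = [v + s if lane >= offset else v
                    for lane, (v, s) in enumerate(zip(vals, shifted))]
        return vals

    # all-ones input gives the inclusive prefix sum 1, 2, ..., 32
    assert warp_prefix_sum_sim([1] * 32) == list(range(1, 33))
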
{quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/METADATA RENAMED
@@ -1,9 +1,9 @@
  Metadata-Version: 2.4
  Name: quack-kernels
- Version: 0.1.9
- Requires-Python: >=3.9
+ Version: 0.1.11
+ Requires-Python: >=3.12
  License-File: LICENSE
- Requires-Dist: nvidia-cutlass-dsl==4.1.0.dev0
+ Requires-Dist: nvidia-cutlass-dsl==4.1.0
  Requires-Dist: torch
  Provides-Extra: dev
  Requires-Dist: pre-commit; extra == "dev"
quack_kernels-0.1.11.dist-info/RECORD ADDED
@@ -0,0 +1,31 @@
+ quack/__init__.py,sha256=AD0T-rBhSfKXpwZ6E4JIPiugvlFaAePjl-3pUhWOlPE,292
+ quack/autotuner.py,sha256=aF9-Cw47gaX7_LZvyVbLsj6Z2AWi4UZ-0Qwjy06Xd5I,10733
+ quack/cross_entropy.py,sha256=xsg2bXZ4wNvusBARhN4PwAzm5PbejEcfwj71nR7bzuE,20852
+ quack/cute_dsl_utils.py,sha256=LkNyFEKwYrgp-tLt_775EZWuBR3v7G80El3UAObHY2U,1292
+ quack/dense_gemm_sm100.py,sha256=W_j8BO-ilb1YUYFuclo7_itfPIRTkjPV_ittWgQy8t4,109937
+ quack/dense_gemm_sm90.py,sha256=Dff0GbIv92uTjrtsUE1GjVKCtwSf6_5KZbrqYZm-ZMY,110418
+ quack/fast_math.py,sha256=XqXVvKLSxXC3c9tIGLvKVRWdPsmjAa_O4C0plmsfZ0w,3106
+ quack/gemm_config.py,sha256=Gz4dkHH1Uwg9IdW-x5W_5tjdaFHBfxq4bn7hJx_xu5s,1789
+ quack/gemm_interface.py,sha256=XHgxo08d8LIu6dTlQKBOBJtjCegUB5uLh4k9hC-5mvY,9525
+ quack/layernorm.py,sha256=1WUspbr6ktPZ25O00kKs-FK_lm_Fejat72BMV8tBSfw,13504
+ quack/linear.py,sha256=Wd0KeXWvWjbkKrgW4Av1ud2v_mbhzf1RvubF7BYhcw4,6425
+ quack/lse.py,sha256=aANOleIYREyrkUQM9cfJ9Gt63eawMb2KVd7YAGWNoZU,2092
+ quack/mlp.py,sha256=D9V7aIfvoBMzhKwN8ZE6GlSOmwFJe_JGqgOvQprU0OQ,8224
+ quack/pipeline.py,sha256=SwvRZAR4RqYH60wAFC3OTu5DisN1XDMv5umQF4czJW4,5867
+ quack/reduction_base.py,sha256=4nAzkZR1yoQVA4Lc-GpU0XMjS5ARAmvYdeE0Doy7UCU,3789
+ quack/rmsnorm.py,sha256=bJEHqc8ila-LTGco-tNNCUyFBjJ2UdXeoMplYNJPXFI,32740
+ quack/softmax.py,sha256=3-5P_ORBrfQ6JYTIzgDs9jwmV7Za73SogaX7q9M7GCM,16698
+ quack/symmetric_dense_gemm_sm90.py,sha256=t-6eLasZwyu1NW4HpnvVBBPOvfqUzOg8VHe9sJQYdmg,88637
+ quack/tensormap_manager.py,sha256=pzBNwLCB8kV_yp8X8_BoDdtbwWeht2jrgRhyyfVIcMI,5261
+ quack/tile_scheduler.py,sha256=mImjD2LuIVchM6USJoJY4-CSG54jGuwyLIvFG6LTP9Y,42205
+ quack/topk.py,sha256=1pObblNJnxKLaE_T3qGvaMnUua0dqG2en9OU5PSp71s,9020
+ quack/utils.py,sha256=4ViEFgHecaX5wcYpO6XzTCzdnuZv2rniUJAJH5Ta0bA,24981
+ quack/sort/bitonic_sort.py,sha256=8t0SG1a6iEpYIlY8YM_AWvm4aN-4AA4vEzdBuJMJm9g,4768
+ quack/sort/generate_sorting_networks.py,sha256=vkJBOjTVEinQkWT4OtFqOWxFVdTIPoNAQocneKc9-rM,14477
+ quack/sort/sorting_networks.py,sha256=l_26zi3gXD_z-tnm2eAczRrmE-mbaz00KmqH6ONivL8,9686
+ quack/sort/utils.py,sha256=Mkr-l97RMAV-ZoNrwuzA1U3KO0Wjr38CV9Jm7ScyZoI,1090
+ quack_kernels-0.1.11.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ quack_kernels-0.1.11.dist-info/METADATA,sha256=WTYlk9lmhr4Jkin71stp3h-NrBdme-8OrBc7lAf4vSw,286
+ quack_kernels-0.1.11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ quack_kernels-0.1.11.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
+ quack_kernels-0.1.11.dist-info/RECORD,,
quack_kernels-0.1.9.dist-info/RECORD DELETED
@@ -1,12 +0,0 @@
- quack/__init__.py,sha256=CT76CeRNh5bzQ9f13yVuRz9Sj7V3MvwzHH4fB1iQIf0,203
- quack/cross_entropy.py,sha256=VYSAd28GmtnMoKQwLrorvySDtJfRhoqVd-aeM52FmsI,20866
- quack/layernorm.py,sha256=1WUspbr6ktPZ25O00kKs-FK_lm_Fejat72BMV8tBSfw,13504
- quack/reduction_base.py,sha256=4nAzkZR1yoQVA4Lc-GpU0XMjS5ARAmvYdeE0Doy7UCU,3789
- quack/rmsnorm.py,sha256=bJEHqc8ila-LTGco-tNNCUyFBjJ2UdXeoMplYNJPXFI,32740
- quack/softmax.py,sha256=3-5P_ORBrfQ6JYTIzgDs9jwmV7Za73SogaX7q9M7GCM,16698
- quack/utils.py,sha256=aiyzBc9BEwq8s965elfiR331hAaLLBKL9kDHjuls86Q,17791
- quack_kernels-0.1.9.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- quack_kernels-0.1.9.dist-info/METADATA,sha256=vOnpbShNHRiUXKAnOUxzfRM7zkpW3RmjW4hIgvYda08,289
- quack_kernels-0.1.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- quack_kernels-0.1.9.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
- quack_kernels-0.1.9.dist-info/RECORD,,