quack-kernels 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
Files changed (40)
  1. quack/__init__.py +1 -8
  2. quack/activation.py +366 -121
  3. quack/broadcast_utils.py +29 -0
  4. quack/compile_utils.py +19 -0
  5. quack/copy_utils.py +487 -0
  6. quack/cross_entropy.py +157 -233
  7. quack/cute_dsl_utils.py +20 -34
  8. quack/gemm.py +194 -0
  9. quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
  10. quack/gemm_config.py +72 -46
  11. quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
  12. quack/gemm_default_epi.py +259 -0
  13. quack/gemm_interface.py +177 -31
  14. quack/gemm_sm100.py +729 -506
  15. quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
  16. quack/gemm_symmetric.py +330 -0
  17. quack/gemm_wrapper_utils.py +3 -1
  18. quack/layout_utils.py +287 -0
  19. quack/linear.py +24 -16
  20. quack/pipeline.py +158 -3
  21. quack/reduce.py +88 -49
  22. quack/reduction_base.py +25 -36
  23. quack/rmsnorm.py +476 -526
  24. quack/sm100_utils.py +62 -0
  25. quack/sm90_utils.py +127 -0
  26. quack/softmax.py +135 -203
  27. quack/sort/bitonic_sort.py +13 -10
  28. quack/sort/utils.py +6 -6
  29. quack/tile_scheduler.py +23 -16
  30. quack/topk.py +409 -85
  31. quack/utils.py +32 -220
  32. quack/varlen_utils.py +370 -1
  33. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/METADATA +4 -2
  34. quack_kernels-0.2.4.dist-info/RECORD +44 -0
  35. quack/layernorm.py +0 -353
  36. quack/symmetric_dense_gemm_sm90.py +0 -2091
  37. quack_kernels-0.2.2.dist-info/RECORD +0 -37
  38. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/WHEEL +0 -0
  39. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/licenses/LICENSE +0 -0
  40. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/top_level.txt +0 -0
quack/topk.py CHANGED
@@ -1,55 +1,57 @@
  # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao.
 
  import math
+ from functools import partial
+ from typing import Type, Optional
+
  import torch
- from typing import Type
 
  import cuda.bindings.driver as cuda
 
  import cutlass
  import cutlass.cute as cute
- from cutlass.cute.runtime import from_dlpack
- from cutlass import const_expr
+ from cutlass import Int32, Float32, const_expr
 
  import quack.utils as utils
+ import quack.copy_utils as copy_utils
+ from quack.compile_utils import make_fake_tensor as fake_tensor
+ from quack.reduction_base import ReductionBase
+ from quack.reduce import row_reduce
  from quack.cute_dsl_utils import torch2cute_dtype_map
  from quack.sort.bitonic_sort import bitonic_topk
 
 
  class TopK:
- def __init__(self, dtype: Type[cutlass.Numeric], N: int, k: int):
+ def __init__(self, dtype: Type[cutlass.Numeric], N: int, k: int, softmax: bool = False):
  self.dtype = dtype
  self.N = N
  self.vecsize = 128 // dtype.width
  self.k = k
+ self.softmax = softmax
  assert N == 2 ** int(math.log2(N)), "N must be a power of 2"
  assert k == 2 ** int(math.log2(k)), "k must be a power of 2"
  assert k <= 128
  assert N <= 4096
 
- def _calculate_threads_per_row(self):
+ def _threads_per_row(self):
  # we want num_elems_per_thread >= self.k
  # and each thread can handle at most 64 elements
  N = self.N
  num_threads_per_row = max(min(N // self.k, 32, N // 64), 1)
  return num_threads_per_row
 
- def _get_tv_layout(self):
+ def _get_tiled_copy(self):
  N = self.N
  vecsize = self.vecsize
  num_threads = 128 if N <= 16384 else 256
- threads_per_row = self._calculate_threads_per_row()
+ threads_per_row = self._threads_per_row()
  cols_per_block = num_threads // threads_per_row
  num_blocks_N = cute.ceil_div(min(N, 16384) // vecsize, threads_per_row)
  tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
- tv_layout = cute.make_layout(
- ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
- stride=(
- (vecsize * cols_per_block, 1),
- (cols_per_block, cols_per_block * vecsize * threads_per_row),
- ),
+ tiled_copy = copy_utils.tiled_copy_2d(
+ self.dtype, threads_per_row, num_threads, num_copy_elems=vecsize
  )
- return tiler_mn, tv_layout
+ return tiled_copy, tiler_mn, threads_per_row
 
  @cute.jit
  def __call__(
@@ -61,10 +63,10 @@ class TopK:
  ):
  assert mX.element_type == self.dtype
  assert mValues.element_type == self.dtype
- assert mIndices.element_type == cutlass.Int32
- tiler_mn, tv_layout = self._get_tv_layout()
- num_threads = cute.size(tv_layout, mode=[0])
- self.kernel(mX, mValues, mIndices, tv_layout, tiler_mn).launch(
+ assert mIndices.element_type == Int32
+ tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy()
+ num_threads = tiled_copy.size
+ self.kernel(mX, mValues, mIndices, tiler_mn, tiled_copy, threads_per_row).launch(
  grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), 1, 1],
  block=[num_threads, 1, 1],
  stream=stream,
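
For a concrete sense of the launch shape this produces, the tile arithmetic from `_threads_per_row` / `_get_tiled_copy` can be evaluated in plain Python for one illustrative configuration (bf16 input, N=4096, k=8; these example numbers are chosen here and are not part of the package):

```python
import math

# Illustrative evaluation of TopK's tile arithmetic for a bf16 input
# with N=4096 columns and k=8 (assumed example values).
dtype_width, N, k = 16, 4096, 8
vecsize = 128 // dtype_width                          # 8 elements per 128-bit copy
threads_per_row = max(min(N // k, 32, N // 64), 1)    # min(512, 32, 64) -> 32
num_threads = 128 if N <= 16384 else 256              # 128 threads per CTA
cols_per_block = num_threads // threads_per_row       # 4 rows handled per CTA
num_blocks_N = math.ceil(min(N, 16384) // vecsize / threads_per_row)  # 512 / 32 -> 16
tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
print(tiler_mn)      # (4, 4096): each CTA covers a 4 x 4096 tile of the input
print(num_threads)   # 128, matching block=[num_threads, 1, 1] in the launch above
```

With these numbers each thread keeps N / threads_per_row = 128 elements of its row in registers before the bitonic top-k.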
@@ -76,103 +78,151 @@
  mX: cute.Tensor,
  mValues: cute.Tensor,
  mIndices: cute.Tensor,
- tv_layout: cute.Layout,
  tiler_mn: cute.Shape,
+ tiled_copy: cute.TiledCopy,
+ threads_per_row: cutlass.Constexpr[int],
  ):
  tidx, _, _ = cute.arch.thread_idx()
  bidx, _, _ = cute.arch.block_idx()
+ tv_layout = tiled_copy.layout_tv_tiled
 
  shape = mX.shape
  idX = cute.make_identity_tensor(shape)
  # slice for CTAs
- # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
- mX = utils.domain_offset_i64((bidx * tiler_mn[0], 0), mX)
- gX = cute.local_tile(mX, tiler_mn, (0, 0))
- cX = cute.local_tile(idX, tiler_mn, (bidx, 0))
-
- # declare the atoms which will be used later for memory copy
- copy_atom_load_X = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128
- )
- thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
- tXgX = thr_copy_X.partition_S(gX)
- tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+ gX, cX = [cute.local_tile(mT, tiler_mn, (bidx, 0)) for mT in (mX, idX)]
+
+ thr_copy = tiled_copy.get_slice(tidx)
 
- # allocate fragments for gmem->rmem
+ tXgX = thr_copy.partition_S(gX)
+ tXcX = thr_copy.partition_S(cX)[(0, None), None, None]
  tXrX = cute.make_fragment_like(tXgX)
 
  is_even_N = const_expr(shape[1] == tiler_mn[1])
  tXpX = (
- utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
+ None if is_even_N else copy_utils.predicate_k(thr_copy.partition_S(cX), limit=shape[1])
  )
+ copy = partial(copy_utils.copy, pred=tXpX)
+
  if tXcX[0][0] < shape[0]:
- cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
- tXrX_f32 = cute.make_fragment(tXrX.shape, cutlass.Float32)
- tXrX_f32.store(tXrX.load().to(cutlass.Float32))
+ copy(tXgX, tXrX)
+ tXrX_f32 = cute.make_fragment(tXrX.shape, Float32)
+ tXrX_f32.store(tXrX.load().to(Float32))
 
  # Encode the indices into the bottom bits of values.
  log_N = int(math.log2(self.N))
  idx_mask = (1 << log_N) - 1
- vecsize = cutlass.const_expr(tv_layout.shape[1][0])
- tXrX_u32 = cute.recast_tensor(tXrX_f32, cutlass.Uint32)
- # Encode indices into the last log_N bits of tXrX_u32
- for i in cutlass.range(cute.size(tXrX_u32), unroll_full=True):
+ vecsize = const_expr(cute.size(tv_layout.shape[1]))
+ tXrX_i32 = cute.recast_tensor(tXrX_f32, Int32)
+ # Encode indices into the last log_N bits of tXrX_i32
+ for i in cutlass.range(cute.size(tXrX_i32), unroll_full=True):
  # tXcX only keeps track of the indices for every @vecsize elements
- col_idx = cutlass.Uint32(tXcX[i // vecsize][1] + i % vecsize)
+ col_idx = Int32(tXcX[i // vecsize][1] + i % vecsize)
  # If positive, invert the bits of the index, so that if there's a tie,
  # indices coming from an earlier column will win.
  encoded_idx = ~col_idx if tXrX_f32[i] >= 0 else col_idx
  # Mask to keep only the last log_N bits of the encoded index
  encoded_idx = encoded_idx & idx_mask
  # Clear the last log_N bits and set them to our encoded index
- tXrX_u32[i] = (tXrX_u32[i] & ~idx_mask) | encoded_idx
+ tXrX_i32[i] = (tXrX_i32[i] & ~idx_mask) | encoded_idx
 
  # Fill OOB values with -inf for top-k
  if const_expr(not is_even_N):
  utils.fill_oob(tXrX_f32, tXpX, -tXrX_f32.element_type.inf)
 
- threads_per_row = tv_layout.shape[0][0]
  topk_vals = bitonic_topk(tXrX_f32, self.k, warp_width=threads_per_row)
 
+ # Thread 0 in each row contains all the top-k values, so we split those into multiple threads
+ vecsize_out = const_expr(min(self.k, vecsize, 128 // mIndices.element_type.width))
+ assert self.k % vecsize_out == 0
+ nvec_per_thread = const_expr(cute.ceil_div(self.k, vecsize_out * threads_per_row))
+ # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
+ mask = cute.arch.WARP_SIZE - threads_per_row
+ mask_and_clamp = mask << 8 | (cute.arch.WARP_SIZE - 1)
+ topk_vals_split = cute.make_fragment((vecsize_out, nvec_per_thread), Float32)
+ for i in cutlass.range(cute.ceil_div(self.k, vecsize_out), unroll_full=True):
+ should_receive = tidx % threads_per_row == i % threads_per_row
+ for v in cutlass.range(vecsize_out, unroll_full=True):
+ if const_expr(threads_per_row > 1):
+ if i * vecsize_out + v < self.k:
+ val = cute.arch.shuffle_sync(
+ topk_vals[i * vecsize_out + v], offset=0, mask_and_clamp=mask_and_clamp
+ )
+ if should_receive:
+ topk_vals_split[v, i // threads_per_row] = val
+ else:
+ topk_vals_split[v, i // threads_per_row] = topk_vals[i * vecsize_out + v]
+
  # Extract indices and clean values
- topk_vals_u32 = cute.recast_tensor(topk_vals, cutlass.Uint32)
- topk_indices = cute.make_fragment(self.k, cutlass.Int32)
- for i in cutlass.range(self.k):
+ topk_vals_i32 = cute.recast_tensor(topk_vals_split, Int32)
+ topk_indices = cute.make_fragment(topk_vals_i32.shape, Int32)
+ for i in cutlass.range(cute.size(topk_vals_i32), unroll_full=True):
  # Extract the encoded index from the last log_N bits
- encoded_idx = topk_vals_u32[i] & idx_mask
+ encoded_idx = topk_vals_i32[i] & idx_mask
  # Check if original value was positive by looking at the cleaned value
- topk_vals_u32[i] = topk_vals_u32[i] & ~idx_mask # Clear last log_N bits
+ topk_vals_i32[i] = topk_vals_i32[i] & ~idx_mask # Clear last log_N bits
  # If positive, we need to invert the bits back to get original index
  col_idx = ~encoded_idx if topk_vals[i] >= 0 else encoded_idx
- topk_indices[i] = cutlass.Int32(col_idx & idx_mask)
+ topk_indices[i] = Int32(col_idx & idx_mask)
+
+ # Compute softmax if requested
+ if const_expr(self.softmax):
+ # Need masking as some elements may be OOB
+ for i in cutlass.range(cute.size(topk_vals_split, mode=[1]), unroll_full=True):
+ col = i * threads_per_row + tidx % threads_per_row
+ if col >= self.k // vecsize_out:
+ for v in cutlass.range(vecsize_out, unroll_full=True):
+ topk_vals_split[v, i] = -Float32.inf
+ # Get max from thread 0 (topk_vals[0] is the max since sorted descending)
+ max_val = cute.arch.shuffle_sync(topk_vals[0], offset=0, mask_and_clamp=mask_and_clamp)
+ log2_e = math.log2(math.e)
+ exp_x = cute.math.exp2(
+ topk_vals_split.load() * log2_e - (max_val * log2_e), fastmath=True
+ )
+ denom = cute.arch.warp_reduction_sum(
+ exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
+ threads_in_group=threads_per_row,
+ )
+ topk_vals_split.store(exp_x * cute.arch.rcp_approx(denom))
 
  # Convert cleaned values to output type
- topk_vals_out = cute.make_fragment_like(topk_vals, mValues.element_type)
- topk_vals_out.store(topk_vals.load().to(mValues.element_type))
+ topk_vals_out = cute.make_fragment_like(topk_vals_split, mValues.element_type)
+ topk_vals_out.store(topk_vals_split.load().to(mValues.element_type))
 
  row = tXcX[0][0]
- # Only the 1st thread in this row writes the top-k values and indices
- if row < shape[0] and tXcX[0][1] == 0:
- # for i in cutlass.range(self.k):
- # mValues[row, i] = topk_vals_out[i]
- # mIndices[row, i] = topk_indices[i]
+ # # Only the 1st thread in this row writes the top-k values and indices
+ # if row < shape[0] and tXcX[0][1] == 0:
+ # # for i in cutlass.range(self.k):
+ # # mValues[row, i] = topk_vals_out[i]
+ # # mIndices[row, i] = topk_indices[i]
+ # # Vectorized write
+ # elems_per_store = const_expr(math.gcd(vecsize, self.k))
+ # mValues_store = cute.tiled_divide(mValues[row, None], (elems_per_store,))
+ # mIndices_store = cute.tiled_divide(mIndices[row, None], (elems_per_store,))
+ # topk_vals_out_store = cute.tiled_divide(topk_vals_out, (elems_per_store,))
+ # topk_indices_store = cute.tiled_divide(topk_indices, (elems_per_store,))
+ # for i in cutlass.range(cute.size(topk_vals_out_store.shape, [1]), unroll_full=True):
+ # cute.autovec_copy(topk_vals_out_store[None, i], mValues_store[None, i])
+ # cute.autovec_copy(topk_indices_store[None, i], mIndices_store[None, i])
+ if tiler_mn[0] == 0 or row < shape[0]:
  # Vectorized write
- elems_per_store = const_expr(math.gcd(vecsize, self.k))
- mValues_store = cute.tiled_divide(mValues[row, None], (elems_per_store,))
- mIndices_store = cute.tiled_divide(mIndices[row, None], (elems_per_store,))
- topk_vals_out_store = cute.tiled_divide(topk_vals_out, (elems_per_store,))
- topk_indices_store = cute.tiled_divide(topk_indices, (elems_per_store,))
- for i in cutlass.range(cute.size(topk_vals_out_store.shape, [1]), unroll_full=True):
- cute.autovec_copy(topk_vals_out_store[None, i], mValues_store[None, i])
- cute.autovec_copy(topk_indices_store[None, i], mIndices_store[None, i])
+ mValues_store = cute.tiled_divide(mValues[row, None], (vecsize_out,))
+ mIndices_store = cute.tiled_divide(mIndices[row, None], (vecsize_out,))
+ for i in cutlass.range(cute.size(topk_vals_out.shape, [1]), unroll_full=True):
+ col = i * threads_per_row + tidx % threads_per_row
+ if col < self.k // vecsize_out:
+ cute.autovec_copy(topk_vals_out[None, i], mValues_store[None, col])
+ cute.autovec_copy(topk_indices[None, i], mIndices_store[None, col])
 
 
  @torch.library.custom_op("quack::_topk_fwd", mutates_args={"values", "indices"})
- def _topk_fwd(x: torch.Tensor, k: int, values: torch.Tensor, indices: torch.Tensor) -> None:
+ def _topk_fwd(
+ x: torch.Tensor, k: int, softmax: bool, values: torch.Tensor, indices: torch.Tensor
+ ) -> None:
  """Top-k forward pass.
  Args:
  x: Input tensor of shape (M, N)
  k: Number of top elements to return
+ softmax: Whether to apply softmax to the top-k values
  Returns:
  Tuple of (values tensor of shape (M, k), indices tensor of shape (M, k))
  """
@@ -182,46 +232,320 @@ def _topk_fwd(x: torch.Tensor, k: int, values: torch.Tensor, indices: torch.Tens
  assert k > 0 and k <= x.shape[1], "k must be positive and <= N"
 
  N = x.size(1)
-
  dtype = torch2cute_dtype_map[x.dtype]
- convert_from_dlpack = lambda tensor: (
- from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
- mode=0, stride_order=(0, 1)
- )
- )
-
- x_tensor, values_tensor, indices_tensor = [
- convert_from_dlpack(tensor) for tensor in (x, values, indices)
- ]
- current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
- compile_key = (dtype, N, k)
+ compile_key = (dtype, N, k, softmax)
  if compile_key not in _topk_fwd.compile_cache:
- topk_op = TopK(dtype, N, k)
+ batch_sym = cute.sym_int()
+ div = math.gcd(128 // dtype.width, N)
+ x_cute = fake_tensor(dtype, (batch_sym, N), div)
+ values_cute = fake_tensor(dtype, (batch_sym, k), div)
+ indices_cute = fake_tensor(Int32, (batch_sym, k), div)
+ topk_op = TopK(dtype, N, k, softmax=softmax)
  _topk_fwd.compile_cache[compile_key] = cute.compile(
- topk_op, x_tensor, values_tensor, indices_tensor, current_stream
+ topk_op,
+ x_cute,
+ values_cute,
+ indices_cute,
+ cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
+ options="--enable-tvm-ffi",
  )
- _topk_fwd.compile_cache[compile_key](x_tensor, values_tensor, indices_tensor, current_stream)
+ _topk_fwd.compile_cache[compile_key](x, values, indices)
 
 
  _topk_fwd.compile_cache = {}
 
 
- def topk(x: torch.Tensor, k: int):
+ def topk_fwd(x: torch.Tensor, k: int, softmax: bool = False):
  """Top-k operation.
 
  Args:
  x: Input tensor of shape (M, N)
  k: Number of top elements to return
+ softmax: Whether to apply softmax to the top-k values
 
  Returns:
  Tuple of (values tensor of shape (M, k), indices tensor of shape (M, k))
  """
-
  M = x.size(0)
-
  values = torch.empty((M, k), dtype=x.dtype, device=x.device)
  indices = torch.empty((M, k), dtype=torch.int32, device=x.device)
+ _topk_fwd(x, k, softmax, values, indices)
+ return values, indices
 
- _topk_fwd(x, k, values, indices)
 
- return values, indices
+ class TopKBackward(ReductionBase):
+ def __init__(self, dtype: Type[cutlass.Numeric], N: int, k: int, softmax: bool = False):
+ super().__init__(dtype, N, stage=1, reduction_dtype=Float32)
+ self.dtype = dtype
+ self.N = N
+ self.k = k
+ self.softmax = softmax
+ assert k <= N
+ assert k <= 32768
+
+ def _num_threads(self):
+ return 128 if self.N <= 16384 else 256
+
+ def _get_tiled_copy(self, N: int, vecsize: Optional[int] = None):
+ if vecsize is None:
+ vecsize = min(N, 128 // self.dtype.width)
+ assert N % vecsize == 0, f"Input N {N} is not divisible by vector size {vecsize}"
+ num_threads = self._num_threads()
+ threads_per_row = min(N // vecsize, num_threads)
+ cols_per_block = num_threads // threads_per_row
+ num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row)
+ tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
+ tiled_copy = copy_utils.tiled_copy_2d(
+ self.dtype, threads_per_row, num_threads, num_copy_elems=vecsize
+ )
+ return tiled_copy, tiler_mn, threads_per_row
+
+ @cute.jit
+ def __call__(
+ self,
+ mdValues: cute.Tensor, # (M, k)
+ mValues: Optional[cute.Tensor], # (M, k)
+ mIndices: cute.Tensor, # (M, k)
+ mdX: cute.Tensor, # (M, N)
+ stream: cuda.CUstream,
+ ):
+ assert mdValues.element_type == self.dtype
+ if const_expr(mValues is not None):
+ assert mValues.element_type == self.dtype
+ assert mIndices.element_type == Int32
+ self._set_cluster_n()
+ largest_dtype_width = const_expr(
+ max(
+ *(t.element_type.width for t in [mdValues, mValues, mIndices, mdX] if t is not None)
+ )
+ )
+ vecsize = math.gcd(self.N, 128 // largest_dtype_width)
+ tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(self.N, vecsize=vecsize)
+ num_threads = tiled_copy.size
+ self.kernel(
+ mdValues,
+ mValues,
+ mIndices,
+ mdX,
+ tiler_mn,
+ tiled_copy,
+ threads_per_row,
+ ).launch(
+ grid=[cute.ceil_div(mdX.shape[0], tiler_mn[0]), 1, 1],
+ block=[num_threads, 1, 1],
+ stream=stream,
+ )
+
+ @cute.kernel
+ def kernel(
+ self,
+ mdValues: cute.Tensor, # (M, k)
+ mValues: Optional[cute.Tensor], # (M, k)
+ mIndices: cute.Tensor, # (M, k)
+ mdX: cute.Tensor, # (M, N)
+ tiler_mn: cute.Shape,
+ tiled_copy: cute.TiledCopy,
+ threads_per_row: cutlass.Constexpr[int],
+ ):
+ tidx, _, _ = cute.arch.thread_idx()
+ bidx, _, _ = cute.arch.block_idx()
+
+ tv_layout = tiled_copy.layout_tv_tiled
+ shape = mdX.shape
+ idX = cute.make_identity_tensor(shape)
+ idTopK = cute.make_identity_tensor(mdValues.shape)
+ # slice for CTAs
+ gdX, cX = [cute.local_tile(mT, tiler_mn, (bidx, 0)) for mT in (mdX, idX)]
+ gdVals, gVals, gIdx, cTopK = [
+ cute.local_tile(mT, tiler_mn, (bidx, 0)) if mT is not None else None
+ for mT in (mdValues, mValues, mIndices, idTopK)
+ ]
+
+ # Allocate smem for output gradients
+ smem = cutlass.utils.SmemAllocator()
+ sdX = smem.allocate_tensor(
+ mdX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+ )
+ reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+ thr_copy = tiled_copy.get_slice(tidx)
+
+ tXgdV = thr_copy.partition_S(gdVals)
+ tXgV = thr_copy.partition_S(gVals) if const_expr(gVals is not None) else None
+ tXgI = thr_copy.partition_S(gIdx)
+ tXrdV = cute.make_fragment_like(tXgdV)
+ tXrV = cute.make_fragment_like(tXgV) if const_expr(tXgV is not None) else None
+ tXrI = cute.make_fragment_like(tXgI)
+ tXrdV.fill(tXrdV.element_type.zero)
+ if const_expr(mValues is not None):
+ tXrV.fill(tXrV.element_type.zero)
+ tXrI.fill(0)
+
+ tXsdX = thr_copy.partition_D(sdX)
+ tXgdX = thr_copy.partition_D(gdX)
+ tXcX = thr_copy.partition_S(cX)[(0, None), None, None]
+ tXrdX = cute.make_fragment_like(tXgdX)
+
+ is_even_N = const_expr(shape[1] == tiler_mn[1])
+ tXpV = copy_utils.predicate_k(thr_copy.partition_S(cTopK), limit=mdValues.shape[1])
+ tXpX = (
+ None if is_even_N else copy_utils.predicate_k(thr_copy.partition_S(cX), limit=shape[1])
+ )
+ copy_k = partial(copy_utils.copy, pred=tXpV)
+ copy_dx = partial(copy_utils.copy, pred=tXpX)
+
+ row = tXcX[0][0]
+ tile_row_start = Int32(cute.arch.block_idx()[0] * tiler_mn[0])
+
+ # Zero out smem
+ utils.fill_oob(tXsdX, None, fill_value=mdX.element_type.zero)
+
+ if row < shape[0]:
+ copy_k(tXgdV, tXrdV)
+ if const_expr(mValues is not None):
+ copy_k(tXgV, tXrV)
+ copy_k(tXgI, tXrI)
+
+ cute.arch.barrier()
+
+ dvals_f32 = tXrdV.load().to(Float32)
+ if const_expr(self.softmax):
+ vals_f32 = tXrV.load().to(Float32)
+ dot = row_reduce(
+ dvals_f32 * vals_f32,
+ cute.ReductionOp.ADD,
+ threads_per_row,
+ reduction_buffer[None, None, 0],
+ )
+ grads = vals_f32 * (dvals_f32 - dot)
+ else:
+ grads = dvals_f32
+ grad_cvt = cute.make_fragment(tXrdV.shape, mdX.element_type)
+ grad_cvt.store(grads.to(mdX.element_type))
+
+ # Scatter values to smem
+ if row < shape[0]:
+ for rest_v in cutlass.range(tXrdV.shape[0][1], unroll_full=True):
+ for n in cutlass.range(tXrdV.shape[2], unroll_full=True):
+ if tXpV[rest_v, 0, n]:
+ for v in cutlass.range(tXrdV.shape[0][0], unroll_full=True):
+ sdX[row - tile_row_start, tXrI[(v, rest_v), 0, n]] = grad_cvt[
+ (v, rest_v), 0, n
+ ]
+ cute.arch.barrier()
+
+ # Read from smem to rmem, then write to gmem
+ cute.autovec_copy(tXsdX, tXrdX)
+ if row < shape[0]:
+ copy_dx(tXrdX, tXgdX)
+
+
+ @torch.library.custom_op("quack::_topk_bwd", mutates_args={"dx"})
+ def _topk_bwd(
+ dvalues: torch.Tensor,
+ values: Optional[torch.Tensor],
+ indices: torch.Tensor,
+ k: int,
+ softmax: bool,
+ dx: torch.Tensor,
+ ) -> None:
+ """Top-k backward pass.
+ Args:
+ dvalues: Upstream gradients tensor of shape (M, k)
+ values: Forward top-k values tensor of shape (M, k)
+ indices: Indices tensor of shape (M, k) from forward pass
+ k: Number of top elements
+ softmax: Whether softmax was applied in forward
+ dx: Output gradient tensor of shape (M, N)
+ """
+ assert dvalues.dim() == 2, "dvalues must be 2D"
+ if values is not None:
+ assert values.dim() == 2, "values must be 2D"
+ assert indices.dim() == 2, "indices must be 2D"
+ assert dvalues.is_cuda and indices.is_cuda, "Tensors must be on CUDA device"
+ assert dvalues.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
+
+ N = dx.size(1)
+ dtype = torch2cute_dtype_map[dvalues.dtype]
+ val_dtype = torch2cute_dtype_map[values.dtype] if values is not None else dtype
+ dx_dtype = torch2cute_dtype_map[dx.dtype]
+ compile_key = (dtype, val_dtype, dx_dtype, N, k, softmax)
+ if compile_key not in _topk_bwd.compile_cache:
+ batch_sym = cute.sym_int()
+ div = math.gcd(128 // dtype.width, N)
+ dvalues_cute = fake_tensor(dtype, (batch_sym, k), div)
+ values_cute = fake_tensor(val_dtype, (batch_sym, k), div) if values is not None else None
+ indices_cute = fake_tensor(Int32, (batch_sym, k), div)
+ dx_cute = fake_tensor(dx_dtype, (batch_sym, N), div)
+ topk_bwd_op = TopKBackward(dtype, N, k, softmax=softmax)
+ _topk_bwd.compile_cache[compile_key] = cute.compile(
+ topk_bwd_op,
+ dvalues_cute,
+ values_cute,
+ indices_cute,
+ dx_cute,
+ cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
+ options="--enable-tvm-ffi",
+ )
+ _topk_bwd.compile_cache[compile_key](dvalues, values, indices, dx)
+
+
+ _topk_bwd.compile_cache = {}
+
+
+ def topk_bwd(
+ dvalues: torch.Tensor,
+ values: Optional[torch.Tensor],
+ indices: torch.Tensor,
+ N: int,
+ softmax: bool = False,
+ ) -> torch.Tensor:
+ """Top-k backward pass.
+
+ Args:
+ dvalues: Upstream gradients tensor of shape (M, k)
+ values: Forward top-k values tensor of shape (M, k), required if softmax=True
+ indices: Indices tensor of shape (M, k) from forward pass
+ N: Size of the original input dimension
+ softmax: Whether softmax was applied in forward
+
+ Returns:
+ Input gradients tensor of shape (M, N)
+ """
+ M, k = dvalues.shape
+ dx = torch.zeros((M, N), dtype=dvalues.dtype, device=dvalues.device)
+ _topk_bwd(dvalues, values, indices, k, softmax, dx)
+ return dx
+
+
+ class TopKFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x: torch.Tensor, k: int, softmax: bool = False):
+ values, indices = topk_fwd(x, k, softmax=softmax)
+ ctx.save_for_backward(values if softmax else None, indices)
+ ctx.k = k
+ ctx.N = x.shape[1]
+ ctx.softmax = softmax
+ ctx.mark_non_differentiable(indices)
+ ctx.set_materialize_grads(False)
+ return values, indices
+
+ @staticmethod
+ def backward(ctx, dvalues: torch.Tensor, dindices_: Optional[torch.Tensor] = None):
+ values, indices = ctx.saved_tensors
+ dx = topk_bwd(dvalues, values, indices, N=ctx.N, softmax=ctx.softmax)
+ return dx, None, None
+
+
+ def topk(x: torch.Tensor, k: int, softmax: bool = False):
+ """Top-k operation.
+
+ Args:
+ x: Input tensor of shape (M, N)
+ k: Number of top elements to return
+ softmax: Whether to apply softmax to the top-k values
+
+ Returns:
+ Tuple of (values tensor of shape (M, k), indices tensor of shape (M, k))
+ """
+ return TopKFunction.apply(x, k, softmax)
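
Taken together, the new `topk_fwd` / `topk_bwd` / `TopKFunction` plumbing gives the top-k kernel (optionally with a fused softmax over the k selected values) an autograd path. A minimal usage sketch, assuming a CUDA device with a supported toolchain and importing straight from `quack.topk`, with shapes chosen to satisfy the power-of-two and size asserts above:

```python
import torch
from quack.topk import topk, topk_fwd

# N and k must be powers of two (N <= 4096, k <= 128) per the asserts in TopK.
x = torch.randn(1024, 2048, device="cuda", dtype=torch.bfloat16)

# Forward-only: top-8 values per row and their column indices (int32).
values, indices = topk_fwd(x, k=8)

# Autograd path with a fused softmax over the top-k values.
x.requires_grad_()
probs, idx = topk(x, k=8, softmax=True)
probs.sum().backward()      # routed through TopKFunction.backward -> topk_bwd
print(x.grad.shape)         # torch.Size([1024, 2048])
```

With softmax=True the backward kernel uses the saved probabilities to form vals * (dvals - <dvals, vals>) per row, the standard softmax Jacobian-vector product; with softmax=False the upstream gradient is simply scattered back to the selected columns of dx.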