quack-kernels 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. quack/__init__.py +1 -8
  2. quack/activation.py +366 -121
  3. quack/broadcast_utils.py +29 -0
  4. quack/compile_utils.py +19 -0
  5. quack/copy_utils.py +487 -0
  6. quack/cross_entropy.py +157 -233
  7. quack/cute_dsl_utils.py +20 -34
  8. quack/gemm.py +194 -0
  9. quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
  10. quack/gemm_config.py +72 -46
  11. quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
  12. quack/gemm_default_epi.py +259 -0
  13. quack/gemm_interface.py +177 -31
  14. quack/gemm_sm100.py +729 -506
  15. quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
  16. quack/gemm_symmetric.py +330 -0
  17. quack/gemm_wrapper_utils.py +3 -1
  18. quack/layout_utils.py +287 -0
  19. quack/linear.py +24 -16
  20. quack/pipeline.py +158 -3
  21. quack/reduce.py +88 -49
  22. quack/reduction_base.py +25 -36
  23. quack/rmsnorm.py +476 -526
  24. quack/sm100_utils.py +62 -0
  25. quack/sm90_utils.py +127 -0
  26. quack/softmax.py +135 -203
  27. quack/sort/bitonic_sort.py +13 -10
  28. quack/sort/utils.py +6 -6
  29. quack/tile_scheduler.py +23 -16
  30. quack/topk.py +409 -85
  31. quack/utils.py +32 -220
  32. quack/varlen_utils.py +370 -1
  33. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
  34. quack_kernels-0.2.3.dist-info/RECORD +44 -0
  35. quack/layernorm.py +0 -353
  36. quack/symmetric_dense_gemm_sm90.py +0 -2091
  37. quack_kernels-0.2.2.dist-info/RECORD +0 -37
  38. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
  39. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
  40. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/softmax.py CHANGED
@@ -1,14 +1,20 @@
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+
  import math
- import torch
  from typing import Type
+ from functools import partial
+
+ import torch

  import cuda.bindings.driver as cuda

  import cutlass
  import cutlass.cute as cute
- from cutlass.cute.runtime import from_dlpack
+ from cutlass import Int64, Float32, const_expr

  import quack.utils as utils
+ import quack.copy_utils as copy_utils
+ from quack.compile_utils import make_fake_tensor as fake_tensor
  from quack.reduce import row_reduce, online_softmax_reduce
  from quack.reduction_base import ReductionBase
  from quack.cute_dsl_utils import torch2cute_dtype_map
@@ -21,45 +27,28 @@ class Softmax(ReductionBase):
  dtype,
  N,
  stage=2 if not online_softmax else 1,
- reduction_dtype=cutlass.Float32 if not online_softmax else cutlass.Int64,
+ reduction_dtype=Float32 if not online_softmax else Int64,
  )
  self.online_softmax = online_softmax

- def _calculate_threads_per_row(self):
+ def _threads_per_row(self):
  N = self.N
- return (
- 8
- if N <= 64
- else (
- 16
- if N <= 128
- else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
- )
- )
+ for limit, threads in [(64, 8), (128, 16), (3072, 32), (6144, 64), (16384, 128)]:
+ if N <= limit:
+ return threads
+ return 256

  def _set_cluster_n(self):
  N = self.N
- if cutlass.const_expr(self.dtype.width == 16):
- cluster_n = (
- 1
- if N <= 16 * 1024
- else (
- 2
- if N <= 32 * 1024
- else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
- )
- )
- else: # fp32
- cluster_n = (
- 1
- if N <= 32 * 1024
- else (
- 2
- if N <= 64 * 1024
- else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
- )
- )
- self.cluster_n = cluster_n
+ if const_expr(self.dtype.width == 16):
+ thresholds = [(16 * 1024, 1), (32 * 1024, 2), (64 * 1024, 4), (128 * 1024, 8)]
+ else:
+ thresholds = [(32 * 1024, 1), (64 * 1024, 2), (128 * 1024, 4), (256 * 1024, 8)]
+ for limit, cluster in thresholds:
+ if N <= limit:
+ self.cluster_n = cluster
+ return
+ self.cluster_n = 16

  @cute.jit
  def __call__(
@@ -69,16 +58,16 @@ class Softmax(ReductionBase):
  stream: cuda.CUstream,
  ):
  assert mX.element_type == self.dtype
- assert mO.element_type == self.dtype
  self._set_cluster_n()
- tiler_mn, tv_layout = self._get_tv_layout()
- num_threads = cute.size(tv_layout, mode=[0])
- num_warps = num_threads // cute.arch.WARP_SIZE
- self.kernel(mX, mO, tv_layout, tiler_mn).launch(
+ largest_dtype_width = const_expr(max(t.element_type.width for t in [mX, mO]))
+ tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(
+ vecsize=128 // largest_dtype_width
+ )
+ num_threads = tiled_copy.size
+ self.kernel(mX, mO, tiler_mn, tiled_copy, threads_per_row).launch(
  grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
  block=[num_threads, 1, 1],
- cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
- smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+ cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None,
  stream=stream,
  )

@@ -87,23 +76,20 @@ class Softmax(ReductionBase):
  self,
  mX: cute.Tensor,
  mO: cute.Tensor,
- tv_layout: cute.Layout,
  tiler_mn: cute.Shape,
+ tiled_copy: cute.TiledCopy,
+ threads_per_row: cutlass.Constexpr[int],
  ):
+ tv_layout = tiled_copy.layout_tv_tiled
+
  tidx, _, _ = cute.arch.thread_idx()
  bidx, _, _ = cute.arch.block_idx()
- if cutlass.const_expr(self.cluster_n > 1):
- cluster_y = cute.arch.block_idx()[1]
- else:
- cluster_y = cutlass.const_expr(0)
+ cluster_y = const_expr(0) if const_expr(self.cluster_n == 1) else cute.arch.block_idx()[1]

  shape = mX.shape
  idX = cute.make_identity_tensor(shape)
  # slice for CTAs
- # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
- mX, mO = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)]
- gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)]
- cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
+ gX, gO, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, mO, idX)]

  smem = cutlass.utils.SmemAllocator()
  sX = smem.allocate_tensor(
@@ -111,52 +97,45 @@
  )
  reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)

- # declare the atoms which will be used later for memory copy
- copy_atom_load_X = cute.make_copy_atom(
- cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
- )
- copy_atom_store_O = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), gO.element_type, num_bits_per_copy=128
- )
-
- thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
- thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
+ thr_copy_X = tiled_copy.get_slice(tidx)

  tXgX = thr_copy_X.partition_S(gX)
  tXsX = thr_copy_X.partition_D(sX)
- tXgO = thr_copy_O.partition_D(gO)
+ tXgO = thr_copy_X.partition_D(gO)
  tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
-
- # allocate fragments for gmem->rmem
  tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]

- num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
- self._initialize_cluster(tidx, mbar_ptr, num_warps)
-
- is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+ is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
  tXpX = (
- utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
+ None
+ if is_even_N
+ else copy_utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
  )
+ # Each copy will use the same predicate
+ copy = partial(copy_utils.copy, pred=tXpX)
+
+ num_warps = cute.size(tiled_copy) // cute.arch.WARP_SIZE
+ self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
  if tXcX[0][0] < shape[0]:
- cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
+ copy(tXgX, tXsX, is_async=True)
  cute.arch.cp_async_commit_group()
  cute.arch.cp_async_wait_group(0)
  # Fill OOB values with -inf
- if cutlass.const_expr(not is_even_N):
+ if const_expr(not is_even_N):
  utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)

  cute.autovec_copy(tXsX, tXrX)
  x = tXrX.load().to(cute.Float32)
- threads_per_row = tv_layout.shape[0][0]
- if cutlass.const_expr(not self.online_softmax):
+ if const_expr(not self.online_softmax):
  max_x = row_reduce(
  x,
  cute.ReductionOp.MAX,
  threads_per_row,
  reduction_buffer[None, None, 0],
- mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
- init_val=-cutlass.Float32.inf,
- hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+ mbar_ptr + 0 if const_expr(self.cluster_n > 1) else None,
+ init_val=-Float32.inf,
+ hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
  )
  log2_e = math.log2(math.e)
  exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=True)
@@ -165,7 +144,7 @@ class Softmax(ReductionBase):
  cute.ReductionOp.ADD,
  threads_per_row,
  reduction_buffer[None, None, 1],
- mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
+ mbar_ptr + 1 if const_expr(self.cluster_n > 1) else None,
  init_val=0.0,
  )
  else:
@@ -174,18 +153,14 @@ class Softmax(ReductionBase):
  threads_per_row,
  reduction_buffer[None, None, 0],
  mbar_ptr,
- hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+ hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
  return_exp_x=True,
  )
- y = exp_x * (1.0 / denom)
+ # y = exp_x * (1.0 / denom)
+ y = exp_x * cute.arch.rcp_approx(denom)
  tXrO.store(y.to(tXrO.element_type))
- tOpO = (
- utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1])
- if cutlass.const_expr(not is_even_N)
- else None
- )
  if tXcX[0][0] < shape[0]:
- cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
+ copy(tXrO, tXgO)


  @torch.library.custom_op("quack::_softmax_fwd", mutates_args={"out"})
@@ -200,21 +175,21 @@ def _softmax_fwd(x: torch.Tensor, out: torch.Tensor) -> None:
  assert x.is_cuda, "Tensor must be on CUDA device"
  assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
  N = x.size(1)
- dtype = torch2cute_dtype_map[x.dtype]
- convert_from_dlpack = lambda tensor: (
- from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
- mode=0, stride_order=(0, 1)
- )
- )
- x_tensor, out_tensor = [convert_from_dlpack(tensor) for tensor in (x, out)]
- current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
- compile_key = (dtype, N)
+ dtype, out_dtype = [torch2cute_dtype_map[t.dtype] for t in [x, out]]
+ compile_key = (dtype, out_dtype, N)
  if compile_key not in _softmax_fwd.compile_cache:
+ batch_sym = cute.sym_int()
+ div = math.gcd(128 // dtype.width, N)
+ x_cute, out_cute = [fake_tensor(dt, (batch_sym, N), div) for dt in [dtype, out_dtype]]
  softmax_op = Softmax(dtype, N)
  _softmax_fwd.compile_cache[compile_key] = cute.compile(
- softmax_op, x_tensor, out_tensor, current_stream
+ softmax_op,
+ x_cute,
+ out_cute,
+ cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
+ options="--enable-tvm-ffi",
  )
- _softmax_fwd.compile_cache[compile_key](x_tensor, out_tensor, current_stream)
+ _softmax_fwd.compile_cache[compile_key](x, out)


  _softmax_fwd.compile_cache = {}
@@ -229,55 +204,30 @@ def softmax_fwd(x: torch.Tensor) -> torch.Tensor:
  class SoftmaxBackward(ReductionBase):
  def __init__(self, dtype: Type[cutlass.Numeric], N: int):
  # 1 stage for computing dot product
- super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Float32)
+ super().__init__(dtype, N, stage=1, reduction_dtype=Float32)

- def _calculate_threads_per_row(self):
+ def _threads_per_row(self):
  N = self.N
- return (
- 8
- if N <= 64
- else (
- 16
- if N <= 128
- else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 8192 else 256)))
- )
- )
+ for limit, threads in [(64, 8), (128, 16), (3072, 32), (6144, 64), (8192, 128)]:
+ if N <= limit:
+ return threads
+ return 256

  def _set_cluster_n(self):
  N = self.N
- if cutlass.const_expr(self.dtype.width == 16):
- cluster_n = (
- 1
- if N <= 16 * 1024
- else (
- 2
- if N <= 32 * 1024
- else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
- )
- )
- else: # fp32
- cluster_n = (
- 1
- if N <= 16 * 1024
- else (
- 2
- if N <= 32 * 1024
- else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
- )
- )
- self.cluster_n = cluster_n
-
- def _get_num_threads(self):
+ if const_expr(self.dtype.width == 16):
+ thresholds = [(16 * 1024, 1), (32 * 1024, 2), (64 * 1024, 4), (128 * 1024, 8)]
+ else:
+ thresholds = [(16 * 1024, 1), (32 * 1024, 2), (64 * 1024, 4), (128 * 1024, 8)]
+ for limit, cluster in thresholds:
+ if N <= limit:
+ self.cluster_n = cluster
+ return
+ self.cluster_n = 16
+
+ def _num_threads(self):
  return 128 if self.N <= 8192 else 256

- def _smem_size_in_bytes(self, tiler_mn, num_warps):
- return (
- # Multiply by 2 since we need space for Y and dY
- cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2
- + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
- + self.stage * (cutlass.Int64.width // 8)
- )
-
  @cute.jit
  def __call__(
  self,
@@ -287,17 +237,16 @@ class SoftmaxBackward(ReductionBase):
  stream: cuda.CUstream,
  ):
  assert mdY.element_type == self.dtype
- assert mY.element_type == self.dtype
- assert mdX.element_type == self.dtype
  self._set_cluster_n()
- tiler_mn, tv_layout = self._get_tv_layout()
- num_threads = cute.size(tv_layout, mode=[0])
- num_warps = num_threads // cute.arch.WARP_SIZE
- self.kernel(mdY, mY, mdX, tv_layout, tiler_mn).launch(
+ largest_dtype_width = const_expr(max(t.element_type.width for t in [mdY, mY, mdX]))
+ tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(
+ vecsize=128 // largest_dtype_width
+ )
+ num_threads = tiled_copy.size
+ self.kernel(mdY, mY, mdX, tiler_mn, tiled_copy, threads_per_row).launch(
  grid=[cute.ceil_div(mdY.shape[0], tiler_mn[0]), self.cluster_n, 1],
  block=[num_threads, 1, 1],
- cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
- smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+ cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None,
  stream=stream,
  )

@@ -307,24 +256,21 @@ class SoftmaxBackward(ReductionBase):
  mdY: cute.Tensor,
  mY: cute.Tensor,
  mdX: cute.Tensor,
- tv_layout: cute.Layout,
  tiler_mn: cute.Shape,
+ tiled_copy: cute.TiledCopy,
+ threads_per_row: cutlass.Constexpr[int],
  ):
  tidx, _, _ = cute.arch.thread_idx()
  bidx, _, _ = cute.arch.block_idx()
- if cutlass.const_expr(self.cluster_n > 1):
- cluster_y = cute.arch.block_idx()[1]
- else:
- cluster_y = cutlass.const_expr(0)
+ cluster_y = const_expr(0) if const_expr(self.cluster_n == 1) else cute.arch.block_idx()[1]
+ tv_layout = tiled_copy.layout_tv_tiled

  shape = mdY.shape
  idX = cute.make_identity_tensor(shape)
  # slice for CTAs
- mdY, mY, mdX = [
- utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mdY, mY, mdX)
+ gdY, gY, gdX, cX = [
+ cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mdY, mY, mdX, idX)
  ]
- gdY, gY, gdX = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mdY, mY, mdX)]
- cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))

  smem = cutlass.utils.SmemAllocator()
  sdY = smem.allocate_tensor(
@@ -335,42 +281,32 @@ class SoftmaxBackward(ReductionBase):
  )
  reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)

- # declare the atoms which will be used later for memory copy
- copy_atom_load = cute.make_copy_atom(
- cute.nvgpu.cpasync.CopyG2SOp(), mdY.element_type, num_bits_per_copy=128
- )
- copy_atom_store = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), gdX.element_type, num_bits_per_copy=128
- )
+ thr_copy = tiled_copy.get_slice(tidx)

- thr_copy_load = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn).get_slice(tidx)
- thr_copy_store = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn).get_slice(tidx)
-
- tdYgdY = thr_copy_load.partition_S(gdY)
- tdYsdY = thr_copy_load.partition_D(sdY)
- tYgY = thr_copy_load.partition_S(gY)
- tYsY = thr_copy_load.partition_D(sY)
- tdXgdX = thr_copy_store.partition_D(gdX)
- tXcX = thr_copy_load.partition_S(cX)[(0, None), None, None]
-
- # allocate fragments for gmem->rmem
+ tdYgdY = thr_copy.partition_S(gdY)
+ tdYsdY = thr_copy.partition_D(sdY)
+ tYgY = thr_copy.partition_S(gY)
+ tYsY = thr_copy.partition_D(sY)
+ tdXgdX = thr_copy.partition_D(gdX)
+ tXcX = thr_copy.partition_S(cX)[(0, None), None, None]
  tdYrdY, tYrY, tdXrdX = [cute.make_fragment_like(thr) for thr in (tdYgdY, tYgY, tdXgdX)]

- num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
- self._initialize_cluster(tidx, mbar_ptr, num_warps)
-
- is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
- tdYpdY = (
- utils.predicate_k(thr_copy_load.partition_S(cX), limit=shape[1])
- if cutlass.const_expr(not is_even_N)
- else None
+ is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+ tXpX = (
+ None if is_even_N else copy_utils.predicate_k(thr_copy.partition_S(cX), limit=shape[1])
  )
+ # Each copy will use the same predicate
+ copy = partial(copy_utils.copy, pred=tXpX)
+
+ num_warps = cute.size(tiled_copy) // cute.arch.WARP_SIZE
+ self._initialize_cluster(tidx, mbar_ptr, num_warps)

  if tXcX[0][0] < shape[0]:
- cute.copy(copy_atom_load, tdYgdY, tdYsdY, pred=tdYpdY)
- cute.copy(copy_atom_load, tYgY, tYsY, pred=tdYpdY)
+ copy(tdYgdY, tdYsdY, is_async=True)
+ copy(tYgY, tYsY, is_async=True)
  cute.arch.cp_async_commit_group()
  cute.arch.cp_async_wait_group(0)
+ # Don't need fill_oob since cp.async will automatically fills OOB elements with zeros

  cute.autovec_copy(tdYsdY, tdYrdY)
  cute.autovec_copy(tYsY, tYrY)
@@ -378,27 +314,21 @@ class SoftmaxBackward(ReductionBase):
  y = tYrY.load().to(cute.Float32)

  # Compute dot product: dot = Σⱼ dy_j × y_j
- threads_per_row = tv_layout.shape[0][0]
  dot = row_reduce(
  dy * y,
  cute.ReductionOp.ADD,
  threads_per_row,
  reduction_buffer[None, None, 0],
- mbar_ptr if cutlass.const_expr(self.cluster_n > 1) else None,
+ mbar_ptr if const_expr(self.cluster_n > 1) else None,
  init_val=0.0,
- hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+ hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
  )

  # Compute gradient: dx_i = y_i × (dy_i - dot)
  dx = y * (dy - dot)
  tdXrdX.store(dx.to(tdXrdX.element_type))
- tdXpdX = (
- utils.predicate_k(thr_copy_store.partition_S(cX), limit=shape[1])
- if cutlass.const_expr(not is_even_N)
- else None
- )
  if tXcX[0][0] < shape[0]:
- cute.copy(copy_atom_store, tdXrdX, tdXgdX, pred=tdXpdX)
+ copy(tdXrdX, tdXgdX)


  @torch.library.custom_op("quack::_softmax_backward", mutates_args={"dx"})
@@ -418,22 +348,24 @@ def _softmax_backward(dy: torch.Tensor, y: torch.Tensor, dx: torch.Tensor) -> No
  assert y.dtype == dy.dtype, "dy and y must have same dtype"

  N = dy.size(1)
- dtype = torch2cute_dtype_map[dy.dtype]
- convert_from_dlpack = lambda tensor: (
- from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
- mode=0, stride_order=(0, 1)
- )
- )
- dy_tensor, y_tensor, dx_tensor = [convert_from_dlpack(tensor) for tensor in (dy, y, dx)]
- current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
-
- compile_key = (dtype, N)
+ dtype, y_dtype, dx_dtype = [torch2cute_dtype_map[t.dtype] for t in [dy, y, dx]]
+ compile_key = (dtype, y_dtype, dx_dtype, N)
  if compile_key not in _softmax_backward.compile_cache:
+ batch_sym = cute.sym_int()
+ div = math.gcd(128 // dtype.width, N)
+ dy_cute, y_cute, dx_cute = [
+ fake_tensor(dt, (batch_sym, N), div) for dt in [dtype, y_dtype, dx_dtype]
+ ]
  softmax_backward_op = SoftmaxBackward(dtype, N)
  _softmax_backward.compile_cache[compile_key] = cute.compile(
- softmax_backward_op, dy_tensor, y_tensor, dx_tensor, current_stream
+ softmax_backward_op,
+ dy_cute,
+ y_cute,
+ dx_cute,
+ cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
+ options="--enable-tvm-ffi",
  )
- _softmax_backward.compile_cache[compile_key](dy_tensor, y_tensor, dx_tensor, current_stream)
+ _softmax_backward.compile_cache[compile_key](dy, y, dx)


  _softmax_backward.compile_cache = {}
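
For reference, softmax_fwd keeps its signature in this diff; what changed is that the first call per (input dtype, output dtype, N) key now compiles the kernel against fake tensors and a TVM-FFI environment stream instead of converting the live tensors through from_dlpack. A minimal usage sketch, assuming the import path quack.softmax and purely illustrative shapes/dtypes (not taken from the package):

import torch
from quack.softmax import softmax_fwd

x = torch.randn(4096, 8192, device="cuda", dtype=torch.bfloat16)
y = softmax_fwd(x)  # row-wise softmax; first call for this (dtype, out_dtype, N) compiles and caches
# Loose sanity check only: the new epilogue uses cute.arch.rcp_approx, an approximate reciprocal
print((y.float() - torch.softmax(x.float(), dim=-1)).abs().max())
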
quack/sort/bitonic_sort.py CHANGED
@@ -5,6 +5,7 @@ from typing import Optional

  import cutlass
  import cutlass.cute as cute
+ from cutlass import Int32, Float32, const_expr

  import quack.utils as utils
  from quack.sort.utils import compare_and_swap
@@ -14,12 +15,14 @@ from quack.sort.sorting_networks import optimal_sort
  @cute.jit
  def bitonic_merge(
  arr: cute.Tensor,
- n: cutlass.Constexpr[int],
- start: cutlass.Constexpr[int],
+ n: Optional[cutlass.Constexpr[int]] = None,
+ start: cutlass.Constexpr[int] = 0,
  ascending: cutlass.Constexpr[bool] = True,
  ) -> None:
  """Merge a bitonic sequence into a sorted sequence using iterative approach."""
- if cutlass.const_expr(n > 1):
+ if const_expr(n is None):
+ n = cute.size(arr.shape)
+ if const_expr(n > 1):
  num_levels = int(math.log2(n))
  assert n == 2**num_levels, "n must be a power of 2"
  # This one must be range_constexpr otherwise it's very slow for n = 128
@@ -48,11 +51,11 @@ def bitonic_sort(
  start: Starting index (default 0)
  ascending: Sort in ascending order (default True)
  """
- if cutlass.const_expr(n is None):
+ if const_expr(n is None):
  n = cute.size(arr.shape)
  assert n <= 128
- if cutlass.const_expr(n > 1):
- if cutlass.const_expr(n in [2, 4, 8, 16, 32, 64]):
+ if const_expr(n > 1):
+ if const_expr(n in [2, 4, 8, 16, 32, 64]):
  optimal_sort(arr, n, start, ascending)
  else: # Fall back to bitonic sort
  assert n % 2 == 0
@@ -73,9 +76,9 @@ def bitonic_topk_merge(
  start1: cutlass.Constexpr[int] = 0,
  ascending: cutlass.Constexpr[bool] = False,
  ) -> None:
- if cutlass.const_expr(k is None):
+ if const_expr(k is None):
  k = cute.size(arr0.shape)
- if cutlass.const_expr(arr0.element_type == cutlass.Float32):
+ if const_expr(arr0.element_type == Float32):
  minmax_fn = utils.fmin if ascending else cute.arch.fmax
  else:
  minmax_fn = min if ascending else max
@@ -101,7 +104,7 @@ def bitonic_topk(
  k: must be power of 2 and <= 128
  ascending: Sort in ascending order (default False)
  """
- assert arr.element_type in [cutlass.Float32, cutlass.Int32]
+ assert arr.element_type in [Float32, Int32]
  n = cute.size(arr.shape)
  assert k == 1 << int(math.log2(k)), "k must be a power of 2"
  assert n % k == 0, "n must be divisible by k"
@@ -109,8 +112,8 @@
  for v in cutlass.range(k, unroll_full=True):
  topk_vals[v] = arr[v]
  bitonic_sort(topk_vals, ascending=ascending)
- other_vals = cute.make_fragment(k, arr.element_type)
  for i in cutlass.range(1, n // k, unroll_full=True):
+ other_vals = cute.make_fragment(k, arr.element_type)
  for v in cutlass.range(k, unroll_full=True):
  other_vals[v] = arr[i * k + v]
  bitonic_sort(other_vals, ascending=ascending)
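
The bitonic_topk hunk above only moves the other_vals fragment allocation inside the per-chunk loop; the chunked top-k structure is unchanged: sort the first k values, then for every further chunk of k, sort it and merge the two sorted runs while keeping the best k. A rough pure-Python stand-in for that structure (an illustration of the idea only, not the CuTe DSL code):

def chunked_topk(values, k, ascending=False):
    # Reference for the shape of bitonic_topk: per-chunk sort followed by a top-k merge.
    assert len(values) % k == 0
    topk = sorted(values[:k], reverse=not ascending)
    for i in range(1, len(values) // k):
        other = sorted(values[i * k:(i + 1) * k], reverse=not ascending)  # fresh buffer each iteration
        topk = sorted(topk + other, reverse=not ascending)[:k]  # the kernel does this step with a min/max network
    return topk
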
quack/sort/utils.py CHANGED
@@ -1,5 +1,5 @@
- import cutlass
  import cutlass.cute as cute
+ from cutlass import Float32, const_expr

  import quack.utils as utils

@@ -9,12 +9,12 @@ def compare_and_swap(
  arr: cute.Tensor, i: int, j: int, ascending: bool = True, use_selection: bool = False
  ) -> None:
  """Compare and swap elements at indices i and j in ascending or descending order."""
- if cutlass.const_expr(use_selection):
+ if const_expr(use_selection):
  a, b = arr[i], arr[j]
  if (a > b) ^ (not ascending):
  arr[i] = b
  arr[j] = a
- # if cutlass.const_expr(ascending):
+ # if const_expr(ascending):
  # if a > b:
  # arr[i] = b
  # arr[j] = a
@@ -23,9 +23,9 @@
  # arr[i] = b
  # arr[j] = a
  else:
- min_fn = min if cutlass.const_expr(arr.element_type != cutlass.Float32) else utils.fmin
- max_fn = max if cutlass.const_expr(arr.element_type != cutlass.Float32) else cute.arch.fmax
- if cutlass.const_expr(ascending):
+ min_fn = min if const_expr(arr.element_type != Float32) else utils.fmin
+ max_fn = max if const_expr(arr.element_type != Float32) else cute.arch.fmax
+ if const_expr(ascending):
  arr[i], arr[j] = min_fn(arr[i], arr[j]), max_fn(arr[i], arr[j])
  else:
  arr[i], arr[j] = max_fn(arr[i], arr[j]), min_fn(arr[i], arr[j])
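
The compare_and_swap changes here are import cleanups only (const_expr and Float32 now imported directly from cutlass). The min/max formulation it keeps using is equivalent to this plain-Python sketch (illustration only, not the DSL code):

def compare_and_swap_ref(arr, i, j, ascending=True):
    # Branch-free idea: after the call, arr[i] <= arr[j] when ascending (reversed otherwise).
    lo, hi = min(arr[i], arr[j]), max(arr[i], arr[j])
    arr[i], arr[j] = (lo, hi) if ascending else (hi, lo)
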