quack-kernels 0.1.11__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +7 -3
- quack/activation.py +279 -0
- quack/autotuner.py +2 -1
- quack/cross_entropy.py +330 -184
- quack/cute_dsl_utils.py +83 -4
- quack/dense_gemm_sm100.py +1 -1
- quack/dense_gemm_sm90.py +911 -1140
- quack/fast_math.py +10 -27
- quack/gemm_act_sm90.py +368 -0
- quack/gemm_config.py +43 -35
- quack/gemm_dact_sm90.py +150 -0
- quack/gemm_interface.py +491 -243
- quack/gemm_wrapper_utils.py +158 -0
- quack/layernorm.py +6 -4
- quack/linear.py +128 -64
- quack/linear_cross_entropy.py +275 -0
- quack/mlp.py +30 -160
- quack/pipeline.py +2 -17
- quack/reduce.py +240 -0
- quack/reduction_base.py +2 -11
- quack/rmsnorm.py +614 -228
- quack/softmax.py +28 -16
- quack/symmetric_dense_gemm_sm90.py +6 -3
- quack/tensormap_manager.py +1 -0
- quack/tile_scheduler.py +64 -61
- quack/topk.py +14 -8
- quack/utils.py +14 -322
- quack/varlen_utils.py +22 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/METADATA +3 -3
- quack_kernels-0.2.1.dist-info/RECORD +37 -0
- quack/lse.py +0 -62
- quack_kernels-0.1.11.dist-info/RECORD +0 -31
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/top_level.txt +0 -0
quack/softmax.py
CHANGED
@@ -9,7 +9,9 @@ import cutlass.cute as cute
 from cutlass.cute.runtime import from_dlpack
 
 import quack.utils as utils
-from quack.
+from quack.reduce import row_reduce, online_softmax_reduce
+from quack.reduction_base import ReductionBase
+from quack.cute_dsl_utils import torch2cute_dtype_map
 
 
 class Softmax(ReductionBase):
@@ -147,7 +149,7 @@ class Softmax(ReductionBase):
         x = tXrX.load().to(cute.Float32)
         threads_per_row = tv_layout.shape[0][0]
         if cutlass.const_expr(not self.online_softmax):
-            max_x =
+            max_x = row_reduce(
                 x,
                 cute.ReductionOp.MAX,
                 threads_per_row,
@@ -157,8 +159,8 @@
                 hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
             )
             log2_e = math.log2(math.e)
-            exp_x = cute.math.exp2(
-            denom =
+            exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=True)
+            denom = row_reduce(
                 exp_x,
                 cute.ReductionOp.ADD,
                 threads_per_row,
@@ -167,7 +169,7 @@
                 init_val=0.0,
             )
         else:
-            max_x, denom, exp_x =
+            max_x, denom, exp_x = online_softmax_reduce(
                 x,
                 threads_per_row,
                 reduction_buffer[None, None, 0],
@@ -186,7 +188,8 @@
         cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
 
 
-def _softmax_fwd(x: torch.Tensor) -> torch.Tensor:
+@torch.library.custom_op("quack::_softmax_fwd", mutates_args={"out"})
+def _softmax_fwd(x: torch.Tensor, out: torch.Tensor) -> None:
     """Softmax forward pass.
     Args:
         x: Input tensor of shape (M, N)
@@ -196,8 +199,7 @@ def _softmax_fwd(x: torch.Tensor) -> torch.Tensor:
     assert x.dim() == 2, "Input must be 2D"
     assert x.is_cuda, "Tensor must be on CUDA device"
     assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
-
-    out = torch.empty_like(x)
+    N = x.size(1)
     dtype = torch2cute_dtype_map[x.dtype]
     convert_from_dlpack = lambda tensor: (
         from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
@@ -213,12 +215,17 @@ def _softmax_fwd(x: torch.Tensor) -> torch.Tensor:
         softmax_op, x_tensor, out_tensor, current_stream
     )
     _softmax_fwd.compile_cache[compile_key](x_tensor, out_tensor, current_stream)
-    return out
 
 
 _softmax_fwd.compile_cache = {}
 
 
+def softmax_fwd(x: torch.Tensor) -> torch.Tensor:
+    out = torch.empty_like(x)
+    _softmax_fwd(x, out)
+    return out
+
+
 class SoftmaxBackward(ReductionBase):
     def __init__(self, dtype: Type[cutlass.Numeric], N: int):
         # 1 stage for computing dot product
@@ -372,7 +379,7 @@ class SoftmaxBackward(ReductionBase):
 
         # Compute dot product: dot = Σⱼ dy_j × y_j
         threads_per_row = tv_layout.shape[0][0]
-        dot =
+        dot = row_reduce(
             dy * y,
             cute.ReductionOp.ADD,
             threads_per_row,
@@ -394,7 +401,8 @@
         cute.copy(copy_atom_store, tdXrdX, tdXgdX, pred=tdXpdX)
 
 
-def _softmax_backward(dy: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+@torch.library.custom_op("quack::_softmax_backward", mutates_args={"dx"})
+def _softmax_backward(dy: torch.Tensor, y: torch.Tensor, dx: torch.Tensor) -> None:
     """Softmax backward pass.
     Args:
         dy: Upstream gradients tensor of shape (M, N)
@@ -409,8 +417,7 @@ def _softmax_backward(dy: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     assert dy.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
     assert y.dtype == dy.dtype, "dy and y must have same dtype"
 
-
-    dx = torch.empty_like(dy)
+    N = dy.size(1)
     dtype = torch2cute_dtype_map[dy.dtype]
     convert_from_dlpack = lambda tensor: (
         from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
@@ -427,23 +434,28 @@ def _softmax_backward(dy: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         softmax_backward_op, dy_tensor, y_tensor, dx_tensor, current_stream
     )
     _softmax_backward.compile_cache[compile_key](dy_tensor, y_tensor, dx_tensor, current_stream)
-    return dx
 
 
 _softmax_backward.compile_cache = {}
 
 
+def softmax_bwd(dy: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    dx = torch.empty_like(dy)
+    _softmax_backward(dy, y, dx)
+    return dx
+
+
 class SoftmaxFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x):
-        y =
+        y = softmax_fwd(x)
         ctx.save_for_backward(y)
         return y
 
     @staticmethod
     def backward(ctx, dy):
         (y,) = ctx.saved_tensors
-        dx =
+        dx = softmax_bwd(dy, y)
         return dx
 
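The recurring API change in this release is visible above: the low-level op no longer allocates and returns its output. It now writes into a caller-provided buffer, is registered through torch.library.custom_op with mutates_args, and a thin wrapper (softmax_fwd / softmax_bwd) owns the allocation and is what the autograd.Function calls. A minimal, self-contained sketch of that registration pattern, using torch.softmax as a stand-in for the CuTe DSL kernel and a hypothetical "demo" namespace (requires a recent PyTorch with torch.library.custom_op):

import torch

# Out-variant op: writes into a caller-provided buffer and returns nothing,
# mirroring quack::_softmax_fwd above (the real op launches a CuTe DSL kernel).
@torch.library.custom_op("demo::softmax_out", mutates_args={"out"})
def softmax_out(x: torch.Tensor, out: torch.Tensor) -> None:
    out.copy_(torch.softmax(x, dim=-1))

# Thin wrapper that owns the allocation, mirroring softmax_fwd in the diff.
def softmax_fwd_demo(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    softmax_out(x, out)
    return out

if __name__ == "__main__":
    x = torch.randn(4, 8)
    torch.testing.assert_close(softmax_fwd_demo(x), torch.softmax(x, dim=-1))

mutates_args tells the dispatcher which inputs the op writes to, which is required for custom ops that mutate their arguments.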
quack/symmetric_dense_gemm_sm90.py
CHANGED
@@ -51,7 +51,7 @@ from quack.tile_scheduler import (
     RasterOrderOption,
     TriangularTileScheduler,
 )
-from quack.
+from quack.cute_dsl_utils import torch2cute_dtype_map
 
 # return PipelineStateWAdvance instead of PipelineState
 from quack.pipeline import make_pipeline_state
@@ -907,8 +907,11 @@ class HopperSymmetricGemmKernel:
 
         acc_shape = tiled_mma.partition_shape_C(cute.select(self.tile_shape_mnk, mode=[0, 1]))
         acc = cute.make_fragment(acc_shape, self.acc_dtype)
-
-
+        acc_slow = (
+            cute.make_fragment(acc_shape, self.acc_dtype)
+            if const_expr(self.fp8_slow_accum)
+            else None
+        )
 
         if const_expr(self.pingpong):
             if warp_group_idx == 0:
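The only substantive change in this hunk is that a second accumulator fragment, acc_slow, is allocated when self.fp8_slow_accum is set; the diff does not show how the MMA loop uses it. The usual motivation for a split accumulator in FP8 GEMMs is to periodically fold the tensor-core accumulator into a longer-lived one so rounding error does not grow with the full K extent. A rough scalar sketch of that idea (names and the folding interval are illustrative, not taken from the kernel):

def split_accumulate(products, interval=4):
    acc_slow = 0.0  # long-lived accumulator, analogous to acc_slow above
    acc = 0.0       # short-lived accumulator, analogous to acc
    for i, p in enumerate(products, start=1):
        acc += p
        if i % interval == 0:  # fold periodically instead of once at the end
            acc_slow += acc
            acc = 0.0
    return acc_slow + acc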
quack/tensormap_manager.py
CHANGED
@@ -99,6 +99,7 @@ class TensorMapManagerSm90(TensorMapManager):
             for gmem_ptr, smem_ptr in zip(tensormap_gmem_ptr, tensormap_smem_ptr):
                 cute.nvgpu.cpasync.cp_fence_tma_desc_release(gmem_ptr, smem_ptr)
         else:
+            assert len(shapes) == len(orders) == len(tensormap_gmem_ptr)
             for gmem_ptr, shape, order in zip(tensormap_gmem_ptr, shapes, orders):
                 gmem_ptr_i64 = gmem_ptr.toint().ir_value()
                 llvm.inline_asm(
quack/tile_scheduler.py
CHANGED
@@ -1,7 +1,7 @@
 # Copyright (c) 2025, Tri Dao.
 
 from typing import Tuple, Optional
-from dataclasses import dataclass, fields
+from dataclasses import dataclass
 from enum import IntEnum
 
 import cutlass
@@ -11,30 +11,7 @@ from cutlass import Int32, Boolean, const_expr
 import quack.utils as utils
 from quack.fast_math import FastDivmod
 from quack.pipeline import PipelineStateWAdvance
-
-
-@dataclass
-class ParamsBase:
-    def __extract_mlir_values__(self):
-        all_fields = [getattr(self, field.name) for field in fields(self)]
-        non_constexpr_fields = [f for f in all_fields if not isinstance(f, cutlass.Constexpr)]
-        values, self._values_pos = [], []
-        for obj in non_constexpr_fields:
-            obj_values = cutlass.extract_mlir_values(obj)
-            values += obj_values
-            self._values_pos.append(len(obj_values))
-        return values
-
-    def __new_from_mlir_values__(self, values):
-        all_fields = {field.name: getattr(self, field.name) for field in fields(self)}
-        constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, cutlass.Constexpr)}
-        non_constexpr_fields = {
-            n: f for n, f in all_fields.items() if not isinstance(f, cutlass.Constexpr)
-        }
-        for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos):
-            non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items])
-            values = values[n_items:]
-        return self.__class__(**non_constexpr_fields, **constexpr_fields)
+from quack.cute_dsl_utils import ArgumentsBase, ParamsBase
 
 
 class RasterOrderOption(IntEnum):
@@ -66,13 +43,24 @@ def get_raster_order_from_option(
     return raster_order
 
 
+# Grouping arguments together that should be passed to __call__
+@dataclass
+class TileSchedulerOptions(ArgumentsBase):
+    max_active_clusters: Int32
+    raster_order: cutlass.Constexpr[RasterOrderOption] = RasterOrderOption.Heuristic
+    max_swizzle_size: Int32 = Int32(8)
+    tile_count_semaphore: Optional[cute.Pointer] = None
+    batch_idx_permute: Optional[cute.Tensor] = None
+
+
 @dataclass
-class TileSchedulerArguments(
+class TileSchedulerArguments(ArgumentsBase):
     problem_shape_ntile_mnl: cute.Shape
-    raster_order: RasterOrderOption
+    raster_order: cutlass.Constexpr[RasterOrderOption]
     group_size: Int32
     cluster_shape_mnk: cutlass.Constexpr[cute.Shape]
     tile_count_semaphore: Optional[cute.Pointer] = None
+    batch_idx_permute: Optional[cute.Tensor] = None
     is_persistent: cutlass.Constexpr[bool] = False
 
 
@@ -87,6 +75,7 @@ class TileScheduler:
         group_size_tail_divmod: FastDivmod
         num_clusters_in_group_divmod: FastDivmod
         tile_count_semaphore: Optional[cute.Pointer]
+        batch_idx_permute: Optional[cute.Tensor]
        cluster_shape_mn: cutlass.Constexpr[cute.Shape]
        is_persistent: cutlass.Constexpr[bool]
 
@@ -128,6 +117,7 @@ class TileScheduler:
             FastDivmod.create(group_size_tail if group_size_tail > 0 else 1),
             FastDivmod.create(num_clusters_in_group),
             args.tile_count_semaphore if const_expr(args.is_persistent) else None,
+            args.batch_idx_permute,
             cluster_shape_mn,
             args.is_persistent,
         )
@@ -256,7 +246,10 @@
         bidx_in_cluster = cute.arch.block_in_cluster_idx()
         pid_m = cid_m * params.cluster_shape_mn[0] + bidx_in_cluster[0]
         pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1]
-
+        batch_idx = (
+            bidz if const_expr(params.batch_idx_permute is None) else params.batch_idx_permute[bidz]
+        )
+        tile_coord_mnkl = (pid_m, pid_n, None, batch_idx)
         if const_expr(not params.is_persistent):
             is_valid = self._num_tiles_executed == 0
         else:
@@ -267,10 +260,10 @@
         return self.get_current_work(loc=loc, ip=ip)
 
     @cute.jit
-    def fetch_next_work(self, is_scheduler_warp: bool | Boolean, *, loc=None, ip=None):
+    def fetch_next_work(self, is_scheduler_warp: bool | Boolean = False, *, loc=None, ip=None):
         """is_scheduler_warp should only be true for one warp in the whole cluster"""
-
-
+        params = self.params
+        if const_expr(params.is_persistent and params.tile_count_semaphore is not None):
             current_work_linear_idx = self._current_work_linear_idx
             if is_scheduler_warp:
                 if cute.arch.lane_idx() == 0:
@@ -283,6 +276,38 @@
             current_work_linear_idx = cute.arch.shuffle_sync(current_work_linear_idx, 0)
             self._current_work_linear_idx = current_work_linear_idx
 
+    # We have to split broadcast_next_work and advance_to_next_work into two functions
+    # due to a bug in cute-dsl 4.2: https://github.com/NVIDIA/cutlass/issues/2647
+    @cute.jit
+    def broadcast_next_work(self, is_scheduler_warp: bool | Boolean = False, *, loc=None, ip=None):
+        """is_scheduler_warp should only be true for one warp in the whole cluster"""
+        params = self.params
+        if const_expr(params.is_persistent and params.tile_count_semaphore is not None):
+            current_work_linear_idx = self._current_work_linear_idx
+            if is_scheduler_warp:
+                self._scheduler_pipeline.producer_acquire(self._pipeline_state)
+                lane_idx = cute.arch.lane_idx()
+                if lane_idx < cute.size(params.cluster_shape_mn):
+                    # cute.printf("Producer bidx = {}, tidx = {}, after empty wait, idx = {}", bidx, tidx, current_work_linear_idx)
+                    if const_expr(cute.size(params.cluster_shape_mn) == 1):
+                        self._tile_count[self._pipeline_state.index] = current_work_linear_idx
+                        self._scheduler_pipeline.producer_commit(self._pipeline_state)
+                    else:
+                        peer_cta_rank_in_cluster = lane_idx
+                        mbar_ptr = self._scheduler_pipeline.producer_get_barrier(
+                            self._pipeline_state
+                        )
+                        cute.arch.mbarrier_arrive_and_expect_tx(
+                            mbar_ptr, 4, peer_cta_rank_in_cluster
+                        )
+                        utils.store_shared_remote(
+                            val=current_work_linear_idx,
+                            smem_ptr=self._tile_count.iterator + self._pipeline_state.index,
+                            mbar_ptr=mbar_ptr,
+                            peer_cta_rank_in_cluster=peer_cta_rank_in_cluster,
+                        )
+                    # cute.printf("Producer bidx = {}, tidx = {}, after full arrive", bidx, tidx)
+
     @cute.jit
     def advance_to_next_work(
         self,
@@ -300,32 +325,10 @@
         if const_expr(params.tile_count_semaphore is None):  # Static persistent
             self._current_work_linear_idx += advance_count * Int32(num_persistent_clusters)
         else:  # Dynamic persistent
-
+            if const_expr(advance_count > 1):
+                self._pipeline_state.advance_iters(advance_count - 1)
             current_work_linear_idx = self._current_work_linear_idx
-            if is_scheduler_warp:
-                self._scheduler_pipeline.producer_acquire(self._pipeline_state)
-                lane_idx = cute.arch.lane_idx()
-                if lane_idx < cute.size(params.cluster_shape_mn):
-                    # cute.printf("Producer bidx = {}, tidx = {}, after empty wait, idx = {}", bidx, tidx, current_work_linear_idx)
-                    if const_expr(cute.size(params.cluster_shape_mn) == 1):
-                        self._tile_count[self._pipeline_state.index] = current_work_linear_idx
-                        self._scheduler_pipeline.producer_commit(self._pipeline_state)
-                    else:
-                        peer_cta_rank_in_cluster = lane_idx
-                        mbar_ptr = self._scheduler_pipeline.producer_get_barrier(
-                            self._pipeline_state
-                        )
-                        cute.arch.mbarrier_arrive_and_expect_tx(
-                            mbar_ptr, 4, peer_cta_rank_in_cluster
-                        )
-                        utils.store_shared_remote(
-                            val=current_work_linear_idx,
-                            smem_ptr=self._tile_count.iterator + self._pipeline_state.index,
-                            mbar_ptr=mbar_ptr,
-                            peer_cta_rank_in_cluster=peer_cta_rank_in_cluster,
-                        )
-                    # cute.printf("Producer bidx = {}, tidx = {}, after full arrive", bidx, tidx)
-            else:
+            if not is_scheduler_warp:
                 # if tidx % 64 == 0: cute.printf("bidx = {},tidx = {}, before full wait, idx = {}", bidx, tidx, current_work_linear_idx)
                 self._scheduler_pipeline.consumer_wait(self._pipeline_state)
                 # if tidx % 64 == 0: cute.printf("bidx = {}, tidx = {}, after full wait, idx = {}", bidx, tidx, current_work_linear_idx)
@@ -387,7 +390,7 @@ def triangular_idx_to_coord(idx: Int32) -> Tuple[Int32, Int32]:
     Convert a triangular index to 2D coordinates.
     This is used to convert the linear index to 2D coordinates for triangular matrices.
     """
-    row = utils.ceil((
+    row = utils.ceil((cute.math.sqrt(2 * idx + 2.25, fastmath=True) - 0.5)) - 1
     col = idx - (row * (row + 1)) // 2
     return row, col
 
@@ -521,7 +524,8 @@ class TriangularTileScheduler(TileScheduler):
             group_size = params.group_size_divmod.divisor
             group_id = (
                 utils.ceil(
-                    (
+                    (cute.math.sqrt(2 * cluster_id_in_problem + 2.25, fastmath=True) - 0.5)
+                    * params.group_size_inv_f32
                 )
                 - 1
             )
@@ -580,7 +584,7 @@ class VarlenMTileSchedulerArguments(ParamsBase):
     cu_seqlens_m: cute.Tensor
     raster_order: cutlass.Constexpr[RasterOrderOption]
     group_size: Int32
-
+    tile_shape_mn: cutlass.Constexpr[cute.Shape]
     cluster_shape_mnk: cutlass.Constexpr[cute.Shape]
     tile_count_semaphore: Optional[cute.Pointer] = None
     is_persistent: cutlass.Constexpr[bool] = False
@@ -609,7 +613,6 @@ class VarlenMTileScheduler(TileScheduler):
     ) -> "VarlenMTileScheduler.Params":
         assert args.cluster_shape_mnk[2] == 1
         cluster_shape_mn = const_expr(cute.select(args.cluster_shape_mnk, mode=[0, 1]))
-        tile_shape_mn = const_expr(cute.select(args.tile_shape_mnk, mode=[0, 1]))
         # problem_shape_ntile_mnl[0] will be None for VarlenM
         problem_shape_ntile_mn = cute.select(args.problem_shape_ntile_mnl, mode=[0, 1])
         problem_shape_ncluster_mn = (
@@ -657,7 +660,7 @@ class VarlenMTileScheduler(TileScheduler):
             FastDivmod.create(num_clusters_in_group)
             if num_clusters_in_group is not None
             else None,
-            tile_shape_mn,
+            args.tile_shape_mn,
             args.tile_count_semaphore if const_expr(args.is_persistent) else None,
             cluster_shape_mn,
             args.is_persistent,
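Both sqrt-based changes in this file use the same closed form: for a lower-triangular enumeration idx = row * (row + 1) / 2 + col with 0 <= col <= row, the row is recovered as ceil(sqrt(2 * idx + 2.25) - 0.5) - 1. A small pure-Python check of that identity (a host-side reference, not the CuTe code path):

import math

def triangular_idx_to_coord_ref(idx: int) -> tuple[int, int]:
    # Same formula as triangular_idx_to_coord above, evaluated on the host.
    row = math.ceil(math.sqrt(2 * idx + 2.25) - 0.5) - 1
    col = idx - row * (row + 1) // 2
    return row, col

idx = 0
for row in range(64):
    for col in range(row + 1):
        assert triangular_idx_to_coord_ref(idx) == (row, col)
        idx += 1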
quack/topk.py
CHANGED
@@ -12,7 +12,7 @@ from cutlass.cute.runtime import from_dlpack
 from cutlass import const_expr
 
 import quack.utils as utils
-from quack.
+from quack.cute_dsl_utils import torch2cute_dtype_map
 from quack.sort.bitonic_sort import bitonic_topk
 
 
@@ -133,6 +133,7 @@ class TopK:
 
         threads_per_row = tv_layout.shape[0][0]
         topk_vals = bitonic_topk(tXrX_f32, self.k, warp_width=threads_per_row)
+
         # Extract indices and clean values
         topk_vals_u32 = cute.recast_tensor(topk_vals, cutlass.Uint32)
         topk_indices = cute.make_fragment(self.k, cutlass.Int32)
@@ -166,7 +167,8 @@ class TopK:
             cute.autovec_copy(topk_indices_store[None, i], mIndices_store[None, i])
 
 
-def _topk_fwd(x: torch.Tensor, k: int):
+@torch.library.custom_op("quack::_topk_fwd", mutates_args={"values", "indices"})
+def _topk_fwd(x: torch.Tensor, k: int, values: torch.Tensor, indices: torch.Tensor) -> None:
     """Top-k forward pass.
     Args:
         x: Input tensor of shape (M, N)
@@ -179,9 +181,7 @@ def _topk_fwd(x: torch.Tensor, k: int):
     assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
     assert k > 0 and k <= x.shape[1], "k must be positive and <= N"
 
-
-    values = torch.empty((M, k), dtype=x.dtype, device=x.device)
-    indices = torch.empty((M, k), dtype=torch.int32, device=x.device)
+    N = x.size(1)
 
     dtype = torch2cute_dtype_map[x.dtype]
     convert_from_dlpack = lambda tensor: (
@@ -202,8 +202,6 @@ def _topk_fwd(x: torch.Tensor, k: int):
     )
     _topk_fwd.compile_cache[compile_key](x_tensor, values_tensor, indices_tensor, current_stream)
 
-    return values, indices
-
 
 _topk_fwd.compile_cache = {}
 
@@ -218,4 +216,12 @@ def topk(x: torch.Tensor, k: int):
     Returns:
         Tuple of (values tensor of shape (M, k), indices tensor of shape (M, k))
     """
-
+
+    M = x.size(0)
+
+    values = torch.empty((M, k), dtype=x.dtype, device=x.device)
+    indices = torch.empty((M, k), dtype=torch.int32, device=x.device)
+
+    _topk_fwd(x, k, values, indices)
+
+    return values, indices
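The same refactor applies to top-k: allocation now lives in topk() while the registered custom op fills the preallocated buffers. Assuming the public entry point is quack.topk.topk as shown above, usage and a quick cross-check against torch.topk might look like the following sketch (value order is not assumed to match, so both results are sorted before comparing):

import torch
from quack.topk import topk

# Requires a CUDA device and one of the dtypes accepted by the asserts above.
x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
values, indices = topk(x, k=16)  # values: (8, 16) in x.dtype, indices: (8, 16) int32

ref = torch.topk(x, k=16, dim=-1).values
torch.testing.assert_close(
    values.sort(dim=-1, descending=True).values,
    ref.sort(dim=-1, descending=True).values,
)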