quack-kernels 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -1
- quack/activation.py +16 -25
- quack/autotuner.py +64 -5
- quack/cross_entropy.py +6 -10
- quack/cute_dsl_utils.py +6 -7
- quack/dense_gemm_sm90.py +582 -287
- quack/gemm_act_sm90.py +70 -29
- quack/gemm_dact_sm90.py +43 -10
- quack/gemm_interface.py +453 -130
- quack/{dense_gemm_sm100.py → gemm_sm100.py} +443 -419
- quack/gemm_wrapper_utils.py +179 -22
- quack/layernorm.py +1 -1
- quack/reduce.py +6 -7
- quack/rmsnorm.py +126 -158
- quack/softmax.py +1 -1
- quack/tile_scheduler.py +37 -49
- quack/utils.py +61 -71
- quack/varlen_utils.py +1 -6
- {quack_kernels-0.2.0.dist-info → quack_kernels-0.2.2.dist-info}/METADATA +3 -3
- quack_kernels-0.2.2.dist-info/RECORD +37 -0
- quack_kernels-0.2.0.dist-info/RECORD +0 -37
- {quack_kernels-0.2.0.dist-info → quack_kernels-0.2.2.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.0.dist-info → quack_kernels-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.0.dist-info → quack_kernels-0.2.2.dist-info}/top_level.txt +0 -0
quack/tile_scheduler.py
CHANGED
@@ -135,7 +135,7 @@ class TileScheduler:
         ip=None,
     ):
         self._current_work_linear_idx = current_work_linear_idx
-        self._num_tiles_executed = num_tiles_executed
+        self.num_tiles_executed = num_tiles_executed
         self._tile_count = tile_count
         self._scheduler_pipeline = scheduler_pipeline
         self._pipeline_state = pipeline_state
@@ -251,7 +251,7 @@ class TileScheduler:
         )
         tile_coord_mnkl = (pid_m, pid_n, None, batch_idx)
         if const_expr(not params.is_persistent):
-            is_valid = self._num_tiles_executed == 0
+            is_valid = self.num_tiles_executed == 0
         else:
             is_valid = self._current_work_linear_idx < cute.size(params.problem_shape_ncluster_mnl)
         return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid)
@@ -276,38 +276,6 @@ class TileScheduler:
         current_work_linear_idx = cute.arch.shuffle_sync(current_work_linear_idx, 0)
         self._current_work_linear_idx = current_work_linear_idx

-    # We have to split broadcast_next_work and advance_to_next_work into two functions
-    # due to a bug in cute-dsl 4.2: https://github.com/NVIDIA/cutlass/issues/2647
-    @cute.jit
-    def broadcast_next_work(self, is_scheduler_warp: bool | Boolean = False, *, loc=None, ip=None):
-        """is_scheduler_warp should only be true for one warp in the whole cluster"""
-        params = self.params
-        if const_expr(params.is_persistent and params.tile_count_semaphore is not None):
-            current_work_linear_idx = self._current_work_linear_idx
-            if is_scheduler_warp:
-                self._scheduler_pipeline.producer_acquire(self._pipeline_state)
-                lane_idx = cute.arch.lane_idx()
-                if lane_idx < cute.size(params.cluster_shape_mn):
-                    # cute.printf("Producer bidx = {}, tidx = {}, after empty wait, idx = {}", bidx, tidx, current_work_linear_idx)
-                    if const_expr(cute.size(params.cluster_shape_mn) == 1):
-                        self._tile_count[self._pipeline_state.index] = current_work_linear_idx
-                        self._scheduler_pipeline.producer_commit(self._pipeline_state)
-                    else:
-                        peer_cta_rank_in_cluster = lane_idx
-                        mbar_ptr = self._scheduler_pipeline.producer_get_barrier(
-                            self._pipeline_state
-                        )
-                        cute.arch.mbarrier_arrive_and_expect_tx(
-                            mbar_ptr, 4, peer_cta_rank_in_cluster
-                        )
-                        utils.store_shared_remote(
-                            val=current_work_linear_idx,
-                            smem_ptr=self._tile_count.iterator + self._pipeline_state.index,
-                            mbar_ptr=mbar_ptr,
-                            peer_cta_rank_in_cluster=peer_cta_rank_in_cluster,
-                        )
-                    # cute.printf("Producer bidx = {}, tidx = {}, after full arrive", bidx, tidx)
-
     @cute.jit
     def advance_to_next_work(
         self,
@@ -328,7 +296,30 @@ class TileScheduler:
         if const_expr(advance_count > 1):
             self._pipeline_state.advance_iters(advance_count - 1)
         current_work_linear_idx = self._current_work_linear_idx
-        if
+        if is_scheduler_warp:
+            self._scheduler_pipeline.producer_acquire(self._pipeline_state)
+            lane_idx = cute.arch.lane_idx()
+            if lane_idx < cute.size(params.cluster_shape_mn):
+                # cute.printf("Producer bidx = {}, tidx = {}, after empty wait, idx = {}", bidx, tidx, current_work_linear_idx)
+                if const_expr(cute.size(params.cluster_shape_mn) == 1):
+                    self._tile_count[self._pipeline_state.index] = current_work_linear_idx
+                    self._scheduler_pipeline.producer_commit(self._pipeline_state)
+                else:
+                    peer_cta_rank_in_cluster = lane_idx
+                    mbar_ptr = self._scheduler_pipeline.producer_get_barrier(
+                        self._pipeline_state
+                    )
+                    cute.arch.mbarrier_arrive_and_expect_tx(
+                        mbar_ptr, 4, peer_cta_rank_in_cluster
+                    )
+                    utils.store_shared_remote(
+                        val=current_work_linear_idx,
+                        smem_ptr=self._tile_count.iterator + self._pipeline_state.index,
+                        mbar_ptr=mbar_ptr,
+                        peer_cta_rank_in_cluster=peer_cta_rank_in_cluster,
+                    )
+                # cute.printf("Producer bidx = {}, tidx = {}, after full arrive", bidx, tidx)
+        else:
             # if tidx % 64 == 0: cute.printf("bidx = {},tidx = {}, before full wait, idx = {}", bidx, tidx, current_work_linear_idx)
             self._scheduler_pipeline.consumer_wait(self._pipeline_state)
             # if tidx % 64 == 0: cute.printf("bidx = {}, tidx = {}, after full wait, idx = {}", bidx, tidx, current_work_linear_idx)
@@ -341,21 +332,17 @@ class TileScheduler:
             # if tidx % 64 == 0: cute.printf("bidx = {}, tidx = {}, after empty arrive", bidx, tidx)
         self._current_work_linear_idx = current_work_linear_idx
         self._pipeline_state.advance()
-        self._num_tiles_executed += Int32(advance_count)
+        self.num_tiles_executed += Int32(advance_count)

     def producer_tail(self):
         if const_expr(self.params.is_persistent and self.params.tile_count_semaphore is not None):
             self._scheduler_pipeline.producer_tail(self._pipeline_state)

-    @property
-    def num_tiles_executed(self) -> Int32:
-        return self._num_tiles_executed
-
     def __extract_mlir_values__(self):
         values, self._values_pos = [], []
         for obj in [
             self._current_work_linear_idx,
-            self._num_tiles_executed,
+            self.num_tiles_executed,
             self._tile_count,
             self._scheduler_pipeline,
             self._pipeline_state,
@@ -371,7 +358,7 @@ class TileScheduler:
         for obj, n_items in zip(
             [
                 self._current_work_linear_idx,
-                self._num_tiles_executed,
+                self.num_tiles_executed,
                 self._tile_count,
                 self._scheduler_pipeline,
                 self._pipeline_state,
@@ -390,7 +377,7 @@ def triangular_idx_to_coord(idx: Int32) -> Tuple[Int32, Int32]:
    Convert a triangular index to 2D coordinates.
    This is used to convert the linear index to 2D coordinates for triangular matrices.
    """
-    row = utils.ceil((
+    row = utils.ceil((cute.math.sqrt(2 * idx + 2.25, fastmath=True) - 0.5)) - 1
     col = idx - (row * (row + 1)) // 2
     return row, col

@@ -524,7 +511,8 @@ class TriangularTileScheduler(TileScheduler):
         group_size = params.group_size_divmod.divisor
         group_id = (
             utils.ceil(
-                (
+                (cute.math.sqrt(2 * cluster_id_in_problem + 2.25, fastmath=True) - 0.5)
+                * params.group_size_inv_f32
             )
             - 1
         )
@@ -561,7 +549,7 @@ class TriangularTileScheduler(TileScheduler):
         pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1]
         tile_coord_mnkl = (pid_m, pid_n, None, bidz)
         if const_expr(not params.is_persistent):
-            is_valid = self._num_tiles_executed == 0
+            is_valid = self.num_tiles_executed == 0
         else:
             is_valid = (
                 self._current_work_linear_idx
@@ -680,7 +668,7 @@ class VarlenMTileScheduler(TileScheduler):
         ip=None,
     ):
         self._current_work_linear_idx = current_work_linear_idx
-        self._num_tiles_executed = num_tiles_executed
+        self.num_tiles_executed = num_tiles_executed
         self._current_batch_idx = current_batch_idx
         self._num_work_idx_before_cur_batch = num_work_idx_before_cur_batch
         self._tile_count = tile_count
@@ -877,7 +865,7 @@ class VarlenMTileScheduler(TileScheduler):
         pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1]
         tile_coord_mnkl = (pid_m, pid_n, None, batch_idx)
         if const_expr(not params.is_persistent):
-            is_valid = self._num_tiles_executed == 0 and batch_idx < num_batch
+            is_valid = self.num_tiles_executed == 0 and batch_idx < num_batch
         else:
             is_valid = batch_idx < num_batch
         return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid)
@@ -904,7 +892,7 @@ class VarlenMTileScheduler(TileScheduler):
         values, self._values_pos = [], []
         for obj in [
             self._current_work_linear_idx,
-            self._num_tiles_executed,
+            self.num_tiles_executed,
             self._current_batch_idx,
             self._num_work_idx_before_cur_batch,
             self._tile_count,
@@ -922,7 +910,7 @@ class VarlenMTileScheduler(TileScheduler):
         for obj, n_items in zip(
             [
                 self._current_work_linear_idx,
-                self._num_tiles_executed,
+                self.num_tiles_executed,
                 self._current_batch_idx,
                 self._num_work_idx_before_cur_batch,
                 self._tile_count,
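The closed form in `triangular_idx_to_coord` (also used in `TriangularTileScheduler` above) inverts the triangular-number indexing `idx = row * (row + 1) // 2 + col`. A minimal plain-Python sketch of the same arithmetic, using `math.sqrt` and `math.ceil` in place of the kernel's `cute.math.sqrt(..., fastmath=True)` and `utils.ceil`, checked against brute-force enumeration:

```python
import math

def triangular_idx_to_coord(idx: int) -> tuple[int, int]:
    # Same closed form as the kernel: invert idx = row * (row + 1) // 2 + col.
    row = math.ceil(math.sqrt(2 * idx + 2.25) - 0.5) - 1
    col = idx - (row * (row + 1)) // 2
    return row, col

# Brute-force reference: enumerate the lower triangle row by row.
expected = [(r, c) for r in range(64) for c in range(r + 1)]
assert [triangular_idx_to_coord(i) for i in range(len(expected))] == expected
```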
quack/utils.py
CHANGED
@@ -6,7 +6,7 @@ from typing import Optional, Tuple, Type, Union
 import cutlass
 import cutlass.cute as cute

-from cutlass import Float32, Int32
+from cutlass import Float32, Int32, const_expr
 from cutlass.cutlass_dsl import T, dsl_user_op
 from cutlass._mlir.dialects import llvm, nvvm, vector
 from cutlass.cute.runtime import from_dlpack
@@ -22,6 +22,59 @@ def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor:
     )


+def transpose_view(a: cute.Tensor) -> cute.Tensor:
+    """Transpose the first two dimensions of a tensor on smem."""
+    shape = (a.shape[1], a.shape[0], *a.shape[2:])
+    order = (1, 0, *range(2, cute.rank(a)))
+    return cute.composition(a, cute.make_ordered_layout(shape, order=order))
+
+
+def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor:
+    return cute.make_tensor(a.iterator, cute.select(a.layout, mode))
+
+
+@dsl_user_op
+def get_copy_atom(
+    dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None
+) -> cute.CopyAtom:
+    num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width))
+    copy_op = cute.nvgpu.cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
+    return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
+
+
+@dsl_user_op
+def copy(
+    src: cute.Tensor,
+    dst: cute.Tensor,
+    *,
+    pred: Optional[cute.Tensor] = None,
+    num_copy_elems: int = 1,
+    is_async: bool = False,
+    loc=None,
+    ip=None,
+    **kwargs,
+) -> None:
+    copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async)
+    cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
+
+
+def tiled_copy_2d(
+    dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = True
+) -> cute.TiledCopy:
+    num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width
+    copy_elems = num_copy_bits // dtype.width
+    copy_op = cute.nvgpu.cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
+    copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
+    gmem_threads_per_row = major_mode_size // copy_elems
+    assert num_threads % gmem_threads_per_row == 0
+    thr_layout = cute.make_ordered_layout(
+        (num_threads // gmem_threads_per_row, gmem_threads_per_row),
+        order=(1, 0),
+    )
+    val_layout = cute.make_layout((1, copy_elems))
+    return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
+
+
 @dsl_user_op
 def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:
     return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip)
@@ -29,7 +82,7 @@ def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:

 @cute.jit
 def load_scalar_or_pointer(x: Float32 | cute.Pointer) -> Float32:
-    if isinstance(x, cute.Pointer):
+    if const_expr(isinstance(x, cute.Pointer)):
         return Float32(cute.make_tensor(x, cute.make_layout(1))[0])
     else:
         assert isinstance(x, Float32)
@@ -71,7 +124,7 @@ def store_shared_remote(
     remote_mbar_ptr_i32 = set_block_rank(
         mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
     ).ir_value()
-    if isinstance(val, float):
+    if const_expr(isinstance(val, float)):
         val = Float32(val)
     assert isinstance(val, (Float32, Int32, cutlass.Int64)), "val must be Float32, Int32, or Int64"
     suffix = {Float32: "f32", Int32: "s32", cutlass.Int64: "s64"}[type(val)]
@@ -100,69 +153,6 @@ def fmin(a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=None)
     )


-@cute.jit
-def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32:
-    """exp2f calculation for both vector and scalar.
-    :param x: input value
-    :type x: cute.TensorSSA or Float32
-    :return: exp2 value
-    :rtype: cute.TensorSSA or Float32
-    """
-    if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
-        res = cute.make_fragment(x.shape, Float32)
-        res.store(x)
-        for i in cutlass.range(cute.size(x.shape), unroll_full=True):
-            res[i] = cute.arch.exp2(res[i])
-        return res.load()
-    else:
-        return cute.arch.exp2(x)
-
-
-@dsl_user_op
-def log2f(a: float | Float32, *, loc=None, ip=None) -> Float32:
-    return Float32(
-        llvm.inline_asm(
-            T.f32(),
-            [Float32(a).ir_value(loc=loc, ip=ip)],
-            "lg2.approx.ftz.f32 $0, $1;",
-            "=f,f",
-            has_side_effects=False,
-            is_align_stack=False,
-            asm_dialect=llvm.AsmDialect.AD_ATT,
-        )
-    )
-
-
-@dsl_user_op
-def sqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
-    return Float32(
-        llvm.inline_asm(
-            T.f32(),
-            [Float32(a).ir_value(loc=loc, ip=ip)],
-            "sqrt.approx.ftz.f32 $0, $1;",
-            "=f,f",
-            has_side_effects=False,
-            is_align_stack=False,
-            asm_dialect=llvm.AsmDialect.AD_ATT,
-        )
-    )
-
-
-@dsl_user_op
-def rsqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
-    return Float32(
-        llvm.inline_asm(
-            T.f32(),
-            [Float32(a).ir_value(loc=loc, ip=ip)],
-            "rsqrt.approx.ftz.f32 $0, $1;",
-            "=f,f",
-            has_side_effects=False,
-            is_align_stack=False,
-            asm_dialect=llvm.AsmDialect.AD_ATT,
-        )
-    )
-
-
 @dsl_user_op
 def ceil(a: float | Float32, *, loc=None, ip=None) -> Int32:
     return Int32(
@@ -259,7 +249,7 @@ def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cute.Numeric
     tXrX_fill.fill(fill_value)
     for rest_v in cutlass.range_constexpr(tXsX.shape[0][1]):
         for rest_k in cutlass.range_constexpr(tXsX.shape[2]):
-            if tXpX is not None:
+            if const_expr(tXpX is not None):
                 if not tXpX[rest_v, 0, rest_k]:
                     cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
             else:
@@ -295,9 +285,9 @@ def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
 def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
     flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
     flat_stride = cute.flatten_to_tuple(tensor.stride)
-    assert len(flat_coord_i64) == len(
-        flat_stride
-    ), "Coordinate and stride must have the same length"
+    assert len(flat_coord_i64) == len(flat_stride), (
+        "Coordinate and stride must have the same length"
+    )
     offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride))
     assert isinstance(tensor.iterator, cute.Pointer)
     # HACK: we assume that applying the offset does not change the pointer alignment
@@ -328,7 +318,7 @@ def coord_offset_i64(

 @cute.jit
 def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> cutlass.Int32:
-    if lane is None:
+    if const_expr(lane is None):
         lane = cute.arch.lane_idx()
     for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))):
         offset = 1 << i
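The new `tiled_copy_2d` helper sizes each copy as the widest vector (capped at 128 bits) whose element count evenly divides the contiguous mode, then splits the thread block into `(rows, threads_per_row)`. A stand-alone sketch of that arithmetic with hypothetical sizes (a 16-bit dtype, 256-element rows, 128 threads):

```python
import math

# Hypothetical sizes: 16-bit dtype (e.g. bf16), 256-element contiguous rows, 128 threads.
dtype_width = 16
major_mode_size = 256
num_threads = 128

# Widest copy, capped at 128 bits, whose element count divides the row length.
num_copy_bits = math.gcd(major_mode_size, 128 // dtype_width) * dtype_width
copy_elems = num_copy_bits // dtype_width               # 8 elements per copy
gmem_threads_per_row = major_mode_size // copy_elems    # 32 threads cover one row
assert num_threads % gmem_threads_per_row == 0
thr_shape = (num_threads // gmem_threads_per_row, gmem_threads_per_row)
print(num_copy_bits, copy_elems, thr_shape)             # 128 8 (4, 32)
```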
quack/varlen_utils.py
CHANGED
@@ -14,9 +14,4 @@ class VarlenArguments(ArgumentsBase):
     mCuSeqlensM: Optional[cute.Tensor] = None
     mCuSeqlensK: Optional[cute.Tensor] = None
     mTensormaps: Optional[cute.Tensor] = None
-
-    def __post_init__(self):
-        if self.mCuSeqlensM is not None or self.mCuSeqlensK is not None:
-            assert (
-                self.mTensormaps is not None
-            ), "mTensormaps must be provided if mCuSeqlensM or mCuSeqlensK is provided"
+    mAIdx: Optional[cute.Tensor] = None
{quack_kernels-0.2.0.dist-info → quack_kernels-0.2.2.dist-info}/METADATA
CHANGED
@@ -1,9 +1,9 @@
 Metadata-Version: 2.4
 Name: quack-kernels
-Version: 0.2.0
-Requires-Python: >=3.
+Version: 0.2.2
+Requires-Python: >=3.10
 License-File: LICENSE
-Requires-Dist: nvidia-cutlass-dsl==4.2.
+Requires-Dist: nvidia-cutlass-dsl==4.2.1
 Requires-Dist: torch
 Provides-Extra: dev
 Requires-Dist: pre-commit; extra == "dev"
quack_kernels-0.2.2.dist-info/RECORD
ADDED
@@ -0,0 +1,37 @@
+quack/__init__.py,sha256=sJum67V7jEQPUDWz4FKJ5Sk7MqmBtbMXjZPVboQnDdE,364
+quack/activation.py,sha256=SzQDUCB-kccqsy1aYUrHYJ2cGxKMXxxqpjJaJoqBYaE,10017
+quack/autotuner.py,sha256=atw0ntedi22RPwSdjWOoge4S56S8VFvRocJQcYhpAlo,13454
+quack/cross_entropy.py,sha256=TE8j21c-7E4cInKtFjcKsgKXNhKCRFkNfhCJpgpasj8,28409
+quack/cute_dsl_utils.py,sha256=d8xLD17a9EsSQgmgWDO8rUWWCTRM8e1kDq1wzilaYC8,4563
+quack/dense_gemm_sm90.py,sha256=LvcR178zzzWClkEerhIx940Sg-AF_BpQdnjqC8s9W1o,113832
+quack/fast_math.py,sha256=E1XUqfUt0_n9BPZNggF-UDzZ6anso9bYUrwqafemWvQ,2297
+quack/gemm_act_sm90.py,sha256=yJEkwCtKjldxzJYq78CpCV6fxoqoZJSpd7KvnglHqfo,16206
+quack/gemm_config.py,sha256=gbYjPFeyT5wAhVwFQroRHlHoMKEJqAWX9P8wWy04l8Q,2258
+quack/gemm_dact_sm90.py,sha256=QOACq-v9XHfY6p5frKzYCvkCbqGDq69beYcfCfl-5Kc,6458
+quack/gemm_interface.py,sha256=qEbQRsvTrwKdLLlGVCMH76diMCKOsA6GqsC0PaepLow,39636
+quack/gemm_sm100.py,sha256=T-2BUrUBXROxQ9Iz-6pB5T8j9go29Vlw4ZCJQ_oM7yg,110396
+quack/gemm_wrapper_utils.py,sha256=oDCXngJuH-qbDI9DJuXkDHUogXleWZrF1mRpI1DAcI8,12687
+quack/layernorm.py,sha256=AOe95-YqhFPw96x8pJq7FfBe26ROX9ZTvH025lM1ILs,13579
+quack/linear.py,sha256=SrhRiAFjC7ONIMVmiNu-kSPLHNUyaCXt59a1f_5nNXo,9383
+quack/linear_cross_entropy.py,sha256=Zhy_gdMsKHOie-jntBaqIuiDJtkiq6qEBwnyuWwIRw4,10092
+quack/mlp.py,sha256=YjdwQRwEePA9KyidFXp5H1-lxiJc8dZ41vl8Fv8pgss,2259
+quack/pipeline.py,sha256=DyCwZX8WvoUBFcMBz7CeYm9VUM31haEGgBhAzmxu8cE,5519
+quack/reduce.py,sha256=0hRFMFfn6xC5QLk32Qmgc17XVkQ1yKC-3TfksccSBaU,10341
+quack/reduction_base.py,sha256=CT-t_j7z8H1ByD9FkQYDRik_-THMDFv9QoXHmr9Xx9E,3636
+quack/rmsnorm.py,sha256=Ak3EL-qzwgaKGZl7O2upiR3FC93776Cgse_B5PZhTu0,45643
+quack/softmax.py,sha256=WFWtgc40iLPFBpdStBBTC9803Npnv9rZjOzb_nK-RDs,17110
+quack/symmetric_dense_gemm_sm90.py,sha256=2UXooIpClT2izdyGis1XaIgYYlLj-7MrcOMg2yR7YCk,88694
+quack/tensormap_manager.py,sha256=Ts3Mxp0_es2RNA0ffvUjWMXN79lsfWEBZ0DQYhtbcnw,5338
+quack/tile_scheduler.py,sha256=5lcprf3VIXWCNusWHBCveHpCWRzQ0nzcIMhaQbXher8,41727
+quack/topk.py,sha256=RQl-23lIicQ9ry9Njur8i0JGem_WbO_Gchr6jy8EtVM,9185
+quack/utils.py,sha256=DVMSbMngPBnIRrHuGDXKqVueiNv9DFCfGv076hxzJms,14747
+quack/varlen_utils.py,sha256=GwXc8tO6BrYoYszhSeJ0u_KmreJAEodP1EAizLS-jaA,464
+quack/sort/bitonic_sort.py,sha256=8t0SG1a6iEpYIlY8YM_AWvm4aN-4AA4vEzdBuJMJm9g,4768
+quack/sort/generate_sorting_networks.py,sha256=vkJBOjTVEinQkWT4OtFqOWxFVdTIPoNAQocneKc9-rM,14477
+quack/sort/sorting_networks.py,sha256=l_26zi3gXD_z-tnm2eAczRrmE-mbaz00KmqH6ONivL8,9686
+quack/sort/utils.py,sha256=Mkr-l97RMAV-ZoNrwuzA1U3KO0Wjr38CV9Jm7ScyZoI,1090
+quack_kernels-0.2.2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+quack_kernels-0.2.2.dist-info/METADATA,sha256=ZZofR2edTztufmX_0ExiJ7CpFsT80koJf-pRRUm3ssg,285
+quack_kernels-0.2.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+quack_kernels-0.2.2.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
+quack_kernels-0.2.2.dist-info/RECORD,,
quack_kernels-0.2.0.dist-info/RECORD
REMOVED
@@ -1,37 +0,0 @@
-quack/__init__.py,sha256=fGBYbb9JlaNT7HdtUTbUnuAkL5G2Dg8XZAA5Ir1R-ow,364
-quack/activation.py,sha256=ysXaVUXX2yGQC5o4ZVeRXw_fDIHOrqnzpHJaIsc0kHc,10271
-quack/autotuner.py,sha256=czO6JrYL0EJpOeJOYDSsVdrJaFuwfL3vTdG8QfL1F34,10792
-quack/cross_entropy.py,sha256=Kc3P83Vsu1nGaCu7llsO3vct3J_t3frRYPxij7JfHMA,28619
-quack/cute_dsl_utils.py,sha256=D2Pw7rzX9jY8u8wikIPvPvinmFLCDeZg95HPBLqGej4,4635
-quack/dense_gemm_sm100.py,sha256=hKBNC34UxdctrTKVP68nvANZl4Dq2rnUjRcweESEq3g,109965
-quack/dense_gemm_sm90.py,sha256=TjnjHnjhAwWH5YQWsFlADq07xSxtsprkw_p2Cy0yw7I,100407
-quack/fast_math.py,sha256=E1XUqfUt0_n9BPZNggF-UDzZ6anso9bYUrwqafemWvQ,2297
-quack/gemm_act_sm90.py,sha256=N5UAFWZvw1na22Vh5JSGgcdqZ2zI6kQMBVOLxYbCAUU,14332
-quack/gemm_config.py,sha256=gbYjPFeyT5wAhVwFQroRHlHoMKEJqAWX9P8wWy04l8Q,2258
-quack/gemm_dact_sm90.py,sha256=KCXgjOzdamSDexwrwf_pX2r-ippPRirbClrlU6BP7b8,4990
-quack/gemm_interface.py,sha256=_JTpE7zQw6NUw-v65Wql_XUOZBfW0oSEgiMnharTJU4,20501
-quack/gemm_wrapper_utils.py,sha256=aMMtu-Ojhtjay_5xJH4AjP-JRVks1AB8jmtNme_DIqU,5960
-quack/layernorm.py,sha256=JkK0sVdUfZ-SmoBmNqLF3wCiszDbdorvcBH2julv0Vg,13560
-quack/linear.py,sha256=SrhRiAFjC7ONIMVmiNu-kSPLHNUyaCXt59a1f_5nNXo,9383
-quack/linear_cross_entropy.py,sha256=Zhy_gdMsKHOie-jntBaqIuiDJtkiq6qEBwnyuWwIRw4,10092
-quack/mlp.py,sha256=YjdwQRwEePA9KyidFXp5H1-lxiJc8dZ41vl8Fv8pgss,2259
-quack/pipeline.py,sha256=DyCwZX8WvoUBFcMBz7CeYm9VUM31haEGgBhAzmxu8cE,5519
-quack/reduce.py,sha256=hsYByu6haCZjLTLB-qpYmKDjqS2UqlwPgfWTup38GNA,10341
-quack/reduction_base.py,sha256=CT-t_j7z8H1ByD9FkQYDRik_-THMDFv9QoXHmr9Xx9E,3636
-quack/rmsnorm.py,sha256=93qlTPjY9JBm3R5M-HeHse1PbAfD9931G3OFs71yo_g,48998
-quack/softmax.py,sha256=Mq3_2Ul8H64zeGUI9wOKEpIISJnrCcHQpZvk2sb10Tg,17101
-quack/symmetric_dense_gemm_sm90.py,sha256=2UXooIpClT2izdyGis1XaIgYYlLj-7MrcOMg2yR7YCk,88694
-quack/tensormap_manager.py,sha256=Ts3Mxp0_es2RNA0ffvUjWMXN79lsfWEBZ0DQYhtbcnw,5338
-quack/tile_scheduler.py,sha256=8qqYmx6GpQzt8XiidcrdLIaWf0TGbJVdwKFfeb1X_us,42265
-quack/topk.py,sha256=RQl-23lIicQ9ry9Njur8i0JGem_WbO_Gchr6jy8EtVM,9185
-quack/utils.py,sha256=tiqeJZiPPFl5irQWCUd7dTPA_OAv4SjHUW5S-u9wO8Y,14526
-quack/varlen_utils.py,sha256=vkduMEpo5bJJvZRNnIcKPb6pp1wD34vaIpMIB0ZGIZA,681
-quack/sort/bitonic_sort.py,sha256=8t0SG1a6iEpYIlY8YM_AWvm4aN-4AA4vEzdBuJMJm9g,4768
-quack/sort/generate_sorting_networks.py,sha256=vkJBOjTVEinQkWT4OtFqOWxFVdTIPoNAQocneKc9-rM,14477
-quack/sort/sorting_networks.py,sha256=l_26zi3gXD_z-tnm2eAczRrmE-mbaz00KmqH6ONivL8,9686
-quack/sort/utils.py,sha256=Mkr-l97RMAV-ZoNrwuzA1U3KO0Wjr38CV9Jm7ScyZoI,1090
-quack_kernels-0.2.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-quack_kernels-0.2.0.dist-info/METADATA,sha256=DAeQymRUqp7lSfSTNyS7TZF3oWcFzCKriGJ2p8JLu6A,285
-quack_kernels-0.2.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-quack_kernels-0.2.0.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
-quack_kernels-0.2.0.dist-info/RECORD,,
{quack_kernels-0.2.0.dist-info → quack_kernels-0.2.2.dist-info}/WHEEL
File without changes
{quack_kernels-0.2.0.dist-info → quack_kernels-0.2.2.dist-info}/licenses/LICENSE
File without changes
{quack_kernels-0.2.0.dist-info → quack_kernels-0.2.2.dist-info}/top_level.txt
File without changes