quack-kernels 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +4 -1
- quack/autotuner.py +309 -0
- quack/cross_entropy.py +2 -5
- quack/cute_dsl_utils.py +40 -0
- quack/dense_gemm_sm100.py +2562 -0
- quack/dense_gemm_sm90.py +2474 -0
- quack/fast_math.py +97 -0
- quack/gemm_config.py +61 -0
- quack/gemm_interface.py +321 -0
- quack/linear.py +176 -0
- quack/lse.py +62 -0
- quack/mlp.py +204 -0
- quack/pipeline.py +166 -0
- quack/sort/bitonic_sort.py +126 -0
- quack/sort/generate_sorting_networks.py +326 -0
- quack/sort/sorting_networks.py +120 -0
- quack/sort/utils.py +31 -0
- quack/symmetric_dense_gemm_sm90.py +2088 -0
- quack/tensormap_manager.py +114 -0
- quack/tile_scheduler.py +935 -0
- quack/topk.py +221 -0
- quack/utils.py +237 -19
- {quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/METADATA +3 -3
- quack_kernels-0.1.11.dist-info/RECORD +31 -0
- quack_kernels-0.1.9.dist-info/RECORD +0 -12
- {quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/top_level.txt +0 -0
quack/tensormap_manager.py (new file)
@@ -0,0 +1,114 @@
+# Copyright (c) 2025, Tri Dao.
+
+from typing import Tuple
+from dataclasses import dataclass
+
+import cutlass
+import cutlass.cute as cute
+from cutlass.cutlass_dsl import Boolean, const_expr, Int32
+from cutlass.utils import TensorMapUpdateMode, TensorMapManager
+from cutlass._mlir.dialects import llvm
+
+
+@dataclass(frozen=True)
+class TensorMapManagerSm90(TensorMapManager):
+    """
+    We have to subclass cutlass.utils.TensorMapManager bc it takes in warp_id and only
+    perform the operation if warp_id matches the current warp.
+    But for Hopper pingpong gemm we want to call it with warp_id 0 and 4.
+    So we take in a boolean `is_manager_warp` to determine whether to perform the operation or not.
+    """
+
+    @cute.jit
+    def init_tensormap_from_atom(
+        self, copy_atom: cute.CopyAtom, dst_ptr: cute.Pointer, is_manager_warp: Boolean
+    ) -> None:
+        if is_manager_warp:
+            with cute.arch.elect_one():
+                cute.nvgpu.cpasync.copy_tensormap(copy_atom, dst_ptr)
+            cute.arch.sync_warp()
+        return
+
+    @cute.jit
+    def update_tensormap(
+        self,
+        tensor_gmem: Tuple[cute.Tensor, ...],
+        tma_copy_atom: Tuple[cute.CopyAtom, ...],
+        tensormap_gmem_ptr: Tuple[cute.Pointer, ...],
+        is_manager_warp: Boolean,
+        tensormap_smem_ptr: Tuple[cute.Pointer, ...],
+    ) -> None:
+        # updates before touching tensormap in global memory
+        if is_manager_warp:
+            if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM):
+                for copy_atom, tensor, smem_ptr in zip(
+                    tma_copy_atom, tensor_gmem, tensormap_smem_ptr
+                ):
+                    cute.nvgpu.cpasync.update_tma_descriptor(copy_atom, tensor, smem_ptr)
+            # wait until it's safe to update tensormap in global memory
+            with cute.arch.elect_one():
+                cute.arch.cp_async_bulk_commit_group()
+                cute.arch.cp_async_bulk_wait_group(0, read=True)
+            cute.arch.sync_warp()
+            # updates to tensormap in global memory
+            if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM):
+                for gmem_ptr, smem_ptr in zip(tensormap_gmem_ptr, tensormap_smem_ptr):
+                    cute.nvgpu.cpasync.cp_fence_tma_desc_release(gmem_ptr, smem_ptr)
+            else:
+                for copy_atom, tensor, gmem_ptr in zip(
+                    tma_copy_atom, tensor_gmem, tensormap_gmem_ptr
+                ):
+                    cute.nvgpu.cpasync.update_tma_descriptor(copy_atom, tensor, gmem_ptr)
+                cute.arch.sync_warp()
+                cute.nvgpu.cpasync.fence_tma_desc_release()
+
+    @cute.jit
+    def update_tensormap_shape(
+        self,
+        tensormap_gmem_ptr: Tuple[cute.Pointer, ...],
+        is_manager_warp: Boolean,
+        tensormap_smem_ptr: Tuple[cute.Pointer, ...],
+        shapes: Tuple[Int32, ...],
+        orders: cutlass.Constexpr[Tuple[int, ...]],
+    ) -> None:
+        # updates before touching tensormap in global memory
+        if is_manager_warp:
+            if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM):
+                for smem_ptr, shape, order in zip(tensormap_smem_ptr, shapes, orders):
+                    smem_ptr_i32 = smem_ptr.toint().ir_value()
+                    llvm.inline_asm(
+                        None,
+                        [smem_ptr_i32, Int32(shape).ir_value(), Int32(order).ir_value()],
+                        "{\n\t"
+                        ".reg .b64 smem_ptr_i64;\n\t"
+                        "cvt.u64.u32 smem_ptr_i64, $0;\n\t"
+                        f"tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [smem_ptr_i64], {order}, $1;\n\t"
+                        "}\n",
+                        "r,r",
+                        has_side_effects=True,
+                        is_align_stack=False,
+                        asm_dialect=llvm.AsmDialect.AD_ATT,
+                    )
+            # wait until it's safe to update tensormap in global memory
+            with cute.arch.elect_one():
+                cute.arch.cp_async_bulk_commit_group()
+                cute.arch.cp_async_bulk_wait_group(0, read=True)
+            cute.arch.sync_warp()
+            # updates to tensormap in global memory
+            if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM):
+                for gmem_ptr, smem_ptr in zip(tensormap_gmem_ptr, tensormap_smem_ptr):
+                    cute.nvgpu.cpasync.cp_fence_tma_desc_release(gmem_ptr, smem_ptr)
+            else:
+                for gmem_ptr, shape, order in zip(tensormap_gmem_ptr, shapes, orders):
+                    gmem_ptr_i64 = gmem_ptr.toint().ir_value()
+                    llvm.inline_asm(
+                        None,
+                        [gmem_ptr_i64, Int32(shape).ir_value(), Int32(order).ir_value()],
+                        f"tensormap.replace.tile.global_dim.global.b1024.b32 [$0], {order}, $1;",
+                        "l,r",
+                        has_side_effects=True,
+                        is_align_stack=False,
+                        asm_dialect=llvm.AsmDialect.AD_ATT,
+                    )
+                cute.arch.sync_warp()
+                cute.nvgpu.cpasync.fence_tma_desc_release()
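
For context, the new subclass replaces the fixed `warp_id` check of `cutlass.utils.TensorMapManager` with a runtime `is_manager_warp` predicate, so a Hopper pingpong kernel can nominate one manager warp per math warpgroup (the docstring names warps 0 and 4). Below is a minimal caller-side sketch of how that predicate might be wired up. It is not part of this diff: the `warp_idx`/`make_warp_uniform` helpers, the `TensorMapManager` constructor arguments, and the `tma_atom_a`/`tensormap_a_gmem_ptr` names are assumptions based on common CuTe DSL conventions, and the snippet would live inside a `@cute.jit` kernel body.

```python
# Hypothetical sketch (not from this package) of selecting the manager warps
# for a pingpong schedule: the first warp of each math warpgroup (0 and 4).
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
is_manager_warp = Boolean((warp_idx == 0) | (warp_idx == 4))

# Assumed constructor fields: update mode and bytes per tensormap
# (a CUtensorMap descriptor is 128 bytes).
tensormap_manager = TensorMapManagerSm90(TensorMapUpdateMode.SMEM, 128)

# Both manager warps call the same method; non-manager warps fall through,
# which is why the class takes a Boolean instead of a single warp_id.
tensormap_manager.init_tensormap_from_atom(tma_atom_a, tensormap_a_gmem_ptr, is_manager_warp)
```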