quack-kernels 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,114 @@
1
+ # Copyright (c) 2025, Tri Dao.
2
+
3
+ from typing import Tuple
4
+ from dataclasses import dataclass
5
+
6
+ import cutlass
7
+ import cutlass.cute as cute
8
+ from cutlass.cutlass_dsl import Boolean, const_expr, Int32
9
+ from cutlass.utils import TensorMapUpdateMode, TensorMapManager
10
+ from cutlass._mlir.dialects import llvm
11
+
12
+
13
@dataclass(frozen=True)
class TensorMapManagerSm90(TensorMapManager):
    """
    We have to subclass cutlass.utils.TensorMapManager because it takes in warp_id and only
    performs the operation if warp_id matches the current warp.
    But for Hopper pingpong gemm we want to call it with warp_id 0 and 4.
    So we take in a boolean `is_manager_warp` to determine whether to perform the operation or not.

    All three methods follow the same guard pattern: the whole body is predicated on
    `is_manager_warp` (a runtime Boolean) instead of the base class's warp-id check.
    """

    @cute.jit
    def init_tensormap_from_atom(
        self, copy_atom: cute.CopyAtom, dst_ptr: cute.Pointer, is_manager_warp: Boolean
    ) -> None:
        """Copy the tensormap embedded in `copy_atom` to `dst_ptr`.

        Only runs when `is_manager_warp` is true. A single elected thread of the
        warp performs the copy (`cute.arch.elect_one`), then the warp is synced so
        every lane observes the initialized tensormap.
        """
        if is_manager_warp:
            with cute.arch.elect_one():
                cute.nvgpu.cpasync.copy_tensormap(copy_atom, dst_ptr)
            cute.arch.sync_warp()
        return

    @cute.jit
    def update_tensormap(
        self,
        tensor_gmem: Tuple[cute.Tensor, ...],
        tma_copy_atom: Tuple[cute.CopyAtom, ...],
        tensormap_gmem_ptr: Tuple[cute.Pointer, ...],
        is_manager_warp: Boolean,
        tensormap_smem_ptr: Tuple[cute.Pointer, ...],
    ) -> None:
        """Rewrite the TMA descriptors to point at new global tensors.

        Parameters (parallel tuples, zipped element-wise):
            tensor_gmem:        new global tensors the descriptors should describe.
            tma_copy_atom:      the TMA copy atoms whose descriptors are updated.
            tensormap_gmem_ptr: destination tensormap locations in global memory.
            is_manager_warp:    perform the update only when true.
            tensormap_smem_ptr: shared-memory staging tensormaps; only used in
                                TensorMapUpdateMode.SMEM (compile-time branch).

        In SMEM mode the descriptor is first patched in shared memory, then copied
        to global memory with a release fence; in GMEM mode it is patched directly
        in global memory followed by an explicit descriptor-release fence.
        """
        # updates before touching tensormap in global memory
        if is_manager_warp:
            # `const_expr`: the update mode is a compile-time property, so only one
            # branch is materialized in the generated kernel.
            if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM):
                for copy_atom, tensor, smem_ptr in zip(
                    tma_copy_atom, tensor_gmem, tensormap_smem_ptr
                ):
                    # Patch the staged descriptor in shared memory.
                    cute.nvgpu.cpasync.update_tma_descriptor(copy_atom, tensor, smem_ptr)
            # wait until it's safe to update tensormap in global memory
            # NOTE(review): commit + wait_group(0, read=True) drains outstanding
            # bulk-async groups before the descriptor is overwritten — presumably
            # so no in-flight TMA still reads the old tensormap; confirm against
            # the surrounding kernel's pipelining.
            with cute.arch.elect_one():
                cute.arch.cp_async_bulk_commit_group()
                cute.arch.cp_async_bulk_wait_group(0, read=True)
            cute.arch.sync_warp()
            # updates to tensormap in global memory
            if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM):
                for gmem_ptr, smem_ptr in zip(tensormap_gmem_ptr, tensormap_smem_ptr):
                    # Copy staged descriptor smem -> gmem; this primitive also
                    # carries the required release fence for the TMA descriptor.
                    cute.nvgpu.cpasync.cp_fence_tma_desc_release(gmem_ptr, smem_ptr)
            else:
                for copy_atom, tensor, gmem_ptr in zip(
                    tma_copy_atom, tensor_gmem, tensormap_gmem_ptr
                ):
                    # Patch the descriptor directly in global memory.
                    cute.nvgpu.cpasync.update_tma_descriptor(copy_atom, tensor, gmem_ptr)
                # GMEM mode needs an explicit sync + release fence after patching.
                cute.arch.sync_warp()
                cute.nvgpu.cpasync.fence_tma_desc_release()

    @cute.jit
    def update_tensormap_shape(
        self,
        tensormap_gmem_ptr: Tuple[cute.Pointer, ...],
        is_manager_warp: Boolean,
        tensormap_smem_ptr: Tuple[cute.Pointer, ...],
        shapes: Tuple[Int32, ...],
        orders: cutlass.Constexpr[Tuple[int, ...]],
    ) -> None:
        """Patch only the `global_dim` (shape) field of each TMA descriptor.

        Cheaper variant of `update_tensormap` for when only one global dimension
        changes: uses the PTX `tensormap.replace.tile.global_dim` instruction
        instead of rebuilding the whole descriptor.

        Parameters (parallel tuples, zipped element-wise):
            tensormap_gmem_ptr: tensormap locations in global memory.
            is_manager_warp:    perform the update only when true.
            tensormap_smem_ptr: shared-memory staging tensormaps (SMEM mode only).
            shapes:             new runtime extents to write into the descriptors.
            orders:             compile-time dimension indices to patch; inlined
                                into the PTX text via the f-string (not a runtime
                                operand).

        NOTE(review): both `llvm.inline_asm` calls below pass three operands
        (ptr, shape, order) but the constraint strings ("r,r" / "l,r") declare
        only two, and only $0/$1 are referenced — the trailing
        `Int32(order).ir_value()` looks dead since `{order}` is a constexpr
        inlined into the asm text. Confirm the DSL's inline_asm tolerates the
        extra operand, or drop it upstream.
        """
        # updates before touching tensormap in global memory
        if is_manager_warp:
            if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM):
                for smem_ptr, shape, order in zip(tensormap_smem_ptr, shapes, orders):
                    # PTX operand must be an integer; $0 is the 32-bit smem address.
                    smem_ptr_i32 = smem_ptr.toint().ir_value()
                    llvm.inline_asm(
                        None,
                        [smem_ptr_i32, Int32(shape).ir_value(), Int32(order).ir_value()],
                        # Widen the 32-bit smem address to 64-bit, then replace
                        # dimension {order}'s extent with $1 in the staged
                        # (shared::cta) 1024-bit descriptor.
                        "{\n\t"
                        ".reg .b64 smem_ptr_i64;\n\t"
                        "cvt.u64.u32 smem_ptr_i64, $0;\n\t"
                        f"tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [smem_ptr_i64], {order}, $1;\n\t"
                        "}\n",
                        "r,r",
                        has_side_effects=True,
                        is_align_stack=False,
                        asm_dialect=llvm.AsmDialect.AD_ATT,
                    )
            # wait until it's safe to update tensormap in global memory
            with cute.arch.elect_one():
                cute.arch.cp_async_bulk_commit_group()
                cute.arch.cp_async_bulk_wait_group(0, read=True)
            cute.arch.sync_warp()
            # updates to tensormap in global memory
            if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM):
                for gmem_ptr, smem_ptr in zip(tensormap_gmem_ptr, tensormap_smem_ptr):
                    # Copy patched descriptor smem -> gmem with release fence.
                    cute.nvgpu.cpasync.cp_fence_tma_desc_release(gmem_ptr, smem_ptr)
            else:
                for gmem_ptr, shape, order in zip(tensormap_gmem_ptr, shapes, orders):
                    gmem_ptr_i64 = gmem_ptr.toint().ir_value()
                    # Patch dimension {order}'s extent in place in the global
                    # (1024-bit) descriptor; $0 is the 64-bit gmem address.
                    llvm.inline_asm(
                        None,
                        [gmem_ptr_i64, Int32(shape).ir_value(), Int32(order).ir_value()],
                        f"tensormap.replace.tile.global_dim.global.b1024.b32 [$0], {order}, $1;",
                        "l,r",
                        has_side_effects=True,
                        is_align_stack=False,
                        asm_dialect=llvm.AsmDialect.AD_ATT,
                    )
                # GMEM mode needs an explicit sync + release fence after patching.
                cute.arch.sync_warp()
                cute.nvgpu.cpasync.fence_tma_desc_release()