quack-kernels 0.1.11__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +7 -3
- quack/activation.py +279 -0
- quack/autotuner.py +2 -1
- quack/cross_entropy.py +330 -184
- quack/cute_dsl_utils.py +83 -4
- quack/dense_gemm_sm100.py +1 -1
- quack/dense_gemm_sm90.py +911 -1140
- quack/fast_math.py +10 -27
- quack/gemm_act_sm90.py +368 -0
- quack/gemm_config.py +43 -35
- quack/gemm_dact_sm90.py +150 -0
- quack/gemm_interface.py +491 -243
- quack/gemm_wrapper_utils.py +158 -0
- quack/layernorm.py +6 -4
- quack/linear.py +128 -64
- quack/linear_cross_entropy.py +275 -0
- quack/mlp.py +30 -160
- quack/pipeline.py +2 -17
- quack/reduce.py +240 -0
- quack/reduction_base.py +2 -11
- quack/rmsnorm.py +614 -228
- quack/softmax.py +28 -16
- quack/symmetric_dense_gemm_sm90.py +6 -3
- quack/tensormap_manager.py +1 -0
- quack/tile_scheduler.py +64 -61
- quack/topk.py +14 -8
- quack/utils.py +14 -322
- quack/varlen_utils.py +22 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/METADATA +3 -3
- quack_kernels-0.2.1.dist-info/RECORD +37 -0
- quack/lse.py +0 -62
- quack_kernels-0.1.11.dist-info/RECORD +0 -31
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/top_level.txt +0 -0
quack/fast_math.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
# Copyright (c) 2025, Tri Dao.
|
|
2
2
|
|
|
3
3
|
from typing import Tuple
|
|
4
|
+
from dataclasses import dataclass
|
|
4
5
|
|
|
5
6
|
import cutlass
|
|
6
7
|
import cutlass.cute as cute
|
|
@@ -8,6 +9,8 @@ from cutlass import Int32, Uint32
|
|
|
8
9
|
from cutlass.cutlass_dsl import T, dsl_user_op
|
|
9
10
|
from cutlass._mlir.dialects import llvm
|
|
10
11
|
|
|
12
|
+
from quack.cute_dsl_utils import ParamsBase
|
|
13
|
+
|
|
11
14
|
|
|
12
15
|
@cute.jit
|
|
13
16
|
def clz(x: Int32) -> Int32:
|
|
@@ -45,18 +48,15 @@ def umulhi(a: Int32, b: Int32, *, loc=None, ip=None) -> Uint32:
|
|
|
45
48
|
)
|
|
46
49
|
|
|
47
50
|
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
self.multiplier = multipler
|
|
54
|
-
self.shift_right = shift_right
|
|
55
|
-
self._loc = loc
|
|
51
|
+
@dataclass
|
|
52
|
+
class FastDivmod(ParamsBase):
|
|
53
|
+
divisor: Int32
|
|
54
|
+
multiplier: Uint32
|
|
55
|
+
shift_right: Uint32
|
|
56
56
|
|
|
57
57
|
# called by host
|
|
58
58
|
@staticmethod
|
|
59
|
-
def create(divisor: Int32
|
|
59
|
+
def create(divisor: Int32) -> "FastDivmod":
|
|
60
60
|
"""Construct the FastDivmod object, in host code.
|
|
61
61
|
This precomputes some values based on the divisor and is computationally expensive.
|
|
62
62
|
"""
|
|
@@ -64,7 +64,7 @@ class FastDivmod:
|
|
|
64
64
|
divisor_u32 = Uint32(divisor)
|
|
65
65
|
multiplier = Uint32(((cutlass.Uint64(1) << p) + divisor_u32 - 1) // divisor_u32)
|
|
66
66
|
shift_right = Uint32(p - 32)
|
|
67
|
-
return FastDivmod(divisor, multiplier, shift_right
|
|
67
|
+
return FastDivmod(divisor, multiplier, shift_right)
|
|
68
68
|
|
|
69
69
|
@cute.jit
|
|
70
70
|
def div(self, dividend: Int32) -> Int32:
|
|
@@ -78,20 +78,3 @@ class FastDivmod:
|
|
|
78
78
|
quotient = self.div(dividend)
|
|
79
79
|
remainder = dividend - quotient * self.divisor
|
|
80
80
|
return quotient, remainder
|
|
81
|
-
|
|
82
|
-
def __extract_mlir_values__(self):
|
|
83
|
-
values, self._values_pos = [], []
|
|
84
|
-
for obj in [self.divisor, self.multiplier, self.shift_right]:
|
|
85
|
-
obj_values = cutlass.extract_mlir_values(obj)
|
|
86
|
-
values += obj_values
|
|
87
|
-
self._values_pos.append(len(obj_values))
|
|
88
|
-
return values
|
|
89
|
-
|
|
90
|
-
def __new_from_mlir_values__(self, values):
|
|
91
|
-
obj_list = []
|
|
92
|
-
for obj, n_items in zip(
|
|
93
|
-
[self.divisor, self.multiplier, self.shift_right], self._values_pos
|
|
94
|
-
):
|
|
95
|
-
obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
|
|
96
|
-
values = values[n_items:]
|
|
97
|
-
return FastDivmod(*(tuple(obj_list)), loc=self._loc)
|
quack/gemm_act_sm90.py
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
1
|
+
# Copyright (c) 2025, Tri Dao.
|
|
2
|
+
from typing import Tuple, Optional, Callable
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
|
|
5
|
+
from torch import Tensor
|
|
6
|
+
|
|
7
|
+
import cutlass
|
|
8
|
+
import cutlass.cute as cute
|
|
9
|
+
from cutlass.cute.nvgpu import warpgroup
|
|
10
|
+
import cutlass.utils.hopper_helpers as sm90_utils
|
|
11
|
+
from cutlass import Int32, Float32, Boolean, const_expr
|
|
12
|
+
import cutlass.torch as cutlass_torch
|
|
13
|
+
|
|
14
|
+
from quack.cute_dsl_utils import ArgumentsBase, ParamsBase
|
|
15
|
+
from quack.dense_gemm_sm90 import GemmSm90
|
|
16
|
+
from quack.cute_dsl_utils import get_max_active_clusters
|
|
17
|
+
from quack.gemm_wrapper_utils import GemmWrapperBase
|
|
18
|
+
import quack.activation
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class GemmActSm90(GemmSm90):
|
|
22
|
+
@dataclass
|
|
23
|
+
class EpilogueArguments(ArgumentsBase):
|
|
24
|
+
mPostAct: cute.Tensor
|
|
25
|
+
act_fn: cutlass.Constexpr[Optional[Callable]] = None
|
|
26
|
+
alpha: Optional[Float32] = None
|
|
27
|
+
beta: Optional[Float32] = None
|
|
28
|
+
|
|
29
|
+
@dataclass
|
|
30
|
+
class EpilogueParams(ParamsBase):
|
|
31
|
+
tma_atom_postact: cute.CopyAtom
|
|
32
|
+
mPostAct_mnl: cute.Tensor
|
|
33
|
+
epi_postact_smem_layout_staged: cute.ComposedLayout
|
|
34
|
+
act_fn: cutlass.Constexpr[Optional[Callable]] = None
|
|
35
|
+
alpha: Optional[Float32] = None
|
|
36
|
+
beta: Optional[Float32] = None
|
|
37
|
+
|
|
38
|
+
def epi_to_underlying_arguments(
|
|
39
|
+
self, args: EpilogueArguments, *, loc=None, ip=None
|
|
40
|
+
) -> EpilogueParams:
|
|
41
|
+
self.postact_dtype = args.mPostAct.element_type
|
|
42
|
+
self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct)
|
|
43
|
+
|
|
44
|
+
self.tile_shape_postact_mn = self.tile_shape_mnk[:2]
|
|
45
|
+
self.epi_tile_postact = self.epi_tile
|
|
46
|
+
postact_major_mode_size = (
|
|
47
|
+
self.epi_tile_postact[1]
|
|
48
|
+
if self.postact_layout.is_n_major_c()
|
|
49
|
+
else self.epi_tile_postact[0]
|
|
50
|
+
)
|
|
51
|
+
postact_smem_layout_atom = warpgroup.make_smem_layout_atom(
|
|
52
|
+
sm90_utils.get_smem_layout_atom(
|
|
53
|
+
self.postact_layout, self.postact_dtype, postact_major_mode_size
|
|
54
|
+
),
|
|
55
|
+
self.postact_dtype,
|
|
56
|
+
)
|
|
57
|
+
epi_postact_smem_layout_staged = cute.tile_to_shape(
|
|
58
|
+
postact_smem_layout_atom,
|
|
59
|
+
cute.append(self.epi_tile_postact, self.epi_stage),
|
|
60
|
+
order=(0, 1, 2),
|
|
61
|
+
)
|
|
62
|
+
tma_atom_postact, tma_tensor_postact = self._make_tma_epi_atoms_and_tensors(
|
|
63
|
+
args.mPostAct,
|
|
64
|
+
epi_postact_smem_layout_staged,
|
|
65
|
+
self.epi_tile_postact,
|
|
66
|
+
store_or_load="store",
|
|
67
|
+
)
|
|
68
|
+
return GemmActSm90.EpilogueParams(
|
|
69
|
+
tma_atom_postact,
|
|
70
|
+
tma_tensor_postact,
|
|
71
|
+
epi_postact_smem_layout_staged,
|
|
72
|
+
args.act_fn,
|
|
73
|
+
args.alpha,
|
|
74
|
+
args.beta,
|
|
75
|
+
)
|
|
76
|
+
|
|
77
|
+
@staticmethod
|
|
78
|
+
def epi_smem_bytes_per_stage(
|
|
79
|
+
args: EpilogueArguments,
|
|
80
|
+
tile_shape_mnk: Tuple[int, int, int],
|
|
81
|
+
epi_tile: Tuple[int, int],
|
|
82
|
+
) -> int:
|
|
83
|
+
postact_dtype = args.mPostAct.element_type
|
|
84
|
+
postact_bytes_per_stage = cute.size(epi_tile) * (postact_dtype.width // 8)
|
|
85
|
+
return postact_bytes_per_stage
|
|
86
|
+
|
|
87
|
+
def epi_get_smem_struct(self, params: EpilogueParams):
|
|
88
|
+
@cute.struct
|
|
89
|
+
class EpiSharedStorage:
|
|
90
|
+
sPostAct: cute.struct.Align[
|
|
91
|
+
cute.struct.MemRange[
|
|
92
|
+
self.postact_dtype, cute.cosize(params.epi_postact_smem_layout_staged)
|
|
93
|
+
],
|
|
94
|
+
self.buffer_align_bytes,
|
|
95
|
+
]
|
|
96
|
+
|
|
97
|
+
return EpiSharedStorage
|
|
98
|
+
|
|
99
|
+
def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
|
|
100
|
+
sPostAct = storage.epi.sPostAct.get_tensor(
|
|
101
|
+
params.epi_postact_smem_layout_staged.outer,
|
|
102
|
+
swizzle=params.epi_postact_smem_layout_staged.inner,
|
|
103
|
+
)
|
|
104
|
+
return (sPostAct,)
|
|
105
|
+
|
|
106
|
+
@cute.jit
|
|
107
|
+
def epilogue(
|
|
108
|
+
self,
|
|
109
|
+
params: EpilogueParams,
|
|
110
|
+
epi_smem_tensors: Tuple[cute.Tensor, ...],
|
|
111
|
+
epi_pipeline: cutlass.pipeline.PipelineAsync,
|
|
112
|
+
epi_read_state: cutlass.pipeline.PipelineState,
|
|
113
|
+
epi_producer_state: cutlass.pipeline.PipelineState,
|
|
114
|
+
tiled_mma: cute.TiledMma,
|
|
115
|
+
tRS_rAcc: cute.Tensor,
|
|
116
|
+
tRS_rD: cute.Tensor,
|
|
117
|
+
tRS_rC: Optional[cute.Tensor],
|
|
118
|
+
tiled_copy_r2s: cute.core.ThrCopy,
|
|
119
|
+
tRS_sD: cute.Tensor,
|
|
120
|
+
tiled_copy_s2r: Optional[cute.core.ThrCopy],
|
|
121
|
+
tSR_rC: Optional[cute.Tensor],
|
|
122
|
+
tSR_sC: Optional[cute.Tensor],
|
|
123
|
+
copy_D: Optional[Callable],
|
|
124
|
+
bSG_sD: cute.Tensor,
|
|
125
|
+
bSG_gD: cute.Tensor,
|
|
126
|
+
epi_load_g2s: Optional[Callable],
|
|
127
|
+
tile_coord_mnkl: cute.Coord,
|
|
128
|
+
cu_seqlens_m: Optional[cute.Tensor],
|
|
129
|
+
epilogue_barrier: cutlass.pipeline.NamedBarrier,
|
|
130
|
+
tile_scheduler,
|
|
131
|
+
tidx: Int32,
|
|
132
|
+
is_tma_warp: Boolean,
|
|
133
|
+
) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
|
|
134
|
+
has_C = const_expr(tRS_rC is not None)
|
|
135
|
+
has_D = const_expr(copy_D is not None)
|
|
136
|
+
assert cu_seqlens_m is None, "GemmActSm90 doesn't support varlen_m for now"
|
|
137
|
+
|
|
138
|
+
tma_atom_postact = params.tma_atom_postact
|
|
139
|
+
mPostAct_mnl = params.mPostAct_mnl
|
|
140
|
+
(sPostAct,) = epi_smem_tensors
|
|
141
|
+
tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
|
|
142
|
+
copy_atom_postact_r2s = sm90_utils.sm90_get_smem_store_op(
|
|
143
|
+
self.postact_layout, elem_ty_d=self.postact_dtype, elem_ty_acc=self.acc_dtype
|
|
144
|
+
)
|
|
145
|
+
tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_C_atom)
|
|
146
|
+
thr_copy_postact_r2s = tiled_copy_postact_r2s.get_slice(tidx)
|
|
147
|
+
tRS_sPostAct = thr_copy_postact_r2s.partition_D(sPostAct)
|
|
148
|
+
bSG_sPostAct, bSG_gPostAct = self.epilog_gmem_copy_and_partition(
|
|
149
|
+
tma_atom_postact,
|
|
150
|
+
mPostAct_mnl,
|
|
151
|
+
self.tile_shape_postact_mn,
|
|
152
|
+
self.epi_tile_postact,
|
|
153
|
+
sPostAct,
|
|
154
|
+
tile_coord_mnkl,
|
|
155
|
+
cu_seqlens_m,
|
|
156
|
+
)
|
|
157
|
+
|
|
158
|
+
# We iterate over epi tiles along the N dimension first, then along the M dimension
|
|
159
|
+
epi_tile_shape = cute.zipped_divide(
|
|
160
|
+
cute.make_layout(self.tile_shape_mnk[:2]), self.epi_tile
|
|
161
|
+
).shape[1]
|
|
162
|
+
epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
|
|
163
|
+
epi_tile_num = cute.size(epi_tile_shape)
|
|
164
|
+
num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num
|
|
165
|
+
|
|
166
|
+
if const_expr(epi_load_g2s is not None):
|
|
167
|
+
for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
|
|
168
|
+
epi_producer_state = epi_load_g2s(epi_producer_state, epi_idx, is_tma_warp)
|
|
169
|
+
|
|
170
|
+
for epi_idx in cutlass.range_constexpr(epi_tile_num):
|
|
171
|
+
# Copy from acc to D registers
|
|
172
|
+
for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)):
|
|
173
|
+
tRS_rD[epi_v] = tRS_rAcc[epi_idx * cute.size(tRS_rD) + epi_v]
|
|
174
|
+
if const_expr(has_C):
|
|
175
|
+
epi_pipeline.consumer_wait(epi_read_state)
|
|
176
|
+
cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
|
|
177
|
+
# Fence to make sure shared memory read is visible to TMA load
|
|
178
|
+
cute.arch.fence_proxy(
|
|
179
|
+
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
|
|
180
|
+
)
|
|
181
|
+
cute.arch.sync_warp()
|
|
182
|
+
with cute.arch.elect_one():
|
|
183
|
+
epi_pipeline.consumer_release(epi_read_state)
|
|
184
|
+
epi_read_state.advance()
|
|
185
|
+
if const_expr(epi_load_g2s is not None and epi_idx + self.epi_c_stage < epi_tile_num):
|
|
186
|
+
epi_producer_state = epi_load_g2s(
|
|
187
|
+
epi_producer_state, epi_idx + self.epi_c_stage, is_tma_warp
|
|
188
|
+
)
|
|
189
|
+
tRS_rPostAct = self.epi_visit_acc_subtile(params, tRS_rD, tRS_rC)
|
|
190
|
+
epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
|
|
191
|
+
# Copy from D registers to shared memory
|
|
192
|
+
if const_expr(has_D):
|
|
193
|
+
# Type conversion
|
|
194
|
+
tRS_rD_out = cute.make_fragment_like(tRS_rD, self.d_dtype)
|
|
195
|
+
tRS_rD_out.store(tRS_rD.load().to(self.d_dtype))
|
|
196
|
+
cute.copy(tiled_copy_r2s, tRS_rD_out, tRS_sD[None, None, None, epi_buffer])
|
|
197
|
+
cute.copy(
|
|
198
|
+
tiled_copy_postact_r2s,
|
|
199
|
+
tiled_copy_postact_r2s.retile(tRS_rPostAct),
|
|
200
|
+
tRS_sPostAct[None, None, None, epi_buffer],
|
|
201
|
+
)
|
|
202
|
+
# Fence and barrier to make sure shared memory store is visible to TMA store
|
|
203
|
+
cute.arch.fence_proxy(
|
|
204
|
+
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
|
|
205
|
+
)
|
|
206
|
+
epilogue_barrier.arrive_and_wait()
|
|
207
|
+
# Get the global memory coordinate for the current epi tile
|
|
208
|
+
gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
|
|
209
|
+
# Copy from shared memory to global memory
|
|
210
|
+
if is_tma_warp:
|
|
211
|
+
if const_expr(has_D):
|
|
212
|
+
copy_D(bSG_sD[None, epi_buffer], bSG_gD[None, gmem_coord])
|
|
213
|
+
cute.copy(
|
|
214
|
+
tma_atom_postact,
|
|
215
|
+
bSG_sPostAct[None, epi_buffer],
|
|
216
|
+
bSG_gPostAct[None, gmem_coord],
|
|
217
|
+
)
|
|
218
|
+
cute.arch.cp_async_bulk_commit_group()
|
|
219
|
+
cute.arch.cp_async_bulk_wait_group(self.epi_stage - 1, read=True)
|
|
220
|
+
epilogue_barrier.arrive_and_wait()
|
|
221
|
+
|
|
222
|
+
return epi_read_state, epi_producer_state
|
|
223
|
+
|
|
224
|
+
@cute.jit
|
|
225
|
+
def epi_visit_acc_subtile(
|
|
226
|
+
self,
|
|
227
|
+
params: EpilogueParams,
|
|
228
|
+
tRS_rD: cute.Tensor,
|
|
229
|
+
tRS_rC: Optional[cute.Tensor] = None,
|
|
230
|
+
) -> Optional[cute.Tensor]:
|
|
231
|
+
# Apply alpha scaling to accumulator if alpha is provided (not None)
|
|
232
|
+
if const_expr(params.alpha is not None):
|
|
233
|
+
tRS_rD.store(tRS_rD.load() * params.alpha)
|
|
234
|
+
# Apply C with beta scaling
|
|
235
|
+
if const_expr(tRS_rC is not None):
|
|
236
|
+
if const_expr(params.beta is None):
|
|
237
|
+
# beta is None, default behavior: add C (beta=1.0)
|
|
238
|
+
tRS_rD.store(tRS_rD.load() + tRS_rC.load().to(tRS_rD.element_type))
|
|
239
|
+
else:
|
|
240
|
+
tRS_rD.store(tRS_rD.load() + params.beta * tRS_rC.load().to(tRS_rD.element_type))
|
|
241
|
+
# Apply activation function if provided
|
|
242
|
+
# If we don't have .shape here, the compiler generates local stores and loads
|
|
243
|
+
if const_expr(params.act_fn is not None):
|
|
244
|
+
tRS_rPostAct = cute.make_fragment(tRS_rD.layout.shape, self.acc_dtype)
|
|
245
|
+
for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
|
|
246
|
+
tRS_rPostAct[i] = params.act_fn(tRS_rD[i])
|
|
247
|
+
else:
|
|
248
|
+
tRS_rPostAct = tRS_rD
|
|
249
|
+
# Type conversion
|
|
250
|
+
tRS_rPostAct_out = cute.make_fragment_like(tRS_rPostAct, self.postact_dtype)
|
|
251
|
+
tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype))
|
|
252
|
+
return tRS_rPostAct_out
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
act_fn_map = {
|
|
256
|
+
None: None,
|
|
257
|
+
"relu": quack.activation.relu,
|
|
258
|
+
"relu_sq": quack.activation.relu_sq,
|
|
259
|
+
"gelu_tanh_approx": quack.activation.gelu_tanh_approx,
|
|
260
|
+
}
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def gemm_act_sm90(
|
|
264
|
+
A: Tensor, # (l, m, k)
|
|
265
|
+
B: Tensor, # (l, n, k)
|
|
266
|
+
D: Optional[Tensor], # (l, m, n)
|
|
267
|
+
C: Optional[Tensor], # (l, m, n)
|
|
268
|
+
PostAct: Tensor, # (l, m, n)
|
|
269
|
+
activation: Optional[str],
|
|
270
|
+
tile_M: int,
|
|
271
|
+
tile_N: int,
|
|
272
|
+
cluster_M: int,
|
|
273
|
+
cluster_N: int,
|
|
274
|
+
pingpong: bool = False,
|
|
275
|
+
persistent: bool = True,
|
|
276
|
+
alpha: float = 1.0,
|
|
277
|
+
beta: float = 1.0,
|
|
278
|
+
) -> None:
|
|
279
|
+
tile_count_semaphore = None
|
|
280
|
+
assert activation in act_fn_map, f"Unsupported activation {activation}"
|
|
281
|
+
L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
|
|
282
|
+
A, B, D, C, additional_tensors={"PostAct": PostAct}
|
|
283
|
+
)
|
|
284
|
+
GemmWrapperBase.permute_tensors(tensor_infos)
|
|
285
|
+
GemmWrapperBase.extract_dtypes(tensor_infos)
|
|
286
|
+
major_configs = {
|
|
287
|
+
"A": ("m", "k", "l"),
|
|
288
|
+
"B": ("n", "k", "l"),
|
|
289
|
+
"D": ("m", "n", "l"),
|
|
290
|
+
"C": ("m", "n", "l"),
|
|
291
|
+
"PostAct": ("m", "n", "l"),
|
|
292
|
+
}
|
|
293
|
+
GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
|
|
294
|
+
|
|
295
|
+
acc_dtype = cutlass.Float32
|
|
296
|
+
tile_shape_mn = (tile_M, tile_N)
|
|
297
|
+
cluster_shape_mnk = (cluster_M, cluster_N, 1)
|
|
298
|
+
if not GemmActSm90.is_valid_dtypes(
|
|
299
|
+
tensor_infos["A"].dtype,
|
|
300
|
+
tensor_infos["B"].dtype,
|
|
301
|
+
acc_dtype,
|
|
302
|
+
tensor_infos["D"].dtype,
|
|
303
|
+
tensor_infos["A"].major,
|
|
304
|
+
tensor_infos["B"].major,
|
|
305
|
+
):
|
|
306
|
+
raise TypeError("Skipping due to unsupported combination of types and majors")
|
|
307
|
+
|
|
308
|
+
max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
|
|
309
|
+
GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
|
|
310
|
+
act_fn = act_fn_map[activation]
|
|
311
|
+
epi_args = GemmActSm90.EpilogueArguments(
|
|
312
|
+
tensor_infos["PostAct"].cute_tensor,
|
|
313
|
+
act_fn,
|
|
314
|
+
alpha=Float32(alpha) if alpha != 1.0 else None,
|
|
315
|
+
beta=Float32(beta) if beta != 1.0 else None,
|
|
316
|
+
)
|
|
317
|
+
scheduler_args = GemmWrapperBase.create_scheduler_args(
|
|
318
|
+
max_active_clusters, tile_count_semaphore
|
|
319
|
+
)
|
|
320
|
+
current_stream = cutlass_torch.current_stream()
|
|
321
|
+
compile_key = GemmWrapperBase.get_compile_key(
|
|
322
|
+
tensor_infos,
|
|
323
|
+
activation,
|
|
324
|
+
tile_shape_mn,
|
|
325
|
+
cluster_shape_mnk,
|
|
326
|
+
pingpong,
|
|
327
|
+
persistent,
|
|
328
|
+
tile_count_semaphore is not None,
|
|
329
|
+
alpha != 1.0,
|
|
330
|
+
beta != 1.0,
|
|
331
|
+
key_tensor_names=("A", "B", "D", "PostAct", "C"),
|
|
332
|
+
)
|
|
333
|
+
cache = gemm_act_sm90.compile_cache
|
|
334
|
+
if compile_key not in cache:
|
|
335
|
+
gemm = GemmActSm90(
|
|
336
|
+
acc_dtype,
|
|
337
|
+
tensor_infos["A"].dtype,
|
|
338
|
+
tile_shape_mn,
|
|
339
|
+
cluster_shape_mnk,
|
|
340
|
+
pingpong=pingpong,
|
|
341
|
+
is_persistent=persistent,
|
|
342
|
+
)
|
|
343
|
+
cache[compile_key] = cute.compile(
|
|
344
|
+
gemm,
|
|
345
|
+
tensor_infos["A"].cute_tensor,
|
|
346
|
+
tensor_infos["B"].cute_tensor,
|
|
347
|
+
tensor_infos["D"].cute_tensor,
|
|
348
|
+
tensor_infos["C"].cute_tensor,
|
|
349
|
+
epi_args,
|
|
350
|
+
scheduler_args,
|
|
351
|
+
None, # varlen_args
|
|
352
|
+
None, # mAIdx
|
|
353
|
+
current_stream,
|
|
354
|
+
)
|
|
355
|
+
cache[compile_key](
|
|
356
|
+
tensor_infos["A"].cute_tensor,
|
|
357
|
+
tensor_infos["B"].cute_tensor,
|
|
358
|
+
tensor_infos["D"].cute_tensor,
|
|
359
|
+
tensor_infos["C"].cute_tensor,
|
|
360
|
+
epi_args,
|
|
361
|
+
scheduler_args,
|
|
362
|
+
None,
|
|
363
|
+
None,
|
|
364
|
+
current_stream,
|
|
365
|
+
)
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
gemm_act_sm90.compile_cache = {}
|
quack/gemm_config.py
CHANGED
|
@@ -1,61 +1,69 @@
|
|
|
1
|
-
# Copyright (C) 2025,
|
|
1
|
+
# Copyright (C) 2025, Tri Dao.
|
|
2
2
|
import itertools
|
|
3
|
-
from typing import Optional
|
|
4
|
-
from
|
|
3
|
+
from typing import Optional, List
|
|
4
|
+
from dataclasses import dataclass
|
|
5
5
|
|
|
6
6
|
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
7
|
+
@dataclass(frozen=True)
|
|
8
|
+
class GemmConfig:
|
|
9
|
+
tile_m: int = 128
|
|
10
|
+
tile_n: int = 192
|
|
11
|
+
pingpong: bool = True
|
|
10
12
|
cluster_m: int = 2
|
|
11
13
|
cluster_n: int = 1
|
|
12
14
|
swap_ab: bool = False
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
max_swizzle_size: int = 1
|
|
15
|
+
# raster_order: int = 1
|
|
16
|
+
# max_swizzle_size: int = 8
|
|
16
17
|
|
|
17
18
|
|
|
18
19
|
def get_all_configs(
|
|
19
|
-
epilogue: Optional[str],
|
|
20
|
-
|
|
21
|
-
tune_raster_order=True,
|
|
22
|
-
) ->
|
|
20
|
+
epilogue: Optional[str] = None,
|
|
21
|
+
tune_coop: bool = True,
|
|
22
|
+
# tune_raster_order=True,
|
|
23
|
+
) -> List[GemmConfig]:
|
|
23
24
|
tile_n_vals = [128, 144, 160, 176, 192, 208]
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
25
|
+
tile_mn_coop_vals = [(256, tile_n) for tile_n in tile_n_vals] + [
|
|
26
|
+
(128, 224),
|
|
27
|
+
(128, 256),
|
|
28
|
+
# (192, 256), # Disabled: crashes with "IOT instruction (core dumped)" in the backward pass
|
|
29
|
+
]
|
|
30
|
+
tile_mn_pingpong_vals = [(128, tile_n) for tile_n in tile_n_vals] + [(192, 128)]
|
|
31
|
+
if epilogue in ["gated"]:
|
|
32
|
+
tile_mn_coop_vals = [(m, n) for m, n in tile_mn_coop_vals if n % 32 == 0 and m != 192]
|
|
33
|
+
tile_mn_pingpong_vals = [(m, n) for m, n in tile_mn_pingpong_vals if n % 32 == 0]
|
|
34
|
+
elif epilogue in ["lse"]:
|
|
35
|
+
tile_mn_coop_vals = [(m, n) for m, n in tile_mn_coop_vals if m != 192]
|
|
36
|
+
tile_mn_vals = []
|
|
37
|
+
if tune_coop:
|
|
38
|
+
tile_mn_vals += [(m, n, False) for m, n in tile_mn_coop_vals]
|
|
39
|
+
tile_mn_vals += [(m, n, True) for m, n in tile_mn_pingpong_vals]
|
|
40
|
+
cluster = [(1, 2), (2, 1)]
|
|
41
|
+
# cluster = [(1, 1), (1, 2), (2, 1)]
|
|
29
42
|
if epilogue in ["lse"]:
|
|
30
43
|
cluster = [(1, 2), (2, 1)]
|
|
31
44
|
swap_ab_vals = [False, True]
|
|
32
|
-
if epilogue in ["lse", "
|
|
45
|
+
if epilogue in ["lse", "gated"]:
|
|
33
46
|
swap_ab_vals = [False]
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
)
|
|
47
|
+
# raster_swizzle = (
|
|
48
|
+
# [(0, 1)]
|
|
49
|
+
# if not tune_raster_order
|
|
50
|
+
# else [(1, 1), (1, 2), (1, 4), (1, 8), (2, 1), (2, 2), (2, 4), (2, 8)]
|
|
51
|
+
# )
|
|
40
52
|
return [
|
|
41
53
|
GemmConfig(
|
|
42
|
-
tile_m=tile_m
|
|
54
|
+
tile_m=tile_m,
|
|
43
55
|
tile_n=tile_n,
|
|
56
|
+
pingpong=pingpong,
|
|
44
57
|
cluster_m=cluster_m,
|
|
45
58
|
cluster_n=cluster_n,
|
|
46
59
|
swap_ab=swap_ab,
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
max_swizzle_size=max_swizzle_size,
|
|
60
|
+
# raster_order=raster_order,
|
|
61
|
+
# max_swizzle_size=max_swizzle_size,
|
|
50
62
|
)
|
|
51
|
-
for (tile_m, tile_n), (cluster_m, cluster_n), swap_ab
|
|
52
|
-
raster_order,
|
|
53
|
-
max_swizzle_size,
|
|
54
|
-
) in itertools.product(
|
|
63
|
+
for (tile_m, tile_n, pingpong), (cluster_m, cluster_n), swap_ab in itertools.product(
|
|
55
64
|
tile_mn_vals,
|
|
56
65
|
cluster,
|
|
57
66
|
swap_ab_vals,
|
|
58
|
-
|
|
59
|
-
raster_swizzle,
|
|
67
|
+
# raster_swizzle,
|
|
60
68
|
)
|
|
61
69
|
]
|
quack/gemm_dact_sm90.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
# Copyright (c) 2025, Tri Dao.
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
import cutlass
|
|
7
|
+
import cutlass.cute as cute
|
|
8
|
+
from cutlass import const_expr
|
|
9
|
+
import cutlass.torch as cutlass_torch
|
|
10
|
+
|
|
11
|
+
from quack.gemm_act_sm90 import GemmActSm90
|
|
12
|
+
from quack.cute_dsl_utils import get_max_active_clusters
|
|
13
|
+
from quack.gemm_wrapper_utils import GemmWrapperBase
|
|
14
|
+
import quack.activation
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class GemmDActSm90(GemmActSm90):
|
|
18
|
+
# Different from GemmActSm90, here act_fn (the activation backward function) must take in 2 arguments (x, dout)
|
|
19
|
+
# and return 2 values (dx, out)
|
|
20
|
+
EpilogueArguments = GemmActSm90.EpilogueArguments
|
|
21
|
+
EpilogueParams = GemmActSm90.EpilogueParams
|
|
22
|
+
|
|
23
|
+
@cute.jit
|
|
24
|
+
def epi_visit_acc_subtile(
|
|
25
|
+
self,
|
|
26
|
+
params: EpilogueParams,
|
|
27
|
+
tRS_rD: cute.Tensor,
|
|
28
|
+
tRS_rC: Optional[cute.Tensor] = None,
|
|
29
|
+
) -> Optional[cute.Tensor]:
|
|
30
|
+
assert tRS_rC is not None
|
|
31
|
+
tRS_rC_acc = cute.make_fragment_like(tRS_rC, self.acc_dtype)
|
|
32
|
+
tRS_rC_acc.store(tRS_rC.load().to(self.acc_dtype))
|
|
33
|
+
# If we don't have .shape here, the compiler generates local stores and loads
|
|
34
|
+
if const_expr(params.act_fn is not None):
|
|
35
|
+
tRS_rPostAct = cute.make_fragment(tRS_rD.layout.shape, self.acc_dtype)
|
|
36
|
+
for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
|
|
37
|
+
tRS_rD[i], tRS_rPostAct[i] = params.act_fn(tRS_rC_acc[i], tRS_rD[i])
|
|
38
|
+
else:
|
|
39
|
+
tRS_rPostAct = tRS_rC_acc
|
|
40
|
+
# Type conversion
|
|
41
|
+
tRS_rPostAct_out = cute.make_fragment_like(tRS_rPostAct, self.postact_dtype)
|
|
42
|
+
tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype))
|
|
43
|
+
return tRS_rPostAct_out
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
dact_fn_map = {
|
|
47
|
+
None: None,
|
|
48
|
+
"relu": quack.activation.drelu,
|
|
49
|
+
"relu_sq": quack.activation.drelu_sq,
|
|
50
|
+
"gelu_tanh_approx": quack.activation.dgelu_tanh_approx,
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def gemm_dact_sm90(
|
|
55
|
+
A: Tensor, # (l, m, k)
|
|
56
|
+
B: Tensor, # (l, n, k)
|
|
57
|
+
Out: Tensor, # (l, m, n)
|
|
58
|
+
PreAct: Tensor, # (l, m, n)
|
|
59
|
+
PostAct: Tensor, # (l, m, n)
|
|
60
|
+
tile_count_semaphore: Optional[Tensor], # (1,)
|
|
61
|
+
activation: Optional[str],
|
|
62
|
+
tile_M: int,
|
|
63
|
+
tile_N: int,
|
|
64
|
+
cluster_M: int,
|
|
65
|
+
cluster_N: int,
|
|
66
|
+
pingpong: bool = True,
|
|
67
|
+
persistent: bool = True,
|
|
68
|
+
) -> None:
|
|
69
|
+
assert activation in dact_fn_map, f"Unsupported activation {activation}"
|
|
70
|
+
L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
|
|
71
|
+
A, B, Out, PreAct, additional_tensors={"PostAct": PostAct}
|
|
72
|
+
)
|
|
73
|
+
GemmWrapperBase.permute_tensors(tensor_infos)
|
|
74
|
+
GemmWrapperBase.extract_dtypes(tensor_infos)
|
|
75
|
+
major_configs = {
|
|
76
|
+
"A": ("m", "k", "l"),
|
|
77
|
+
"B": ("n", "k", "l"),
|
|
78
|
+
"D": ("m", "n", "l"),
|
|
79
|
+
"C": ("m", "n", "l"),
|
|
80
|
+
"PostAct": ("m", "n", "l"),
|
|
81
|
+
}
|
|
82
|
+
GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
|
|
83
|
+
|
|
84
|
+
acc_dtype = cutlass.Float32
|
|
85
|
+
tile_shape_mn = (tile_M, tile_N)
|
|
86
|
+
cluster_shape_mnk = (cluster_M, cluster_N, 1)
|
|
87
|
+
if not GemmDActSm90.is_valid_dtypes(
|
|
88
|
+
tensor_infos["A"].dtype,
|
|
89
|
+
tensor_infos["B"].dtype,
|
|
90
|
+
acc_dtype,
|
|
91
|
+
tensor_infos["D"].dtype,
|
|
92
|
+
tensor_infos["A"].major,
|
|
93
|
+
tensor_infos["B"].major,
|
|
94
|
+
):
|
|
95
|
+
raise TypeError("Skipping due to unsupported combination of types and majors")
|
|
96
|
+
|
|
97
|
+
max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
|
|
98
|
+
GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
|
|
99
|
+
act_fn = dact_fn_map[activation]
|
|
100
|
+
epi_args = GemmDActSm90.EpilogueArguments(tensor_infos["PostAct"].cute_tensor, act_fn)
|
|
101
|
+
scheduler_args = GemmWrapperBase.create_scheduler_args(
|
|
102
|
+
max_active_clusters, tile_count_semaphore
|
|
103
|
+
)
|
|
104
|
+
current_stream = cutlass_torch.current_stream()
|
|
105
|
+
compile_key = GemmWrapperBase.get_compile_key(
|
|
106
|
+
tensor_infos,
|
|
107
|
+
activation,
|
|
108
|
+
tile_shape_mn,
|
|
109
|
+
cluster_shape_mnk,
|
|
110
|
+
pingpong,
|
|
111
|
+
persistent,
|
|
112
|
+
tile_count_semaphore is not None,
|
|
113
|
+
key_tensor_names=("A", "B", "D", "PostAct", "C"),
|
|
114
|
+
)
|
|
115
|
+
cache = gemm_dact_sm90.compile_cache
|
|
116
|
+
if compile_key not in cache:
|
|
117
|
+
gemm = GemmDActSm90(
|
|
118
|
+
acc_dtype,
|
|
119
|
+
tensor_infos["A"].dtype,
|
|
120
|
+
tile_shape_mn,
|
|
121
|
+
cluster_shape_mnk,
|
|
122
|
+
pingpong=pingpong,
|
|
123
|
+
is_persistent=persistent,
|
|
124
|
+
)
|
|
125
|
+
cache[compile_key] = cute.compile(
|
|
126
|
+
gemm,
|
|
127
|
+
tensor_infos["A"].cute_tensor,
|
|
128
|
+
tensor_infos["B"].cute_tensor,
|
|
129
|
+
tensor_infos["D"].cute_tensor,
|
|
130
|
+
tensor_infos["C"].cute_tensor,
|
|
131
|
+
epi_args,
|
|
132
|
+
scheduler_args,
|
|
133
|
+
None, # varlen_args
|
|
134
|
+
None, # mAIdx
|
|
135
|
+
current_stream,
|
|
136
|
+
)
|
|
137
|
+
cache[compile_key](
|
|
138
|
+
tensor_infos["A"].cute_tensor,
|
|
139
|
+
tensor_infos["B"].cute_tensor,
|
|
140
|
+
tensor_infos["D"].cute_tensor,
|
|
141
|
+
tensor_infos["C"].cute_tensor,
|
|
142
|
+
epi_args,
|
|
143
|
+
scheduler_args,
|
|
144
|
+
None,
|
|
145
|
+
None,
|
|
146
|
+
current_stream,
|
|
147
|
+
)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
gemm_dact_sm90.compile_cache = {}
|