quack-kernels 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -8
- quack/activation.py +366 -121
- quack/broadcast_utils.py +29 -0
- quack/compile_utils.py +19 -0
- quack/copy_utils.py +487 -0
- quack/cross_entropy.py +157 -233
- quack/cute_dsl_utils.py +20 -34
- quack/gemm.py +194 -0
- quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
- quack/gemm_config.py +72 -46
- quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
- quack/gemm_default_epi.py +259 -0
- quack/gemm_interface.py +177 -31
- quack/gemm_sm100.py +729 -506
- quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
- quack/gemm_symmetric.py +330 -0
- quack/gemm_wrapper_utils.py +3 -1
- quack/layout_utils.py +287 -0
- quack/linear.py +24 -16
- quack/pipeline.py +158 -3
- quack/reduce.py +88 -49
- quack/reduction_base.py +25 -36
- quack/rmsnorm.py +476 -526
- quack/sm100_utils.py +62 -0
- quack/sm90_utils.py +127 -0
- quack/softmax.py +135 -203
- quack/sort/bitonic_sort.py +13 -10
- quack/sort/utils.py +6 -6
- quack/tile_scheduler.py +23 -16
- quack/topk.py +409 -85
- quack/utils.py +32 -220
- quack/varlen_utils.py +370 -1
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/METADATA +4 -2
- quack_kernels-0.2.4.dist-info/RECORD +44 -0
- quack/layernorm.py +0 -353
- quack/symmetric_dense_gemm_sm90.py +0 -2091
- quack_kernels-0.2.2.dist-info/RECORD +0 -37
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/top_level.txt +0 -0
|
@@ -1,41 +1,54 @@
|
|
|
1
|
-
# Copyright (c) 2025, Tri Dao.
|
|
1
|
+
# Copyright (c) 2025, Wentao Guo, Tri Dao.
|
|
2
2
|
from typing import Tuple, Optional, Callable
|
|
3
|
+
from functools import partial
|
|
3
4
|
from dataclasses import dataclass
|
|
4
5
|
|
|
5
6
|
from torch import Tensor
|
|
6
7
|
|
|
7
8
|
import cutlass
|
|
8
9
|
import cutlass.cute as cute
|
|
9
|
-
|
|
10
|
-
import cutlass.utils.
|
|
10
|
+
import cutlass.utils.hopper_helpers as sm90_utils_og
|
|
11
|
+
import cutlass.utils.blackwell_helpers as sm100_utils
|
|
11
12
|
from cutlass import Int32, Float32, Boolean, const_expr
|
|
13
|
+
from cutlass.cutlass_dsl import if_generate
|
|
12
14
|
import cutlass.torch as cutlass_torch
|
|
15
|
+
from cutlass.cute.runtime import from_dlpack
|
|
13
16
|
|
|
14
17
|
from quack.cute_dsl_utils import ArgumentsBase, ParamsBase
|
|
15
|
-
from quack.
|
|
16
|
-
from quack.
|
|
18
|
+
from quack.varlen_utils import VarlenManager
|
|
19
|
+
from quack.gemm_sm90 import GemmSm90
|
|
20
|
+
from quack.gemm_sm100 import GemmSm100
|
|
21
|
+
from quack.gemm_default_epi import GemmDefaultEpiMixin
|
|
22
|
+
from quack.cute_dsl_utils import get_device_capacity, get_max_active_clusters
|
|
17
23
|
from quack.gemm_wrapper_utils import GemmWrapperBase
|
|
24
|
+
import quack.sm90_utils as sm90_utils
|
|
25
|
+
import quack.copy_utils as copy_utils
|
|
18
26
|
import quack.activation
|
|
19
27
|
|
|
20
28
|
|
|
21
|
-
class GemmActSm90(GemmSm90):
|
|
29
|
+
class GemmActMixin(GemmDefaultEpiMixin):
|
|
22
30
|
num_epi_tensormaps: int = 1
|
|
23
31
|
|
|
24
32
|
@dataclass
|
|
25
33
|
class EpilogueArguments(ArgumentsBase):
|
|
26
34
|
mPostAct: cute.Tensor
|
|
27
35
|
act_fn: cutlass.Constexpr[Optional[Callable]] = None
|
|
28
|
-
alpha: Optional[Float32] = None
|
|
29
|
-
beta: Optional[Float32] = None
|
|
36
|
+
alpha: Optional[Float32 | cute.Tensor] = None
|
|
37
|
+
beta: Optional[Float32 | cute.Tensor] = None
|
|
38
|
+
mRowVecBroadcast: Optional[cute.Tensor] = None
|
|
39
|
+
mColVecBroadcast: Optional[cute.Tensor] = None
|
|
30
40
|
|
|
31
41
|
@dataclass
|
|
32
42
|
class EpilogueParams(ParamsBase):
|
|
33
43
|
tma_atom_postact: cute.CopyAtom
|
|
34
44
|
mPostAct_mnl: cute.Tensor
|
|
35
45
|
epi_postact_smem_layout_staged: cute.ComposedLayout
|
|
46
|
+
epi_tile_postact: cute.Tile
|
|
36
47
|
act_fn: cutlass.Constexpr[Optional[Callable]] = None
|
|
37
|
-
alpha: Optional[Float32] = None
|
|
38
|
-
beta: Optional[Float32] = None
|
|
48
|
+
alpha: Optional[Float32 | cute.Tensor] = None
|
|
49
|
+
beta: Optional[Float32 | cute.Tensor] = None
|
|
50
|
+
mRowVecBroadcast: Optional[cute.Tensor] = None
|
|
51
|
+
mColVecBroadcast: Optional[cute.Tensor] = None
|
|
39
52
|
|
|
40
53
|
def epi_to_underlying_arguments(
|
|
41
54
|
self, args: EpilogueArguments, *, loc=None, ip=None
|
|
@@ -44,36 +57,38 @@ class GemmActSm90(GemmSm90):
|
|
|
44
57
|
self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct)
|
|
45
58
|
|
|
46
59
|
self.cta_tile_shape_postact_mn = self.cta_tile_shape_mnk[:2]
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
else self.epi_tile_postact[0]
|
|
52
|
-
)
|
|
53
|
-
postact_smem_layout_atom = warpgroup.make_smem_layout_atom(
|
|
54
|
-
sm90_utils.get_smem_layout_atom(
|
|
55
|
-
self.postact_layout, self.postact_dtype, postact_major_mode_size
|
|
56
|
-
),
|
|
57
|
-
self.postact_dtype,
|
|
58
|
-
)
|
|
59
|
-
epi_postact_smem_layout_staged = cute.tile_to_shape(
|
|
60
|
-
postact_smem_layout_atom,
|
|
61
|
-
cute.append(self.epi_tile_postact, self.epi_stage),
|
|
62
|
-
order=(0, 1, 2),
|
|
60
|
+
epi_tile_postact = self.epi_tile
|
|
61
|
+
utils_cls = sm100_utils if self.arch == 100 else sm90_utils
|
|
62
|
+
epi_postact_smem_layout_staged = utils_cls.make_smem_layout_epi(
|
|
63
|
+
self.postact_dtype, self.postact_layout, epi_tile_postact, self.epi_stage
|
|
63
64
|
)
|
|
64
65
|
tma_atom_postact, tma_tensor_postact = self._make_tma_epi_atoms_and_tensors(
|
|
65
66
|
args.mPostAct,
|
|
66
67
|
epi_postact_smem_layout_staged,
|
|
67
|
-
|
|
68
|
+
epi_tile_postact,
|
|
68
69
|
op_type="store",
|
|
69
70
|
)
|
|
70
|
-
|
|
71
|
+
# Assume all strides are divisible by 32 bits except the last stride
|
|
72
|
+
new_stride = lambda t: tuple(
|
|
73
|
+
cute.assume(s, divby=32 // t.element_type.width) if not cute.is_static(s) else s
|
|
74
|
+
for s in t.stride
|
|
75
|
+
)
|
|
76
|
+
mRowVecBroadcast, mColVecBroadcast = [
|
|
77
|
+
cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
|
|
78
|
+
if t is not None
|
|
79
|
+
else None
|
|
80
|
+
for t in (args.mRowVecBroadcast, args.mColVecBroadcast)
|
|
81
|
+
]
|
|
82
|
+
return self.EpilogueParams(
|
|
71
83
|
tma_atom_postact,
|
|
72
84
|
tma_tensor_postact,
|
|
73
85
|
epi_postact_smem_layout_staged,
|
|
86
|
+
epi_tile_postact,
|
|
74
87
|
args.act_fn,
|
|
75
|
-
args.alpha,
|
|
76
|
-
args.beta,
|
|
88
|
+
alpha=args.alpha,
|
|
89
|
+
beta=args.beta,
|
|
90
|
+
mRowVecBroadcast=mRowVecBroadcast,
|
|
91
|
+
mColVecBroadcast=mColVecBroadcast,
|
|
77
92
|
)
|
|
78
93
|
|
|
79
94
|
def epi_get_tma_atoms(
|
|
@@ -84,29 +99,41 @@ class GemmActSm90(GemmSm90):
|
|
|
84
99
|
def epi_get_tensormap_update_shapes_orders(
|
|
85
100
|
self,
|
|
86
101
|
params: EpilogueParams,
|
|
87
|
-
cu_seqlens_m: cute.Tensor,
|
|
102
|
+
cu_seqlens_m: Optional[cute.Tensor],
|
|
88
103
|
batch_idx: Int32,
|
|
89
104
|
*,
|
|
90
105
|
loc=None,
|
|
91
106
|
ip=None,
|
|
92
107
|
) -> tuple[list[Int32], list[int]]:
|
|
93
|
-
shapes = [cu_seqlens_m[batch_idx + 1]]
|
|
108
|
+
shapes = [cu_seqlens_m[batch_idx + 1] if cu_seqlens_m is not None else None]
|
|
94
109
|
orders = [0 if const_expr(self.postact_layout.is_m_major_c()) else 1]
|
|
95
110
|
return shapes, orders
|
|
96
111
|
|
|
97
112
|
@staticmethod
|
|
98
113
|
def epi_smem_bytes_per_stage(
|
|
99
|
-
args: EpilogueArguments,
|
|
100
|
-
cta_tile_shape_mnk: Tuple[int, int, int],
|
|
101
|
-
epi_tile: Tuple[int, int],
|
|
114
|
+
args: EpilogueArguments, cta_tile_shape_mnk: Tuple[int, int, int], epi_tile: cute.Tile
|
|
102
115
|
) -> int:
|
|
103
116
|
postact_dtype = args.mPostAct.element_type
|
|
104
|
-
postact_bytes_per_stage = cute.size(epi_tile) * (postact_dtype.width // 8)
|
|
105
|
-
|
|
117
|
+
postact_bytes_per_stage = cute.size(cute.shape(epi_tile)) * (postact_dtype.width // 8)
|
|
118
|
+
rowvec_colvec_bytes = GemmDefaultEpiMixin.epi_smem_bytes_per_stage(
|
|
119
|
+
args, cta_tile_shape_mnk, epi_tile
|
|
120
|
+
)
|
|
121
|
+
return postact_bytes_per_stage + rowvec_colvec_bytes
|
|
106
122
|
|
|
107
123
|
def epi_get_smem_struct(self, params: EpilogueParams):
|
|
124
|
+
row_vec_smem_size = 0 if params.mRowVecBroadcast is None else self.cta_tile_shape_mnk[1]
|
|
125
|
+
col_vec_smem_size = 0 if params.mColVecBroadcast is None else self.cta_tile_shape_mnk[0]
|
|
126
|
+
row_vec_dtype = (
|
|
127
|
+
params.mRowVecBroadcast.element_type if params.mRowVecBroadcast is not None else Float32
|
|
128
|
+
)
|
|
129
|
+
col_vec_dtype = (
|
|
130
|
+
params.mColVecBroadcast.element_type if params.mColVecBroadcast is not None else Float32
|
|
131
|
+
)
|
|
132
|
+
|
|
108
133
|
@cute.struct
|
|
109
134
|
class EpiSharedStorage:
|
|
135
|
+
sRowVec: cute.struct.Align[cute.struct.MemRange[row_vec_dtype, row_vec_smem_size], 16]
|
|
136
|
+
sColVec: cute.struct.Align[cute.struct.MemRange[col_vec_dtype, col_vec_smem_size], 16]
|
|
110
137
|
sPostAct: cute.struct.Align[
|
|
111
138
|
cute.struct.MemRange[
|
|
112
139
|
self.postact_dtype, cute.cosize(params.epi_postact_smem_layout_staged)
|
|
@@ -117,11 +144,12 @@ class GemmActSm90(GemmSm90):
|
|
|
117
144
|
return EpiSharedStorage
|
|
118
145
|
|
|
119
146
|
def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
|
|
147
|
+
sRowVec, sColVec = super().epi_get_smem_tensors(params, storage)
|
|
120
148
|
sPostAct = storage.epi.sPostAct.get_tensor(
|
|
121
149
|
params.epi_postact_smem_layout_staged.outer,
|
|
122
150
|
swizzle=params.epi_postact_smem_layout_staged.inner,
|
|
123
151
|
)
|
|
124
|
-
return (sPostAct,)
|
|
152
|
+
return (sRowVec, sColVec, sPostAct)
|
|
125
153
|
|
|
126
154
|
@cute.jit
|
|
127
155
|
def epilogue(
|
|
@@ -133,21 +161,20 @@ class GemmActSm90(GemmSm90):
|
|
|
133
161
|
epi_store_pipeline: cutlass.pipeline.PipelineAsync,
|
|
134
162
|
epi_read_state: cutlass.pipeline.PipelineState,
|
|
135
163
|
epi_producer_state: cutlass.pipeline.PipelineState,
|
|
136
|
-
|
|
137
|
-
|
|
164
|
+
epi_tile: cute.Tile,
|
|
165
|
+
load_acc_subtile: Callable,
|
|
138
166
|
tRS_rD: cute.Tensor,
|
|
139
167
|
tRS_rC: Optional[cute.Tensor],
|
|
140
|
-
|
|
168
|
+
tiled_copy_t2r: Optional[cute.TiledCopy], # Only for Sm100
|
|
169
|
+
tiled_copy_r2s: cute.TiledCopy,
|
|
141
170
|
tRS_sD: cute.Tensor,
|
|
142
|
-
tiled_copy_s2r: Optional[cute.
|
|
171
|
+
tiled_copy_s2r: Optional[cute.TiledCopy],
|
|
143
172
|
tSR_rC: Optional[cute.Tensor],
|
|
144
173
|
tSR_sC: Optional[cute.Tensor],
|
|
145
174
|
copy_D: Optional[Callable],
|
|
146
|
-
|
|
147
|
-
bSG_gD: cute.Tensor,
|
|
148
|
-
epi_load_g2s: Optional[Callable],
|
|
175
|
+
copy_C: Optional[Callable],
|
|
149
176
|
tile_coord_mnkl: cute.Coord,
|
|
150
|
-
|
|
177
|
+
varlen_manager: VarlenManager,
|
|
151
178
|
epilogue_barrier: cutlass.pipeline.NamedBarrier,
|
|
152
179
|
tile_scheduler,
|
|
153
180
|
tidx: Int32,
|
|
@@ -158,41 +185,85 @@ class GemmActSm90(GemmSm90):
|
|
|
158
185
|
|
|
159
186
|
tma_atom_postact = params.tma_atom_postact
|
|
160
187
|
mPostAct_mnl = params.mPostAct_mnl
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
188
|
+
sRowVec, sColVec, sPostAct = epi_smem_tensors
|
|
189
|
+
get_smem_store_op = (
|
|
190
|
+
partial(sm100_utils.get_smem_store_op, tiled_tmem_load=tiled_copy_t2r)
|
|
191
|
+
if self.arch == 100
|
|
192
|
+
else sm90_utils_og.sm90_get_smem_store_op
|
|
193
|
+
)
|
|
194
|
+
copy_atom_postact_r2s = get_smem_store_op(
|
|
195
|
+
self.postact_layout, self.postact_dtype, self.acc_dtype
|
|
165
196
|
)
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
197
|
+
# tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
|
|
198
|
+
# tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_C_atom)
|
|
199
|
+
tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_r2s)
|
|
200
|
+
tRS_sPostAct = tiled_copy_postact_r2s.get_slice(tidx).partition_D(sPostAct)
|
|
201
|
+
(tma_desc_postact_ptr,) = tma_desc_epi_ptrs
|
|
202
|
+
batch_idx = tile_coord_mnkl[3]
|
|
203
|
+
copy_postact, _, _ = self.epilog_gmem_copy_and_partition(
|
|
170
204
|
tma_atom_postact,
|
|
171
|
-
mPostAct_mnl,
|
|
205
|
+
varlen_manager.offset_batch_epi(mPostAct_mnl, batch_idx),
|
|
172
206
|
self.cta_tile_shape_postact_mn,
|
|
173
|
-
|
|
207
|
+
params.epi_tile_postact,
|
|
174
208
|
sPostAct,
|
|
175
209
|
tile_coord_mnkl,
|
|
176
|
-
|
|
210
|
+
tma_desc_ptr=tma_desc_postact_ptr,
|
|
177
211
|
)
|
|
178
|
-
(tma_desc_postact_ptr,) = tma_desc_epi_ptrs
|
|
179
212
|
|
|
180
213
|
# We iterate over epi tiles in the N dimension first before the M dimension
|
|
181
214
|
epi_tile_shape = cute.zipped_divide(
|
|
182
|
-
cute.make_layout(self.cta_tile_shape_mnk[:2]),
|
|
215
|
+
cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile
|
|
183
216
|
).shape[1]
|
|
184
217
|
epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
|
|
185
218
|
epi_tile_num = cute.size(epi_tile_shape)
|
|
186
219
|
num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num
|
|
187
220
|
|
|
188
|
-
|
|
221
|
+
epi_tensors = self.epi_begin(
|
|
222
|
+
params,
|
|
223
|
+
epi_smem_tensors,
|
|
224
|
+
epi_tile,
|
|
225
|
+
tiled_copy_t2r,
|
|
226
|
+
tiled_copy_r2s,
|
|
227
|
+
tile_coord_mnkl,
|
|
228
|
+
varlen_manager,
|
|
229
|
+
epilogue_barrier,
|
|
230
|
+
tidx,
|
|
231
|
+
)
|
|
232
|
+
|
|
233
|
+
if const_expr(copy_C is not None):
|
|
189
234
|
for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
|
|
190
|
-
|
|
235
|
+
gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx)
|
|
236
|
+
if is_tma_warp:
|
|
237
|
+
epi_pipeline.producer_acquire(epi_producer_state)
|
|
238
|
+
copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
|
|
239
|
+
epi_pipeline.producer_commit(epi_producer_state)
|
|
240
|
+
epi_producer_state.advance()
|
|
241
|
+
|
|
242
|
+
def tma_store_fn(src_idx, dst_idx):
|
|
243
|
+
# Fence and barrier to make sure shared memory store is visible to TMA store
|
|
244
|
+
cute.arch.fence_proxy(
|
|
245
|
+
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
|
|
246
|
+
)
|
|
247
|
+
epilogue_barrier.arrive_and_wait()
|
|
248
|
+
# Copy from shared memory to global memory
|
|
249
|
+
if is_tma_warp:
|
|
250
|
+
if const_expr(has_D):
|
|
251
|
+
copy_D(src_idx=src_idx, dst_idx=dst_idx)
|
|
252
|
+
copy_postact(src_idx=src_idx, dst_idx=dst_idx)
|
|
253
|
+
# Can't use if statement here, epi_store_pipeline object isn't captured somehow
|
|
254
|
+
if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_commit())
|
|
255
|
+
if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_acquire())
|
|
256
|
+
epilogue_barrier.arrive_and_wait()
|
|
257
|
+
|
|
258
|
+
delay_tma_store = True
|
|
191
259
|
|
|
260
|
+
src_idx_prev, dst_idx_prev = None, None
|
|
192
261
|
for epi_idx in cutlass.range_constexpr(epi_tile_num):
|
|
262
|
+
# The global memory coordinate for the current epi tile
|
|
263
|
+
gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
|
|
193
264
|
# Copy from acc to D registers
|
|
194
|
-
|
|
195
|
-
|
|
265
|
+
load_acc_subtile(tRS_rD, epi_idx)
|
|
266
|
+
epi_loop_tensors = self.epi_begin_loop(params, epi_tensors, gmem_coord)
|
|
196
267
|
if const_expr(has_C):
|
|
197
268
|
epi_pipeline.consumer_wait(epi_read_state)
|
|
198
269
|
cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
|
|
@@ -204,69 +275,67 @@ class GemmActSm90(GemmSm90):
|
|
|
204
275
|
with cute.arch.elect_one():
|
|
205
276
|
epi_pipeline.consumer_release(epi_read_state)
|
|
206
277
|
epi_read_state.advance()
|
|
207
|
-
if const_expr(
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
278
|
+
if const_expr(copy_C is not None and epi_idx + self.epi_c_stage < epi_tile_num):
|
|
279
|
+
gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx + self.epi_c_stage)
|
|
280
|
+
if is_tma_warp:
|
|
281
|
+
epi_pipeline.producer_acquire(epi_producer_state)
|
|
282
|
+
copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
|
|
283
|
+
epi_pipeline.producer_commit(epi_producer_state)
|
|
284
|
+
epi_producer_state.advance()
|
|
285
|
+
tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC)
|
|
212
286
|
epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
|
|
287
|
+
if const_expr(delay_tma_store):
|
|
288
|
+
if const_expr(epi_idx > 0):
|
|
289
|
+
tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev)
|
|
290
|
+
src_idx_prev, dst_idx_prev = epi_buffer, gmem_coord
|
|
213
291
|
# Copy from D registers to shared memory
|
|
214
292
|
if const_expr(has_D):
|
|
215
|
-
|
|
216
|
-
tRS_rD_out = cute.make_fragment_like(tRS_rD, self.d_dtype)
|
|
217
|
-
tRS_rD_out.store(tRS_rD.load().to(self.d_dtype))
|
|
218
|
-
cute.copy(tiled_copy_r2s, tRS_rD_out, tRS_sD[None, None, None, epi_buffer])
|
|
293
|
+
copy_utils.cvt_copy(tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer])
|
|
219
294
|
cute.copy(
|
|
220
295
|
tiled_copy_postact_r2s,
|
|
221
296
|
tiled_copy_postact_r2s.retile(tRS_rPostAct),
|
|
222
297
|
tRS_sPostAct[None, None, None, epi_buffer],
|
|
223
298
|
)
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
)
|
|
241
|
-
epi_store_pipeline.producer_commit()
|
|
242
|
-
epi_store_pipeline.producer_acquire()
|
|
243
|
-
epilogue_barrier.arrive_and_wait()
|
|
299
|
+
if const_expr(not delay_tma_store):
|
|
300
|
+
tma_store_fn(src_idx=epi_buffer, dst_idx=gmem_coord)
|
|
301
|
+
|
|
302
|
+
if const_expr(delay_tma_store):
|
|
303
|
+
tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev)
|
|
304
|
+
|
|
305
|
+
self.epi_end(
|
|
306
|
+
params,
|
|
307
|
+
epi_tensors,
|
|
308
|
+
epi_tile,
|
|
309
|
+
tiled_copy_t2r,
|
|
310
|
+
tiled_copy_r2s,
|
|
311
|
+
tile_coord_mnkl,
|
|
312
|
+
varlen_manager,
|
|
313
|
+
tidx,
|
|
314
|
+
)
|
|
244
315
|
|
|
245
316
|
return epi_read_state, epi_producer_state
|
|
246
317
|
|
|
247
318
|
@cute.jit
|
|
248
|
-
def
|
|
319
|
+
def epi_visit_subtile(
|
|
249
320
|
self,
|
|
250
321
|
params: EpilogueParams,
|
|
322
|
+
epi_loop_tensors: Tuple[cute.Tensor, ...],
|
|
251
323
|
tRS_rD: cute.Tensor,
|
|
252
324
|
tRS_rC: Optional[cute.Tensor] = None,
|
|
253
325
|
) -> Optional[cute.Tensor]:
|
|
254
|
-
|
|
255
|
-
if const_expr(params.alpha is not None):
|
|
256
|
-
tRS_rD.store(tRS_rD.load() * params.alpha)
|
|
257
|
-
# Apply C with beta scaling
|
|
258
|
-
if const_expr(tRS_rC is not None):
|
|
259
|
-
if const_expr(params.beta is None):
|
|
260
|
-
# beta is None, default behavior: add C (beta=1.0)
|
|
261
|
-
tRS_rD.store(tRS_rD.load() + tRS_rC.load().to(tRS_rD.element_type))
|
|
262
|
-
else:
|
|
263
|
-
tRS_rD.store(tRS_rD.load() + params.beta * tRS_rC.load().to(tRS_rD.element_type))
|
|
326
|
+
GemmDefaultEpiMixin.epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC)
|
|
264
327
|
# Apply activation function if provided
|
|
265
328
|
# If we don't have .shape here, the compiler generates local stores and loads
|
|
266
329
|
if const_expr(params.act_fn is not None):
|
|
267
330
|
tRS_rPostAct = cute.make_fragment(tRS_rD.layout.shape, self.acc_dtype)
|
|
268
|
-
|
|
269
|
-
|
|
331
|
+
if const_expr(self.arch < 100):
|
|
332
|
+
for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
|
|
333
|
+
tRS_rPostAct[i] = params.act_fn(tRS_rD[i])
|
|
334
|
+
else:
|
|
335
|
+
for i in cutlass.range(cute.size(tRS_rPostAct) // 2, unroll_full=True):
|
|
336
|
+
tRS_rPostAct[2 * i], tRS_rPostAct[2 * i + 1] = params.act_fn(
|
|
337
|
+
(tRS_rD[2 * i], tRS_rD[2 * i + 1])
|
|
338
|
+
)
|
|
270
339
|
else:
|
|
271
340
|
tRS_rPostAct = tRS_rD
|
|
272
341
|
# Type conversion
|
|
@@ -275,6 +344,14 @@ class GemmActSm90(GemmSm90):
|
|
|
275
344
|
return tRS_rPostAct_out
|
|
276
345
|
|
|
277
346
|
|
|
347
|
+
class GemmActSm90(GemmActMixin, GemmSm90):
|
|
348
|
+
pass
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
class GemmActSm100(GemmActMixin, GemmSm100):
|
|
352
|
+
pass
|
|
353
|
+
|
|
354
|
+
|
|
278
355
|
act_fn_map = {
|
|
279
356
|
None: None,
|
|
280
357
|
"relu": quack.activation.relu,
|
|
@@ -283,7 +360,7 @@ act_fn_map = {
|
|
|
283
360
|
}
|
|
284
361
|
|
|
285
362
|
|
|
286
|
-
def gemm_act_sm90(
|
|
363
|
+
def gemm_act(
|
|
287
364
|
A: Tensor, # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m
|
|
288
365
|
B: Tensor, # (l, n, k)
|
|
289
366
|
D: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m
|
|
@@ -297,6 +374,9 @@ def gemm_act_sm90(
|
|
|
297
374
|
cluster_N: int,
|
|
298
375
|
pingpong: bool = False,
|
|
299
376
|
persistent: bool = True,
|
|
377
|
+
max_swizzle_size: int = 8,
|
|
378
|
+
rowvec_bias: Optional[Tensor] = None, # (l, n)
|
|
379
|
+
colvec_bias: Optional[Tensor] = None, # (l, m), or (total_m,) if varlen_m
|
|
300
380
|
cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length
|
|
301
381
|
A_idx: Optional[Tensor] = None, # (total_m,) if gather_A with varlen_m
|
|
302
382
|
) -> None:
|
|
@@ -326,10 +406,14 @@ def gemm_act_sm90(
|
|
|
326
406
|
}
|
|
327
407
|
GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
|
|
328
408
|
|
|
329
|
-
|
|
409
|
+
device_capacity = get_device_capacity(A.device)
|
|
410
|
+
assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
|
|
411
|
+
GemmCls = GemmActSm100 if device_capacity[0] > 9 else GemmActSm90
|
|
412
|
+
|
|
413
|
+
acc_dtype = Float32
|
|
330
414
|
tile_shape_mn = (tile_M, tile_N)
|
|
331
415
|
cluster_shape_mnk = (cluster_M, cluster_N, 1)
|
|
332
|
-
if not
|
|
416
|
+
if not GemmCls.is_valid_dtypes(
|
|
333
417
|
tensor_infos["A"].dtype,
|
|
334
418
|
tensor_infos["B"].dtype,
|
|
335
419
|
acc_dtype,
|
|
@@ -342,9 +426,22 @@ def gemm_act_sm90(
|
|
|
342
426
|
max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
|
|
343
427
|
GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
|
|
344
428
|
act_fn = act_fn_map[activation]
|
|
345
|
-
epi_args =
|
|
429
|
+
epi_args = GemmCls.EpilogueArguments(
|
|
430
|
+
tensor_infos["PostAct"].cute_tensor,
|
|
431
|
+
act_fn,
|
|
432
|
+
mRowVecBroadcast=from_dlpack(rowvec_bias.detach(), assumed_align=4).mark_layout_dynamic(
|
|
433
|
+
leading_dim=1
|
|
434
|
+
)
|
|
435
|
+
if rowvec_bias is not None
|
|
436
|
+
else None,
|
|
437
|
+
mColVecBroadcast=from_dlpack(colvec_bias.detach(), assumed_align=4).mark_layout_dynamic(
|
|
438
|
+
leading_dim=1 if cu_seqlens_m is None else 0
|
|
439
|
+
)
|
|
440
|
+
if colvec_bias is not None
|
|
441
|
+
else None,
|
|
442
|
+
)
|
|
346
443
|
scheduler_args = GemmWrapperBase.create_scheduler_args(
|
|
347
|
-
max_active_clusters, tile_count_semaphore
|
|
444
|
+
max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size
|
|
348
445
|
)
|
|
349
446
|
|
|
350
447
|
# Create varlen arguments if needed (assumes persistent=True when varlen_m)
|
|
@@ -355,7 +452,7 @@ def gemm_act_sm90(
|
|
|
355
452
|
max_active_clusters,
|
|
356
453
|
cluster_shape_mnk,
|
|
357
454
|
tensor_infos,
|
|
358
|
-
|
|
455
|
+
GemmCls.num_epi_tensormaps,
|
|
359
456
|
pingpong,
|
|
360
457
|
)
|
|
361
458
|
|
|
@@ -368,23 +465,27 @@ def gemm_act_sm90(
|
|
|
368
465
|
pingpong,
|
|
369
466
|
persistent,
|
|
370
467
|
tile_count_semaphore is not None,
|
|
468
|
+
device_capacity,
|
|
469
|
+
max_swizzle_size,
|
|
470
|
+
rowvec_bias.dtype if rowvec_bias is not None else None,
|
|
471
|
+
colvec_bias.dtype if colvec_bias is not None else None,
|
|
371
472
|
cu_seqlens_m is not None,
|
|
372
473
|
A_idx is not None,
|
|
373
474
|
key_tensor_names=("A", "B", "D", "PostAct", "C"),
|
|
374
475
|
)
|
|
375
|
-
cache =
|
|
476
|
+
cache = gemm_act.compile_cache
|
|
376
477
|
if compile_key not in cache:
|
|
377
|
-
|
|
478
|
+
if device_capacity[0] == 9:
|
|
479
|
+
GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
|
|
480
|
+
gemm_obj = GemmCls(
|
|
378
481
|
acc_dtype,
|
|
379
482
|
tensor_infos["A"].dtype,
|
|
380
483
|
tile_shape_mn,
|
|
381
484
|
cluster_shape_mnk,
|
|
382
|
-
pingpong=pingpong,
|
|
383
|
-
is_persistent=persistent,
|
|
384
485
|
gather_A=gather_A,
|
|
385
486
|
)
|
|
386
487
|
cache[compile_key] = cute.compile(
|
|
387
|
-
|
|
488
|
+
gemm_obj,
|
|
388
489
|
tensor_infos["A"].cute_tensor,
|
|
389
490
|
tensor_infos["B"].cute_tensor,
|
|
390
491
|
tensor_infos["D"].cute_tensor,
|
|
@@ -406,4 +507,4 @@ def gemm_act_sm90(
|
|
|
406
507
|
)
|
|
407
508
|
|
|
408
509
|
|
|
409
|
-
|
|
510
|
+
gemm_act.compile_cache = {}
|