quack-kernels 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -8
- quack/activation.py +366 -121
- quack/broadcast_utils.py +29 -0
- quack/compile_utils.py +19 -0
- quack/copy_utils.py +487 -0
- quack/cross_entropy.py +157 -233
- quack/cute_dsl_utils.py +20 -34
- quack/gemm.py +194 -0
- quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
- quack/gemm_config.py +72 -46
- quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
- quack/gemm_default_epi.py +259 -0
- quack/gemm_interface.py +177 -31
- quack/gemm_sm100.py +729 -506
- quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
- quack/gemm_symmetric.py +330 -0
- quack/gemm_wrapper_utils.py +3 -1
- quack/layout_utils.py +287 -0
- quack/linear.py +24 -16
- quack/pipeline.py +158 -3
- quack/reduce.py +88 -49
- quack/reduction_base.py +25 -36
- quack/rmsnorm.py +476 -526
- quack/sm100_utils.py +62 -0
- quack/sm90_utils.py +127 -0
- quack/softmax.py +135 -203
- quack/sort/bitonic_sort.py +13 -10
- quack/sort/utils.py +6 -6
- quack/tile_scheduler.py +23 -16
- quack/topk.py +409 -85
- quack/utils.py +32 -220
- quack/varlen_utils.py +370 -1
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
- quack_kernels-0.2.3.dist-info/RECORD +44 -0
- quack/layernorm.py +0 -353
- quack/symmetric_dense_gemm_sm90.py +0 -2091
- quack_kernels-0.2.2.dist-info/RECORD +0 -37
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/broadcast_utils.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# Copyright (c) 2025, Tri Dao.
|
|
2
|
+
from typing import Callable
|
|
3
|
+
|
|
4
|
+
import cutlass
|
|
5
|
+
import cutlass.cute as cute
|
|
6
|
+
from cutlass import Float32, const_expr
|
|
7
|
+
|
|
8
|
+
from quack.layout_utils import make_acc_tensor_mn_view
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@cute.jit
def vec_op(tCrC: cute.Tensor, tCrVec: cute.Tensor, op: Callable, is_colvec: bool) -> None:
    """Combine a fragment with a broadcast vector in place: ``tCrC = op(tCrC, vec)``.

    The fragment is viewed as an (M, N) matrix via ``make_acc_tensor_mn_view``.
    - ``is_colvec=True``: ``tCrVec`` has one entry per row (mode 0) and row ``r`` becomes
      ``op(row_r, tCrVec[r])``.
    - ``is_colvec=False``: ``tCrVec`` has one entry per column (mode 1) and column ``c``
      becomes ``op(col_c, tCrVec[c])``.

    The arithmetic is done in Float32. A non-Float32 ``tCrC`` is copied into a temporary
    Float32 fragment and converted back to its original element type at the end.
    """
    if const_expr(tCrC.element_type != Float32):  # Convert to f32
        tCrC_f32 = cute.make_fragment(tCrC.shape, Float32)
        tCrC_f32.store(tCrC.load().to(Float32))
    else:
        # Already Float32: operate directly on tCrC, no write-back needed.
        tCrC_f32 = tCrC
    # this happens to work for frgA layout too, not just acc layout
    tCrC_f32_mn = make_acc_tensor_mn_view(tCrC_f32)
    if const_expr(is_colvec):
        # Vector length must match the number of rows of the (M, N) view.
        assert cute.size(tCrC_f32_mn, mode=[0]) == cute.size(tCrVec)
        for r in cutlass.range(cute.size(tCrC_f32_mn, mode=[0]), unroll_full=True):
            tCrC_f32_mn[r, None].store(op(tCrC_f32_mn[r, None].load(), tCrVec[r]))
    else:
        # Vector length must match the number of columns of the (M, N) view.
        assert cute.size(tCrC_f32_mn, mode=[1]) == cute.size(tCrVec)
        for c in cutlass.range(cute.size(tCrC_f32_mn, mode=[1]), unroll_full=True):
            tCrC_f32_mn[None, c].store(op(tCrC_f32_mn[None, c].load(), tCrVec[c]))
    if const_expr(tCrC.element_type != Float32):  # Convert back to original dtype
        tCrC.store(tCrC_f32.load().to(tCrC.element_type))
|
quack/compile_utils.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import cutlass.cute as cute
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def make_fake_tensor(dtype, shape, divisibility=1, leading_dim=-1) -> Optional[cute.Tensor]:
    """Create a symbolic ("fake") tensor description for compilation.

    Every mode except ``leading_dim`` gets a symbolic int64 stride constrained to be a
    multiple of ``divisibility``. The ``leading_dim`` mode gets stride 1. The assumed
    pointer alignment is ``divisibility`` elements, expressed in bytes.

    Args:
        dtype: CuTe numeric type, or None for an absent (optional) operand.
        shape: Tuple of extents (may contain symbolic ints).
        divisibility: Stride / alignment divisibility in elements.
        leading_dim: Index of the unit-stride mode; negative values count from the end.

    Returns:
        The fake tensor, or None when ``dtype`` is None.
    """
    # Check for an absent operand first so placeholder shapes are never inspected.
    if dtype is None:
        return None
    if leading_dim < 0:
        leading_dim = len(shape) + leading_dim
    stride = tuple(
        cute.sym_int64(divisibility=divisibility) if i != leading_dim else 1
        for i in range(len(shape))
    )
    # NOTE(review): for sub-byte dtypes this floors to 0 when divisibility * width < 8
    # (e.g. 4-bit types with divisibility=1) — confirm make_fake_tensor accepts that.
    assumed_align = divisibility * dtype.width // 8
    return cute.runtime.make_fake_tensor(dtype, shape, stride=stride, assumed_align=assumed_align)
|
quack/copy_utils.py
ADDED
|
@@ -0,0 +1,487 @@
|
|
|
1
|
+
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from typing import Optional, Type, Tuple, Callable
|
|
5
|
+
|
|
6
|
+
import cutlass
|
|
7
|
+
import cutlass.cute as cute
|
|
8
|
+
|
|
9
|
+
from cutlass import Int32, Boolean, const_expr
|
|
10
|
+
from cutlass.cute.nvgpu import cpasync
|
|
11
|
+
from cutlass.cutlass_dsl import dsl_user_op
|
|
12
|
+
import cutlass.pipeline
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dsl_user_op
def cvt_copy(
    atom: cute.CopyAtom,
    src: cute.Tensor,
    dst: cute.Tensor,
    *,
    pred: Optional[cute.Tensor] = None,
    loc=None,
    ip=None,
    **kwargs,
) -> None:
    """Copy a register tensor into ``dst`` with ``atom``, converting dtype on the way.

    ``src`` must be a pointer-backed register (rmem) tensor. If its element type differs
    from ``dst``'s, it is first converted into a temporary fragment of ``dst``'s type.
    ``pred`` and any extra keyword arguments are forwarded to ``cute.copy``.
    """
    assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem
    target_dtype = dst.element_type
    if const_expr(src.element_type != target_dtype):
        converted = cute.make_fragment_like(src, target_dtype)
        converted.store(src.load().to(target_dtype))
        src = converted
    cute.copy(atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dsl_user_op
def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
    """Copy ``src`` into a new fragment with the same layout and element type, and return it."""
    fragment = cute.make_fragment_like(src, src.element_type, loc=loc, ip=ip)
    cute.autovec_copy(src, fragment, loc=loc, ip=ip)
    return fragment
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@dsl_user_op
def load_s2r_retile(
    tiled_copy: cute.TiledCopy,
    src: cute.Tensor,
    dst_shape: cute.Tensor | cute.Shape,
    *,
    loc=None,
    ip=None,
) -> cute.Tensor:
    """Copy ``src`` with ``tiled_copy`` into a destination retiled for that copy.

    ``dst_shape`` is either a shape, in which case a new fragment of ``src``'s element type
    is allocated, or an existing tensor, which is written into directly. The destination
    tensor is returned.
    """
    dst = (
        dst_shape
        if const_expr(isinstance(dst_shape, cute.Tensor))
        else cute.make_fragment(dst_shape, src.element_type, loc=loc, ip=ip)
    )
    cute.copy(tiled_copy, src, tiled_copy.retile(dst), loc=loc, ip=ip)
    return dst
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@dsl_user_op
def get_copy_atom(
    dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None
) -> cute.CopyAtom:
    """Build a copy atom that moves ``num_copy_elems`` elements of ``dtype`` per copy.

    The per-copy width is capped at 128 bits. With ``is_async`` the atom uses
    ``cpasync.CopyG2SOp``; otherwise it uses ``cute.nvgpu.CopyUniversalOp``.
    ``loc``/``ip`` are forwarded so the generated IR carries the caller's location.
    """
    # Presumably capped because 128 bits is the widest single vectorized access.
    num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width))
    copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
    return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits, loc=loc, ip=ip)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@dsl_user_op
def copy(
    src: cute.Tensor,
    dst: cute.Tensor,
    *,
    pred: Optional[cute.Tensor] = None,
    is_async: bool = False,
    loc=None,
    ip=None,
    **kwargs,
) -> None:
    """Vectorized copy from ``src`` to ``dst`` with an atom derived from ``src``.

    The number of elements per copy is read from ``src.shape[0][0]``. This presumes a
    copy-partitioned tensor whose first mode is ``(atom_v, rest_v)`` — TODO confirm with
    callers. ``is_async`` selects a cp.async atom. ``pred`` and extra keyword arguments
    are forwarded to ``cute.copy``.
    """
    num_copy_elems = src.shape[0][0]
    # Forward loc/ip so the atom creation is attributed to the caller as well.
    copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async, loc=loc, ip=ip)
    cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def tiled_copy_1d(
    dtype: Type[cutlass.Numeric], num_threads: int, num_copy_elems: int = 1, is_async: bool = False
) -> cute.TiledCopy:
    """Build a 1-D tiled copy over ``num_threads`` threads.

    Each thread moves ``num_copy_elems`` elements of ``dtype`` per copy. The thread layout
    is ``num_threads:1`` and the value layout is ``num_copy_elems:1``. ``is_async`` selects
    ``cpasync.CopyG2SOp`` instead of ``CopyUniversalOp``.
    """
    op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
    atom = cute.make_copy_atom(op, dtype, num_bits_per_copy=dtype.width * num_copy_elems)
    return cute.make_tiled_copy_tv(
        atom, cute.make_layout(num_threads), cute.make_layout(num_copy_elems)
    )
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def tiled_copy_2d(
    dtype: Type[cutlass.Numeric],
    threads_per_row: int,
    num_threads: int,
    num_copy_elems: int = 1,
    is_async: bool = False,
) -> cute.TiledCopy:
    """Build a 2-D tiled copy with threads arranged row-major.

    ``num_threads`` threads form a ``(num_threads // threads_per_row, threads_per_row)``
    grid in which consecutive thread ids walk along a row. Each thread copies
    ``num_copy_elems`` consecutive elements of one row per copy. ``num_threads`` must be a
    multiple of ``threads_per_row`` (asserted).
    """
    op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
    atom = cute.make_copy_atom(op, dtype, num_bits_per_copy=dtype.width * num_copy_elems)
    assert num_threads % threads_per_row == 0
    num_rows = num_threads // threads_per_row
    # order=(1, 0): mode 1 (within a row) varies fastest.
    thr_layout = cute.make_ordered_layout((num_rows, threads_per_row), order=(1, 0))
    return cute.make_tiled_copy_tv(atom, thr_layout, cute.make_layout((1, num_copy_elems)))
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
@cute.jit
def predicate_k(tAcA: cute.Tensor, limit: Int32) -> cute.Tensor:
    """Build a K-dimension predicate fragment from a partitioned coordinate tensor.

    ``tAcA`` is indexed below as ``tAcA[(0, rest_v), 0, rest_k]``. It is therefore assumed
    to be a partitioned identity (coordinate) tensor whose first mode is nested —
    presumably ((atom_v, rest_v), rest_mn, rest_k); TODO confirm against callers.
    Entry ``[rest_v, *, rest_k]`` of the result is True iff component 1 (the K coordinate)
    of ``tAcA[(0, rest_v), 0, rest_k]`` is less than ``limit``.

    Returns:
        A Boolean fragment of shape
        (size(tAcA, mode [0, 1]), size(tAcA, mode [1]), size(tAcA, mode [2])).
    """
    # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
    tApA = cute.make_fragment(
        cute.make_layout(
            (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
            # Mode 1 (mn) has stride 0, so every mn index aliases the same predicates.
            # That is why only index 0 of that mode is written below.
            stride=(cute.size(tAcA, mode=[2]), 0, 1),
        ),
        Boolean,
    )
    for rest_v in cutlass.range_constexpr(tApA.shape[0]):
        for rest_k in cutlass.range_constexpr(tApA.shape[2]):
            tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
    return tApA
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
# def tiled_copy_2d(
|
|
131
|
+
# dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False
|
|
132
|
+
# ) -> cute.TiledCopy:
|
|
133
|
+
# num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width
|
|
134
|
+
# copy_elems = num_copy_bits // dtype.width
|
|
135
|
+
# copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
|
|
136
|
+
# copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
|
|
137
|
+
# gmem_threads_per_row = major_mode_size // copy_elems
|
|
138
|
+
# assert num_threads % gmem_threads_per_row == 0
|
|
139
|
+
# thr_layout = cute.make_ordered_layout(
|
|
140
|
+
# (num_threads // gmem_threads_per_row, gmem_threads_per_row),
|
|
141
|
+
# order=(1, 0),
|
|
142
|
+
# )
|
|
143
|
+
# val_layout = cute.make_layout((1, copy_elems))
|
|
144
|
+
# return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def parse_swizzle_from_pointer(ptr: cute.Pointer) -> Tuple[int, int, int]:
    """Extract swizzle parameters from a pointer's swizzle_type.

    The swizzle_type string has the form '!cute.swizzle<"S<b,m,s>">' where
    b, m, s are the swizzle parameters (bits, base, shift).

    Returns:
        The tuple ``(b, m, s)`` of parsed ints. This is not a ``cute.Swizzle``; use
        ``cute.make_swizzle(*parse_swizzle_from_pointer(ptr))`` to build one.

    Raises:
        ValueError: If the swizzle_type string cannot be parsed. Only non-negative
            decimal parameters are recognised, so a negative shift also raises.
    """
    # Ideally there should be a better API to get swizzle parameters, but we'll just parse
    # the string here.
    swizzle_str = str(ptr.type.swizzle_type)
    # Extract the inner part "S<b,m,s>"
    match = re.search(r"S<(\d+),(\d+),(\d+)>", swizzle_str)
    if match is None:
        raise ValueError(f"Could not parse swizzle_type: {swizzle_str}")
    b, m, s = (int(group) for group in match.groups())
    return b, m, s
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def swizzle_int(ptr_int: Int32, b: int, m: int, s: int) -> Int32:
    """Apply the CuTe swizzle ``S<b, m, s>`` to an integer offset or address.

    The ``b`` "YYY" bits starting at bit ``m + max(0, s)`` are XOR-ed onto the ``b`` "ZZZ"
    bits starting at bit ``m - min(0, s)``. In other words the result is
    ``ptr_int ^ shiftr(ptr_int & yyy_msk, s)``, where a negative ``s`` shifts left. This
    matches CuTe's ``Swizzle<BBits, MBase, SShift>::apply``, which permits a negative
    shift.

    Args:
        ptr_int: Integer to swizzle (Python int or DSL integer).
        b: Number of bits in each mask.
        m: Base bit position.
        s: Shift distance (may be negative).
    """
    bit_msk = (1 << b) - 1
    yyy_msk = bit_msk << (m + max(0, s))
    yyy_bits = ptr_int & yyy_msk
    # Positive shift moves YYY down onto ZZZ; negative shift moves it up.
    return ptr_int ^ (yyy_bits >> s if s >= 0 else yyy_bits << -s)
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def swizzle_ptr(ptr: cute.Pointer):
    """Return a new pointer whose address is ``ptr``'s address with its swizzle applied.

    The swizzle parameters come from ``ptr``'s type. The dtype, memory space and
    alignment of ``ptr`` are carried over to the result.
    """
    bits, base, shift = parse_swizzle_from_pointer(ptr)
    swizzled_addr = swizzle_int(ptr.toint(), bits, base, shift)
    return cute.make_ptr(ptr.dtype, swizzled_addr, ptr.memspace, assumed_align=ptr.alignment)
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def as_position_independent_swizzle_tensor(tensor: cute.Tensor) -> cute.Tensor:
    """Move the swizzle carried by ``tensor``'s pointer into its layout.

    The result views the same memory through a pointer without the swizzle and a composed
    layout ``swizzle o layout``. The swizzle is therefore applied to layout offsets rather
    than to the absolute address.
    """
    elem_bits = tensor.element_type.width
    swizzle = cute.make_swizzle(*parse_swizzle_from_pointer(tensor.iterator))
    # The pointer swizzle is expressed in bytes (e.g. <3, 4, 3>). Compose it with the layout
    # in byte units, then recast back to element units, which is equivalent to e.g.
    # <3, 3, 3> for 16-bit and <3, 2, 3> for 32-bit elements.
    byte_layout = cute.recast_layout(8, elem_bits, tensor.layout)
    swizzled_layout = cute.recast_layout(
        elem_bits, 8, cute.make_composed_layout(swizzle, 0, byte_layout)
    )
    # recast_ptr drops the swizzle from the pointer type.
    plain_ptr = cute.recast_ptr(tensor.iterator, dtype=tensor.element_type)
    return cute.make_tensor(plain_ptr, swizzled_layout)
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def partition_D_position_independent(
    thr_copy: cute.core.ThrCopy, tensor: cute.Tensor
) -> cute.Tensor:
    """Partition ``tensor`` as a copy destination using a position-independent swizzle.

    The base pointer is this thread's ``partition_D`` pointer with its swizzle applied to
    the address. The layout comes from partitioning the swizzle-in-layout view of
    ``tensor``.
    """
    thr_base_ptr = swizzle_ptr(thr_copy.partition_D(tensor).iterator)
    thr_layout = thr_copy.partition_D(as_position_independent_swizzle_tensor(tensor)).layout
    return cute.make_tensor(thr_base_ptr, thr_layout)
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def partition_S_position_independent(
    thr_copy: cute.core.ThrCopy, tensor: cute.Tensor
) -> cute.Tensor:
    """Partition ``tensor`` as a copy source using a position-independent swizzle.

    The base pointer is this thread's ``partition_S`` pointer with its swizzle applied to
    the address. The layout comes from partitioning the swizzle-in-layout view of
    ``tensor``.
    """
    thr_base_ptr = swizzle_ptr(thr_copy.partition_S(tensor).iterator)
    thr_layout = thr_copy.partition_S(as_position_independent_swizzle_tensor(tensor)).layout
    return cute.make_tensor(thr_base_ptr, thr_layout)
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
@dsl_user_op
def sm90_get_smem_load_op(
    layout_c: cutlass.utils.LayoutEnum,
    elem_ty_c: Type[cutlass.Numeric],
    *,
    loc=None,
    ip=None,
) -> cute.CopyAtom:
    """
    Pick the shared-memory load atom for reading a C tile.

    16-bit element types get an ldmatrix atom (``LdMatrix8x8x16bOp`` with 4 matrices, its
    first argument taken from ``layout_c.is_m_major_c()``). All other widths fall back to
    ``CopyUniversalOp``.

    Parameters:
    -----------
    layout_c : LayoutEnum
        Layout of the C tensor; only ``is_m_major_c()`` is consulted.

    elem_ty_c : Type[Numeric]
        Element type of the C tensor.

    Returns:
    --------
    A ``cute.CopyAtom``: ldmatrix-based for 16-bit types, universal copy otherwise.

    Raises:
    -------
    TypeError
        If ``elem_ty_c`` is not a CuTe numeric type.
    """
    if not isinstance(elem_ty_c, cutlass.cutlass_dsl.NumericMeta):
        raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}")
    is_m_major = layout_c.is_m_major_c()
    if elem_ty_c.width != 16:
        return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip)
    ldmatrix_op = cute.nvgpu.warp.LdMatrix8x8x16bOp(is_m_major, 4)
    return cute.make_copy_atom(ldmatrix_op, elem_ty_c, loc=loc, ip=ip)
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def get_smem_store_atom(
    arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False
) -> cute.CopyAtom:
    """Pick the shared-memory store atom for ``element_type`` on architecture ``arch``.

    On arch >= 90 with 16-bit elements this is ``StMatrix8x8x16bOp`` with 4 matrices
    (forwarding ``transpose``). Otherwise it is a universal copy that stores two elements
    per copy, or one element when ``transpose`` is set.
    """
    use_stmatrix = arch >= 90 and element_type.width == 16
    if const_expr(use_stmatrix):
        return cute.make_copy_atom(
            cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
            element_type,
        )
    elems_per_copy = 1 if transpose else 2
    return cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(),
        element_type,
        num_bits_per_copy=elems_per_copy * element_type.width,
    )
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def tma_get_copy_fn(
    atom: cute.CopyAtom,
    cta_coord: cute.Coord,
    cta_layout: cute.Layout,
    src_tensor: cute.Tensor,
    dst_tensor: cute.Tensor,
    filter_zeros: bool = False,
    **kwargs,
) -> Tuple[Callable, cute.Tensor, cute.Tensor]:
    """Partition a TMA copy and return a closure that issues it for one index pair.

    The direction is inferred from ``src_tensor``. If it is a pointer-backed smem tensor,
    the copy is smem -> gmem; otherwise it is gmem -> smem. Both tensors are grouped over
    all but their last mode before ``cpasync.tma_partition``. With ``filter_zeros``,
    ``cute.filter_zeros`` is applied to both partitions.

    Returns:
        ``(copy_tma, s, g)``, where ``s``/``g`` are the partitioned smem/gmem tensors.
        ``copy_tma(src_idx, dst_idx, **new_kwargs)`` copies ``src[None, src_idx]`` to
        ``dst[None, dst_idx]``, passing ``new_kwargs`` together with the ``**kwargs``
        captured here. A key given in both raises ``TypeError``.
    """
    src_is_smem = const_expr(
        isinstance(src_tensor.iterator, cute.Pointer)
        and src_tensor.memspace == cute.AddressSpace.smem
    )
    smem_tensor, gmem_tensor = (src_tensor, dst_tensor) if src_is_smem else (dst_tensor, src_tensor)
    # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
    s, g = cpasync.tma_partition(
        atom,
        cta_coord,
        cta_layout,
        cute.group_modes(smem_tensor, 0, cute.rank(smem_tensor) - 1),
        cute.group_modes(gmem_tensor, 0, cute.rank(gmem_tensor) - 1),
    )
    if const_expr(filter_zeros):
        s = cute.filter_zeros(s)
        g = cute.filter_zeros(g)
    src, dst = (s, g) if src_is_smem else (g, s)

    def copy_tma(src_idx, dst_idx, **new_kwargs):
        cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs)

    return copy_tma, s, g
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync):
    """Bind a TMA copy closure (as returned by ``tma_get_copy_fn``) to a producer pipeline.

    The returned ``copy_fn(src_idx, producer_state, **new_kwargs)`` copies into stage
    ``producer_state.index``. It passes ``pipeline.producer_get_barrier(producer_state)``
    as ``tma_bar_ptr``, along with any extra keyword arguments.
    """

    def copy_fn(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwargs):
        stage = producer_state.index
        barrier = pipeline.producer_get_barrier(producer_state)
        copy(src_idx=src_idx, dst_idx=stage, tma_bar_ptr=barrier, **new_kwargs)

    return copy_fn
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
@cute.jit
def gather_m_get_copy_fn(
    thr_copy_A: cute.ThrCopy,
    mA: cute.Tensor,  # (whatever, K)
    sA: cute.Tensor,  # (tile_M, tile_K, STAGE)
    gsAIdx: cute.Tensor,  # (tile_M), either gmem or smem
    limit_m: Int32,
    limit_k: Int32,
) -> Callable:
    """Build a copy closure that gathers rows of ``mA`` into ``sA``.

    Tile row ``r`` of ``sA`` is loaded from row ``gsAIdx[r]`` of ``mA``. The row indices
    for the rows this thread copies are read once, up front. Rows whose M coordinate is not
    below ``limit_m`` get index 0. Those rows are skipped at copy time unless the thread
    tiling divides tile_M evenly (``is_even_m_smem``); in that case row 0 is loaded instead.

    The returned ``copy_fn(src_idx, dst_idx, pred=False)`` copies K tile ``src_idx`` of the
    gathered rows into stage ``dst_idx`` of ``sA``. With ``pred=True``, columns whose K
    coordinate is not below ``limit_k`` are masked. Each thread must issue exactly one load
    per row (asserted).
    """
    tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1]))
    tAsA = thr_copy_A.partition_D(sA)
    # k-major
    assert tAsA.shape[2] == 1
    tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2)

    is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0
    if const_expr(not is_even_m_smem):
        # Rows past the tile can exist in the thread tiling; never treat them as valid.
        limit_m = min(limit_m, tile_shape_mk[0])
    elems_per_load = cute.size(tAsA.shape[0][0])
    cA = cute.make_identity_tensor(tile_shape_mk)
    tAcA = thr_copy_A.partition_S(cA)
    t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
    # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
    # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0].
    # This is so that when we do the comparison, t0AcA is known at compile time.
    limit_m = limit_m - tAcA[0][0]
    limit_k = limit_k - tAcA[0][1]
    # Read and cache indices for A
    rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
    cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
    tApA_m = cute.make_fragment(rows_per_thread, Boolean)
    for m in cutlass.range(rows_per_thread, unroll_full=True):
        tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
    m_idx = cute.make_fragment(rows_per_thread, Int32)
    for m in cutlass.range(rows_per_thread, unroll_full=True):
        row_idx = tAcA[0, m, 0][0]
        if tApA_m[m]:
            m_idx[m] = gsAIdx[row_idx]
        else:
            m_idx[m] = 0  # It's ok to load row 0 in the case of OOB

    # Split the K mode into tiles of tile_K so copy_fn can select tile src_idx.
    mA_k = cute.logical_divide(mA, (None, tile_shape_mk[1]))

    def copy_fn(src_idx, dst_idx, pred: bool = False):
        tApA_k = None
        if const_expr(pred):
            tApA_k = cute.make_fragment(cols_per_thread, Boolean)
            limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
            for k in cutlass.range(cols_per_thread, unroll_full=True):
                tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
        mA_cur = mA_k[None, (None, src_idx)]
        for m in cutlass.range_constexpr(tAcA.shape[1]):
            # cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,)) would give shape
            # ((elems_per_load), thread_per_row)
            # But we actually want shape ((elems_per_load, 1), thread_per_row) to match tAsA
            # So we append 1s to the last dimension and then do tiled_divide, then slice.
            mA_row = cute.tiled_divide(
                cute.append_ones(mA_cur[m_idx[m], None], up_to_rank=2), (elems_per_load, 1)
            )[None, None, 0]
            if const_expr(is_even_m_smem) or tApA_m[m]:
                # There's only 1 load per row
                assert cute.size(tAcA.shape, mode=[2]) == 1
                ki = tAcA[0, 0, 0][1] // elems_per_load
                cute.copy(thr_copy_A, mA_row[None, ki], tAsA[(None, m), dst_idx], pred=tApA_k)

    return copy_fn
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
@cute.jit
def gather_k_get_copy_fn(
    thr_copy_A: cute.ThrCopy,
    mA: cute.Tensor,  # (tile_M, whatever)
    sA: cute.Tensor,  # (tile_M, tile_K, STAGE)
    gsAIdx: cute.Tensor,  # (tile_K, RestK), either gmem or smem
    limit_m: Int32,
    limit_k: Int32,
) -> Callable:
    """Build closures that gather K columns of ``mA`` into ``sA`` via an index tensor.

    ``gsAIdx`` holds the column indices and may live in gmem or smem. Its memory space,
    checked at trace time, selects which prefetch function is returned.

    Returns:
        ``(copy_fn, prefetch_fn)``. ``prefetch_fn`` is ``prefetch_from_gmem_fn`` when
        ``gsAIdx`` is in gmem, and ``prefetch_from_smem_fn`` otherwise (smem is asserted).
        NOTE(review): the ``-> Callable`` annotation does not match this tuple return.

    prefetch_from_gmem_fn(src_idx, pred=False) -> (k_idx, tApA_k):
        Reads this thread's column indices for ``gsAIdx[None, src_idx]``. With
        ``pred=True``, columns whose K coordinate is not below ``limit_k`` get index -1.
    prefetch_from_smem_fn(a_prefetch_pipeline, src_idx, dst_idx, a_prefetch_consumer_state,
                          pred=False) -> (k_idx, tApA_k):
        Waits on ``a_prefetch_pipeline`` for ``a_prefetch_consumer_state`` and reads the
        indices from ``gsAIdx[None, dst_idx]``. It then syncs the warp and releases the
        state from one elected thread. Masked columns keep their loaded index; only
        ``tApA_k`` marks them.
    copy_fn(src_idx, dst_idx, k_idx_tApA_k, pred=False):
        For every (k, m) whose row predicate holds, copies column ``k_idx[k]`` of this
        thread's slice of ``mA`` into stage ``dst_idx`` of ``sA``. When ``pred`` is set,
        the copy is masked by ``tApA_k``. ``src_idx`` is unused here because the K tile
        is already encoded in ``k_idx``.
    """
    gAIdx, sAIdx = None, None
    if const_expr(gsAIdx.memspace == cute.AddressSpace.gmem):
        gAIdx = gsAIdx
    else:
        assert gsAIdx.memspace == cute.AddressSpace.smem
        sAIdx = gsAIdx
    tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1]))
    # (atom_v, CPY_M, 1, STAGE)
    tAsA = thr_copy_A.partition_D(sA)
    # m-major
    tAsA = cute.group_modes(tAsA, 0, 3)

    is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0
    if const_expr(not is_even_m_smem):
        limit_m = min(limit_m, tile_shape_mk[0])
    elems_per_load = cute.size(tAsA.shape[0][0])
    cA = cute.make_identity_tensor(tile_shape_mk)
    tAcA = thr_copy_A.partition_S(cA)
    t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
    # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
    # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0].
    # This is so that when we do the comparison, t0AcA is known at compile time.
    limit_m = limit_m - tAcA[0][0]
    limit_k = limit_k - tAcA[0][1]
    # Read and cache indices for A
    rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
    cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
    tApA_m = cute.make_fragment(rows_per_thread, Boolean)
    for m in cutlass.range(rows_per_thread, unroll_full=True):
        tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
    threads_per_col = const_expr(thr_copy_A.tiler_mn[0].shape // elems_per_load)
    # This is convoluted, but there is no obvious simpler way.
    # for tile_M=128, flat_divide gives (8, 16, K),
    # then logical_divide gives ((8, 1), (8, 2), K).
    tidx = thr_copy_A.thr_idx
    # Select this thread's position within a column group (tidx % threads_per_col).
    tAmA = cute.logical_divide(
        cute.flat_divide(mA, (elems_per_load,)), (elems_per_load, threads_per_col)
    )[None, (tidx % threads_per_col, None), None]  # ((8, 1), 2, K)

    def prefetch_from_gmem_fn(src_idx, pred: bool = False) -> Tuple[cute.Tensor, cute.Tensor]:
        # Prefetch mAIdx early, even before smem is free
        tApA_k = None
        if const_expr(pred):
            tApA_k = cute.make_fragment(cols_per_thread, Boolean)
            limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
            for k in cutlass.range(cols_per_thread, unroll_full=True):
                tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
        gAIdx_cur = gAIdx[None, src_idx]
        k_idx = cute.make_fragment(cols_per_thread, Int32)
        for k in cutlass.range(cols_per_thread):
            col_idx = tAcA[0, 0, k][1]
            if const_expr(not pred):
                k_idx[k] = gAIdx_cur[col_idx]
            else:
                if tApA_k[k]:
                    k_idx[k] = gAIdx_cur[col_idx]
                else:
                    # Out-of-range column: do not read gAIdx; use -1 as a placeholder.
                    k_idx[k] = -1
        return k_idx, tApA_k

    def prefetch_from_smem_fn(
        a_prefetch_pipeline, src_idx, dst_idx, a_prefetch_consumer_state, pred: bool = False
    ) -> Tuple[cute.Tensor, cute.Tensor]:
        tApA_k = None
        if const_expr(pred):
            tApA_k = cute.make_fragment(cols_per_thread, Boolean)
            limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
            for k in cutlass.range(cols_per_thread, unroll_full=True):
                tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
        # Wait until the index stage has been written before reading it.
        a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state)
        sAIdx_cur = sAIdx[None, dst_idx]
        k_idx = cute.make_fragment(cols_per_thread, Int32)
        for k in cutlass.range(cols_per_thread):
            col_idx = tAcA[0, 0, k][1]
            k_idx[k] = sAIdx_cur[col_idx]
        # All lanes must finish reading before one elected lane releases the stage.
        cute.arch.sync_warp()
        with cute.arch.elect_one():
            a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state)
        return k_idx, tApA_k

    def copy_fn(
        src_idx, dst_idx, k_idx_tApA_k: Tuple[cute.Tensor, cute.Tensor], pred: bool = False
    ):
        k_idx, tApA_k = k_idx_tApA_k
        tApA_k_pred = None
        if const_expr(pred):
            tApA_k_pred = cute.prepend_ones(tApA_k, up_to_rank=2)  # (1, cols_per_thread)
        for k in cutlass.range_constexpr(tAcA.shape[2]):
            # copy_A(tAmA[None, None, k_idx[k]], tAsA[(None, None, k), smem_idx], pred=cute.prepend_ones(tApA_m, up_to_rank=2))
            for m in cutlass.range_constexpr(tAcA.shape[1]):
                if tApA_m[m]:
                    cute.copy(
                        thr_copy_A,
                        tAmA[None, m, k_idx[k]],
                        tAsA[(None, m, k), dst_idx],
                        pred=None if const_expr(tApA_k_pred is None) else tApA_k_pred[None, k],
                    )

    return copy_fn, prefetch_from_gmem_fn if const_expr(
        gAIdx is not None
    ) else prefetch_from_smem_fn
|