quack-kernels 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. quack/__init__.py +1 -8
  2. quack/activation.py +366 -121
  3. quack/broadcast_utils.py +29 -0
  4. quack/compile_utils.py +19 -0
  5. quack/copy_utils.py +487 -0
  6. quack/cross_entropy.py +157 -233
  7. quack/cute_dsl_utils.py +20 -34
  8. quack/gemm.py +194 -0
  9. quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
  10. quack/gemm_config.py +72 -46
  11. quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
  12. quack/gemm_default_epi.py +259 -0
  13. quack/gemm_interface.py +177 -31
  14. quack/gemm_sm100.py +729 -506
  15. quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
  16. quack/gemm_symmetric.py +330 -0
  17. quack/gemm_wrapper_utils.py +3 -1
  18. quack/layout_utils.py +287 -0
  19. quack/linear.py +24 -16
  20. quack/pipeline.py +158 -3
  21. quack/reduce.py +88 -49
  22. quack/reduction_base.py +25 -36
  23. quack/rmsnorm.py +476 -526
  24. quack/sm100_utils.py +62 -0
  25. quack/sm90_utils.py +127 -0
  26. quack/softmax.py +135 -203
  27. quack/sort/bitonic_sort.py +13 -10
  28. quack/sort/utils.py +6 -6
  29. quack/tile_scheduler.py +23 -16
  30. quack/topk.py +409 -85
  31. quack/utils.py +32 -220
  32. quack/varlen_utils.py +370 -1
  33. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/METADATA +4 -2
  34. quack_kernels-0.2.4.dist-info/RECORD +44 -0
  35. quack/layernorm.py +0 -353
  36. quack/symmetric_dense_gemm_sm90.py +0 -2091
  37. quack_kernels-0.2.2.dist-info/RECORD +0 -37
  38. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/WHEEL +0 -0
  39. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/licenses/LICENSE +0 -0
  40. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/top_level.txt +0 -0
quack/broadcast_utils.py ADDED
@@ -0,0 +1,29 @@
+ # Copyright (c) 2025, Tri Dao.
+ from typing import Callable
+
+ import cutlass
+ import cutlass.cute as cute
+ from cutlass import Float32, const_expr
+
+ from quack.layout_utils import make_acc_tensor_mn_view
+
+
+ @cute.jit
+ def vec_op(tCrC: cute.Tensor, tCrVec: cute.Tensor, op: Callable, is_colvec: bool) -> None:
+     if const_expr(tCrC.element_type != Float32):  # Convert to f32
+         tCrC_f32 = cute.make_fragment(tCrC.shape, Float32)
+         tCrC_f32.store(tCrC.load().to(Float32))
+     else:
+         tCrC_f32 = tCrC
+     # this happens to work for frgA layout too, not just acc layout
+     tCrC_f32_mn = make_acc_tensor_mn_view(tCrC_f32)
+     if const_expr(is_colvec):
+         assert cute.size(tCrC_f32_mn, mode=[0]) == cute.size(tCrVec)
+         for r in cutlass.range(cute.size(tCrC_f32_mn, mode=[0]), unroll_full=True):
+             tCrC_f32_mn[r, None].store(op(tCrC_f32_mn[r, None].load(), tCrVec[r]))
+     else:
+         assert cute.size(tCrC_f32_mn, mode=[1]) == cute.size(tCrVec)
+         for c in cutlass.range(cute.size(tCrC_f32_mn, mode=[1]), unroll_full=True):
+             tCrC_f32_mn[None, c].store(op(tCrC_f32_mn[None, c].load(), tCrVec[c]))
+     if const_expr(tCrC.element_type != Float32):  # Convert back to original dtype
+         tCrC.store(tCrC_f32.load().to(tCrC.element_type))
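The new `vec_op` helper broadcasts a per-row or per-column vector over a register fragment, round-tripping through f32 when the fragment is a narrower dtype. A minimal usage sketch; the fragment names `tCrAcc`, `tCrBias`, `tCrScale` and the surrounding `@cute.jit` epilogue are assumptions for illustration, not part of this diff:

    import operator
    from quack.broadcast_utils import vec_op

    # Inside a @cute.jit kernel epilogue, with tCrAcc an accumulator fragment and
    # tCrBias / tCrScale register fragments holding one value per row / per column:
    vec_op(tCrAcc, tCrBias, op=operator.add, is_colvec=True)    # row-wise: acc[r, :] += bias[r]
    vec_op(tCrAcc, tCrScale, op=operator.mul, is_colvec=False)  # column-wise: acc[:, c] *= scale[c]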
quack/compile_utils.py ADDED
@@ -0,0 +1,19 @@
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+
+ from typing import Optional
+
+ import cutlass.cute as cute
+
+
+ def make_fake_tensor(dtype, shape, divisibility=1, leading_dim=-1) -> Optional[cute.Tensor]:
+     if leading_dim < 0:
+         leading_dim = len(shape) + leading_dim
+     if dtype is None:
+         return None
+     stride = tuple(
+         cute.sym_int64(divisibility=divisibility) if i != leading_dim else 1
+         for i in range(len(shape))
+     )
+     return cute.runtime.make_fake_tensor(
+         dtype, shape, stride=stride, assumed_align=divisibility * dtype.width // 8
+     )
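`make_fake_tensor` builds a fake tensor whose non-leading strides are symbolic with a stated divisibility, and short-circuits to `None` when `dtype` is `None` (presumably for optional inputs). A hedged sketch of how it might be called; the `Float16` dtype and the concrete `(8192, 4096)` shape are illustrative assumptions, and concrete integer shapes are assumed to be accepted here:

    from cutlass import Float16
    from quack.compile_utils import make_fake_tensor

    # Row-major (M, N) tensor: last dim contiguous, other strides symbolic with
    # divisibility 8, so assumed_align = 8 * 16 bits // 8 = 16 bytes.
    x_fake = make_fake_tensor(Float16, (8192, 4096), divisibility=8)
    bias_fake = make_fake_tensor(None, (4096,))  # absent optional input -> None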
quack/copy_utils.py ADDED
@@ -0,0 +1,487 @@
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+
+ import re
+ from typing import Optional, Type, Tuple, Callable
+
+ import cutlass
+ import cutlass.cute as cute
+
+ from cutlass import Int32, Boolean, const_expr
+ from cutlass.cute.nvgpu import cpasync
+ from cutlass.cutlass_dsl import dsl_user_op
+ import cutlass.pipeline
+
+
+ @dsl_user_op
+ def cvt_copy(
+     atom: cute.CopyAtom,
+     src: cute.Tensor,
+     dst: cute.Tensor,
+     *,
+     pred: Optional[cute.Tensor] = None,
+     loc=None,
+     ip=None,
+     **kwargs,
+ ) -> None:
+     assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem
+     if const_expr(src.element_type != dst.element_type):
+         src_cvt = cute.make_fragment_like(src, dst.element_type)
+         src_cvt.store(src.load().to(dst.element_type))
+         src = src_cvt
+     cute.copy(atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
+
+
+ @dsl_user_op
+ def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
+     dst = cute.make_fragment_like(src, src.element_type, loc=loc, ip=ip)
+     cute.autovec_copy(src, dst, loc=loc, ip=ip)
+     return dst
+
+
+ @dsl_user_op
+ def load_s2r_retile(
+     tiled_copy: cute.TiledCopy,
+     src: cute.Tensor,
+     dst_shape: cute.Tensor | cute.Shape,
+     *,
+     loc=None,
+     ip=None,
+ ) -> cute.Tensor:
+     # Will also accept dst_shape being a tensor, in which case we write into that tensor
+     if const_expr(not isinstance(dst_shape, cute.Tensor)):
+         dst = cute.make_fragment(dst_shape, src.element_type, loc=loc, ip=ip)
+     else:
+         dst = dst_shape
+     cute.copy(tiled_copy, src, tiled_copy.retile(dst), loc=loc, ip=ip)
+     return dst
+
+
+ @dsl_user_op
+ def get_copy_atom(
+     dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None
+ ) -> cute.CopyAtom:
+     num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width))
+     copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
+     return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
+
+
+ @dsl_user_op
+ def copy(
+     src: cute.Tensor,
+     dst: cute.Tensor,
+     *,
+     pred: Optional[cute.Tensor] = None,
+     is_async: bool = False,
+     loc=None,
+     ip=None,
+     **kwargs,
+ ) -> None:
+     num_copy_elems = src.shape[0][0]
+     copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async)
+     cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
+
+
+ def tiled_copy_1d(
+     dtype: Type[cutlass.Numeric], num_threads: int, num_copy_elems: int = 1, is_async: bool = False
+ ) -> cute.TiledCopy:
+     num_copy_bits = num_copy_elems * dtype.width
+     copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
+     copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
+     thr_layout = cute.make_layout(num_threads)
+     val_layout = cute.make_layout(num_copy_elems)
+     return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
+
+
+ def tiled_copy_2d(
+     dtype: Type[cutlass.Numeric],
+     threads_per_row: int,
+     num_threads: int,
+     num_copy_elems: int = 1,
+     is_async: bool = False,
+ ) -> cute.TiledCopy:
+     num_copy_bits = num_copy_elems * dtype.width
+     copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
+     copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
+     assert num_threads % threads_per_row == 0
+     thr_layout = cute.make_ordered_layout(
+         (num_threads // threads_per_row, threads_per_row),
+         order=(1, 0),
+     )
+     val_layout = cute.make_layout((1, num_copy_elems))
+     return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
+
+
+ @cute.jit
+ def predicate_k(tAcA: cute.Tensor, limit: Int32) -> cute.Tensor:
+     # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
+     tApA = cute.make_fragment(
+         cute.make_layout(
+             (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
+             stride=(cute.size(tAcA, mode=[2]), 0, 1),
+         ),
+         Boolean,
+     )
+     for rest_v in cutlass.range_constexpr(tApA.shape[0]):
+         for rest_k in cutlass.range_constexpr(tApA.shape[2]):
+             tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
+     return tApA
+
+
+ # def tiled_copy_2d(
+ #     dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False
+ # ) -> cute.TiledCopy:
+ #     num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width
+ #     copy_elems = num_copy_bits // dtype.width
+ #     copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
+ #     copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
+ #     gmem_threads_per_row = major_mode_size // copy_elems
+ #     assert num_threads % gmem_threads_per_row == 0
+ #     thr_layout = cute.make_ordered_layout(
+ #         (num_threads // gmem_threads_per_row, gmem_threads_per_row),
+ #         order=(1, 0),
+ #     )
+ #     val_layout = cute.make_layout((1, copy_elems))
+ #     return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
+
+
+ def parse_swizzle_from_pointer(ptr: cute.Pointer) -> Tuple[int, int, int]:
+     """Extract swizzle parameters from a pointer's swizzle_type.
+
+     The swizzle_type string has the form '!cute.swizzle<"S<b,m,s>">' where
+     b, m, s are the swizzle parameters (bits, base, shift).
+
+     Returns:
+         The tuple (b, m, s) of swizzle parameters extracted from the pointer.
+
+     Raises:
+         ValueError: If the swizzle_type string cannot be parsed
+     """
+     # Ideally there should be a better API to get swizzle parameters, but we'll just parse
+     # the string here.
+     swizzle_str = str(ptr.type.swizzle_type)
+     # Extract the inner part "S<b,m,s>"
+     match = re.search(r"S<(\d+),(\d+),(\d+)>", swizzle_str)
+     if match:
+         b, m, s = int(match.group(1)), int(match.group(2)), int(match.group(3))
+         return b, m, s
+     else:
+         raise ValueError(f"Could not parse swizzle_type: {swizzle_str}")
+
+
+ def swizzle_int(ptr_int: Int32, b: int, m: int, s: int) -> Int32:
+     bit_msk = (1 << b) - 1
+     yyy_msk = bit_msk << (m + s)
+     return ptr_int ^ ((ptr_int & yyy_msk) >> s)
+
+
+ def swizzle_ptr(ptr: cute.Pointer):
+     b, m, s = parse_swizzle_from_pointer(ptr)
+     ptr_int = swizzle_int(ptr.toint(), b, m, s)
+     return cute.make_ptr(ptr.dtype, ptr_int, ptr.memspace, assumed_align=ptr.alignment)
+
+
+ def as_position_independent_swizzle_tensor(tensor: cute.Tensor) -> cute.Tensor:
+     outer = tensor.layout
+     width = tensor.element_type.width
+     inner = cute.make_swizzle(*parse_swizzle_from_pointer(tensor.iterator))
+     # Need to recast the swizzle from byte units (e.g. <3, 4, 3>) to element units
+     # (e.g. <3, 3, 3> for 16 bits and <3, 2, 3> for 32 bits)
+     new_layout = cute.recast_layout(
+         width, 8, cute.make_composed_layout(inner, 0, cute.recast_layout(8, width, outer))
+     )
+     # recast_ptr to remove the pointer swizzle
+     return cute.make_tensor(cute.recast_ptr(tensor.iterator, dtype=tensor.element_type), new_layout)
+
+
+ def partition_D_position_independent(
+     thr_copy: cute.core.ThrCopy, tensor: cute.Tensor
+ ) -> cute.Tensor:
+     return cute.make_tensor(
+         swizzle_ptr(thr_copy.partition_D(tensor).iterator),
+         thr_copy.partition_D(as_position_independent_swizzle_tensor(tensor)).layout,
+     )
+
+
+ def partition_S_position_independent(
+     thr_copy: cute.core.ThrCopy, tensor: cute.Tensor
+ ) -> cute.Tensor:
+     return cute.make_tensor(
+         swizzle_ptr(thr_copy.partition_S(tensor).iterator),
+         thr_copy.partition_S(as_position_independent_swizzle_tensor(tensor)).layout,
+     )
+
+
+ @dsl_user_op
+ def sm90_get_smem_load_op(
+     layout_c: cutlass.utils.LayoutEnum,
+     elem_ty_c: Type[cutlass.Numeric],
+     *,
+     loc=None,
+     ip=None,
+ ) -> cute.CopyAtom:
+     """
+     Selects the largest vectorized smem load atom available subject to constraint of gmem layout.
+
+     Parameters:
+     -----------
+     layout_c : LayoutEnum
+         The layout enum of the output tensor D.
+
+     elem_ty_c : Type[Numeric]
+         The element type for output tensor D.
+
+     Returns:
+     --------
+     Either SmemLoadMatrix or SimtSyncCopy, based on the input parameters.
+     """
+
+     if not isinstance(elem_ty_c, cutlass.cutlass_dsl.NumericMeta):
+         raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}")
+     is_m_major = layout_c.is_m_major_c()
+     if elem_ty_c.width == 16:
+         return cute.make_copy_atom(
+             cute.nvgpu.warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip
+         )
+     else:
+         return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip)
+
+
+ def get_smem_store_atom(
+     arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False
+ ) -> cute.CopyAtom:
+     if const_expr(arch < 90 or element_type.width != 16):
+         return cute.make_copy_atom(
+             cute.nvgpu.CopyUniversalOp(),
+             element_type,
+             num_bits_per_copy=(2 if not transpose else 1) * element_type.width,
+         )
+     else:
+         return cute.make_copy_atom(
+             cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
+             element_type,
+         )
+
+
+ def tma_get_copy_fn(
+     atom: cute.CopyAtom,
+     cta_coord: cute.Coord,
+     cta_layout: cute.Layout,
+     src_tensor: cute.Tensor,
+     dst_tensor: cute.Tensor,
+     filter_zeros: bool = False,
+     **kwargs,
+ ) -> Callable:
+     src_is_smem = const_expr(
+         isinstance(src_tensor.iterator, cute.Pointer)
+         and src_tensor.memspace == cute.AddressSpace.smem
+     )
+     smem_tensor, gmem_tensor = (src_tensor, dst_tensor) if src_is_smem else (dst_tensor, src_tensor)
+     # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
+     s, g = cpasync.tma_partition(
+         atom,
+         cta_coord,
+         cta_layout,
+         cute.group_modes(smem_tensor, 0, cute.rank(smem_tensor) - 1),
+         cute.group_modes(gmem_tensor, 0, cute.rank(gmem_tensor) - 1),
+     )
+     if const_expr(filter_zeros):
+         s = cute.filter_zeros(s)
+         g = cute.filter_zeros(g)
+     src, dst = (s, g) if src_is_smem else (g, s)
+
+     def copy_tma(src_idx, dst_idx, **new_kwargs):
+         cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs)
+
+     return copy_tma, s, g
+
+
+ def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync):
+     def copy_fn(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwargs):
+         copy(
+             src_idx=src_idx,
+             dst_idx=producer_state.index,
+             tma_bar_ptr=pipeline.producer_get_barrier(producer_state),
+             **new_kwargs,
+         )
+
+     return copy_fn
+
+
+ @cute.jit
+ def gather_m_get_copy_fn(
+     thr_copy_A: cute.ThrCopy,
+     mA: cute.Tensor,  # (whatever, K)
+     sA: cute.Tensor,  # (tile_M, tile_N, STAGE)
+     gsAIdx: cute.Tensor,  # (tile_M), either gmem or smem
+     limit_m: Int32,
+     limit_k: Int32,
+ ) -> Callable:
+     tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1]))
+     tAsA = thr_copy_A.partition_D(sA)
+     # k-major
+     assert tAsA.shape[2] == 1
+     tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2)
+
+     is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0
+     if const_expr(not is_even_m_smem):
+         limit_m = min(limit_m, tile_shape_mk[0])
+     elems_per_load = cute.size(tAsA.shape[0][0])
+     cA = cute.make_identity_tensor(tile_shape_mk)
+     tAcA = thr_copy_A.partition_S(cA)
+     t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
+     # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
+     # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0].
+     # This is so that when we do the comparison, t0AcA is known at compile time.
+     limit_m = limit_m - tAcA[0][0]
+     limit_k = limit_k - tAcA[0][1]
+     # Read and cache indices for A
+     rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
+     cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
+     tApA_m = cute.make_fragment(rows_per_thread, Boolean)
+     for m in cutlass.range(rows_per_thread, unroll_full=True):
+         tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
+     m_idx = cute.make_fragment(rows_per_thread, Int32)
+     for m in cutlass.range(rows_per_thread, unroll_full=True):
+         row_idx = tAcA[0, m, 0][0]
+         if tApA_m[m]:
+             m_idx[m] = gsAIdx[row_idx]
+         else:
+             m_idx[m] = 0  # It's ok to load row 0 in the case of OOB
+
+     mA_k = cute.logical_divide(mA, (None, tile_shape_mk[1]))
+
+     def copy_fn(src_idx, dst_idx, pred: bool = False):
+         tApA_k = None
+         if const_expr(pred):
+             tApA_k = cute.make_fragment(cols_per_thread, Boolean)
+             limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
+             for k in cutlass.range(cols_per_thread, unroll_full=True):
+                 tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
+         mA_cur = mA_k[None, (None, src_idx)]
+         for m in cutlass.range_constexpr(tAcA.shape[1]):
+             # cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,)) would give shape
+             # ((elems_per_load), thread_per_row)
+             # But we actually want shape ((elems_per_load, 1), thread_per_row) to match tAsA
+             # So we append 1s to the last dimension and then do tiled_divide, then slice.
+             mA_row = cute.tiled_divide(
+                 cute.append_ones(mA_cur[m_idx[m], None], up_to_rank=2), (elems_per_load, 1)
+             )[None, None, 0]
+             if const_expr(is_even_m_smem) or tApA_m[m]:
+                 # There's only 1 load per row
+                 assert cute.size(tAcA.shape, mode=[2]) == 1
+                 ki = tAcA[0, 0, 0][1] // elems_per_load
+                 cute.copy(thr_copy_A, mA_row[None, ki], tAsA[(None, m), dst_idx], pred=tApA_k)
+
+     return copy_fn
+
+
+ @cute.jit
+ def gather_k_get_copy_fn(
+     thr_copy_A: cute.ThrCopy,
+     mA: cute.Tensor,  # (tile_M, whatever)
+     sA: cute.Tensor,  # (tile_M, tile_N, STAGE)
+     gsAIdx: cute.Tensor,  # (tile_K, RestK), either gmem or smem
+     limit_m: Int32,
+     limit_k: Int32,
+ ) -> Callable:
+     gAIdx, sAIdx = None, None
+     if const_expr(gsAIdx.memspace == cute.AddressSpace.gmem):
+         gAIdx = gsAIdx
+     else:
+         assert gsAIdx.memspace == cute.AddressSpace.smem
+         sAIdx = gsAIdx
+     tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1]))
+     # (atom_v, CPY_M, 1, STAGE)
+     tAsA = thr_copy_A.partition_D(sA)
+     # m-major
+     tAsA = cute.group_modes(tAsA, 0, 3)
+
+     is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0
+     if const_expr(not is_even_m_smem):
+         limit_m = min(limit_m, tile_shape_mk[0])
+     elems_per_load = cute.size(tAsA.shape[0][0])
+     cA = cute.make_identity_tensor(tile_shape_mk)
+     tAcA = thr_copy_A.partition_S(cA)
+     t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
+     # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
+     # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0].
+     # This is so that when we do the comparison, t0AcA is known at compile time.
+     limit_m = limit_m - tAcA[0][0]
+     limit_k = limit_k - tAcA[0][1]
+     # Read and cache indices for A
+     rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
+     cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
+     tApA_m = cute.make_fragment(rows_per_thread, Boolean)
+     for m in cutlass.range(rows_per_thread, unroll_full=True):
+         tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
+     threads_per_col = const_expr(thr_copy_A.tiler_mn[0].shape // elems_per_load)
+     # This is very convoluted but idk a better way
+     # for tile_M=128, flat_divide gives (8, 16, K),
+     # then logical_divide gives ((8, 1), (8, 2), K).
+     tidx = thr_copy_A.thr_idx
+     tAmA = cute.logical_divide(
+         cute.flat_divide(mA, (elems_per_load,)), (elems_per_load, threads_per_col)
+     )[None, (tidx % threads_per_col, None), None]  # ((8, 1), 2, K)
+
+     def prefetch_from_gmem_fn(src_idx, pred: bool = False) -> Tuple[cute.Tensor, cute.Tensor]:
+         # Prefetch mAIdx early, even before smem is free
+         tApA_k = None
+         if const_expr(pred):
+             tApA_k = cute.make_fragment(cols_per_thread, Boolean)
+             limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
+             for k in cutlass.range(cols_per_thread, unroll_full=True):
+                 tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
+         gAIdx_cur = gAIdx[None, src_idx]
+         k_idx = cute.make_fragment(cols_per_thread, Int32)
+         for k in cutlass.range(cols_per_thread):
+             col_idx = tAcA[0, 0, k][1]
+             if const_expr(not pred):
+                 k_idx[k] = gAIdx_cur[col_idx]
+             else:
+                 if tApA_k[k]:
+                     k_idx[k] = gAIdx_cur[col_idx]
+                 else:
+                     k_idx[k] = -1
+         return k_idx, tApA_k
+
+     def prefetch_from_smem_fn(
+         a_prefetch_pipeline, src_idx, dst_idx, a_prefetch_consumer_state, pred: bool = False
+     ) -> Tuple[cute.Tensor, cute.Tensor]:
+         tApA_k = None
+         if const_expr(pred):
+             tApA_k = cute.make_fragment(cols_per_thread, Boolean)
+             limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
+             for k in cutlass.range(cols_per_thread, unroll_full=True):
+                 tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
+         a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state)
+         sAIdx_cur = sAIdx[None, dst_idx]
+         k_idx = cute.make_fragment(cols_per_thread, Int32)
+         for k in cutlass.range(cols_per_thread):
+             col_idx = tAcA[0, 0, k][1]
+             k_idx[k] = sAIdx_cur[col_idx]
+         cute.arch.sync_warp()
+         with cute.arch.elect_one():
+             a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state)
+         return k_idx, tApA_k
+
+     def copy_fn(
+         src_idx, dst_idx, k_idx_tApA_k: Tuple[cute.Tensor, cute.Tensor], pred: bool = False
+     ):
+         k_idx, tApA_k = k_idx_tApA_k
+         tApA_k_pred = None
+         if const_expr(pred):
+             tApA_k_pred = cute.prepend_ones(tApA_k, up_to_rank=2)  # (1, cols_per_thread)
+         for k in cutlass.range_constexpr(tAcA.shape[2]):
+             # copy_A(tAmA[None, None, k_idx[k]], tAsA[(None, None, k), smem_idx], pred=cute.prepend_ones(tApA_m, up_to_rank=2))
+             for m in cutlass.range_constexpr(tAcA.shape[1]):
+                 if tApA_m[m]:
+                     cute.copy(
+                         thr_copy_A,
+                         tAmA[None, m, k_idx[k]],
+                         tAsA[(None, m, k), dst_idx],
+                         pred=None if const_expr(tApA_k_pred is None) else tApA_k_pred[None, k],
+                     )
+
+     return copy_fn, prefetch_from_gmem_fn if const_expr(
+         gAIdx is not None
+     ) else prefetch_from_smem_fn
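`swizzle_int` above is plain integer arithmetic, so its effect can be checked outside the DSL. A standalone sketch using the `<3, 4, 3>` byte-unit parameters mentioned in the comments (the re-definition is only for illustration): it is the same XOR that `swizzle_ptr` applies to the raw pointer address, folding bits [m+s, m+s+b) of the address into bits [m, m+b).

    def swizzle_int(ptr_int: int, b: int, m: int, s: int) -> int:
        bit_msk = (1 << b) - 1                       # mask of b ones
        yyy_msk = bit_msk << (m + s)                 # selects bits [m+s, m+s+b)
        return ptr_int ^ ((ptr_int & yyy_msk) >> s)  # XOR them into bits [m, m+b)

    assert swizzle_int(0x000, 3, 4, 3) == 0x000  # nothing set in bits 7..9 -> unchanged
    assert swizzle_int(0x280, 3, 4, 3) == 0x2D0  # bits 7 and 9 fold down into bits 4 and 6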