quack-kernels 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. quack/__init__.py +1 -8
  2. quack/activation.py +366 -121
  3. quack/autotuner.py +64 -5
  4. quack/broadcast_utils.py +29 -0
  5. quack/compile_utils.py +19 -0
  6. quack/copy_utils.py +487 -0
  7. quack/cross_entropy.py +157 -233
  8. quack/cute_dsl_utils.py +20 -35
  9. quack/gemm.py +194 -0
  10. quack/gemm_act.py +510 -0
  11. quack/gemm_config.py +72 -46
  12. quack/gemm_dact.py +215 -0
  13. quack/gemm_default_epi.py +259 -0
  14. quack/gemm_interface.py +615 -146
  15. quack/{dense_gemm_sm100.py → gemm_sm100.py} +1034 -787
  16. quack/{dense_gemm_sm90.py → gemm_sm90.py} +552 -727
  17. quack/gemm_symmetric.py +330 -0
  18. quack/gemm_wrapper_utils.py +182 -23
  19. quack/layout_utils.py +287 -0
  20. quack/linear.py +24 -16
  21. quack/pipeline.py +158 -3
  22. quack/reduce.py +88 -49
  23. quack/reduction_base.py +25 -36
  24. quack/rmsnorm.py +508 -624
  25. quack/sm100_utils.py +62 -0
  26. quack/sm90_utils.py +127 -0
  27. quack/softmax.py +135 -203
  28. quack/sort/bitonic_sort.py +13 -10
  29. quack/sort/utils.py +6 -6
  30. quack/tile_scheduler.py +55 -61
  31. quack/topk.py +409 -85
  32. quack/utils.py +37 -172
  33. quack/varlen_utils.py +370 -6
  34. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
  35. quack_kernels-0.2.3.dist-info/RECORD +44 -0
  36. quack/gemm_act_sm90.py +0 -368
  37. quack/gemm_dact_sm90.py +0 -150
  38. quack/layernorm.py +0 -353
  39. quack/symmetric_dense_gemm_sm90.py +0 -2091
  40. quack_kernels-0.2.1.dist-info/RECORD +0 -37
  41. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
  42. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
  43. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/autotuner.py CHANGED
@@ -11,7 +11,7 @@ import hashlib
  import json
  from pathlib import Path
  from functools import cached_property, partial
- from typing import Dict, Tuple
+ from typing import Dict, Tuple, List, Optional, Any

  import torch
  from torch import Tensor
@@ -53,7 +53,22 @@ def _base32(key):


  class Autotuner:
-     def __init__(self, fn, key, configs, restore_value=None, do_bench=None, cache_results=False):
+     def __init__(
+         self,
+         fn,
+         key,
+         configs,
+         restore_value=None,
+         prune_configs_by: Optional[Dict] = None,
+         do_bench=None,
+         cache_results=False,
+     ):
+         """
+         :param prune_configs_by: a dict of functions that are used to prune configs, fields:
+             'perf_model': performance model used to predict running time with different configs, returns running time
+             'top_k': number of configs to bench
+             'early_config_prune' (optional): a function used to do early pruning (e.g. of num_stages). It takes configs: List[Config] as its input, and returns pruned configs.
+         """
          if not configs:
              self.configs = [AutotuneConfig()]
          else:
@@ -90,6 +105,16 @@ class Autotuner:
          else:
              self.post_hook = None

+         self.perf_model = None
+         self.configs_top_k = 1.0
+         self.early_config_prune = None
+         if prune_configs_by:
+             self.perf_model = prune_configs_by.get("perf_model", self.perf_model)
+             self.configs_top_k = prune_configs_by.get("top_k", self.configs_top_k)
+             self.early_config_prune = prune_configs_by.get(
+                 "early_config_prune", self.early_config_prune
+             )
+
          self.fn = fn
          self._do_bench = do_bench

@@ -198,13 +223,14 @@ class Autotuner:
              key = tuple(key)
              if key not in self.cache:
                  used_cached_result = False
+                 pruned_configs = self.prune_configs(kwargs)

                  @torch.compiler.disable  # Don't want any tracing here
                  def benchmark():
                      bench_start = time.time()
                      timings = {
                          config: self._bench(*args, config=config, **kwargs)
-                         for config in self.configs
+                         for config in pruned_configs
                      }
                      bench_end = time.time()
                      if os.getenv(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1":
@@ -215,7 +241,7 @@ class Autotuner:
                      self.configs_timings = timings

                  if self.cache_results:
-                     self.check_disk_cache(key, self.configs, benchmark)
+                     self.check_disk_cache(key, pruned_configs, benchmark)
                  else:
                      benchmark()

@@ -239,6 +265,32 @@ class Autotuner:
          self.nargs = None
          return ret

+     def prune_configs(self, kwargs: Dict) -> List[Any]:
+         pruned_configs = self.configs
+         if self.early_config_prune:
+             pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs)
+         if self.perf_model:
+             top_k = self.configs_top_k
+             if isinstance(top_k, float) and top_k <= 1.0:
+                 top_k = int(len(self.configs) * top_k)
+             elif not isinstance(top_k, int):
+                 # Slice index must be an integer
+                 raise TypeError(
+                     "Error while pruning configs, top_k must be either 1) a float <= 1.0 or 2) an int"
+                 )
+
+             if len(pruned_configs) > top_k:
+                 est_timing = {
+                     config: self.perf_model(
+                         **self.nargs,
+                         **kwargs,
+                         **config.all_kwargs(),
+                     )
+                     for config in pruned_configs
+                 }
+                 pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
+         return pruned_configs
+

  class AutotuneConfig:
      """
@@ -272,7 +324,9 @@ class AutotuneConfig:
          return self_tuple == other_tuple


- def autotune(configs, key=None, restore_value=None, do_bench=None, cache_results=True):
+ def autotune(
+     configs, key=None, prune_configs_by=None, restore_value=None, do_bench=None, cache_results=True
+ ):
      f"""
      Decorator for auto-tuning a function function.

@@ -286,6 +340,10 @@ def autotune(configs, key=None, restore_value=None, do_bench=None, cache_results
      :type configs: list[AutotuneConfig]
      :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
      :type key: list[str]
+     :param prune_configs_by: a dict of functions that are used to prune configs, fields:
+         'perf_model': performance model used to predict running time with different configs, returns running time
+         'top_k': number of configs to bench
+         'early_config_prune' (optional): a function used to do early pruning (e.g. of num_stages). It takes configs: List[Config] as its input, and returns pruned configs.
      :param restore_value: a list of argument names whose value will be restored after evaluating any configs.
      :type restore_value: list[str]
      :param do_bench: a benchmark function to measure the time of each run.
@@ -303,6 +361,7 @@ def autotune(configs, key=None, restore_value=None, do_bench=None, cache_results
              key,
              configs,
              restore_value=restore_value,
+             prune_configs_by=prune_configs_by,
              do_bench=do_bench,
              cache_results=cache_results,
          )
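
For orientation, a hypothetical use of the new prune_configs_by hook could look like the sketch below; the kernel, config fields, and cost model are illustrative only and not taken from the package:

    from quack.autotuner import autotune, AutotuneConfig

    def estimate_ms(**kwargs):
        # Toy cost model: receives the tuned function's named arguments plus each
        # config's fields (config.all_kwargs()) and returns an estimated runtime.
        return float(kwargs.get("M", 1))

    def drop_half(configs, nargs, **kwargs):
        # Early prune: gets the full config list and the call's named arguments,
        # and returns the subset worth benchmarking.
        return configs[: max(1, len(configs) // 2)]

    @autotune(
        configs=[AutotuneConfig(), AutotuneConfig()],  # real configs would carry tuning knobs
        key=["M"],
        prune_configs_by={
            "perf_model": estimate_ms,
            "top_k": 2,
            "early_config_prune": drop_half,
        },
    )
    def my_kernel(x, M=None):
        ...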
quack/broadcast_utils.py ADDED
@@ -0,0 +1,29 @@
+ # Copyright (c) 2025, Tri Dao.
+ from typing import Callable
+ 
+ import cutlass
+ import cutlass.cute as cute
+ from cutlass import Float32, const_expr
+ 
+ from quack.layout_utils import make_acc_tensor_mn_view
+ 
+ 
+ @cute.jit
+ def vec_op(tCrC: cute.Tensor, tCrVec: cute.Tensor, op: Callable, is_colvec: bool) -> None:
+     if const_expr(tCrC.element_type != Float32):  # Convert to f32
+         tCrC_f32 = cute.make_fragment(tCrC.shape, Float32)
+         tCrC_f32.store(tCrC.load().to(Float32))
+     else:
+         tCrC_f32 = tCrC
+     # this happens to work for frgA layout too, not just acc layout
+     tCrC_f32_mn = make_acc_tensor_mn_view(tCrC_f32)
+     if const_expr(is_colvec):
+         assert cute.size(tCrC_f32_mn, mode=[0]) == cute.size(tCrVec)
+         for r in cutlass.range(cute.size(tCrC_f32_mn, mode=[0]), unroll_full=True):
+             tCrC_f32_mn[r, None].store(op(tCrC_f32_mn[r, None].load(), tCrVec[r]))
+     else:
+         assert cute.size(tCrC_f32_mn, mode=[1]) == cute.size(tCrVec)
+         for c in cutlass.range(cute.size(tCrC_f32_mn, mode=[1]), unroll_full=True):
+             tCrC_f32_mn[None, c].store(op(tCrC_f32_mn[None, c].load(), tCrVec[c]))
+     if const_expr(tCrC.element_type != Float32):  # Convert back to original dtype
+         tCrC.store(tCrC_f32.load().to(tCrC.element_type))
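
In plain NumPy terms, vec_op combines the (M, N) accumulator tile with either a length-M vector (is_colvec=True, one value per row) or a length-N vector (one value per column), doing the arithmetic in fp32 and casting back to the accumulator dtype. A rough reference model of the semantics only, not of how the CuTe kernel executes:

    import numpy as np

    def vec_op_reference(acc, vec, op, is_colvec):
        # acc: (M, N) tile; vec: length M if is_colvec else length N.
        acc_f32 = acc.astype(np.float32)
        vec_f32 = np.asarray(vec, dtype=np.float32)
        out = op(acc_f32, vec_f32[:, None]) if is_colvec else op(acc_f32, vec_f32[None, :])
        return out.astype(acc.dtype)

    acc = np.arange(12, dtype=np.float16).reshape(3, 4)
    bias = np.array([10.0, 20.0, 30.0, 40.0])
    out = vec_op_reference(acc, bias, np.add, is_colvec=False)  # one bias value per column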
quack/compile_utils.py ADDED
@@ -0,0 +1,19 @@
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+ 
+ from typing import Optional
+ 
+ import cutlass.cute as cute
+ 
+ 
+ def make_fake_tensor(dtype, shape, divisibility=1, leading_dim=-1) -> Optional[cute.Tensor]:
+     if leading_dim < 0:
+         leading_dim = len(shape) + leading_dim
+     if dtype is None:
+         return None
+     stride = tuple(
+         cute.sym_int64(divisibility=divisibility) if i != leading_dim else 1
+         for i in range(len(shape))
+     )
+     return cute.runtime.make_fake_tensor(
+         dtype, shape, stride=stride, assumed_align=divisibility * dtype.width // 8
+     )
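
A rough usage sketch of the helper above, assuming a cutlass dtype such as cutlass.BFloat16 is available (the shapes and divisibility below are illustrative): it describes an operand whose leading dimension is contiguous and whose other strides are only known to be divisible by the given factor, and it returns None when the operand is absent.

    import cutlass
    from quack.compile_utils import make_fake_tensor

    # (M, N) operand, contiguous in its last dim, other strides divisible by 8 elements.
    fake_x = make_fake_tensor(cutlass.BFloat16, (128, 256), divisibility=8, leading_dim=-1)

    # dtype=None marks an optional operand as absent.
    assert make_fake_tensor(None, (128, 256)) is None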
quack/copy_utils.py ADDED
@@ -0,0 +1,487 @@
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+ 
+ import re
+ from typing import Optional, Type, Tuple, Callable
+ 
+ import cutlass
+ import cutlass.cute as cute
+ 
+ from cutlass import Int32, Boolean, const_expr
+ from cutlass.cute.nvgpu import cpasync
+ from cutlass.cutlass_dsl import dsl_user_op
+ import cutlass.pipeline
+ 
+ 
+ @dsl_user_op
+ def cvt_copy(
+     atom: cute.CopyAtom,
+     src: cute.Tensor,
+     dst: cute.Tensor,
+     *,
+     pred: Optional[cute.Tensor] = None,
+     loc=None,
+     ip=None,
+     **kwargs,
+ ) -> None:
+     assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem
+     if const_expr(src.element_type != dst.element_type):
+         src_cvt = cute.make_fragment_like(src, dst.element_type)
+         src_cvt.store(src.load().to(dst.element_type))
+         src = src_cvt
+     cute.copy(atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
+ 
+ 
+ @dsl_user_op
+ def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
+     dst = cute.make_fragment_like(src, src.element_type, loc=loc, ip=ip)
+     cute.autovec_copy(src, dst, loc=loc, ip=ip)
+     return dst
+ 
+ 
+ @dsl_user_op
+ def load_s2r_retile(
+     tiled_copy: cute.TiledCopy,
+     src: cute.Tensor,
+     dst_shape: cute.Tensor | cute.Shape,
+     *,
+     loc=None,
+     ip=None,
+ ) -> cute.Tensor:
+     # Will also accept dst_shape being a tensor, in which case we write into that tensor
+     if const_expr(not isinstance(dst_shape, cute.Tensor)):
+         dst = cute.make_fragment(dst_shape, src.element_type, loc=loc, ip=ip)
+     else:
+         dst = dst_shape
+     cute.copy(tiled_copy, src, tiled_copy.retile(dst), loc=loc, ip=ip)
+     return dst
+ 
+ 
+ @dsl_user_op
+ def get_copy_atom(
+     dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None
+ ) -> cute.CopyAtom:
+     num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width))
+     copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
+     return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
+ 
+ 
+ @dsl_user_op
+ def copy(
+     src: cute.Tensor,
+     dst: cute.Tensor,
+     *,
+     pred: Optional[cute.Tensor] = None,
+     is_async: bool = False,
+     loc=None,
+     ip=None,
+     **kwargs,
+ ) -> None:
+     num_copy_elems = src.shape[0][0]
+     copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async)
+     cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
+ 
+ 
+ def tiled_copy_1d(
+     dtype: Type[cutlass.Numeric], num_threads: int, num_copy_elems: int = 1, is_async: bool = False
+ ) -> cute.TiledCopy:
+     num_copy_bits = num_copy_elems * dtype.width
+     copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
+     copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
+     thr_layout = cute.make_layout(num_threads)
+     val_layout = cute.make_layout(num_copy_elems)
+     return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
+ 
+ 
+ def tiled_copy_2d(
+     dtype: Type[cutlass.Numeric],
+     threads_per_row: int,
+     num_threads: int,
+     num_copy_elems: int = 1,
+     is_async: bool = False,
+ ) -> cute.TiledCopy:
+     num_copy_bits = num_copy_elems * dtype.width
+     copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
+     copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
+     assert num_threads % threads_per_row == 0
+     thr_layout = cute.make_ordered_layout(
+         (num_threads // threads_per_row, threads_per_row),
+         order=(1, 0),
+     )
+     val_layout = cute.make_layout((1, num_copy_elems))
+     return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
+ 
+ 
+ @cute.jit
+ def predicate_k(tAcA: cute.Tensor, limit: Int32) -> cute.Tensor:
+     # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
+     tApA = cute.make_fragment(
+         cute.make_layout(
+             (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
+             stride=(cute.size(tAcA, mode=[2]), 0, 1),
+         ),
+         Boolean,
+     )
+     for rest_v in cutlass.range_constexpr(tApA.shape[0]):
+         for rest_k in cutlass.range_constexpr(tApA.shape[2]):
+             tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
+     return tApA
+ 
+ 
+ # def tiled_copy_2d(
+ #     dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False
+ # ) -> cute.TiledCopy:
+ #     num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width
+ #     copy_elems = num_copy_bits // dtype.width
+ #     copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
+ #     copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
+ #     gmem_threads_per_row = major_mode_size // copy_elems
+ #     assert num_threads % gmem_threads_per_row == 0
+ #     thr_layout = cute.make_ordered_layout(
+ #         (num_threads // gmem_threads_per_row, gmem_threads_per_row),
+ #         order=(1, 0),
+ #     )
+ #     val_layout = cute.make_layout((1, copy_elems))
+ #     return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
+ 
+ 
+ def parse_swizzle_from_pointer(ptr: cute.Pointer) -> Tuple[int, int, int]:
+     """Extract swizzle parameters from a pointer's swizzle_type.
+ 
+     The swizzle_type string has the form '!cute.swizzle<"S<b,m,s>">' where
+     b, m, s are the swizzle parameters (bits, base, shift).
+ 
+     Returns:
+         The tuple (b, m, s) of swizzle parameters
+ 
+     Raises:
+         ValueError: If the swizzle_type string cannot be parsed
+     """
+     # Ideally there should be a better API to get swizzle parameters, but we'll just parse
+     # the string here.
+     swizzle_str = str(ptr.type.swizzle_type)
+     # Extract the inner part "S<b,m,s>"
+     match = re.search(r"S<(\d+),(\d+),(\d+)>", swizzle_str)
+     if match:
+         b, m, s = int(match.group(1)), int(match.group(2)), int(match.group(3))
+         return b, m, s
+     else:
+         raise ValueError(f"Could not parse swizzle_type: {swizzle_str}")
+ 
+ 
+ def swizzle_int(ptr_int: Int32, b: int, m: int, s: int) -> Int32:
+     bit_msk = (1 << b) - 1
+     yyy_msk = bit_msk << (m + s)
+     return ptr_int ^ ((ptr_int & yyy_msk) >> s)
+ 
+ 
+ def swizzle_ptr(ptr: cute.Pointer):
+     b, m, s = parse_swizzle_from_pointer(ptr)
+     ptr_int = swizzle_int(ptr.toint(), b, m, s)
+     return cute.make_ptr(ptr.dtype, ptr_int, ptr.memspace, assumed_align=ptr.alignment)
+ 
+ 
+ def as_position_independent_swizzle_tensor(tensor: cute.Tensor) -> cute.Tensor:
+     outer = tensor.layout
+     width = tensor.element_type.width
+     inner = cute.make_swizzle(*parse_swizzle_from_pointer(tensor.iterator))
+     # Need to recast the swizzle from byte units (e.g. <3, 4, 3>) to element units
+     # (e.g. <3, 3, 3> for 16 bits and <3, 2, 3> for 32 bits)
+     new_layout = cute.recast_layout(
+         width, 8, cute.make_composed_layout(inner, 0, cute.recast_layout(8, width, outer))
+     )
+     # recast_ptr to remove the pointer swizzle
+     return cute.make_tensor(cute.recast_ptr(tensor.iterator, dtype=tensor.element_type), new_layout)
+ 
+ 
+ def partition_D_position_independent(
+     thr_copy: cute.core.ThrCopy, tensor: cute.Tensor
+ ) -> cute.Tensor:
+     return cute.make_tensor(
+         swizzle_ptr(thr_copy.partition_D(tensor).iterator),
+         thr_copy.partition_D(as_position_independent_swizzle_tensor(tensor)).layout,
+     )
+ 
+ 
+ def partition_S_position_independent(
+     thr_copy: cute.core.ThrCopy, tensor: cute.Tensor
+ ) -> cute.Tensor:
+     return cute.make_tensor(
+         swizzle_ptr(thr_copy.partition_S(tensor).iterator),
+         thr_copy.partition_S(as_position_independent_swizzle_tensor(tensor)).layout,
+     )
+ 
+ 
+ @dsl_user_op
+ def sm90_get_smem_load_op(
+     layout_c: cutlass.utils.LayoutEnum,
+     elem_ty_c: Type[cutlass.Numeric],
+     *,
+     loc=None,
+     ip=None,
+ ) -> cute.CopyAtom:
+     """
+     Selects the largest vectorized smem load atom available subject to constraint of gmem layout.
+ 
+     Parameters:
+     -----------
+     layout_c : LayoutEnum
+         The layout enum of the output tensor D.
+ 
+     elem_ty_c : Type[Numeric]
+         The element type for output tensor D.
+ 
+     Returns:
+     --------
+     Either SmemLoadMatrix or SimtSyncCopy, based on the input parameters.
+     """
+ 
+     if not isinstance(elem_ty_c, cutlass.cutlass_dsl.NumericMeta):
+         raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}")
+     is_m_major = layout_c.is_m_major_c()
+     if elem_ty_c.width == 16:
+         return cute.make_copy_atom(
+             cute.nvgpu.warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip
+         )
+     else:
+         return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip)
+ 
+ 
+ def get_smem_store_atom(
+     arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False
+ ) -> cute.CopyAtom:
+     if const_expr(arch < 90 or element_type.width != 16):
+         return cute.make_copy_atom(
+             cute.nvgpu.CopyUniversalOp(),
+             element_type,
+             num_bits_per_copy=(2 if not transpose else 1) * element_type.width,
+         )
+     else:
+         return cute.make_copy_atom(
+             cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
+             element_type,
+         )
+ 
+ 
+ def tma_get_copy_fn(
+     atom: cute.CopyAtom,
+     cta_coord: cute.Coord,
+     cta_layout: cute.Layout,
+     src_tensor: cute.Tensor,
+     dst_tensor: cute.Tensor,
+     filter_zeros: bool = False,
+     **kwargs,
+ ) -> Callable:
+     src_is_smem = const_expr(
+         isinstance(src_tensor.iterator, cute.Pointer)
+         and src_tensor.memspace == cute.AddressSpace.smem
+     )
+     smem_tensor, gmem_tensor = (src_tensor, dst_tensor) if src_is_smem else (dst_tensor, src_tensor)
+     # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
+     s, g = cpasync.tma_partition(
+         atom,
+         cta_coord,
+         cta_layout,
+         cute.group_modes(smem_tensor, 0, cute.rank(smem_tensor) - 1),
+         cute.group_modes(gmem_tensor, 0, cute.rank(gmem_tensor) - 1),
+     )
+     if const_expr(filter_zeros):
+         s = cute.filter_zeros(s)
+         g = cute.filter_zeros(g)
+     src, dst = (s, g) if src_is_smem else (g, s)
+ 
+     def copy_tma(src_idx, dst_idx, **new_kwargs):
+         cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs)
+ 
+     return copy_tma, s, g
+ 
+ 
+ def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync):
+     def copy_fn(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwargs):
+         copy(
+             src_idx=src_idx,
+             dst_idx=producer_state.index,
+             tma_bar_ptr=pipeline.producer_get_barrier(producer_state),
+             **new_kwargs,
+         )
+ 
+     return copy_fn
+ 
+ 
+ @cute.jit
+ def gather_m_get_copy_fn(
+     thr_copy_A: cute.ThrCopy,
+     mA: cute.Tensor,  # (whatever, K)
+     sA: cute.Tensor,  # (tile_M, tile_N, STAGE)
+     gsAIdx: cute.Tensor,  # (tile_M), either gmem or smem
+     limit_m: Int32,
+     limit_k: Int32,
+ ) -> Callable:
+     tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1]))
+     tAsA = thr_copy_A.partition_D(sA)
+     # k-major
+     assert tAsA.shape[2] == 1
+     tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2)
+ 
+     is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0
+     if const_expr(not is_even_m_smem):
+         limit_m = min(limit_m, tile_shape_mk[0])
+     elems_per_load = cute.size(tAsA.shape[0][0])
+     cA = cute.make_identity_tensor(tile_shape_mk)
+     tAcA = thr_copy_A.partition_S(cA)
+     t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
+     # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
+     # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0].
+     # This is so that when we do the comparison, t0AcA is known at compile time.
+     limit_m = limit_m - tAcA[0][0]
+     limit_k = limit_k - tAcA[0][1]
+     # Read and cache indices for A
+     rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
+     cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
+     tApA_m = cute.make_fragment(rows_per_thread, Boolean)
+     for m in cutlass.range(rows_per_thread, unroll_full=True):
+         tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
+     m_idx = cute.make_fragment(rows_per_thread, Int32)
+     for m in cutlass.range(rows_per_thread, unroll_full=True):
+         row_idx = tAcA[0, m, 0][0]
+         if tApA_m[m]:
+             m_idx[m] = gsAIdx[row_idx]
+         else:
+             m_idx[m] = 0  # It's ok to load row 0 in the case of OOB
+ 
+     mA_k = cute.logical_divide(mA, (None, tile_shape_mk[1]))
+ 
+     def copy_fn(src_idx, dst_idx, pred: bool = False):
+         tApA_k = None
+         if const_expr(pred):
+             tApA_k = cute.make_fragment(cols_per_thread, Boolean)
+             limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
+             for k in cutlass.range(cols_per_thread, unroll_full=True):
+                 tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
+         mA_cur = mA_k[None, (None, src_idx)]
+         for m in cutlass.range_constexpr(tAcA.shape[1]):
+             # cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,)) would give shape
+             # ((elems_per_load), thread_per_row)
+             # But we actually want shape ((elems_per_load, 1), thread_per_row) to match tAsA
+             # So we append 1s to the last dimension and then do tiled_divide, then slice.
+             mA_row = cute.tiled_divide(
+                 cute.append_ones(mA_cur[m_idx[m], None], up_to_rank=2), (elems_per_load, 1)
+             )[None, None, 0]
+             if const_expr(is_even_m_smem) or tApA_m[m]:
+                 # There's only 1 load per row
+                 assert cute.size(tAcA.shape, mode=[2]) == 1
+                 ki = tAcA[0, 0, 0][1] // elems_per_load
+                 cute.copy(thr_copy_A, mA_row[None, ki], tAsA[(None, m), dst_idx], pred=tApA_k)
+ 
+     return copy_fn
+ 
+ 
+ @cute.jit
+ def gather_k_get_copy_fn(
+     thr_copy_A: cute.ThrCopy,
+     mA: cute.Tensor,  # (tile_M, whatever)
+     sA: cute.Tensor,  # (tile_M, tile_N, STAGE)
+     gsAIdx: cute.Tensor,  # (tile_K, RestK), either gmem or smem
+     limit_m: Int32,
+     limit_k: Int32,
+ ) -> Callable:
+     gAIdx, sAIdx = None, None
+     if const_expr(gsAIdx.memspace == cute.AddressSpace.gmem):
+         gAIdx = gsAIdx
+     else:
+         assert gsAIdx.memspace == cute.AddressSpace.smem
+         sAIdx = gsAIdx
+     tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1]))
+     # (atom_v, CPY_M, 1, STAGE)
+     tAsA = thr_copy_A.partition_D(sA)
+     # m-major
+     tAsA = cute.group_modes(tAsA, 0, 3)
+ 
+     is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0
+     if const_expr(not is_even_m_smem):
+         limit_m = min(limit_m, tile_shape_mk[0])
+     elems_per_load = cute.size(tAsA.shape[0][0])
+     cA = cute.make_identity_tensor(tile_shape_mk)
+     tAcA = thr_copy_A.partition_S(cA)
+     t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
+     # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
+     # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0].
+     # This is so that when we do the comparison, t0AcA is known at compile time.
+     limit_m = limit_m - tAcA[0][0]
+     limit_k = limit_k - tAcA[0][1]
+     # Read and cache indices for A
+     rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
+     cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
+     tApA_m = cute.make_fragment(rows_per_thread, Boolean)
+     for m in cutlass.range(rows_per_thread, unroll_full=True):
+         tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
+     threads_per_col = const_expr(thr_copy_A.tiler_mn[0].shape // elems_per_load)
+     # This is very convoluted but idk a better way
+     # for tile_M=128, flat_divide gives (8, 16, K),
+     # then logical_divide gives ((8, 1), (8, 2), K).
+     tidx = thr_copy_A.thr_idx
+     tAmA = cute.logical_divide(
+         cute.flat_divide(mA, (elems_per_load,)), (elems_per_load, threads_per_col)
+     )[None, (tidx % threads_per_col, None), None]  # ((8, 1), 2, K)
+ 
+     def prefetch_from_gmem_fn(src_idx, pred: bool = False) -> Tuple[cute.Tensor, cute.Tensor]:
+         # Prefetch mAIdx early, even before smem is free
+         tApA_k = None
+         if const_expr(pred):
+             tApA_k = cute.make_fragment(cols_per_thread, Boolean)
+             limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
+             for k in cutlass.range(cols_per_thread, unroll_full=True):
+                 tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
+         gAIdx_cur = gAIdx[None, src_idx]
+         k_idx = cute.make_fragment(cols_per_thread, Int32)
+         for k in cutlass.range(cols_per_thread):
+             col_idx = tAcA[0, 0, k][1]
+             if const_expr(not pred):
+                 k_idx[k] = gAIdx_cur[col_idx]
+             else:
+                 if tApA_k[k]:
+                     k_idx[k] = gAIdx_cur[col_idx]
+                 else:
+                     k_idx[k] = -1
+         return k_idx, tApA_k
+ 
+     def prefetch_from_smem_fn(
+         a_prefetch_pipeline, src_idx, dst_idx, a_prefetch_consumer_state, pred: bool = False
+     ) -> Tuple[cute.Tensor, cute.Tensor]:
+         tApA_k = None
+         if const_expr(pred):
+             tApA_k = cute.make_fragment(cols_per_thread, Boolean)
+             limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
+             for k in cutlass.range(cols_per_thread, unroll_full=True):
+                 tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
+         a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state)
+         sAIdx_cur = sAIdx[None, dst_idx]
+         k_idx = cute.make_fragment(cols_per_thread, Int32)
+         for k in cutlass.range(cols_per_thread):
+             col_idx = tAcA[0, 0, k][1]
+             k_idx[k] = sAIdx_cur[col_idx]
+         cute.arch.sync_warp()
+         with cute.arch.elect_one():
+             a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state)
+         return k_idx, tApA_k
+ 
+     def copy_fn(
+         src_idx, dst_idx, k_idx_tApA_k: Tuple[cute.Tensor, cute.Tensor], pred: bool = False
+     ):
+         k_idx, tApA_k = k_idx_tApA_k
+         tApA_k_pred = None
+         if const_expr(pred):
+             tApA_k_pred = cute.prepend_ones(tApA_k, up_to_rank=2)  # (1, cols_per_thread)
+         for k in cutlass.range_constexpr(tAcA.shape[2]):
+             # copy_A(tAmA[None, None, k_idx[k]], tAsA[(None, None, k), smem_idx], pred=cute.prepend_ones(tApA_m, up_to_rank=2))
+             for m in cutlass.range_constexpr(tAcA.shape[1]):
+                 if tApA_m[m]:
+                     cute.copy(
+                         thr_copy_A,
+                         tAmA[None, m, k_idx[k]],
+                         tAsA[(None, m, k), dst_idx],
+                         pred=None if const_expr(tApA_k_pred is None) else tApA_k_pred[None, k],
+                     )
+ 
+     return copy_fn, prefetch_from_gmem_fn if const_expr(
+         gAIdx is not None
+     ) else prefetch_from_smem_fn
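
The swizzle helpers above are the easiest part of this file to misread, so here is a pure-Python model of parse_swizzle_from_pointer and swizzle_int with a worked value. The S<3,4,3> string is only an illustrative example of what ptr.type.swizzle_type stringifies to:

    import re

    def parse_swizzle(swizzle_str):
        # Mirrors parse_swizzle_from_pointer: pull (b, m, s) out of '!cute.swizzle<"S<b,m,s>">'.
        match = re.search(r"S<(\d+),(\d+),(\d+)>", swizzle_str)
        if match is None:
            raise ValueError(f"Could not parse swizzle_type: {swizzle_str}")
        return tuple(int(g) for g in match.groups())

    def swizzle_int(ptr_int, b, m, s):
        # XOR the b bits starting at bit (m + s) into the b bits starting at bit m.
        bit_msk = (1 << b) - 1
        yyy_msk = bit_msk << (m + s)
        return ptr_int ^ ((ptr_int & yyy_msk) >> s)

    b, m, s = parse_swizzle('!cute.swizzle<"S<3,4,3>">')     # a common 128-byte swizzle
    assert swizzle_int(0b110010000, b, m, s) == 0b110100000  # 400 -> 416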