quack-kernels 0.2.4-py3-none-any.whl → 0.2.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +11 -1
- quack/copy_utils.py +133 -6
- quack/cute_dsl_ptxas.py +151 -0
- quack/layout_utils.py +8 -0
- quack/pipeline.py +31 -13
- quack/sm90_utils.py +31 -1
- quack/sort/bitonic_sort.py +1 -1
- {quack_kernels-0.2.4.dist-info → quack_kernels-0.2.5.dist-info}/METADATA +2 -2
- {quack_kernels-0.2.4.dist-info → quack_kernels-0.2.5.dist-info}/RECORD +12 -11
- {quack_kernels-0.2.4.dist-info → quack_kernels-0.2.5.dist-info}/WHEEL +1 -1
- {quack_kernels-0.2.4.dist-info → quack_kernels-0.2.5.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.4.dist-info → quack_kernels-0.2.5.dist-info}/top_level.txt +0 -0
quack/__init__.py
CHANGED
@@ -1,9 +1,19 @@
-__version__ = "0.2.4"
+__version__ = "0.2.5"
+
+import os
 
 from quack.rmsnorm import rmsnorm
 from quack.softmax import softmax
 from quack.cross_entropy import cross_entropy
 
+
+if os.environ.get("CUTE_DSL_PTXAS_PATH", None) is not None:
+    import quack.cute_dsl_ptxas  # noqa: F401
+
+    # Patch to dump ptx and then use system ptxas to compile to cubin
+    quack.cute_dsl_ptxas.patch()
+
+
 __all__ = [
     "rmsnorm",
     "softmax",
quack/copy_utils.py
CHANGED
@@ -7,18 +7,19 @@ import cutlass
 import cutlass.cute as cute
 
 from cutlass import Int32, Boolean, const_expr
-from cutlass.cute.nvgpu import cpasync
+from cutlass.cute.nvgpu import cpasync, warpgroup
 from cutlass.cutlass_dsl import dsl_user_op
 import cutlass.pipeline
 
 
 @dsl_user_op
 def cvt_copy(
-
+    tiled_copy: cute.TiledCopy,
     src: cute.Tensor,
     dst: cute.Tensor,
     *,
     pred: Optional[cute.Tensor] = None,
+    retile: bool = False,
     loc=None,
     ip=None,
     **kwargs,
@@ -28,7 +29,9 @@ def cvt_copy(
         src_cvt = cute.make_fragment_like(src, dst.element_type)
         src_cvt.store(src.load().to(dst.element_type))
         src = src_cvt
-
+    if const_expr(retile):
+        src = tiled_copy.retile(src)
+    cute.copy(tiled_copy, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
 
 
 @dsl_user_op
@@ -262,6 +265,124 @@ def get_smem_store_atom(
     )
 
 
+def get_smem_load_atom(
+    arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False
+) -> cute.CopyAtom:
+    if const_expr(arch < 90 or element_type.width != 16):
+        return cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(),
+            element_type,
+            num_bits_per_copy=(2 if not transpose else 1) * element_type.width,
+        )
+    else:
+        return cute.make_copy_atom(
+            cute.nvgpu.warp.LdMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
+            element_type,
+        )
+
+
+def get_smem_store_C(
+    tiled_mma: cute.TiledMma,
+    sC: cute.Tensor,
+    tidx: Int32,
+    arch: int,
+    transpose: bool = False,
+    position_independent=False,
+) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
+    dtype = sC.element_type
+    copy_atom = get_smem_store_atom(arch, dtype, transpose)
+    tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma)
+    thr_copy = tiled_copy.get_slice(tidx)
+    if const_expr(not position_independent):
+        tRS_sC = thr_copy.partition_D(sC)
+    else:
+        tRS_sC = partition_D_position_independent(thr_copy, sC)
+
+    def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs):
+        cvt_copy(tiled_copy, src, tRS_sC[None, None, None, dst_idx], retile=True, **new_kwargs)
+
+    return copy_fn, thr_copy, tRS_sC
+
+
+def get_smem_load_C(
+    tiled_mma: cute.TiledMma,
+    sC: cute.Tensor,
+    tidx: Int32,
+    arch: int,
+    transpose: bool = False,
+    position_independent=False,
+) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
+    dtype = sC.element_type
+    copy_atom = get_smem_load_atom(arch, dtype, transpose)
+    tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma)
+    thr_copy = tiled_copy.get_slice(tidx)
+    if const_expr(not position_independent):
+        tSR_sC = thr_copy.partition_S(sC)
+    else:
+        tSR_sC = partition_S_position_independent(thr_copy, sC)
+    copy_atom_RS = get_smem_store_atom(arch, dtype, transpose)
+    thr_copy_RS = cute.make_tiled_copy_C(copy_atom_RS, tiled_mma).get_slice(tidx)
+    tRS_shape = thr_copy_RS.partition_S(cute.make_identity_tensor(sC.shape[:2])).shape
+
+    def copy_fn(src_idx: Int32, **new_kwargs):
+        return load_s2r_retile(
+            tiled_copy, tSR_sC[None, None, None, src_idx], dst_shape=tRS_shape, **new_kwargs
+        )
+
+    return copy_fn, thr_copy, tSR_sC
+
+
+def get_smem_store_A(
+    tiled_mma: cute.TiledMma, sA: cute.Tensor, tidx: Int32, arch: int, position_independent=False
+) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
+    dtype = sA.element_type
+    transpose = tiled_mma.op.a_major_mode == warpgroup.OperandMajorMode.MN
+    copy_atom = get_smem_store_atom(arch, dtype, transpose)
+    tiled_copy = cute.make_tiled_copy_A(copy_atom, tiled_mma)
+    thr_copy = tiled_copy.get_slice(tidx)
+    if const_expr(not position_independent):
+        tRS_sA = thr_copy.partition_D(sA)
+    else:
+        tRS_sA = partition_D_position_independent(thr_copy, sA)
+
+    def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs):
+        cvt_copy(tiled_copy, src, tRS_sA[None, None, None, dst_idx], retile=True, **new_kwargs)
+
+    return copy_fn, thr_copy, tRS_sA
+
+
+def get_smem_load_A(
+    tiled_mma: cute.TiledMma,
+    sA: cute.Tensor,
+    tidx: Int32,
+    arch: int,
+    with_dst_tensor: bool = False,
+    position_independent=False,
+) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
+    dtype = sA.element_type
+    transpose = tiled_mma.op.a_major_mode == warpgroup.OperandMajorMode.MN
+    copy_atom = get_smem_load_atom(arch, dtype, transpose)
+    tiled_copy = cute.make_tiled_copy_A(copy_atom, tiled_mma)
+    thr_copy = tiled_copy.get_slice(tidx)
+    if const_expr(not position_independent):
+        tSR_sA = thr_copy.partition_S(sA)
+    else:
+        tSR_sA = partition_S_position_independent(thr_copy, sA)
+    copy_atom_RS = get_smem_store_atom(arch, dtype, transpose)
+    thr_copy_RS = cute.make_tiled_copy_C(copy_atom_RS, tiled_mma).get_slice(tidx)
+    tRS_shape = tiled_mma.partition_shape_A(sA.shape[:2])
+
+    def copy_fn(src_idx: Int32, **new_kwargs):
+        return load_s2r_retile(
+            tiled_copy, tSR_sA[None, None, None, src_idx], dst_shape=tRS_shape, **new_kwargs
+        )
+
+    def copy_fn_w_dst_tensor(src_idx: Int32, dst: cute.Tensor, **new_kwargs):
+        return load_s2r_retile(tiled_copy, tSR_sA[None, None, None, src_idx], dst, **new_kwargs)
+
+    return copy_fn if not with_dst_tensor else copy_fn_w_dst_tensor, thr_copy, tSR_sA
+
+
 def tma_get_copy_fn(
     atom: cute.CopyAtom,
     cta_coord: cute.Coord,
@@ -269,6 +390,7 @@ def tma_get_copy_fn(
     src_tensor: cute.Tensor,
     dst_tensor: cute.Tensor,
     filter_zeros: bool = False,
+    single_stage: bool = False,
     **kwargs,
 ) -> Callable:
     src_is_smem = const_expr(
@@ -276,13 +398,15 @@ def tma_get_copy_fn(
         and src_tensor.memspace == cute.AddressSpace.smem
     )
     smem_tensor, gmem_tensor = (src_tensor, dst_tensor) if src_is_smem else (dst_tensor, src_tensor)
+    group_rank_smem = const_expr(cute.rank(smem_tensor) - (1 if not single_stage else 0))
+    group_rank_gmem = const_expr(cute.rank(gmem_tensor) - (1 if not single_stage else 0))
     # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
     s, g = cpasync.tma_partition(
         atom,
         cta_coord,
         cta_layout,
-        cute.group_modes(smem_tensor, 0,
-        cute.group_modes(gmem_tensor, 0,
+        cute.group_modes(smem_tensor, 0, group_rank_smem),
+        cute.group_modes(gmem_tensor, 0, group_rank_gmem),
     )
     if const_expr(filter_zeros):
         s = cute.filter_zeros(s)
@@ -292,7 +416,10 @@ def tma_get_copy_fn(
     def copy_tma(src_idx, dst_idx, **new_kwargs):
         cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs)
 
-
+    def copy_tma_single_stage(**new_kwargs):
+        cute.copy(atom, src, dst, **new_kwargs, **kwargs)
+
+    return (copy_tma if const_expr(not single_stage) else copy_tma_single_stage), s, g
 
 
 def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync):
quack/cute_dsl_ptxas.py
ADDED
@@ -0,0 +1,151 @@
+"""
+System ptxas replacement for CUTLASS DSL.
+Environment variables:
+CUTE_DSL_PTXAS_PATH - Path to ptxas (e.g., /usr/local/cuda/bin/ptxas)
+CUTE_DSL_PTXAS_VERBOSE - Set to 1 for verbose output
+"""
+
+import os
+import sys
+import re
+import ctypes
+import subprocess
+from pathlib import Path
+
+import cutlass
+
+
+CUTE_DSL_PTXAS_PATH = os.environ.get("CUTE_DSL_PTXAS_PATH", None)
+VERBOSE = os.environ.get("CUTE_DSL_PTXAS_VERBOSE", "0") == "1"
+
+_original_load_cuda_library = None
+_user_wanted_ptx = False  # True if user originally set CUTE_DSL_KEEP_PTX=1
+
+
+def _log(msg):
+    if VERBOSE:
+        print(f"[ptxas] {msg}", file=sys.stderr)
+
+
+def _get_ptx(compiled_func) -> tuple[str, Path] | None:
+    """Find and read PTX file, stripping null bytes."""
+    func_name = getattr(compiled_func, "function_name", None)
+    if not func_name:
+        return None
+
+    dump_dir = os.environ.get("CUTE_DSL_DUMP_DIR", Path.cwd())
+    for ptx_path in Path(dump_dir).glob(f"*{func_name}*.ptx"):
+        content = ptx_path.read_text().rstrip("\x00")
+        if ".entry " in content and content.rstrip().endswith("}"):
+            _log(f"Found PTX: {ptx_path}")
+            return content, ptx_path
+    return None
+
+
+def _compile_ptx(ptx_path: Path, ptx_content: str) -> bytes:
+    """Compile PTX to cubin using system ptxas."""
+    # Extract arch from PTX
+    match = re.search(r"\.target\s+(sm_\d+[a-z]?)", ptx_content)
+    arch = match.group(1) if match else "sm_90a"
+
+    # Write stripped content back if needed
+    if ptx_path.read_text() != ptx_content:
+        ptx_path.write_text(ptx_content)
+
+    # Compile
+    cubin_tmp = ptx_path.with_suffix(".cubin.tmp")
+    try:
+        assert CUTE_DSL_PTXAS_PATH is not None
+        result = subprocess.run(
+            [CUTE_DSL_PTXAS_PATH, f"-arch={arch}", "-O3", "-o", str(cubin_tmp), str(ptx_path)],
+            capture_output=True,
+            text=True,
+        )
+        if result.returncode != 0:
+            raise RuntimeError(f"ptxas failed: {result.stderr}")
+
+        cubin_data = cubin_tmp.read_bytes()
+        _log(f"Compiled {ptx_path.name} -> {len(cubin_data)} bytes ({arch})")
+
+        # Save cubin if CUTE_DSL_KEEP_CUBIN is set
+        if os.environ.get("CUTE_DSL_KEEP_CUBIN", "0") == "1":
+            cubin_out = ptx_path.with_suffix(".cubin")
+            cubin_out.write_bytes(cubin_data)
+            _log(f"Saved: {cubin_out}")
+
+        return cubin_data
+    finally:
+        cubin_tmp.unlink(missing_ok=True)
+
+
+def _patched_load_cuda_library(self):
+    """Replacement for _load_cuda_library that uses system ptxas."""
+
+    result = _get_ptx(self)
+    if not result:
+        _log("PTX not found, falling back to embedded ptxas")
+        return _original_load_cuda_library(self)
+
+    ptx_content, ptx_path = result
+
+    try:
+        cubin = _compile_ptx(ptx_path, ptx_content)
+    except Exception as e:
+        _log(f"Compilation failed ({e}), falling back to embedded ptxas")
+        return _original_load_cuda_library(self)
+
+    # Load cubin
+    import cuda.bindings.runtime as cuda_runtime
+
+    err, library = cuda_runtime.cudaLibraryLoadData(cubin, None, None, 0, None, None, 0)
+    if err != cuda_runtime.cudaError_t.cudaSuccess:
+        _log(f"cudaLibraryLoadData failed ({err}), falling back to embedded ptxas")
+        return _original_load_cuda_library(self)
+
+    # Register kernels on all devices
+    _, cuda_load_to_device = self._get_cuda_init_and_load()
+    lib_ptr = ctypes.c_void_p(int(library))
+    dev_id = ctypes.c_int32(0)
+    err_val = ctypes.c_int32(0)
+    args = (ctypes.c_void_p * 3)(
+        ctypes.cast(ctypes.pointer(lib_ptr), ctypes.c_void_p),
+        ctypes.cast(ctypes.pointer(dev_id), ctypes.c_void_p),
+        ctypes.cast(ctypes.pointer(err_val), ctypes.c_void_p),
+    )
+
+    for dev in range(self.num_devices):
+        dev_id.value = dev
+        cuda_load_to_device(args)
+        if err_val.value != 0:
+            _log("cuda_load_to_device failed, falling back to embedded ptxas")
+            return _original_load_cuda_library(self)
+
+    _log(f"Loaded kernel from {ptx_path.name}")
+
+    # Delete PTX if user didn't originally want it kept
+    if not _user_wanted_ptx:
+        ptx_path.unlink(missing_ok=True)
+
+    return [cuda_runtime.cudaLibrary_t(lib_ptr.value)]
+
+
+def patch():
+    """Install system ptxas hook. Call before importing cutlass."""
+    global _original_load_cuda_library, _user_wanted_ptx
+
+    assert CUTE_DSL_PTXAS_PATH is not None
+    if not os.path.isfile(CUTE_DSL_PTXAS_PATH) or not os.access(CUTE_DSL_PTXAS_PATH, os.X_OK):
+        raise RuntimeError(f"ptxas not found: {CUTE_DSL_PTXAS_PATH}")
+
+    # Track if user originally wanted PTX kept
+    _user_wanted_ptx = os.environ.get("CUTE_DSL_KEEP_PTX", "0") == "1"
+    # os.environ['CUTE_DSL_KEEP_PTX'] = '1'
+    assert os.environ.get("CUTE_DSL_KEEP_PTX", "0") == "1", (
+        "Require CUTE_DSL_KEEP_PTX=1 to use system's ptxas"
+    )
+
+    cls = cutlass.cutlass_dsl.cuda_jit_executor.CudaDialectJitCompiledFunction
+    _original_load_cuda_library = cls._load_cuda_library
+    cls._load_cuda_library = _patched_load_cuda_library
+    _log("Patch applied")
+    return
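What `_compile_ptx` does can also be reproduced by hand on a dumped `.ptx` file. A minimal sketch (the ptxas path and the PTX filename are assumptions for illustration; the dump itself comes from running with CUTE_DSL_KEEP_PTX=1):

```python
# Illustrative manual equivalent of _compile_ptx: run a system ptxas over a
# dumped PTX file and read back the cubin. Paths are assumed examples.
import re
import subprocess
from pathlib import Path

ptx_path = Path("./my_kernel.ptx")          # hypothetical dump location
ptx = ptx_path.read_text().rstrip("\x00")   # strip trailing null bytes, as the module does

# Pick the architecture out of the PTX ".target" directive, falling back like _compile_ptx.
match = re.search(r"\.target\s+(sm_\d+[a-z]?)", ptx)
arch = match.group(1) if match else "sm_90a"

cubin_path = ptx_path.with_suffix(".cubin")
subprocess.run(
    ["/usr/local/cuda/bin/ptxas", f"-arch={arch}", "-O3", "-o", str(cubin_path), str(ptx_path)],
    check=True,
)
print(f"{cubin_path} -> {cubin_path.stat().st_size} bytes")
```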
quack/layout_utils.py
CHANGED
@@ -187,6 +187,10 @@ def make_acc_tensor_mn_view(acc: cute.Tensor) -> cute.Tensor:
     return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout))
 
 
+def reshape_acc_to_mn(acc: cute.Tensor) -> cute.Tensor:
+    return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout))
+
+
 @cute.jit
 def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout:
     # For back to back gemm, convert layout of acc0 to gemm 1 accept layout.
@@ -227,6 +231,10 @@ def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout:
     return rA_mma_view
 
 
+def reshape_acc_to_frgA(acc: cute.Tensor) -> cute.Tensor:
+    return cute.make_tensor(acc.iterator, convert_layout_acc_frgA(acc.layout))
+
+
 def convert_layout_zero_stride(
     input: cute.Tensor | cute.Layout, ref_layout: cute.Layout
 ) -> cute.Layout:
quack/pipeline.py
CHANGED
@@ -5,14 +5,15 @@ from dataclasses import dataclass
 
 import cutlass.cute as cute
 from cutlass import Boolean, Int32, const_expr
-from cutlass.cutlass_dsl import if_generate, and_
+from cutlass.cutlass_dsl import if_generate, and_, dsl_user_op
 from cutlass.pipeline import MbarrierArray, CooperativeGroup, PipelineOp, pipeline_init_wait
 from cutlass.pipeline import PipelineAsync, PipelineTmaAsync, PipelineState, PipelineUserType
 from cutlass.pipeline import PipelineTmaUmma
 
 
 class PipelineStateWAdvance(PipelineState):
-
+    @dsl_user_op
+    def advance_iters(self, num_iterations: Int32, *, loc=None, ip=None):
         self._count += Int32(num_iterations)
         new_index = self._index + Int32(num_iterations)
         # How many times did we cross the stages boundary
@@ -126,34 +127,40 @@ class PipelineTmaCpAsync(PipelineTmaAsync):
             is_signalling_thread,
         )
 
+    @dsl_user_op
     def producer_acquire(
         self,
         state: PipelineState,
         try_acquire_token: Optional[Boolean] = None,
         is_tma_warp: Optional[Boolean] = True,
+        *,
+        loc=None,
+        ip=None,
     ):
         """
        TMA producer commit conditionally waits on buffer empty and sets the transaction barrier.
         """
         if_generate(
             try_acquire_token is None or try_acquire_token == 0,
-            lambda: self.sync_object_empty.wait(state.index, state.phase),
+            lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
         )
         # This is the difference between this and PipelineTmaAsync: we could have multiple
         # warps calling this, but only 1 warp should do the arrive on the full barrier
         if_generate(
             is_tma_warp,
-            lambda: self.sync_object_full.arrive(state.index, self.producer_mask),
+            lambda: self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip),
         )
 
-
+    @dsl_user_op
+    def producer_cpasync_commit(self, state: PipelineState, *, loc=None, ip=None):
         """
         We need the mbarrier to track the completion of cp.async
         """
-        cute.arch.cp_async_mbarrier_arrive_noinc(self.producer_get_barrier(state))
+        cute.arch.cp_async_mbarrier_arrive_noinc(self.producer_get_barrier(state, loc=loc, ip=ip), loc=loc, ip=ip)
 
 
 class MbarrierArrayWDropCount(MbarrierArray):
+    @dsl_user_op
     def __init__(
         self,
         barrier_storage: cute.Pointer,
@@ -161,6 +168,9 @@ class MbarrierArrayWDropCount(MbarrierArray):
         agent: tuple[PipelineOp, CooperativeGroup],
         tx_count: int = 0,
         drop_count: Optional[Int32] = None,
+        *,
+        loc=None,
+        ip=None,
     ) -> None:
         self.barrier_storage = barrier_storage
         self.tx_count = tx_count
@@ -183,7 +193,7 @@ class MbarrierArrayWDropCount(MbarrierArray):
         self.mbarrier_base = self.barrier_storage
 
         # Mbarrier initialization in constructor
-        self.mbarrier_init()
+        self.mbarrier_init(loc=loc, ip=ip)
 
     def __extract_mlir_values__(self):
         return [self.barrier_storage, self.drop_count]
@@ -211,6 +221,7 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
         barrier_storage: cute.Pointer = None,
         cta_layout_vmnk: Optional[cute.Layout] = None,
         producer_drop_count: Optional[Int32] = None,
+        mcast_mode_mn: tuple[int, int] = (1, 1),
     ):
         """
         This helper function computes any necessary attributes and returns an instance of PipelineTmaUmma.
@@ -226,6 +237,8 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
         :type tx_count: int
         :param cta_layout_vmnk: Layout of the cluster shape
         :type cta_layout_vmnk: cute.Layout | None
+        :param mcast_mode_mn: Tuple specifying multicast modes for m and n dimensions (each 0 or 1)
+        :type mcast_mode_mn: tuple[int, int], optional
         """
         if not isinstance(barrier_storage, cute.Pointer):
             raise ValueError(
@@ -245,7 +258,7 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
             tx_count,
             drop_count=producer_drop_count,
         )
-        sync_object_empty =
+        sync_object_empty = PipelineTmaUmma._make_sync_object(
             barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
         )
 
@@ -255,7 +268,7 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
             # All threadblocks are leaders if not using clusters
             is_leader_cta = True
         else:
-            producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(cta_layout_vmnk)
+            producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(cta_layout_vmnk, mcast_mode_mn)
             is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk)
 
         cta_group = (
@@ -278,11 +291,15 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
             cta_group,
         )
 
+    @dsl_user_op
     def producer_acquire(
         self,
         state: PipelineState,
         try_acquire_token: Optional[Boolean] = None,
         is_tma_warp: Optional[Boolean] = True,
+        *,
+        loc=None,
+        ip=None,
     ):
         """
         TMA producer commit conditionally waits on buffer empty and sets the
@@ -290,17 +307,18 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
         """
         if_generate(
             try_acquire_token is None or try_acquire_token == 0,
-            lambda: self.sync_object_empty.wait(state.index, state.phase),
+            lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
         )
         # This is the difference between this and PipelineTmaAsync: we could have multiple
         # warps calling this, but only 1 warp should do the arrive on the full barrier
         if_generate(
             and_(self.is_leader_cta, is_tma_warp),
-            lambda: self.sync_object_full.arrive(state.index, self.producer_mask),
+            lambda: self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip),
         )
 
-
+    @dsl_user_op
+    def producer_cpasync_commit(self, state: PipelineState, *, loc=None, ip=None):
         """
         We need the mbarrier to track the completion of cp.async
         """
-        cute.arch.cp_async_mbarrier_arrive_noinc(self.producer_get_barrier(state))
+        cute.arch.cp_async_mbarrier_arrive_noinc(self.producer_get_barrier(state, loc=loc, ip=ip), loc=loc, ip=ip)
quack/sm90_utils.py
CHANGED
@@ -27,10 +27,11 @@ def make_smem_layout(
         sm90_utils_og.get_smem_layout_atom(layout, dtype, major_mode_size),
         dtype,
     )
+    order = (1, 0, 2) if const_expr(layout.is_m_major_c()) else (0, 1, 2)
     smem_layout_staged = cute.tile_to_shape(
         smem_layout_atom,
         cute.append(shape, stage) if const_expr(stage is not None) else shape,
-        order=(
+        order=order if const_expr(stage is not None) else order[:2],
     )
     return smem_layout_staged
 
@@ -125,3 +126,32 @@ def gemm_w_idx(
     rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
     rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
     gemm(tiled_mma, acc, rA, rB, zero_init=zero_init, wg_wait=wg_wait)
+
+
+def partition_fragment_ABC(
+    thr_mma: cute.ThrMma,
+    shape_mnk: cute.Shape,
+    sA: Optional[cute.Tensor],
+    sB: Optional[cute.Tensor],
+    swap_AB: bool = False,
+):
+    is_rs = thr_mma.op.a_src == warpgroup.OperandSource.RMEM
+    if const_expr(not swap_AB):
+        acc = cute.make_fragment(thr_mma.partition_shape_C(shape_mnk[:2]), Float32)
+        if const_expr(not is_rs):
+            assert sA is not None
+            tCrA = thr_mma.make_fragment_A(thr_mma.partition_A(sA))
+        else:
+            tCrA = thr_mma.make_fragment_A(thr_mma.partition_shape_A((shape_mnk[0], shape_mnk[2])))
+        assert sB is not None
+        tCrB = thr_mma.make_fragment_B(thr_mma.partition_B(sB))
+    else:
+        acc = cute.make_fragment(thr_mma.partition_shape_C((shape_mnk[1], shape_mnk[0])), Float32)
+        if const_expr(not is_rs):
+            assert sB is not None
+            tCrB = thr_mma.make_fragment_A(thr_mma.partition_A(sB))
+        else:  # B in rmem
+            tCrB = thr_mma.make_fragment_A(thr_mma.partition_shape_A((shape_mnk[1], shape_mnk[2])))
+        assert sA is not None
+        tCrA = thr_mma.make_fragment_B(thr_mma.partition_B(sA))
+    return acc, tCrA, tCrB
quack/sort/bitonic_sort.py
CHANGED
@@ -83,7 +83,7 @@ def bitonic_topk_merge(
     else:
         minmax_fn = min if ascending else max
         # Write the top k elements to the first half of the array
-        for i in cutlass.range(k,
+        for i in cutlass.range(k, unroll_full=True):
            arr0[start0 + i] = minmax_fn(arr0[start0 + i], arr1[start1 + k - 1 - i])
         # Now the 1st half is bitonic, we just need to merge it
         bitonic_merge(arr0, k, start0, ascending)
{quack_kernels-0.2.4.dist-info → quack_kernels-0.2.5.dist-info}/METADATA
CHANGED

@@ -1,9 +1,9 @@
 Metadata-Version: 2.4
 Name: quack-kernels
-Version: 0.2.4
+Version: 0.2.5
 Requires-Python: >=3.10
 License-File: LICENSE
-Requires-Dist: nvidia-cutlass-dsl
+Requires-Dist: nvidia-cutlass-dsl>=4.4.0.dev0
 Requires-Dist: torch
 Requires-Dist: apache-tvm-ffi<0.2,>=0.1.6
 Requires-Dist: torch-c-dlpack-ext
{quack_kernels-0.2.4.dist-info → quack_kernels-0.2.5.dist-info}/RECORD
CHANGED

@@ -1,10 +1,11 @@
-quack/__init__.py,sha256=
+quack/__init__.py,sha256=0MnyCHBHxG4k12KHNzd-JlElf1M0qRrzhs89PJZJUHA,427
 quack/activation.py,sha256=-lZgojraqdyLjOzgOXBehoVeRBhBq30UX7kOkXsCpGI,20855
 quack/autotuner.py,sha256=atw0ntedi22RPwSdjWOoge4S56S8VFvRocJQcYhpAlo,13454
 quack/broadcast_utils.py,sha256=X5vWg2RtIIWU9Z7nEUW6m0EP0Cfd9XtCKxp4tSyp4Mg,1283
 quack/compile_utils.py,sha256=qJ3oTsDlbAiddrJHtEO7LPYVqn_s-neNfiw-_KvfXZU,591
-quack/copy_utils.py,sha256=
+quack/copy_utils.py,sha256=IIXtLJv0wQSKfinjIJwG10xQScRvAZvKw1yBV2MXckw,23682
 quack/cross_entropy.py,sha256=w6fjHC_vXt5ji2KfoLrSOdAvpLrQszrYU9rmRij2yY8,24899
+quack/cute_dsl_ptxas.py,sha256=IfBnTJ9amdfDOQkuSdWCLTh7CkZziIvs_xrAc8taxhk,5122
 quack/cute_dsl_utils.py,sha256=4uQx5aYDG9UvVzbWwJTjjJLrnoympz70_CD8b37FQWo,3854
 quack/fast_math.py,sha256=E1XUqfUt0_n9BPZNggF-UDzZ6anso9bYUrwqafemWvQ,2297
 quack/gemm.py,sha256=8V23MPq49QbV3csv-_AxjfE9qf8R3NIqFK9Q9db6t2c,7417
@@ -17,28 +18,28 @@ quack/gemm_sm100.py,sha256=U9jmzpST_d1W6CBFf1ZHhTtr0K8hENCsUz7dXvHaMZc,122344
 quack/gemm_sm90.py,sha256=u-Q3fN6DPm1fEdz0LcMecMbGTBcRunUCWopufwO8cHU,92015
 quack/gemm_symmetric.py,sha256=mqx7wgOCY6Dh9hjL6gR9PBstMD476GhpA_NkGeaEtik,13349
 quack/gemm_wrapper_utils.py,sha256=EaPyR3Lq19z_RkdB2_xxRj0IPSJMgyfpkrTXyvY3B6M,12775
-quack/layout_utils.py,sha256=
+quack/layout_utils.py,sha256=qar8x_6VPKOdrz_lAGH_c4W_HKfYLk3Lhtd3Rv1OBBE,12197
 quack/linear.py,sha256=mhN2A98w7H7X4MS63XCCK3gpOm1eS8H7a4WO9ovkt5U,9791
 quack/linear_cross_entropy.py,sha256=Zhy_gdMsKHOie-jntBaqIuiDJtkiq6qEBwnyuWwIRw4,10092
 quack/mlp.py,sha256=YjdwQRwEePA9KyidFXp5H1-lxiJc8dZ41vl8Fv8pgss,2259
-quack/pipeline.py,sha256=
+quack/pipeline.py,sha256=3d4D8CPHw7ytZfdH9HFkfDng12YTnGf3pAe2DYxHjK4,11993
 quack/reduce.py,sha256=ySKT2xh1_pIlbJX29BPmwH6yJ7MxIrRZyxHIPPYVpm0,12698
 quack/reduction_base.py,sha256=QqlPs5L2VCxwDrO4CHPq-KY6f_BAYRbvsR6k81LPzTU,3180
 quack/rmsnorm.py,sha256=esy18s5JtT7KBPRPhWf_anLRTrtromwqeJmg2yzOm60,44678
 quack/sm100_utils.py,sha256=-p5qj3Wi9n4WDLy2sl-fApYpGp5rH3JvZQb712OTxPs,1901
-quack/sm90_utils.py,sha256=
+quack/sm90_utils.py,sha256=RLfIZFPhx7Mb9gXwilJ-QSULaj_Q4unaQJA2tFjGIJ4,5545
 quack/softmax.py,sha256=ZqeVbnGfzwkro1LfWBHagbS7B7ug7b9SLZWuGx_Y3Kc,14367
 quack/tensormap_manager.py,sha256=Ts3Mxp0_es2RNA0ffvUjWMXN79lsfWEBZ0DQYhtbcnw,5338
 quack/tile_scheduler.py,sha256=vbKq0xp94eII0uJ63yY_3sgvJkQI7Irc8y1OttO6cRA,42514
 quack/topk.py,sha256=43xHpRGbwZCSRsulmfrG4WA_r2eLHc3sniaUFU7wn-o,22522
 quack/utils.py,sha256=WIttE1iiwyPIwR1NpaeO26Pn9YkZb361TDxFTUDH-IE,7354
 quack/varlen_utils.py,sha256=SOYkomxX2FoqjYlybg99CqNhS9IARM6F9ba2AkIVvT4,15811
-quack/sort/bitonic_sort.py,sha256
+quack/sort/bitonic_sort.py,sha256=-4VmHGmnqRLaVF-IrNhbJqNEJcz-FJT5GuzSWTFeIfI,4831
 quack/sort/generate_sorting_networks.py,sha256=vkJBOjTVEinQkWT4OtFqOWxFVdTIPoNAQocneKc9-rM,14477
 quack/sort/sorting_networks.py,sha256=l_26zi3gXD_z-tnm2eAczRrmE-mbaz00KmqH6ONivL8,9686
 quack/sort/utils.py,sha256=RbubEY1GcEpsjiz_6o5o2WB47IeMOzaajW6Jis0s444,1059
-quack_kernels-0.2.
-quack_kernels-0.2.
-quack_kernels-0.2.
-quack_kernels-0.2.
-quack_kernels-0.2.
+quack_kernels-0.2.5.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+quack_kernels-0.2.5.dist-info/METADATA,sha256=5FnKfn7JrhBVjFUILnccs-OL2I8UN6Lqo7QR0i4tAlA,366
+quack_kernels-0.2.5.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+quack_kernels-0.2.5.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
+quack_kernels-0.2.5.dist-info/RECORD,,

{quack_kernels-0.2.4.dist-info → quack_kernels-0.2.5.dist-info}/licenses/LICENSE
File without changes

{quack_kernels-0.2.4.dist-info → quack_kernels-0.2.5.dist-info}/top_level.txt
File without changes