quack-kernels 0.2.3__tar.gz → 0.2.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. {quack_kernels-0.2.3/quack_kernels.egg-info → quack_kernels-0.2.5}/PKG-INFO +3 -3
  2. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/pyproject.toml +4 -3
  3. quack_kernels-0.2.5/quack/__init__.py +21 -0
  4. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/copy_utils.py +133 -6
  5. quack_kernels-0.2.5/quack/cute_dsl_ptxas.py +151 -0
  6. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/layout_utils.py +8 -0
  7. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/pipeline.py +31 -13
  8. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/sm90_utils.py +31 -1
  9. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/sort/bitonic_sort.py +1 -1
  10. {quack_kernels-0.2.3 → quack_kernels-0.2.5/quack_kernels.egg-info}/PKG-INFO +3 -3
  11. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack_kernels.egg-info/SOURCES.txt +1 -0
  12. quack_kernels-0.2.5/quack_kernels.egg-info/requires.txt +8 -0
  13. quack_kernels-0.2.5/quack_kernels.egg-info/top_level.txt +1 -0
  14. quack_kernels-0.2.3/quack/__init__.py +0 -11
  15. quack_kernels-0.2.3/quack_kernels.egg-info/requires.txt +0 -8
  16. quack_kernels-0.2.3/quack_kernels.egg-info/top_level.txt +0 -5
  17. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/LICENSE +0 -0
  18. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/README.md +0 -0
  19. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/activation.py +0 -0
  20. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/autotuner.py +0 -0
  21. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/broadcast_utils.py +0 -0
  22. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/compile_utils.py +0 -0
  23. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/cross_entropy.py +0 -0
  24. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/cute_dsl_utils.py +0 -0
  25. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/fast_math.py +0 -0
  26. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/gemm.py +0 -0
  27. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/gemm_act.py +0 -0
  28. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/gemm_config.py +0 -0
  29. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/gemm_dact.py +0 -0
  30. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/gemm_default_epi.py +0 -0
  31. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/gemm_interface.py +0 -0
  32. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/gemm_sm100.py +0 -0
  33. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/gemm_sm90.py +0 -0
  34. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/gemm_symmetric.py +0 -0
  35. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/gemm_wrapper_utils.py +0 -0
  36. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/linear.py +0 -0
  37. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/linear_cross_entropy.py +0 -0
  38. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/mlp.py +0 -0
  39. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/reduce.py +0 -0
  40. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/reduction_base.py +0 -0
  41. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/rmsnorm.py +0 -0
  42. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/sm100_utils.py +0 -0
  43. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/softmax.py +0 -0
  44. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/sort/generate_sorting_networks.py +0 -0
  45. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/sort/sorting_networks.py +0 -0
  46. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/sort/utils.py +0 -0
  47. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/tensormap_manager.py +0 -0
  48. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/tile_scheduler.py +0 -0
  49. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/topk.py +0 -0
  50. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/utils.py +0 -0
  51. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack/varlen_utils.py +0 -0
  52. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/quack_kernels.egg-info/dependency_links.txt +0 -0
  53. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/setup.cfg +0 -0
  54. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/tests/test_cross_entropy.py +0 -0
  55. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/tests/test_layernorm.py +0 -0
  56. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/tests/test_linear.py +0 -0
  57. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/tests/test_linear_cross_entropy.py +0 -0
  58. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/tests/test_linear_varlen_k.py +0 -0
  59. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/tests/test_linear_varlen_m.py +0 -0
  60. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/tests/test_rmsnorm.py +0 -0
  61. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/tests/test_softmax.py +0 -0
  62. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/tests/test_symmetric_gemm.py +0 -0
  63. {quack_kernels-0.2.3 → quack_kernels-0.2.5}/tests/test_topk.py +0 -0
@@ -1,11 +1,11 @@
  Metadata-Version: 2.4
  Name: quack-kernels
- Version: 0.2.3
+ Version: 0.2.5
  Requires-Python: >=3.10
  License-File: LICENSE
- Requires-Dist: nvidia-cutlass-dsl==4.3.3
+ Requires-Dist: nvidia-cutlass-dsl>=4.4.0.dev0
  Requires-Dist: torch
- Requires-Dist: apache-tvm-ffi<0.2,>=0.1.5
+ Requires-Dist: apache-tvm-ffi<0.2,>=0.1.6
  Requires-Dist: torch-c-dlpack-ext
  Provides-Extra: dev
  Requires-Dist: pre-commit; extra == "dev"
@@ -7,9 +7,9 @@ name = "quack-kernels"
  dynamic = ["version"]
  requires-python = ">=3.10"
  dependencies = [
-     "nvidia-cutlass-dsl==4.3.3",
+     "nvidia-cutlass-dsl>=4.4.0.dev0",
      "torch",
-     "apache-tvm-ffi>=0.1.5,<0.2",
+     "apache-tvm-ffi>=0.1.6,<0.2",
      "torch-c-dlpack-ext",
  ]

@@ -20,7 +20,8 @@ dev = [
  ]

  [tool.setuptools.packages.find]
- exclude = ["tests", "benchmarks"]
+ where = ["."]
+ include = ["quack*"]

  [tool.setuptools.dynamic]
  version = {attr = "quack.__version__"}
@@ -0,0 +1,21 @@
+ __version__ = "0.2.5"
+
+ import os
+
+ from quack.rmsnorm import rmsnorm
+ from quack.softmax import softmax
+ from quack.cross_entropy import cross_entropy
+
+
+ if os.environ.get("CUTE_DSL_PTXAS_PATH", None) is not None:
+     import quack.cute_dsl_ptxas  # noqa: F401
+
+     # Patch to dump ptx and then use system ptxas to compile to cubin
+     quack.cute_dsl_ptxas.patch()
+
+
+ __all__ = [
+     "rmsnorm",
+     "softmax",
+     "cross_entropy",
+ ]
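Usage sketch (illustrative, not part of the diff): the new __init__ only wires in the system-ptxas hook when CUTE_DSL_PTXAS_PATH is set, and patch() additionally asserts CUTE_DSL_KEEP_PTX=1, so opting in looks roughly like the following. The ptxas path is just the example location quoted in the module docstring; point it at whatever ptxas build you actually want to use.

    import os

    os.environ["CUTE_DSL_PTXAS_PATH"] = "/usr/local/cuda/bin/ptxas"  # example path, not a requirement
    os.environ["CUTE_DSL_KEEP_PTX"] = "1"  # patch() requires the PTX dump to be kept

    import quack  # importing quack now calls quack.cute_dsl_ptxas.patch()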
@@ -7,18 +7,19 @@ import cutlass
  import cutlass.cute as cute

  from cutlass import Int32, Boolean, const_expr
- from cutlass.cute.nvgpu import cpasync
+ from cutlass.cute.nvgpu import cpasync, warpgroup
  from cutlass.cutlass_dsl import dsl_user_op
  import cutlass.pipeline


  @dsl_user_op
  def cvt_copy(
-     atom: cute.CopyAtom,
+     tiled_copy: cute.TiledCopy,
      src: cute.Tensor,
      dst: cute.Tensor,
      *,
      pred: Optional[cute.Tensor] = None,
+     retile: bool = False,
      loc=None,
      ip=None,
      **kwargs,
@@ -28,7 +29,9 @@ def cvt_copy(
          src_cvt = cute.make_fragment_like(src, dst.element_type)
          src_cvt.store(src.load().to(dst.element_type))
          src = src_cvt
-     cute.copy(atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
+     if const_expr(retile):
+         src = tiled_copy.retile(src)
+     cute.copy(tiled_copy, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)


  @dsl_user_op
@@ -262,6 +265,124 @@ def get_smem_store_atom(
262
265
  )
263
266
 
264
267
 
268
+ def get_smem_load_atom(
269
+ arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False
270
+ ) -> cute.CopyAtom:
271
+ if const_expr(arch < 90 or element_type.width != 16):
272
+ return cute.make_copy_atom(
273
+ cute.nvgpu.CopyUniversalOp(),
274
+ element_type,
275
+ num_bits_per_copy=(2 if not transpose else 1) * element_type.width,
276
+ )
277
+ else:
278
+ return cute.make_copy_atom(
279
+ cute.nvgpu.warp.LdMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
280
+ element_type,
281
+ )
282
+
283
+
284
+ def get_smem_store_C(
285
+ tiled_mma: cute.TiledMma,
286
+ sC: cute.Tensor,
287
+ tidx: Int32,
288
+ arch: int,
289
+ transpose: bool = False,
290
+ position_independent=False,
291
+ ) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
292
+ dtype = sC.element_type
293
+ copy_atom = get_smem_store_atom(arch, dtype, transpose)
294
+ tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma)
295
+ thr_copy = tiled_copy.get_slice(tidx)
296
+ if const_expr(not position_independent):
297
+ tRS_sC = thr_copy.partition_D(sC)
298
+ else:
299
+ tRS_sC = partition_D_position_independent(thr_copy, sC)
300
+
301
+ def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs):
302
+ cvt_copy(tiled_copy, src, tRS_sC[None, None, None, dst_idx], retile=True, **new_kwargs)
303
+
304
+ return copy_fn, thr_copy, tRS_sC
305
+
306
+
307
+ def get_smem_load_C(
308
+ tiled_mma: cute.TiledMma,
309
+ sC: cute.Tensor,
310
+ tidx: Int32,
311
+ arch: int,
312
+ transpose: bool = False,
313
+ position_independent=False,
314
+ ) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
315
+ dtype = sC.element_type
316
+ copy_atom = get_smem_load_atom(arch, dtype, transpose)
317
+ tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma)
318
+ thr_copy = tiled_copy.get_slice(tidx)
319
+ if const_expr(not position_independent):
320
+ tSR_sC = thr_copy.partition_S(sC)
321
+ else:
322
+ tSR_sC = partition_S_position_independent(thr_copy, sC)
323
+ copy_atom_RS = get_smem_store_atom(arch, dtype, transpose)
324
+ thr_copy_RS = cute.make_tiled_copy_C(copy_atom_RS, tiled_mma).get_slice(tidx)
325
+ tRS_shape = thr_copy_RS.partition_S(cute.make_identity_tensor(sC.shape[:2])).shape
326
+
327
+ def copy_fn(src_idx: Int32, **new_kwargs):
328
+ return load_s2r_retile(
329
+ tiled_copy, tSR_sC[None, None, None, src_idx], dst_shape=tRS_shape, **new_kwargs
330
+ )
331
+
332
+ return copy_fn, thr_copy, tSR_sC
333
+
334
+
335
+ def get_smem_store_A(
336
+ tiled_mma: cute.TiledMma, sA: cute.Tensor, tidx: Int32, arch: int, position_independent=False
337
+ ) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
338
+ dtype = sA.element_type
339
+ transpose = tiled_mma.op.a_major_mode == warpgroup.OperandMajorMode.MN
340
+ copy_atom = get_smem_store_atom(arch, dtype, transpose)
341
+ tiled_copy = cute.make_tiled_copy_A(copy_atom, tiled_mma)
342
+ thr_copy = tiled_copy.get_slice(tidx)
343
+ if const_expr(not position_independent):
344
+ tRS_sA = thr_copy.partition_D(sA)
345
+ else:
346
+ tRS_sA = partition_D_position_independent(thr_copy, sA)
347
+
348
+ def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs):
349
+ cvt_copy(tiled_copy, src, tRS_sA[None, None, None, dst_idx], retile=True, **new_kwargs)
350
+
351
+ return copy_fn, thr_copy, tRS_sA
352
+
353
+
354
+ def get_smem_load_A(
355
+ tiled_mma: cute.TiledMma,
356
+ sA: cute.Tensor,
357
+ tidx: Int32,
358
+ arch: int,
359
+ with_dst_tensor: bool = False,
360
+ position_independent=False,
361
+ ) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
362
+ dtype = sA.element_type
363
+ transpose = tiled_mma.op.a_major_mode == warpgroup.OperandMajorMode.MN
364
+ copy_atom = get_smem_load_atom(arch, dtype, transpose)
365
+ tiled_copy = cute.make_tiled_copy_A(copy_atom, tiled_mma)
366
+ thr_copy = tiled_copy.get_slice(tidx)
367
+ if const_expr(not position_independent):
368
+ tSR_sA = thr_copy.partition_S(sA)
369
+ else:
370
+ tSR_sA = partition_S_position_independent(thr_copy, sA)
371
+ copy_atom_RS = get_smem_store_atom(arch, dtype, transpose)
372
+ thr_copy_RS = cute.make_tiled_copy_C(copy_atom_RS, tiled_mma).get_slice(tidx)
373
+ tRS_shape = tiled_mma.partition_shape_A(sA.shape[:2])
374
+
375
+ def copy_fn(src_idx: Int32, **new_kwargs):
376
+ return load_s2r_retile(
377
+ tiled_copy, tSR_sA[None, None, None, src_idx], dst_shape=tRS_shape, **new_kwargs
378
+ )
379
+
380
+ def copy_fn_w_dst_tensor(src_idx: Int32, dst: cute.Tensor, **new_kwargs):
381
+ return load_s2r_retile(tiled_copy, tSR_sA[None, None, None, src_idx], dst, **new_kwargs)
382
+
383
+ return copy_fn if not with_dst_tensor else copy_fn_w_dst_tensor, thr_copy, tSR_sA
384
+
385
+
265
386
  def tma_get_copy_fn(
266
387
  atom: cute.CopyAtom,
267
388
  cta_coord: cute.Coord,
@@ -269,6 +390,7 @@ def tma_get_copy_fn(
      src_tensor: cute.Tensor,
      dst_tensor: cute.Tensor,
      filter_zeros: bool = False,
+     single_stage: bool = False,
      **kwargs,
  ) -> Callable:
      src_is_smem = const_expr(
@@ -276,13 +398,15 @@
          and src_tensor.memspace == cute.AddressSpace.smem
      )
      smem_tensor, gmem_tensor = (src_tensor, dst_tensor) if src_is_smem else (dst_tensor, src_tensor)
+     group_rank_smem = const_expr(cute.rank(smem_tensor) - (1 if not single_stage else 0))
+     group_rank_gmem = const_expr(cute.rank(gmem_tensor) - (1 if not single_stage else 0))
      # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
      s, g = cpasync.tma_partition(
          atom,
          cta_coord,
          cta_layout,
-         cute.group_modes(smem_tensor, 0, cute.rank(smem_tensor) - 1),
-         cute.group_modes(gmem_tensor, 0, cute.rank(gmem_tensor) - 1),
+         cute.group_modes(smem_tensor, 0, group_rank_smem),
+         cute.group_modes(gmem_tensor, 0, group_rank_gmem),
      )
      if const_expr(filter_zeros):
          s = cute.filter_zeros(s)
@@ -292,7 +416,10 @@
      def copy_tma(src_idx, dst_idx, **new_kwargs):
          cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs)

-     return copy_tma, s, g
+     def copy_tma_single_stage(**new_kwargs):
+         cute.copy(atom, src, dst, **new_kwargs, **kwargs)
+
+     return (copy_tma if const_expr(not single_stage) else copy_tma_single_stage), s, g


  def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync):
@@ -0,0 +1,151 @@
+ """
+ System ptxas replacement for CUTLASS DSL.
+ Environment variables:
+     CUTE_DSL_PTXAS_PATH - Path to ptxas (e.g., /usr/local/cuda/bin/ptxas)
+     CUTE_DSL_PTXAS_VERBOSE - Set to 1 for verbose output
+ """
+
+ import os
+ import sys
+ import re
+ import ctypes
+ import subprocess
+ from pathlib import Path
+
+ import cutlass
+
+
+ CUTE_DSL_PTXAS_PATH = os.environ.get("CUTE_DSL_PTXAS_PATH", None)
+ VERBOSE = os.environ.get("CUTE_DSL_PTXAS_VERBOSE", "0") == "1"
+
+ _original_load_cuda_library = None
+ _user_wanted_ptx = False  # True if user originally set CUTE_DSL_KEEP_PTX=1
+
+
+ def _log(msg):
+     if VERBOSE:
+         print(f"[ptxas] {msg}", file=sys.stderr)
+
+
+ def _get_ptx(compiled_func) -> tuple[str, Path] | None:
+     """Find and read PTX file, stripping null bytes."""
+     func_name = getattr(compiled_func, "function_name", None)
+     if not func_name:
+         return None
+
+     dump_dir = os.environ.get("CUTE_DSL_DUMP_DIR", Path.cwd())
+     for ptx_path in Path(dump_dir).glob(f"*{func_name}*.ptx"):
+         content = ptx_path.read_text().rstrip("\x00")
+         if ".entry " in content and content.rstrip().endswith("}"):
+             _log(f"Found PTX: {ptx_path}")
+             return content, ptx_path
+     return None
+
+
+ def _compile_ptx(ptx_path: Path, ptx_content: str) -> bytes:
+     """Compile PTX to cubin using system ptxas."""
+     # Extract arch from PTX
+     match = re.search(r"\.target\s+(sm_\d+[a-z]?)", ptx_content)
+     arch = match.group(1) if match else "sm_90a"
+
+     # Write stripped content back if needed
+     if ptx_path.read_text() != ptx_content:
+         ptx_path.write_text(ptx_content)
+
+     # Compile
+     cubin_tmp = ptx_path.with_suffix(".cubin.tmp")
+     try:
+         assert CUTE_DSL_PTXAS_PATH is not None
+         result = subprocess.run(
+             [CUTE_DSL_PTXAS_PATH, f"-arch={arch}", "-O3", "-o", str(cubin_tmp), str(ptx_path)],
+             capture_output=True,
+             text=True,
+         )
+         if result.returncode != 0:
+             raise RuntimeError(f"ptxas failed: {result.stderr}")
+
+         cubin_data = cubin_tmp.read_bytes()
+         _log(f"Compiled {ptx_path.name} -> {len(cubin_data)} bytes ({arch})")
+
+         # Save cubin if CUTE_DSL_KEEP_CUBIN is set
+         if os.environ.get("CUTE_DSL_KEEP_CUBIN", "0") == "1":
+             cubin_out = ptx_path.with_suffix(".cubin")
+             cubin_out.write_bytes(cubin_data)
+             _log(f"Saved: {cubin_out}")
+
+         return cubin_data
+     finally:
+         cubin_tmp.unlink(missing_ok=True)
+
+
+ def _patched_load_cuda_library(self):
+     """Replacement for _load_cuda_library that uses system ptxas."""
+
+     result = _get_ptx(self)
+     if not result:
+         _log("PTX not found, falling back to embedded ptxas")
+         return _original_load_cuda_library(self)
+
+     ptx_content, ptx_path = result
+
+     try:
+         cubin = _compile_ptx(ptx_path, ptx_content)
+     except Exception as e:
+         _log(f"Compilation failed ({e}), falling back to embedded ptxas")
+         return _original_load_cuda_library(self)
+
+     # Load cubin
+     import cuda.bindings.runtime as cuda_runtime
+
+     err, library = cuda_runtime.cudaLibraryLoadData(cubin, None, None, 0, None, None, 0)
+     if err != cuda_runtime.cudaError_t.cudaSuccess:
+         _log(f"cudaLibraryLoadData failed ({err}), falling back to embedded ptxas")
+         return _original_load_cuda_library(self)
+
+     # Register kernels on all devices
+     _, cuda_load_to_device = self._get_cuda_init_and_load()
+     lib_ptr = ctypes.c_void_p(int(library))
+     dev_id = ctypes.c_int32(0)
+     err_val = ctypes.c_int32(0)
+     args = (ctypes.c_void_p * 3)(
+         ctypes.cast(ctypes.pointer(lib_ptr), ctypes.c_void_p),
+         ctypes.cast(ctypes.pointer(dev_id), ctypes.c_void_p),
+         ctypes.cast(ctypes.pointer(err_val), ctypes.c_void_p),
+     )
+
+     for dev in range(self.num_devices):
+         dev_id.value = dev
+         cuda_load_to_device(args)
+         if err_val.value != 0:
+             _log("cuda_load_to_device failed, falling back to embedded ptxas")
+             return _original_load_cuda_library(self)
+
+     _log(f"Loaded kernel from {ptx_path.name}")
+
+     # Delete PTX if user didn't originally want it kept
+     if not _user_wanted_ptx:
+         ptx_path.unlink(missing_ok=True)
+
+     return [cuda_runtime.cudaLibrary_t(lib_ptr.value)]
+
+
+ def patch():
+     """Install system ptxas hook. Call before importing cutlass."""
+     global _original_load_cuda_library, _user_wanted_ptx
+
+     assert CUTE_DSL_PTXAS_PATH is not None
+     if not os.path.isfile(CUTE_DSL_PTXAS_PATH) or not os.access(CUTE_DSL_PTXAS_PATH, os.X_OK):
+         raise RuntimeError(f"ptxas not found: {CUTE_DSL_PTXAS_PATH}")
+
+     # Track if user originally wanted PTX kept
+     _user_wanted_ptx = os.environ.get("CUTE_DSL_KEEP_PTX", "0") == "1"
+     # os.environ['CUTE_DSL_KEEP_PTX'] = '1'
+     assert os.environ.get("CUTE_DSL_KEEP_PTX", "0") == "1", (
+         "Require CUTE_DSL_KEEP_PTX=1 to use system's ptxas"
+     )
+
+     cls = cutlass.cutlass_dsl.cuda_jit_executor.CudaDialectJitCompiledFunction
+     _original_load_cuda_library = cls._load_cuda_library
+     cls._load_cuda_library = _patched_load_cuda_library
+     _log("Patch applied")
+     return
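For reference, the -arch flag handed to the external ptxas is recovered from the PTX's own .target directive. A small, standalone illustration of that extraction (the sample PTX text below is made up for the example):

    import re

    sample_ptx = ".version 8.3\n.target sm_90a\n.address_size 64\n"
    match = re.search(r"\.target\s+(sm_\d+[a-z]?)", sample_ptx)
    arch = match.group(1) if match else "sm_90a"  # same fallback default as _compile_ptx
    print(arch)  # prints: sm_90a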
@@ -187,6 +187,10 @@ def make_acc_tensor_mn_view(acc: cute.Tensor) -> cute.Tensor:
187
187
  return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout))
188
188
 
189
189
 
190
+ def reshape_acc_to_mn(acc: cute.Tensor) -> cute.Tensor:
191
+ return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout))
192
+
193
+
190
194
  @cute.jit
191
195
  def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout:
192
196
  # For back to back gemm, convert layout of acc0 to gemm 1 accept layout.
@@ -227,6 +231,10 @@ def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout:
      return rA_mma_view


+ def reshape_acc_to_frgA(acc: cute.Tensor) -> cute.Tensor:
+     return cute.make_tensor(acc.iterator, convert_layout_acc_frgA(acc.layout))
+
+
  def convert_layout_zero_stride(
      input: cute.Tensor | cute.Layout, ref_layout: cute.Layout
  ) -> cute.Layout:
@@ -5,14 +5,15 @@ from dataclasses import dataclass

  import cutlass.cute as cute
  from cutlass import Boolean, Int32, const_expr
- from cutlass.cutlass_dsl import if_generate, and_
+ from cutlass.cutlass_dsl import if_generate, and_, dsl_user_op
  from cutlass.pipeline import MbarrierArray, CooperativeGroup, PipelineOp, pipeline_init_wait
  from cutlass.pipeline import PipelineAsync, PipelineTmaAsync, PipelineState, PipelineUserType
  from cutlass.pipeline import PipelineTmaUmma


  class PipelineStateWAdvance(PipelineState):
-     def advance_iters(self, num_iterations: Int32):
+     @dsl_user_op
+     def advance_iters(self, num_iterations: Int32, *, loc=None, ip=None):
          self._count += Int32(num_iterations)
          new_index = self._index + Int32(num_iterations)
          # How many times did we cross the stages boundary
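The hunk above cuts off after the first lines of advance_iters; only the decorator and signature changed. As a plain-Python model of the wrap-around bookkeeping a multi-step advance has to do (an assumption based on standard pipeline-state semantics, where the index wraps modulo the stage count and the phase bit flips on each wrap; this is not the DSL implementation itself):

    def advance_iters_model(index, phase, count, num_stages, num_iterations):
        # Advance a circular pipeline state by num_iterations slots, using plain ints.
        count += num_iterations
        new_index = index + num_iterations
        crossings = new_index // num_stages  # how many times the stage boundary was crossed
        phase ^= crossings & 1               # phase toggles once per wrap
        index = new_index % num_stages
        return index, phase, count

    print(advance_iters_model(index=2, phase=0, count=2, num_stages=4, num_iterations=5))  # (3, 1, 7)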
@@ -126,34 +127,40 @@ class PipelineTmaCpAsync(PipelineTmaAsync):
              is_signalling_thread,
          )

+     @dsl_user_op
      def producer_acquire(
          self,
          state: PipelineState,
          try_acquire_token: Optional[Boolean] = None,
          is_tma_warp: Optional[Boolean] = True,
+         *,
+         loc=None,
+         ip=None,
      ):
          """
          TMA producer commit conditionally waits on buffer empty and sets the transaction barrier.
          """
          if_generate(
              try_acquire_token is None or try_acquire_token == 0,
-             lambda: self.sync_object_empty.wait(state.index, state.phase),
+             lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
          )
          # This is the difference between this and PipelineTmaAsync: we could have multiple
          # warps calling this, but only 1 warp should do the arrive on the full barrier
          if_generate(
              is_tma_warp,
-             lambda: self.sync_object_full.arrive(state.index, self.producer_mask),
+             lambda: self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip),
          )

-     def producer_cpasync_commit(self, state: PipelineState):
+     @dsl_user_op
+     def producer_cpasync_commit(self, state: PipelineState, *, loc=None, ip=None):
          """
          We need the mbarrier to track the completion of cp.async
          """
-         cute.arch.cp_async_mbarrier_arrive_noinc(self.producer_get_barrier(state))
+         cute.arch.cp_async_mbarrier_arrive_noinc(self.producer_get_barrier(state, loc=loc, ip=ip), loc=loc, ip=ip)


  class MbarrierArrayWDropCount(MbarrierArray):
+     @dsl_user_op
      def __init__(
          self,
          barrier_storage: cute.Pointer,
@@ -161,6 +168,9 @@ class MbarrierArrayWDropCount(MbarrierArray):
          agent: tuple[PipelineOp, CooperativeGroup],
          tx_count: int = 0,
          drop_count: Optional[Int32] = None,
+         *,
+         loc=None,
+         ip=None,
      ) -> None:
          self.barrier_storage = barrier_storage
          self.tx_count = tx_count
@@ -183,7 +193,7 @@
          self.mbarrier_base = self.barrier_storage

          # Mbarrier initialization in constructor
-         self.mbarrier_init()
+         self.mbarrier_init(loc=loc, ip=ip)

      def __extract_mlir_values__(self):
          return [self.barrier_storage, self.drop_count]
@@ -211,6 +221,7 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
          barrier_storage: cute.Pointer = None,
          cta_layout_vmnk: Optional[cute.Layout] = None,
          producer_drop_count: Optional[Int32] = None,
+         mcast_mode_mn: tuple[int, int] = (1, 1),
      ):
          """
          This helper function computes any necessary attributes and returns an instance of PipelineTmaUmma.
@@ -226,6 +237,8 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
          :type tx_count: int
          :param cta_layout_vmnk: Layout of the cluster shape
          :type cta_layout_vmnk: cute.Layout | None
+         :param mcast_mode_mn: Tuple specifying multicast modes for m and n dimensions (each 0 or 1)
+         :type mcast_mode_mn: tuple[int, int], optional
          """
          if not isinstance(barrier_storage, cute.Pointer):
              raise ValueError(
@@ -245,7 +258,7 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
              tx_count,
              drop_count=producer_drop_count,
          )
-         sync_object_empty = PipelineAsync._make_sync_object(
+         sync_object_empty = PipelineTmaUmma._make_sync_object(
              barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
          )

@@ -255,7 +268,7 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
              # All threadblocks are leaders if not using clusters
              is_leader_cta = True
          else:
-             producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(cta_layout_vmnk)
+             producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(cta_layout_vmnk, mcast_mode_mn)
              is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk)

          cta_group = (
@@ -278,11 +291,15 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
              cta_group,
          )

+     @dsl_user_op
      def producer_acquire(
          self,
          state: PipelineState,
          try_acquire_token: Optional[Boolean] = None,
          is_tma_warp: Optional[Boolean] = True,
+         *,
+         loc=None,
+         ip=None,
      ):
          """
          TMA producer commit conditionally waits on buffer empty and sets the
@@ -290,17 +307,18 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
          """
          if_generate(
              try_acquire_token is None or try_acquire_token == 0,
-             lambda: self.sync_object_empty.wait(state.index, state.phase),
+             lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
          )
          # This is the difference between this and PipelineTmaAsync: we could have multiple
          # warps calling this, but only 1 warp should do the arrive on the full barrier
          if_generate(
              and_(self.is_leader_cta, is_tma_warp),
-             lambda: self.sync_object_full.arrive(state.index, self.producer_mask),
+             lambda: self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip),
          )

-     def producer_cpasync_commit(self, state: PipelineState):
+     @dsl_user_op
+     def producer_cpasync_commit(self, state: PipelineState, *, loc=None, ip=None):
          """
          We need the mbarrier to track the completion of cp.async
          """
-         cute.arch.cp_async_mbarrier_arrive_noinc(self.producer_get_barrier(state))
+         cute.arch.cp_async_mbarrier_arrive_noinc(self.producer_get_barrier(state, loc=loc, ip=ip), loc=loc, ip=ip)
@@ -27,10 +27,11 @@ def make_smem_layout(
          sm90_utils_og.get_smem_layout_atom(layout, dtype, major_mode_size),
          dtype,
      )
+     order = (1, 0, 2) if const_expr(layout.is_m_major_c()) else (0, 1, 2)
      smem_layout_staged = cute.tile_to_shape(
          smem_layout_atom,
          cute.append(shape, stage) if const_expr(stage is not None) else shape,
-         order=(1, 0, 2) if layout.is_m_major_c() else (0, 1, 2),
+         order=order if const_expr(stage is not None) else order[:2],
      )
      return smem_layout_staged

@@ -125,3 +126,32 @@ def gemm_w_idx(
      rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
      rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
      gemm(tiled_mma, acc, rA, rB, zero_init=zero_init, wg_wait=wg_wait)
+
+
+ def partition_fragment_ABC(
+     thr_mma: cute.ThrMma,
+     shape_mnk: cute.Shape,
+     sA: Optional[cute.Tensor],
+     sB: Optional[cute.Tensor],
+     swap_AB: bool = False,
+ ):
+     is_rs = thr_mma.op.a_src == warpgroup.OperandSource.RMEM
+     if const_expr(not swap_AB):
+         acc = cute.make_fragment(thr_mma.partition_shape_C(shape_mnk[:2]), Float32)
+         if const_expr(not is_rs):
+             assert sA is not None
+             tCrA = thr_mma.make_fragment_A(thr_mma.partition_A(sA))
+         else:
+             tCrA = thr_mma.make_fragment_A(thr_mma.partition_shape_A((shape_mnk[0], shape_mnk[2])))
+         assert sB is not None
+         tCrB = thr_mma.make_fragment_B(thr_mma.partition_B(sB))
+     else:
+         acc = cute.make_fragment(thr_mma.partition_shape_C((shape_mnk[1], shape_mnk[0])), Float32)
+         if const_expr(not is_rs):
+             assert sB is not None
+             tCrB = thr_mma.make_fragment_A(thr_mma.partition_A(sB))
+         else:  # B in rmem
+             tCrB = thr_mma.make_fragment_A(thr_mma.partition_shape_A((shape_mnk[1], shape_mnk[2])))
+         assert sA is not None
+         tCrA = thr_mma.make_fragment_B(thr_mma.partition_B(sA))
+     return acc, tCrA, tCrB
@@ -83,7 +83,7 @@ def bitonic_topk_merge(
      else:
          minmax_fn = min if ascending else max
      # Write the top k elements to the first half of the array
-     for i in cutlass.range(k, unfoll_full=True):
+     for i in cutlass.range(k, unroll_full=True):
          arr0[start0 + i] = minmax_fn(arr0[start0 + i], arr1[start1 + k - 1 - i])
      # Now the 1st half is bitonic, we just need to merge it
      bitonic_merge(arr0, k, start0, ascending)
@@ -1,11 +1,11 @@
  Metadata-Version: 2.4
  Name: quack-kernels
- Version: 0.2.3
+ Version: 0.2.5
  Requires-Python: >=3.10
  License-File: LICENSE
- Requires-Dist: nvidia-cutlass-dsl==4.3.3
+ Requires-Dist: nvidia-cutlass-dsl>=4.4.0.dev0
  Requires-Dist: torch
- Requires-Dist: apache-tvm-ffi<0.2,>=0.1.5
+ Requires-Dist: apache-tvm-ffi<0.2,>=0.1.6
  Requires-Dist: torch-c-dlpack-ext
  Provides-Extra: dev
  Requires-Dist: pre-commit; extra == "dev"
@@ -8,6 +8,7 @@ quack/broadcast_utils.py
  quack/compile_utils.py
  quack/copy_utils.py
  quack/cross_entropy.py
+ quack/cute_dsl_ptxas.py
  quack/cute_dsl_utils.py
  quack/fast_math.py
  quack/gemm.py
@@ -0,0 +1,8 @@
+ nvidia-cutlass-dsl>=4.4.0.dev0
+ torch
+ apache-tvm-ffi<0.2,>=0.1.6
+ torch-c-dlpack-ext
+
+ [dev]
+ pre-commit
+ ruff
@@ -1,11 +0,0 @@
- __version__ = "0.2.3"
-
- from quack.rmsnorm import rmsnorm
- from quack.softmax import softmax
- from quack.cross_entropy import cross_entropy
-
- __all__ = [
-     "rmsnorm",
-     "softmax",
-     "cross_entropy",
- ]
@@ -1,8 +0,0 @@
- nvidia-cutlass-dsl==4.3.3
- torch
- apache-tvm-ffi<0.2,>=0.1.5
- torch-c-dlpack-ext
-
- [dev]
- pre-commit
- ruff
@@ -1,5 +0,0 @@
- benchmarks
- dist
- docs
- media
- quack