mslk_cuda_nightly-2026.1.19-cp310-cp310-manylinux_2_28_x86_64.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/flash_attn/hopper_helpers.py
@@ -0,0 +1,102 @@
+# @nolint # fbcode
+# Copyright (c) 2025, Tri Dao.
+from typing import Type, Union, Optional
+import cutlass
+import cutlass.cute as cute
+from cutlass import Int32, Float32, Boolean, const_expr
+from cutlass.cute.nvgpu import warpgroup
+from cutlass.cutlass_dsl import Numeric, dsl_user_op
+from cutlass.utils import LayoutEnum
+import cutlass.utils.hopper_helpers as sm90_utils_og
+
+
+@cute.jit
+def gemm(
+    tiled_mma: cute.TiledMma,
+    acc: cute.Tensor,
+    tCrA: cute.Tensor,
+    tCrB: cute.Tensor,
+    zero_init: cutlass.Constexpr[bool] = False,
+    wg_wait: cutlass.Constexpr[int] = 0,
+    # A_in_regs: cutlass.Constexpr[bool] = False,
+    swap_AB: cutlass.Constexpr[bool] = False,
+) -> None:
+    if const_expr(swap_AB):
+        gemm(tiled_mma, acc, tCrB, tCrA, zero_init=zero_init, wg_wait=wg_wait, swap_AB=False)
+    else:
+        warpgroup.fence()
+        # We make a new mma_atom since we'll be modifying its attribute (accumulate).
+        # Otherwise the compiler complains "operand #0 does not dominate this use"
+        mma_atom = cute.make_mma_atom(tiled_mma.op)
+        mma_atom.set(warpgroup.Field.ACCUMULATE, not zero_init)
+        for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):
+            cute.gemm(mma_atom, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
+            mma_atom.set(warpgroup.Field.ACCUMULATE, True)
+        warpgroup.commit_group()
+        if const_expr(wg_wait >= 0):
+            warpgroup.wait_group(wg_wait)
+
+
+def gemm_zero_init(
+    tiled_mma: cute.TiledMma,
+    shape: cute.Shape,
+    tCrA: cute.Tensor,
+    tCrB: cute.Tensor,
+    A_idx: Optional[Int32] = None,
+    B_idx: Optional[Int32] = None,
+    wg_wait: int = -1,
+    swap_AB: bool = False,
+) -> cute.Tensor:
+    if const_expr(swap_AB):
+        return gemm_zero_init(
+            tiled_mma, shape[::-1], tCrB, tCrA, B_idx, A_idx, wg_wait, swap_AB=False
+        )
+    else:
+        acc = cute.make_fragment(tiled_mma.partition_shape_C(shape), Float32)
+        rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
+        rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
+        gemm(tiled_mma, acc, rA, rB, zero_init=True, wg_wait=wg_wait)
+        return acc
+
+
+def gemm_w_idx(
+    tiled_mma: cute.TiledMma,
+    acc: cute.Tensor,
+    tCrA: cute.Tensor,
+    tCrB: cute.Tensor,
+    zero_init: Boolean,
+    A_idx: Optional[Int32] = None,
+    B_idx: Optional[Int32] = None,
+    wg_wait: int = -1,
+    swap_AB: bool = False,
+) -> None:
+    if const_expr(swap_AB):
+        gemm_w_idx(tiled_mma, acc, tCrB, tCrA, zero_init, B_idx, A_idx, wg_wait, swap_AB=False)
+    else:
+        rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
+        rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
+        gemm(tiled_mma, acc, rA, rB, zero_init=zero_init, wg_wait=wg_wait)
+
+
+@dsl_user_op
+def make_smem_layout(
+    dtype: Type[Numeric],
+    layout: LayoutEnum,
+    shape: cute.Shape,
+    stage: Optional[int] = None,
+    *,
+    loc=None,
+    ip=None,
+) -> Union[cute.Layout, cute.ComposedLayout]:
+    major_mode_size = shape[1] if layout.is_n_major_c() else shape[0]
+    smem_layout_atom = warpgroup.make_smem_layout_atom(
+        sm90_utils_og.get_smem_layout_atom(layout, dtype, major_mode_size),
+        dtype,
+    )
+    order = (1, 0, 2) if const_expr(layout.is_m_major_c()) else (0, 1, 2)
+    smem_layout_staged = cute.tile_to_shape(
+        smem_layout_atom,
+        cute.append(shape, stage) if const_expr(stage is not None) else shape,
+        order=order if const_expr(stage is not None) else order[:2],
+    )
+    return smem_layout_staged
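
The helpers in this hunk are internal building blocks for Hopper (SM90) warp-group MMA mainloops rather than a public API. As a rough, hypothetical sketch of the intended call pattern only (assuming the hunk above is mslk/attention/flash_attn/hopper_helpers.py, the one +102-line file in the listing; the function name, tensor names, and tile shape below are invented for illustration and are not part of the wheel):

import cutlass.cute as cute
from mslk.attention.flash_attn import hopper_helpers as sm90_utils  # assumed module path


@cute.jit
def qk_gemm_example(
    tiled_mma: cute.TiledMma,
    tile_shape_mn: cute.Shape,  # (tile_M, tile_N) of the S = Q @ K^T tile, e.g. (128, 128)
    tSrQ: cute.Tensor,          # MMA-partitioned Q fragment
    tSrK: cute.Tensor,          # MMA-partitioned K fragment
) -> cute.Tensor:
    # Hypothetical example: allocate a fresh fp32 accumulator, run the WGMMA loop
    # with zero_init=True, and wait for all outstanding warp-group MMA groups
    # (wg_wait=0) before returning.
    acc_S = sm90_utils.gemm_zero_init(tiled_mma, tile_shape_mn, tSrQ, tSrK, wg_wait=0)
    return acc_S

Note how the swap_AB branches in gemm, gemm_zero_init, and gemm_w_idx simply re-invoke the same helper with the A/B operands (and their optional index arguments) exchanged, so an N-major accumulator reuses the M-major code path.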