mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116) hide show
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
@@ -0,0 +1,109 @@
1
+ # @nolint # fbcode
2
+ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
+ from typing import Tuple, Optional
4
+ from dataclasses import dataclass
5
+
6
+ import cutlass
7
+ import cutlass.cute as cute
8
+ from cutlass import Int32, const_expr
9
+
10
+ from mslk.attention.flash_attn.seqlen_info import SeqlenInfoQK
11
+
12
+
13
@dataclass(frozen=True)
class BlockInfo:
    """Tile sizes and masking flags, plus helpers that bound block iteration.

    Rows of the problem are tiled along ``seqlen_q`` in "m blocks" of
    ``tile_m`` and columns along ``seqlen_k`` in "n blocks" of ``tile_n``.
    Each helper converts a block index on one axis into a range (or a single
    boundary) of block indices on the other axis, taking the causal / local
    window settings into account.
    """

    # Block size used to tile seqlen_q.
    tile_m: cutlass.Constexpr[int]
    # Block size used to tile seqlen_k.
    tile_n: cutlass.Constexpr[int]
    # When set, the right-hand bound is applied with no window offset.
    is_causal: cutlass.Constexpr[bool]
    # When set, window_size_left / window_size_right (if not None) are applied.
    is_local: cutlass.Constexpr[bool] = False
    # When set, get_n_block_min_max carves the n-block range into splits.
    is_split_kv: cutlass.Constexpr[bool] = False
    # Window offsets; None disables the corresponding side.
    window_size_left: Optional[Int32] = None
    window_size_right: Optional[Int32] = None
    # When > 1, row indices are divided by this before being mapped to key
    # coordinates. Presumably the number of query heads packed per KV head;
    # confirm against callers.
    qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1

    @cute.jit
    def get_n_block_min_max(
        self,
        seqlen_info: SeqlenInfoQK,
        m_block: Int32,
        split_idx: cutlass.Int32 = 0,
        num_splits: cutlass.Int32 = 1,
    ) -> Tuple[Int32, Int32]:
        """Return ``(n_block_min, n_block_max)`` for the given ``m_block``.

        The upper bound starts as the total number of n blocks
        (``ceil_div(seqlen_k, tile_n)``), so it is exclusive. It is tightened
        by the causal / right-window bound, and the lower bound is raised by
        the left-window bound. With ``is_split_kv`` the resulting range is
        divided into ``num_splits`` chunks and the chunk for ``split_idx`` is
        returned.
        """
        upper = cute.ceil_div(seqlen_info.seqlen_k, self.tile_n)
        if const_expr(self.is_causal or (self.is_local and self.window_size_right is not None)):
            # One past the last row of this m block, in (possibly unpacked) row units.
            row_end = (m_block + 1) * self.tile_m
            if const_expr(self.qhead_per_kvhead_packgqa > 1):
                row_end = cute.ceil_div(row_end, self.qhead_per_kvhead_packgqa)
            # Shift from query coordinates to key coordinates.
            col_end = row_end + seqlen_info.seqlen_k - seqlen_info.seqlen_q
            if const_expr(not self.is_causal):
                col_end = col_end + self.window_size_right
            upper = min(upper, cute.ceil_div(col_end, self.tile_n))

        lower = 0
        if const_expr(self.is_local and self.window_size_left is not None):
            # First row of this m block, in (possibly unpacked) row units.
            row_start = m_block * self.tile_m
            if const_expr(self.qhead_per_kvhead_packgqa > 1):
                row_start = row_start // self.qhead_per_kvhead_packgqa
            col_start = row_start + seqlen_info.seqlen_k - seqlen_info.seqlen_q
            lower = cutlass.max((col_start - self.window_size_left) // self.tile_n, 0)

        if cutlass.const_expr(self.is_split_kv):
            # Ceil-divide the block count across splits; an empty range yields 0 per split.
            blocks_per_split = (
                cutlass.Int32(0)
                if upper <= lower
                else (upper - lower + num_splits - 1) // num_splits
            )
            lower = lower + split_idx * blocks_per_split
            # Clamp so the last split does not run past the unsplit upper bound.
            upper = cutlass.min(lower + blocks_per_split, upper)
        return lower, upper

    @cute.jit
    def get_m_block_min_max(self, seqlen_info: SeqlenInfoQK, n_block: Int32) -> Tuple[Int32, Int32]:
        """Return ``(m_block_min, m_block_max)`` for the given ``n_block``.

        The upper bound starts as ``ceil_div(seqlen_q, tile_m)`` (exclusive)
        and the lower bound as 0; the causal / right-window setting raises the
        lower bound and the left-window setting lowers the upper bound.
        ``qhead_per_kvhead_packgqa`` is not consulted here.
        """
        first = 0
        last = cute.ceil_div(seqlen_info.seqlen_q, self.tile_m)
        if const_expr(self.is_causal or (self.is_local and self.window_size_right is not None)):
            col_start = n_block * self.tile_n
            # Shift from key coordinates to query coordinates.
            row_start = col_start + seqlen_info.seqlen_q - seqlen_info.seqlen_k
            if const_expr(not self.is_causal):
                row_start = row_start - self.window_size_right
            first = max(first, row_start // self.tile_m)
        if const_expr(self.is_local and self.window_size_left is not None):
            col_end = (n_block + 1) * self.tile_n
            row_end = col_end + seqlen_info.seqlen_q - seqlen_info.seqlen_k
            last = min(last, cute.ceil_div(row_end + self.window_size_left, self.tile_m))
        return first, last

    @cute.jit
    def get_n_block_min_causal_local_mask(
        self,
        seqlen_info: SeqlenInfoQK,
        m_block: Int32,
        n_block_min: Int32,
    ) -> Int32:
        """Return the n block where the leading causal/local-masked iterations stop.

        Used when iterations with causal or local masking are run separately
        at the start. The result is the block containing the key column that
        corresponds to the first row of ``m_block`` (plus
        ``window_size_right`` when a local right window is set), never less
        than ``n_block_min``.
        """
        row_start = m_block * self.tile_m
        if const_expr(self.qhead_per_kvhead_packgqa > 1):
            row_start = row_start // self.qhead_per_kvhead_packgqa
        col_bound = row_start + seqlen_info.seqlen_k - seqlen_info.seqlen_q
        if const_expr(self.is_local and self.window_size_right is not None):
            col_bound = col_bound + self.window_size_right
        return cutlass.max(n_block_min, col_bound // self.tile_n)

    @cute.jit
    def get_n_block_min_before_local_mask(
        self,
        seqlen_info: SeqlenInfoQK,
        m_block: Int32,
        n_block_min: Int32,
    ) -> Int32:
        """Return the n block where the non-masked iterations stop.

        Used when iterations with local masking are run separately at the
        end. Without a local left window this is simply ``n_block_min``;
        otherwise it is the (ceil-divided) block of the key column that
        corresponds to the end of ``m_block`` minus ``window_size_left``,
        never less than ``n_block_min``.
        """
        if const_expr(self.is_local and self.window_size_left is not None):
            row_end = (m_block + 1) * self.tile_m
            if const_expr(self.qhead_per_kvhead_packgqa > 1):
                row_end = cute.ceil_div(row_end, self.qhead_per_kvhead_packgqa)
            col_end = row_end + seqlen_info.seqlen_k - seqlen_info.seqlen_q
            col_left = col_end - self.window_size_left
            return cutlass.max(n_block_min, cute.ceil_div(col_left, self.tile_n))
        else:
            return n_block_min