mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/fmha/ck_splitk.py
@@ -0,0 +1,204 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+# pyre-unsafe
+
+from typing import Any, Iterable, List, Optional, Tuple
+
+import torch
+
+from .attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask
+from .common import AttentionFwOpBase, check_lastdim_alignment_stride1, Context, Inputs
+from .utils.op_common import get_operator, register_operator
+
+
+@register_operator
+class FwOp(AttentionFwOpBase):
+    OPERATOR = get_operator("xformers", "efficient_attention_forward_decoder_splitk_ck")
+    SUPPORTED_DEVICES = {"cuda"}
+    SUPPORTED_DTYPES = {
+        torch.half,
+        torch.bfloat16,
+        torch.float,
+    }  # Those are dtypes of Q. In the quantized case K/V has dtype int32
+    SUPPORTED_MAX_K = 256
+    SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = (
+        type(None),
+        BlockDiagonalCausalWithOffsetPaddedKeysMask,
+    )
+    SUPPORTS_DROPOUT = False
+    SUPPORTS_CUSTOM_SCALE = True
+    SUPPORTS_BMGHK = True
+    NAME = "ck_splitKF"
+
+    SPLIT_K: Optional[int] = None
+    BLOCK_M = 16
+    BLOCK_N = 64
+
+    NUM_GROUPS = 1  # Default quantization is row-wise
+
+    @classmethod
+    def shape_not_supported_reasons(
+        cls, Mq: int, Mkv: int, K: int, Kv: int
+    ) -> List[str]:
+        reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv)
+        # if K not in {16, 32, 64, 128}:
+        #     reasons.append(f"Embed dim {K} not supported")
+        return reasons
+
+    @classmethod
+    def not_supported_reasons(cls, d: Inputs) -> List[str]:
+        reasons = super(FwOp, cls).not_supported_reasons(d)
+        check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
+        if d.key.dtype != torch.int32:
+            check_lastdim_alignment_stride1(reasons, "key", d.key, 8)
+            check_lastdim_alignment_stride1(reasons, "value", d.value, 8)
+        if cls.OPERATOR is None:
+            reasons.append("triton is not available")
+        if d.device.type == "cuda":
+            # Has only been tested on 8.0 / 9.0.
+            if torch.cuda.get_device_capability(d.device) < (7, 0):
+                reasons.append(
+                    "requires GPU with sm80 minimum compute capability, e.g., A100/H100/L4"
+                )
+
+        q_len = d.query.shape[1]
+        if isinstance(d.attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask):
+            seqinfo = d.attn_bias.q_seqinfo
+            if q_len != seqinfo.seqstart_py[-1]:
+                reasons.append(
+                    f"Expected total {seqinfo.seqstart_py[-1]} queries not {q_len}"
+                )
+            q_len = seqinfo.min_seqlen
+            if q_len != seqinfo.max_seqlen:
+                reasons.append(
+                    "Variable query len is not supported in the presence of causal mask."
+                )
+
+        if d.key.ndim in [4, 5] and d.key.shape[-2] != 1:
+            if d.key.stride(-2) == 0 and d.value.stride(-2) == 0 and q_len > 1:
+                reasons.append("multiquery is only supported with query seqlen=1")
+
+        if d.attn_bias is not None and q_len > 1:
+            reasons.append(
+                "query with seqlen > 1 is not supported in the presence of causal mask"
+            )
+        return reasons
+
+    @classmethod
+    def get_split_k(cls, B: int, H: int, Mk: int) -> int:
+        """Heuristic for the number of splits"""
+        bh = max(B * H, 1)  # NOTE: Handle B*h=0 case
+        split_k = max(Mk, 1024) // bh
+        max_chunk_size = 64 if Mk <= 512 and bh <= 64 else 128
+        while split_k > 0 and Mk / split_k < max_chunk_size:
+            split_k = split_k // 2
+        split_k = min(split_k, 64)
+        split_k = max(split_k, 1)
+        return split_k
+
+    @classmethod
+    def apply(
+        cls, inp: Inputs, needs_gradient: bool
+    ) -> Tuple[torch.Tensor, Optional[Context]]:
+        attn_bias = inp.attn_bias
+        q, k, v = inp.get_qkv_in_bmghk()
+
+        if attn_bias is not None:
+            assert isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
+            attn_bias.k_seqinfo.to(k.device)
+            attn_bias.q_seqinfo.to(q.device)
+            padding = attn_bias.k_seqinfo.padding
+            seq_positions_gpu = attn_bias.k_seqinfo.seqlen
+        else:
+            padding = k.shape[1]
+            seq_positions_gpu = None
+
+        if attn_bias is not None:
+            # key: (1, B * padding, G, 1 if multiquery else Hkv, D)
+            # value: like key
+            # query: (1, B * q_seqlen, G, Hq, D)
+            multiquery = k.stride(3) == 0
+            if multiquery:
+                key = k[0, :, :, :1].unflatten(0, (-1, padding))
+                value = v[0, :, :, :1].unflatten(0, (-1, padding))
+            else:
+                key = k[0].unflatten(0, (-1, padding))
+                value = v[0].unflatten(0, (-1, padding))
+            query = q[0].unflatten(0, (key.shape[0], -1))
+        else:
+            # key: (B, padding, G, 1 if multiquery else Hkv, D)
+            # value: like key
+            # query: (B, q_seqlen, G, Hq, D)
+            key = k
+            query = q
+            value = v
+
+        B, _, _, H, _ = query.shape
+        _, Mk, _, _, _ = key.shape
+
+        if cls.SPLIT_K is not None:
+            split_k = cls.SPLIT_K
+        else:
+            # Use heuristics
+            split_k = cls.get_split_k(B, H, Mk)
+
+        if inp.scale is not None:
+            qk_scale = inp.scale
+        else:
+            qk_scale = torch.rsqrt(
+                torch.tensor(k.shape[-1], dtype=torch.float32)
+            ).item()
+
+        out = cls.OPERATOR(
+            query=query,
+            key=key,
+            value=value,
+            seq_positions=seq_positions_gpu,
+            scale=qk_scale,
+            split_k=split_k,
+        )
+
+        return out, None
+
+
+class FwOp_S1(FwOp):
+    SPLIT_K = 1
+    NAME = "ck_splitK1"
+
+
+class FwOp_S2(FwOp):
+    SPLIT_K = 2
+    NAME = "ck_splitK2"
+
+
+class FwOp_S4(FwOp):
+    SPLIT_K = 4
+    NAME = "ck_splitK4"
+
+
+class FwOp_S8(FwOp):
+    SPLIT_K = 8
+    NAME = "ck_splitK8"
+
+
+class FwOp_S16(FwOp):
+    SPLIT_K = 16
+    NAME = "ck_splitK16"
+
+
+class FwOp_S32(FwOp):
+    SPLIT_K = 32
+    NAME = "ck_splitK32"
+
+
+class FwOp_S64(FwOp):
+    SPLIT_K = 64
+    NAME = "ck_splitK64"
+
+
+class FwOp_S128(FwOp):
+    SPLIT_K = 128
+    NAME = "ck_splitK128"
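For reference, the split-k heuristic in FwOp.get_split_k above can be exercised on its own. The sketch below is a minimal standalone restatement for illustration only; the helper name and the sample shapes are assumptions and are not part of the packaged module.

def split_k_heuristic(B: int, H: int, Mk: int) -> int:
    # Standalone restatement of FwOp.get_split_k, for illustration only.
    bh = max(B * H, 1)             # guard against B * H == 0
    split_k = max(Mk, 1024) // bh  # initial split count: grows with Mk, shrinks with batch*heads
    max_chunk_size = 64 if Mk <= 512 and bh <= 64 else 128
    while split_k > 0 and Mk / split_k < max_chunk_size:
        split_k //= 2              # halve until each split covers at least max_chunk_size KV rows
    return max(min(split_k, 64), 1)  # clamp to [1, 64]

# A long single-sequence decode is split aggressively; a large batch is not.
print(split_k_heuristic(1, 8, 8192))   # 64
print(split_k_heuristic(64, 16, 512))  # 1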