mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
@@ -0,0 +1,224 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+# pyre-unsafe
+
+
+import textwrap
+from collections import deque
+from typing import Any, List, Optional, Sequence, Tuple, Type, TypeVar
+
+import torch
+
+from . import attn_bias, ck, cutlass, flash, flash3, flash_mtia, triton_splitk
+from .common import AttentionBwOpBase, AttentionFwOpBase, Inputs
+
+T = TypeVar("T", Type[AttentionFwOpBase], Type[AttentionBwOpBase])
+
+
+try:
+    import mtia.host_runtime.torch_mtia.dynamic_library # noqa # type: ignore
+
+    # Use MTIA flash attention if the MTIA libraries are available
+    _USE_MTIA_FLASH_ATTENTION = True
+except (ImportError, OSError):
+    # Failed to load MTIA libraries, so don't use MTIA flash attention
+    _USE_MTIA_FLASH_ATTENTION = False
+
+
+_USE_FLASH_ATTENTION_3 = False
+
+
+def _set_use_fa3(use_flash_attention3: bool) -> None:
+    global _USE_FLASH_ATTENTION_3
+    _USE_FLASH_ATTENTION_3 = use_flash_attention3
+
+
+def _get_use_fa3() -> bool:
+    return _USE_FLASH_ATTENTION_3
+
+
+def fa3_available() -> bool:
+    has_valid_flash3 = flash3._C_flashattention3 is not None # pyre-ignore[16]
+    is_90a = torch.version.cuda and torch.cuda.get_device_capability() >= (9, 0)
+    return has_valid_flash3 and is_90a
+
+
+def _format_inputs_description(inp: Inputs) -> str:
+    return f"""query : shape={tuple(inp.query.shape)} ({inp.query.dtype})
+key : shape={tuple(inp.key.shape)} ({inp.key.dtype})
+value : shape={tuple(inp.value.shape)} ({inp.value.dtype})
+attn_bias : {type(inp.attn_bias)}
+p : {inp.p}"""
+
+
+def _ensure_op_supports_or_raise(exc_type, name: str, op, inp: Inputs) -> None:
+    reasons = op.not_supported_reasons(inp)
+    if not reasons:
+        return
+    raise exc_type(
+        f"""Operator `{name}` does not support inputs:
+{textwrap.indent(_format_inputs_description(inp), " ")}
+{_format_not_supported_reasons(op, reasons)}"""
+    )
+
+
+def _format_not_supported_reasons(op, reasons: List[str]) -> str:
+    return f"`{op.NAME}` is not supported because:\n " + "\n ".join(reasons)
+
+
+def _run_priority_list(
+    name: str,
+    priority_list: Sequence[T],
+    inp: Inputs,
+    extra_op_reasons: Optional[List[Tuple[Any, List[str]]]] = None,
+) -> T:
+    not_supported_reasons: List[List[str]] = []
+    for op in priority_list:
+        not_supported = op.not_supported_reasons(inp)
+        if not not_supported:
+            return op
+        not_supported_reasons.append(not_supported)
+
+    # Let's write a nice message explaining what we tried and why it's not supported
+    msg = f"""No operator found for `{name}` with inputs:
+{textwrap.indent(_format_inputs_description(inp), " ")}"""
+    for op, not_supported in zip(priority_list, not_supported_reasons):
+        msg += "\n" + _format_not_supported_reasons(op, not_supported)
+    if extra_op_reasons is not None:
+        for op, not_supported in extra_op_reasons:
+            msg += "\n" + _format_not_supported_reasons(op, not_supported)
+    raise NotImplementedError(msg)
+
+
+def _dispatch_fw_priority_list(
+    inp: Inputs, needs_gradient: bool
+) -> Sequence[Type[AttentionFwOpBase]]:
+    if torch.version.cuda:
+        flash3_op = [flash3.FwOp] if _get_use_fa3() else []
+        priority_list_ops = deque(
+            flash3_op
+            + [
+                flash.FwOp,
+                cutlass.FwOp,
+            ]
+        )
+    else:
+        priority_list_ops = deque(
+            [
+                ck.FwOp,
+            ]
+        )
+    priority_list_ops.append(triton_splitk.FwOp)
+    if not needs_gradient:
+        mqa_or_gqa = (
+            inp.key.ndim > 3 and inp.key.stride(-2) == 0 and inp.key.shape[-2] > 1
+        )
+        # Split-KV is useful with MQA
+        # for short Q-seqlen / long K-seqlen
+        if mqa_or_gqa and inp.query.shape[1] <= 32 and inp.key.shape[1] >= 256:
+            parallelism_BH = 0 # BMK
+            if inp.query.ndim == 3:
+                parallelism_BH = inp.query.shape[0]
+            elif inp.query.ndim == 4: # BMHK
+                parallelism_BH = inp.query.shape[0] * inp.query.shape[2]
+            elif inp.query.ndim == 5: # BMGHK
+                parallelism_BH = inp.query.shape[0] * inp.query.shape[2]
+            if (
+                parallelism_BH > 0
+                and parallelism_BH < 64
+                and not torch.mtia.is_available()
+            ):
+                # priority_list_ops.appendleft(ck_splitk.FwOp)
+                priority_list_ops.remove(triton_splitk.FwOp)
+                priority_list_ops.appendleft(triton_splitk.FwOp)
+            # Without variable seqlen flash is fastest
+            if torch.version.cuda and not isinstance(
+                inp.attn_bias, attn_bias.BlockDiagonalMask
+            ):
+                if _get_use_fa3():
+                    priority_list_ops.remove(flash3.FwOp)
+                priority_list_ops.remove(flash.FwOp)
+                priority_list_ops.appendleft(flash.FwOp)
+
+    # torch.mtia.is_available() cannot be called here because it isn't supported
+    # when tracing with PT2, so we simply add flash_mtia to the end if the MTIA
+    # dynamic library can be loaded
+    if _USE_MTIA_FLASH_ATTENTION:
+        priority_list_ops.append(flash_mtia.FwOp)
+
+    return priority_list_ops
+
+
+def _dispatch_fw(inp: Inputs, needs_gradient: bool) -> Type[AttentionFwOpBase]:
+    """Computes the best operator for forward
+
+    Raises:
+        NotImplementedError: if no operator was found
+
+    Returns:
+        AttentionOp: The best operator for the configuration
+    """
+    return _run_priority_list(
+        "memory_efficient_attention_forward",
+        _dispatch_fw_priority_list(inp, needs_gradient),
+        inp,
+    )
+
+
+def _is_cutlassB_faster_than_flash(inp: Inputs) -> bool:
+    return False
+
+
+def _dispatch_bw(
+    inp: Inputs, varlen_lse_packed: Optional[bool]
+) -> Type[AttentionBwOpBase]:
+    if torch.version.cuda:
+        priority_list_ops: List[Type[AttentionBwOpBase]] = [
+            flash.BwOp,
+            cutlass.BwOp,
+        ]
+        if _get_use_fa3():
+            priority_list_ops = [flash3.BwOp] + priority_list_ops
+    else:
+        priority_list_ops: List[Type[AttentionBwOpBase]] = [
+            ck.BwOp,
+        ]
+
+    # NOTE: If we have a variable seqlen `attn_bias`, we need to get a BW pass
+    # that supports the LSE format
+    # *unless* we are in the case where both formats are the same (bs=1)
+    extra_op_reasons = []
+    if (
+        isinstance(inp.attn_bias, attn_bias.VARLEN_BIASES)
+        and inp.attn_bias.q_seqinfo.seqstart.shape[0] > 2
+    ):
+        assert varlen_lse_packed is not None
+        for op in priority_list_ops:
+            if op.VARLEN_LSE_PACKED != varlen_lse_packed:
+                extra_op_reasons.append(
+                    (
+                        op,
+                        [
+                            f"LSE is in {'packed' if varlen_lse_packed else 'padded'} format"
+                        ],
+                    )
+                )
+        priority_list_ops = [
+            op for op in priority_list_ops if op.VARLEN_LSE_PACKED == varlen_lse_packed
+        ]
+    if torch.version.cuda and _is_cutlassB_faster_than_flash(inp):
+        priority_list_ops.remove(cutlass.BwOp)
+        priority_list_ops.insert(0, cutlass.BwOp)
+
+    # torch.mtia.is_available() cannot be called here because it isn't supported
+    # when tracing with PT2, so we simply add flash_mtia to the end if the MTIA
+    # dynamic library can be loaded
+    if _USE_MTIA_FLASH_ATTENTION:
+        priority_list_ops.append(flash_mtia.BwOp)
+
+    return _run_priority_list(
+        "memory_efficient_attention_backward", priority_list_ops, inp
+    )
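
The forward dispatcher above never inspects the attention math itself; it keys off tensor layout: MQA/GQA is detected from a zero stride on the key's head dimension, and the split-K operator is only promoted for short queries against long keys when the batch*heads parallelism is small. A minimal sketch of exercising that path follows. It is illustrative only and not part of the wheel: the module path mslk.attention.fmha.dispatch and the Inputs keyword constructor are assumptions inferred from the file listing and the imports in the hunk.

import torch

# Assumed paths: the hunk above is read as mslk/attention/fmha/dispatch.py (the only
# +224-line file in the listing), and Inputs comes from the `.common` module it imports.
from mslk.attention.fmha.common import Inputs
from mslk.attention.fmha.dispatch import _dispatch_fw

B, M, Mk, H, K = 2, 8, 1024, 8, 128  # short queries vs. long keys: the split-K case

q = torch.randn(B, M, H, K, device="cuda", dtype=torch.bfloat16)
# One real KV head broadcast to H heads with expand(): the head-dim stride becomes 0,
# which is exactly the `inp.key.stride(-2) == 0` test the dispatcher uses for MQA/GQA.
k = torch.randn(B, Mk, 1, K, device="cuda", dtype=torch.bfloat16).expand(B, Mk, H, K)
v = torch.randn(B, Mk, 1, K, device="cuda", dtype=torch.bfloat16).expand(B, Mk, H, K)
assert k.ndim > 3 and k.stride(-2) == 0 and k.shape[-2] > 1  # mqa_or_gqa is True

# Inputs is assumed to accept query/key/value keywords (the fields read by
# _format_inputs_description); attn_bias and p are left at their defaults.
inp = Inputs(query=q, key=k, value=v)
op = _dispatch_fw(inp, needs_gradient=False)
print(op.NAME)  # which forward operator the priority list settled on

Here parallelism_BH = B * H = 16 < 64, so triton_splitk.FwOp is moved to the head of the deque; whether it stays there depends on the bias, since for a plain (non-varlen) bias the branch guarded by `not isinstance(inp.attn_bias, attn_bias.BlockDiagonalMask)` promotes flash.FwOp back to the front.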