mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
@@ -0,0 +1,262 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # pyre-unsafe
7
+
8
+ import ast
9
+ import copy
10
+ import functools
11
+ import linecache
12
+ import os
13
+ import sys
14
+ import tempfile
15
+ from enum import Enum
16
+ from typing import Any, Dict, List
17
+
18
+ import triton
19
+
20
+
21
class _ForLoopUnroller(ast.NodeTransformer):
    """Rewrite one copy of a loop body for a single, fixed loop iteration.

    Given the loop variable name ``target`` and the concrete iteration index
    ``loop_iter``:

    * every reference to ``target`` becomes the constant ``loop_iter``;
    * every ``name[target]`` where ``name`` is in ``inline_variables`` becomes
      the flattened name ``f"{name}{loop_iter}"`` (e.g. ``acc[i]`` -> ``acc2``
      when ``loop_iter == 2``).

    Subscripts that do not match that pattern are still traversed, so uses of
    ``target`` nested inside them (``foo[i]``, ``a[i][0]``, ``a[i][j]``) are
    rewritten as well.
    """

    def __init__(self, target, inline_variables, loop_iter):
        self.loop_iter = loop_iter
        self.target = target
        self.inline_variables = inline_variables

    def visit_Name(self, node):
        if node.id != self.target:
            return node
        # A real constant (rather than a Name spelled "0") lets `ast.unparse`
        # emit valid code for things like `i.bit_length()` -> `0 .bit_length()`.
        return ast.Constant(self.loop_iter)

    def visit_Subscript(self, node):
        # Pattern-matching `value[slice]` where `value` is an inlined array and
        # `slice` is the loop variable itself.
        if (
            isinstance(node.slice, ast.Name)
            and node.slice.id == self.target
            and isinstance(node.value, ast.Name)
            and node.value.id in self.inline_variables
        ):
            # Keep the original context so `acc[i] = ...` stays a store target.
            return ast.Name(f"{node.value.id}{self.loop_iter}", ctx=node.ctx)
        # Not our pattern: keep the subscript but rewrite its children.
        return self.generic_visit(node)
42
+
43
+
44
class _VisitorVarargKernel(ast.NodeTransformer):
    """Base transformer that flattens ``VAR_ARGS_ARRAY`` parameters.

    Every parameter annotated with the string ``"VAR_ARGS_ARRAY"``, and the
    ``*args`` parameter if present, is replaced by ``N`` scalar parameters
    ``name0 .. name{N-1}``. Its name is recorded in ``inline_variables`` so
    that subclasses can rewrite the uses of the array in the body. Local
    declarations ``name: "VAR_ARGS_ARRAY"`` (without a value) are removed and
    recorded the same way. Keyword-only parameters are turned into trailing
    positional parameters.
    """

    def __init__(self, N):
        self.inline_variables = set()
        self.N = N

    @staticmethod
    def _is_var_args_array(annotation):
        """Return True if `annotation` is the string constant "VAR_ARGS_ARRAY"."""
        return (
            isinstance(annotation, ast.Constant)
            and annotation.value == "VAR_ARGS_ARRAY"
        )

    def visit_AnnAssign(self, node):
        # Pattern-matching a bare declaration:
        #   var_name: "VAR_ARGS_ARRAY"
        if (
            node.value is None
            and node.simple == 1
            and isinstance(node.target, ast.Name)
            and self._is_var_args_array(node.annotation)
        ):
            self.inline_variables.add(node.target.id)
            return []  # drop the declaration from the generated source
        if node.value is not None:
            node.value = self.visit(node.value)
        if node.annotation is not None:
            node.annotation = self.visit(node.annotation)
        if node.target is not None:
            node.target = self.visit(node.target)
        return node

    def visit_arguments(self, node):
        n_posonly = len(node.posonlyargs)
        n_positional = n_posonly + len(node.args)
        # `node.defaults` belongs to the *last* len(defaults) positional
        # parameters (posonlyargs + args). Spell it out per parameter
        # (None = no default) so it survives the reshuffling below.
        defaults = [None] * (n_positional - len(node.defaults)) + list(node.defaults)

        new_args = []
        new_defaults = defaults[:n_posonly]
        # Replace `args` annotated with `VAR_ARGS_ARRAY`
        for arg, default in zip(node.args, defaults[n_posonly:]):
            if self._is_var_args_array(arg.annotation):
                self.inline_variables.add(arg.arg)
                new_args += [ast.arg(f"{arg.arg}{i}") for i in range(self.N)]
                new_defaults += [None] * self.N
                continue
            new_args.append(arg)
            new_defaults.append(default)
        if node.vararg is not None:
            self.inline_variables.add(node.vararg.arg)
            new_args += [ast.arg(f"{node.vararg.arg}{i}") for i in range(self.N)]
            new_defaults += [None] * self.N
            node.vararg = None
        # Keyword-only parameters become trailing positional ones; carry their
        # defaults along (`kw_defaults` uses None for "no default").
        new_args += node.kwonlyargs
        new_defaults += node.kw_defaults
        node.kwonlyargs = []
        node.kw_defaults = []
        node.args = new_args

        # Python only allows defaults on a trailing run of positional
        # parameters. A default that is now followed by a required parameter
        # (e.g. `b=1` in `def f(a, b=1, *xs)`) cannot be kept; drop it rather
        # than letting it shift onto a different parameter.
        trailing = []
        for default in reversed(new_defaults):
            if default is None:
                break
            trailing.append(default)
        node.defaults = trailing[::-1]
        return node
90
+
91
+
92
class _VisitorUnrollKernel(_VisitorVarargKernel):
    """Flatten VAR_ARGS_ARRAY parameters and unroll the loops over them.

    A loop of the exact shape ``for i in range(len(arr)): ...``, where ``arr``
    is one of the inlined arrays, is replaced by ``N`` copies of its body. The
    k-th copy has ``i`` replaced by ``k`` and ``arr[i]`` by ``arrk`` (done by
    ``_ForLoopUnroller``). Every other loop is kept and traversed normally.
    """

    def _is_unrollable(self, node):
        """Return True if `node` is `for <name> in range(len(<inlined array>))`."""
        it = node.iter
        # Check node types before touching `.id`: loops such as
        # `for i in tl.static_range(4)` have an Attribute as `func`.
        if not (
            isinstance(node.target, ast.Name)
            and isinstance(it, ast.Call)
            and isinstance(it.func, ast.Name)
            and it.func.id == "range"
            and len(it.args) == 1
        ):
            return False
        length = it.args[0]
        return (
            isinstance(length, ast.Call)
            and isinstance(length.func, ast.Name)
            and length.func.id == "len"
            and len(length.args) == 1
            and isinstance(length.args[0], ast.Name)
            and length.args[0].id in self.inline_variables
        )

    @staticmethod
    def _as_statements(result):
        """Normalize a visitor result (node, list of nodes, or None) to a list."""
        if result is None:
            return []
        if isinstance(result, ast.AST):
            return [result]
        return list(result)

    def visit_For(self, node):
        if not self._is_unrollable(node):
            # `generic_visit` also traverses `iter`/`orelse` and splices in the
            # lists returned by child visitors (nested unrolled loops, removed
            # declarations) instead of nesting them inside `body`.
            return self.generic_visit(node)
        # We know we have to modify this loop
        new_nodes = []
        for i in range(self.N):
            unroller = _ForLoopUnroller(
                target=node.target.id,
                inline_variables=self.inline_variables,
                loop_iter=i,
            )
            for stmt in node.body:
                # Each iteration needs its own copy, since the unroller mutates it.
                unrolled = ast.fix_missing_locations(unroller.visit(copy.deepcopy(stmt)))
                new_nodes.extend(self._as_statements(self.visit(unrolled)))
        # The unrolled loop always runs to completion, so a `for ... else`
        # clause executes once after the copies.
        for stmt in node.orelse:
            new_nodes.extend(self._as_statements(self.visit(stmt)))
        return new_nodes
119
+
120
+
121
class _VisitorConditionalKernel(_VisitorVarargKernel):
    """Flatten VAR_ARGS_ARRAY parameters and replace indexing with conditionals.

    ``arr[i]`` (``arr`` an inlined array, ``i`` a plain name) becomes
    ``arr0 if i == 0 else arr1 if i == 1 else ... arr{N-1}``, and ``len(arr)``
    becomes the constant ``N``. Loops are left in place.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.extra_nodes = None  # NOTE(review): not read within this module

    def visit_Subscript(self, node):
        if (
            isinstance(node.value, ast.Name)
            and node.value.id in self.inline_variables
            and isinstance(node.slice, ast.Name)
        ):
            if self.N < 1:
                raise ValueError(
                    f"cannot index into VAR_ARGS_ARRAY `{node.value.id}` with N={self.N}"
                )
            name = node.value.id
            # given `a[i]`, replace with `res`, where `res` is:
            # a0 if i == 0 else a1 if i == 1 else a2 if i == 2 ...
            # Built from the innermost `else` outwards; the last element needs
            # no test because it is the fallback.
            expr = ast.Name(f"{name}{self.N - 1}")
            for i in reversed(range(self.N - 1)):
                expr = ast.IfExp(
                    test=ast.Compare(node.slice, [ast.Eq()], [ast.Constant(i)]),
                    body=ast.Name(f"{name}{i}"),
                    orelse=expr,
                )
            return expr
        # Anything else (nested subscripts, attribute/call values, other
        # arrays): keep the node but rewrite inlined accesses and `len(...)`
        # calls found in both its value and its slice.
        return self.generic_visit(node)

    def visit_Call(self, node):
        if (
            isinstance(node.func, ast.Name)
            and node.func.id == "len"
            and len(node.args) == 1
            and isinstance(node.args[0], ast.Name)
            and node.args[0].id in self.inline_variables
        ):
            return ast.Constant(self.N)
        return self.generic_visit(node)
161
+
162
+
163
# `exec`-created functions have no source file that `linecache` can read, yet
# triton needs their source (via `inspect.getsource`). Work around this by
# serving the generated source from an in-memory table - see
# https://stackoverflow.com/a/69668999
_getlines_orig = None  # the real `linecache.getlines`, saved when patching
_FILENAME_TO_SRC: Dict[str, List[str]] = {}  # pseudo-filename -> source lines

# Writing the generated kernels to disk can help external tools, e.g. ncu.
# Off by default: writing files as a side effect (e.g. at module import time)
# is unexpected and error-prone.
_should_materialize_codegen = os.getenv("XFORMERS_MATERIALIZE_CODEGEN") == "1"
# When materializing, keep the files instead of a self-cleaning temp directory.
_should_keep_materialized_source = os.getenv("XFORMERS_KEEP_CODEGEN") == "1"
_tmp_dir = None  # lazily-created TemporaryDirectory for materialized sources
173
+
174
+
175
def _monkey_patched_getlines(filename, module_globals=None):
    """Drop-in replacement for `linecache.getlines`.

    Returns the registered lines for the pseudo-filenames created by
    `unroll_varargs`, and defers to the original implementation otherwise.
    """
    try:
        registered = _FILENAME_TO_SRC[filename]
    except KeyError:
        return _getlines_orig(filename, module_globals)  # type: ignore
    return registered
180
+
181
+
182
class VarargMode(Enum):
    """Strategy used by `unroll_varargs` to specialize a VAR_ARGS_ARRAY.

    UNROLL duplicates the body of each loop over the varargs once per element.
    CONDITIONAL replaces `arr[i]` with a chain of conditional expressions and
    leaves the loops in place.
    """

    UNROLL = "unroll"  # one copy of the loop body per element
    CONDITIONAL = "conditional"  # `a0 if i == 0 else a1 if i == 1 else ...`
185
+
186
+
187
@functools.lru_cache(None)
def unroll_varargs(kernel, N: int, mode: VarargMode = VarargMode.UNROLL):
    """
    Specializes a triton kernel with variable number of inputs
    to a specific number of inputs `N`.

    `mode` can either be `UNROLL` or `CONDITIONAL`. Both options
    implement the same functionality, but have different implementations
    and can have different performance. In `UNROLL` mode, any loops that
    loop over the varargs will be unrolled. In `CONDITIONAL` mode,
    indexing into the list of varargs is replaced with conditional
    statements like `a0 if i==0 else a1 if i==1 else a2...`.
    `CONDITIONAL` mode is generally better if `N` is large, because it
    generates a smaller triton kernel that should fit in the
    instruction cache and will compile faster.

    Args:
        kernel: object whose `.fn` attribute is the original Python kernel
            function (e.g. the result of `triton.jit`).
        N: number of elements each VAR_ARGS_ARRAY / `*args` is expanded to.
        mode: a `VarargMode` member, or its string value
            ("unroll" / "conditional").

    Returns:
        A new `triton.jit` function built from the specialized source.

    Raises:
        RuntimeError: on Python older than 3.9 (`ast.unparse` is required).
        ValueError: if `mode` is not a valid `VarargMode`.

    NOTE: Because it's quite costly to call `triton.jit`,
    we cache the returned value with `lru_cache`
    """
    global _getlines_orig, _tmp_dir

    # NOTE: `ast.unparse` requires python 3.9+; fail before doing any work.
    if sys.version_info < (3, 9):
        raise RuntimeError("Error: This functionality requires python 3.9 or above")

    # Normalizes string values and rejects unknown modes with a ValueError
    # (instead of an UnboundLocalError on `node_visitor` further down).
    mode = VarargMode(mode)

    k = triton.JITFunction(kernel.fn)
    parsed = ast.parse(k.src)  # type: ignore
    node_visitor: _VisitorVarargKernel
    if mode is VarargMode.UNROLL:
        node_visitor = _VisitorUnrollKernel(N=N)
    else:
        node_visitor = _VisitorConditionalKernel(N=N)
    parsed = ast.fix_missing_locations(node_visitor.visit(parsed))
    new_src = ast.unparse(parsed)  # type: ignore

    # Now we want to `exec` the function, but we need all this
    # boilerplate code to make sure triton can run `inspect.getsource`.
    fn_basename = f"unroll_varargs-{kernel.fn.__name__}-{mode.value}-{N}"
    if _should_materialize_codegen:
        if _should_keep_materialized_source:
            # A `mkdtemp` directory is never removed automatically, so the
            # file reliably outlives this call.
            out_dir = tempfile.mkdtemp()
        else:
            if _tmp_dir is None:
                _tmp_dir = tempfile.TemporaryDirectory()
            out_dir = _tmp_dir.name
        fn_filename = os.path.join(out_dir, f"{fn_basename}.py")
        # Python source files are read back as UTF-8 by default.
        with open(fn_filename, "w", encoding="utf-8") as f:
            f.write(new_src)
    else:
        # Patch `getlines` only the first time. Keying on the saved original
        # (not on the table being empty) avoids ever saving the patched
        # function as "original", which would recurse forever.
        if _getlines_orig is None:
            _getlines_orig = linecache.getlines
            linecache.getlines = _monkey_patched_getlines
        fn_filename = f"<{fn_basename}>"
        _FILENAME_TO_SRC[fn_filename] = new_src.splitlines(keepends=True)

    # Create function given source
    code = compile(new_src, fn_filename, "exec")

    _locals: Dict[str, Any] = {}
    exec(code, kernel.fn.__globals__, _locals)
    # The generated source is expected to define exactly one function.
    assert len(_locals) == 1, len(_locals)
    fn = next(iter(_locals.values()))

    jitted_fn = triton.jit(fn)
    if not hasattr(jitted_fn, "_unsafe_update_src"):
        # Triton older than 3.2
        jitted_fn.src = new_src
    return jitted_fn
258
+
259
+
260
# Type alias backing the "VAR_ARGS_ARRAY" annotation. It exists only so that
# importing it keeps mypy happy when kernels annotate variables or parameters
# with `VAR_ARGS_ARRAY`; the transformers above match on the annotation text.
VAR_ARGS_ARRAY = List[Any]