mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/fmha/_triton/vararg_kernel.py

@@ -0,0 +1,262 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+# pyre-unsafe
+
+import ast
+import copy
+import functools
+import linecache
+import os
+import sys
+import tempfile
+from enum import Enum
+from typing import Any, Dict, List
+
+import triton
+
+
+class _ForLoopUnroller(ast.NodeTransformer):
+    def __init__(self, target, inline_variables, loop_iter):
+        self.loop_iter = loop_iter
+        self.target = target
+        self.inline_variables = inline_variables
+
+    def visit_Name(self, node):
+        if node.id != self.target:
+            return node
+        return ast.Name(str(self.loop_iter))
+
+    def visit_Subscript(self, node):
+        # Pattern-matching `value[slice]`
+        if (
+            isinstance(node.slice, ast.Name)
+            and node.slice.id == self.target
+            and isinstance(node.value, ast.Name)
+            and node.value.id in self.inline_variables
+        ):
+            return ast.Name(f"{node.value.id}{self.loop_iter}")
+        return node
+
+
+class _VisitorVarargKernel(ast.NodeTransformer):
+    def __init__(self, N):
+        self.inline_variables = set()
+        self.N = N
+
+    def visit_AnnAssign(self, node):
+        # Pattern-matching:
+        # var_name: "VAR_ARGS_ARRAY"
+        if (
+            node.value is None
+            and node.simple == 1
+            and isinstance(node.target, ast.Name)
+            and isinstance(node.annotation, ast.Constant)
+            and node.annotation.value == "VAR_ARGS_ARRAY"
+        ):
+            self.inline_variables.add(node.target.id)
+            return []
+        if node.value is not None:
+            node.value = self.visit(node.value)
+        if node.annotation is not None:
+            node.annotation = self.visit(node.annotation)
+        if node.target is not None:
+            node.target = self.visit(node.target)
+        return node
+
+    def visit_arguments(self, node):
+        # Replace `args` annotated with `VAR_ARGS_ARRAY`
+        new_args = []
+        for arg in node.args:
+            if (
+                arg.annotation is not None
+                and isinstance(arg.annotation, ast.Constant)
+                and arg.annotation.value == "VAR_ARGS_ARRAY"
+            ):
+                self.inline_variables.add(arg.arg)
+                new_args += [ast.arg(f"{arg.arg}{i}") for i in range(self.N)]
+                continue
+            new_args.append(arg)
+        if node.vararg is not None:
+            self.inline_variables.add(node.vararg.arg)
+            new_args += [ast.arg(f"{node.vararg.arg}{i}") for i in range(self.N)]
+            node.vararg = None
+        new_args += node.kwonlyargs
+        node.kwonlyargs = []
+        node.args = new_args
+        return node
+
+
+class _VisitorUnrollKernel(_VisitorVarargKernel):
+    def visit_For(self, node):
+        if (
+            not isinstance(node.iter, ast.Call)
+            or node.iter.func.id != "range"
+            or len(node.iter.args) != 1
+            or not isinstance(node.iter.args[0], ast.Call)
+            or node.iter.args[0].func.id != "len"
+            or len(node.iter.args[0].args) != 1
+            or node.iter.args[0].args[0].id not in self.inline_variables
+        ):
+            node.body = [self.visit(x) for x in node.body]
+            return node
+        # We know we have to modify this loop
+        new_nodes = []
+        for i in range(self.N):
+            unroller = _ForLoopUnroller(
+                target=node.target.id,
+                inline_variables=self.inline_variables,
+                loop_iter=i,
+            )
+            for body in node.body:
+                body = copy.deepcopy(body)
+                new_node = ast.fix_missing_locations(unroller.visit(body))
+                new_node = self.visit(new_node)
+                new_nodes.append(new_node)
+        return new_nodes
+
+
+class _VisitorConditionalKernel(_VisitorVarargKernel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.extra_nodes = None
+
+    def visit_Subscript(self, node):
+        if isinstance(node.value, ast.Subscript):
+            node.value = self.visit_Subscript(node.value)
+            return node
+        if not isinstance(node.value, ast.Name):
+            return node
+        if node.value.id in self.inline_variables and isinstance(node.slice, ast.Name):
+            # given `a[i]`, replace with `res`, where `res` is:
+            # a0 if i == 0 else a1 if i== 1 else a2 if i == 2 ...
+            if_statements = [None] * self.N
+            if_statements[-1] = ast.Name(f"{node.value.id}{self.N - 1}")
+
+            for i in reversed(range(self.N - 1)):
+                test = ast.Compare(node.slice, [ast.Eq()], [ast.Constant(i)])
+                body = ast.Name(f"{node.value.id}{i}")
+                if_statements[i] = ast.IfExp(
+                    test=test,
+                    body=body,
+                    orelse=if_statements[i + 1],
+                )
+
+            return if_statements[0]
+        return node
+
+    def visit_Call(self, node):
+        if (
+            isinstance(node.func, ast.Name)
+            and node.func.id == "len"
+            and len(node.args) == 1
+            and isinstance(node.args[0], ast.Name)
+            and node.args[0].id in self.inline_variables
+        ):
+            return ast.Constant(self.N)
+        self.generic_visit(node)
+        return node
+
+
+# Hackfix to get access to get source-code for
+# `exec`-created functions - see https://stackoverflow.com/a/69668999
+_getlines_orig = None
+_FILENAME_TO_SRC: Dict[str, List[str]] = {}
+
+# Materializing the codegen to disk can be useful for external tools, e.g. ncu
+# Disabled by default because writing to disk at module import time is unexpected and error-prone.
+_should_materialize_codegen = os.environ.get("XFORMERS_MATERIALIZE_CODEGEN") == "1"
+_should_keep_materialized_source = os.environ.get("XFORMERS_KEEP_CODEGEN") == "1"
+_tmp_dir = None
+
+
+def _monkey_patched_getlines(filename, module_globals=None):
+    if filename in _FILENAME_TO_SRC:
+        return _FILENAME_TO_SRC[filename]
+    else:
+        return _getlines_orig(filename, module_globals)  # type: ignore
+
+
+class VarargMode(Enum):
+    UNROLL = "unroll"
+    CONDITIONAL = "conditional"
+
+
+@functools.lru_cache(None)
+def unroll_varargs(kernel, N: int, mode: VarargMode = VarargMode.UNROLL):
+    """
+    Specializes a triton kernel with variable number of inputs
+    to a specific number of inputs `N`.
+
+    `mode` can either be `UNROLL` or `CONDITIONAL`. Both options
+    implement the same functionality, but have different implementations
+    and can have different performance. In `UNROLL` mode, any loops that
+    loop over the varargs will be unrolled. In `CONDITIONAL` mode,
+    indexing into the list of varargs is replaced with conditional
+    statements like `a0 if i==0 else a1 if i==1 else a2...`.
+    `CONDITIONAL` mode is generally better if `N` is large, because it
+    generates a smaller triton kernel that should fit in the
+    instruction cache and will compile faster.
+
+    NOTE: Because it's quite costly to call `triton.jit`,
+    we cache the returned value with `lru_cache`
+    """
+    global _getlines_orig, _tmp_dir
+
+    k = triton.JITFunction(kernel.fn)
+    parsed = ast.parse(k.src)  # type: ignore
+    if mode == VarargMode.UNROLL:
+        nodeVisitor: _VisitorVarargKernel = _VisitorUnrollKernel(N=N)
+    elif mode == VarargMode.CONDITIONAL:
+        nodeVisitor = _VisitorConditionalKernel(N=N)
+    parsed = nodeVisitor.visit(parsed)
+    parsed = ast.fix_missing_locations(parsed)
+
+    # NOTE: `ast.unparse` requires python 3.9+
+    if (sys.version_info.major, sys.version_info.minor) <= (3, 8):
+        raise RuntimeError("Error: This functionality requires python 3.9 or above")
+    new_src = ast.unparse(parsed)  # type: ignore
+
+    # Now we want to `eval` the function, but we need all this
+    # boilerplate code to make sure triton can run `inspect.getsource`
+
+    fn_basename = f"unroll_varargs-{kernel.fn.__name__}-{mode.value}-{N}"
+    if _should_materialize_codegen:
+        if not _tmp_dir:
+            _tmp_dir = tempfile.TemporaryDirectory()
+        fn_filename = os.path.join(_tmp_dir.name, f"{fn_basename}.py")
+        if _should_keep_materialized_source:
+            # destroy the TemporaryDirectory object
+            _tmp_dir = None
+            # create path if not exists
+            os.makedirs(os.path.dirname(fn_filename), exist_ok=True)
+        with open(fn_filename, "w") as f:
+            f.write(new_src)
+    else:
+        # Patch `getlines` only the first time
+        if not _FILENAME_TO_SRC:
+            _getlines_orig = linecache.getlines
+            linecache.getlines = _monkey_patched_getlines
+        fn_filename = f"<{fn_basename}>"
+        _FILENAME_TO_SRC[fn_filename] = new_src.splitlines(keepends=True)
+
+    # Create function given source
+    code = compile(new_src, fn_filename, "exec")
+
+    _locals: Dict[str, Any] = {}
+    exec(code, kernel.fn.__globals__, _locals)
+    assert len(_locals) == 1, len(_locals)
+    fn = next(iter(_locals.values()))
+
+    jitted_fn = triton.jit(fn)
+    if not hasattr(jitted_fn, "_unsafe_update_src"):
+        # Triton older than 3.2
+        jitted_fn.src = new_src
+    return jitted_fn
+
+
+# Note: just import this to make mypy happy
+# when annotating variables with `VAR_ARGS_ARRAY`
+VAR_ARGS_ARRAY = List[Any]
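For orientation, the sketch below shows how the `unroll_varargs` API defined in the file above is typically driven. It is not part of the package: the kernel `_sum_inputs_kernel`, the sizes, and the launch parameters are illustrative assumptions, and the import path merely mirrors the file layout listed at the top of this diff. The idea is that a kernel parameter annotated with the string "VAR_ARGS_ARRAY" is expanded into N concrete arguments, and any `for i in range(len(...))` loop over it is rewritten when the kernel is specialized.

# Hypothetical usage sketch (not shipped in this wheel); the import path simply
# mirrors the file layout listed above.
import torch
import triton
import triton.language as tl

from mslk.attention.fmha._triton.vararg_kernel import VarargMode, unroll_varargs


@triton.jit
def _sum_inputs_kernel(  # illustrative kernel, not part of the package
    out_ptr,
    n_elements,
    inputs: "VAR_ARGS_ARRAY",  # expanded into inputs0, inputs1, ... at specialization time
    BLOCK: tl.constexpr,
):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    acc = tl.zeros([BLOCK], dtype=tl.float32)
    # In UNROLL mode this loop is unrolled N times; in CONDITIONAL mode the
    # `inputs[i]` access becomes a chain of conditional expressions instead.
    for i in range(len(inputs)):
        acc += tl.load(inputs[i] + offs, mask=mask, other=0.0)
    tl.store(out_ptr + offs, acc, mask=mask)


# Specialize the vararg kernel to exactly three inputs; the result is a regular
# @triton.jit kernel with signature (out_ptr, n_elements, inputs0, inputs1, inputs2, BLOCK).
kernel_3 = unroll_varargs(_sum_inputs_kernel, N=3, mode=VarargMode.UNROLL)

xs = [torch.randn(1024, device="cuda") for _ in range(3)]
out = torch.empty(1024, device="cuda")
grid = (triton.cdiv(1024, 256),)
kernel_3[grid](out, 1024, *xs, BLOCK=256)

With `mode=VarargMode.CONDITIONAL`, the same `inputs[i]` access would instead be rewritten to `inputs0 if i == 0 else inputs1 if i == 1 else inputs2`, which, per the docstring above, keeps the generated kernel smaller and faster to compile when N is large.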