mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/testing/attributes.py
ADDED
@@ -0,0 +1,60 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import os
+import subprocess
+from typing import Tuple
+
+import mslk
+import torch
+
+################################################################################
+# Unit test skip attributes for environments
+################################################################################
+
+gpu_unavailable: Tuple[bool, str] = (
+    not torch.cuda.is_available() or torch.cuda.device_count() == 0,
+    "GPU is not available or no GPUs detected",
+)
+
+running_in_github: Tuple[bool, str] = (
+    os.getenv("GITHUB_ENV") is not None,
+    "Test fails or hangs when run in the GitHub runners",
+)
+
+running_in_oss: Tuple[bool, str] = (
+    # pyre-ignore [16]
+    getattr(mslk, "open_source", False),
+    "Test is currently known to fail in OSS mode",
+)
+
+################################################################################
+# Unit test skip attributes for platforms
+################################################################################
+
+running_on_arm: Tuple[bool, str] = (
+    subprocess.run(["uname", "-m"], stdout=subprocess.PIPE)
+    .stdout.decode("utf-8")
+    .strip()
+    == "aarch64",
+    "Test is currently known to fail when running on ARM platform",
+)
+
+running_on_cuda: Tuple[bool, str] = (
+    torch.cuda.is_available()
+    and torch.cuda.device_count() > 0
+    and torch.version.hip is not None,
+    "Test currently doesn't work on the ROCm stack",
+)
+
+running_on_rocm: Tuple[bool, str] = (
+    torch.cuda.is_available()
+    and torch.cuda.device_count() > 0
+    and torch.version.hip is not None,
+    "Test currently doesn't work on the ROCm stack",
+)
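Each of these attributes is a `(condition, reason)` pair shaped to be splatted into `unittest.skipIf`. A minimal usage sketch (not part of the wheel; the test class and method names are hypothetical):

```python
import unittest

import torch

from mslk.testing.attributes import gpu_unavailable, running_on_rocm


class ExampleTest(unittest.TestCase):
    # Unpack the (condition, reason) tuple into skipIf(condition, reason).
    @unittest.skipIf(*gpu_unavailable)
    def test_needs_gpu(self) -> None:
        x = torch.ones(4, device="cuda")
        self.assertEqual(x.sum().item(), 4.0)

    @unittest.skipIf(*running_on_rocm)
    def test_cuda_only(self) -> None:
        self.assertIsNone(torch.version.hip)


if __name__ == "__main__":
    unittest.main()
```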
mslk/testing/rocm.py
ADDED
@@ -0,0 +1,91 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import os
+import unittest
+from functools import wraps
+from typing import Any, Callable
+
+import torch
+
+running_on_rocm: bool = (
+    torch.cuda.is_available()
+    and torch.cuda.device_count() > 0
+    and torch.version.hip is not None
+)
+
+
+def skipIfRocm(
+    reason: str = "Test does not work on ROCm",
+) -> Any:
+    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
+    def decorator(fn: Callable) -> Any:
+        @wraps(fn)
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
+            if running_on_rocm:
+                raise unittest.SkipTest(reason)
+            else:
+                fn(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+def skipIfNotRocm(
+    reason: str = "Test only works on ROCm",
+) -> Any:
+    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
+    def decorator(fn: Callable) -> Any:
+        @wraps(fn)
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
+            if running_on_rocm:
+                fn(*args, **kwargs)
+            else:
+                raise unittest.SkipTest(reason)
+
+        return wrapper
+
+    return decorator
+
+
+def skipIfRocmLessThan(min_version: int) -> Any:
+    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
+    def decorator(testfn: Callable) -> Any:
+        @wraps(testfn)
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
+            ROCM_VERSION_FILEPATH = "/opt/rocm/.info/version"
+            if running_on_rocm:
+                # Fail if ROCm version file is missing.
+                if not os.path.isfile(ROCM_VERSION_FILEPATH):
+                    raise AssertionError(
+                        f"ROCm version file {ROCM_VERSION_FILEPATH} is missing!"
+                    )
+
+                # Parse the version number from the file.
+                with open(ROCM_VERSION_FILEPATH, "r") as file:
+                    version = file.read().strip()
+                    version = version.replace("-", "").split(".")
+                    version = (
+                        int(version[0]) * 10000 + int(version[1]) * 100 + int(version[2])
+                    )
+
+                # Fail if ROCm version is less than the minimum version.
+                if version < min_version:
+                    raise unittest.SkipTest(
+                        f"Skip the test since the ROCm version is less than {min_version}"
+                    )
+                else:
+                    testfn(*args, **kwargs)
+
+            else:
+                testfn(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
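A usage sketch for these decorators (not part of the wheel; the test names are hypothetical). Note that `skipIfRocmLessThan` encodes versions as `major * 10000 + minor * 100 + patch`, so ROCm 6.4.1 becomes 60401:

```python
import unittest

from mslk.testing.rocm import skipIfNotRocm, skipIfRocm, skipIfRocmLessThan


class RocmGatingTest(unittest.TestCase):
    @skipIfRocm("Known to fail on ROCm")
    def test_cuda_path(self) -> None:
        ...

    @skipIfNotRocm()
    @skipIfRocmLessThan(60401)  # requires ROCm >= 6.4.1
    def test_rocm_only_feature(self) -> None:
        ...


if __name__ == "__main__":
    unittest.main()
```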
mslk/utils/torch/library.py
ADDED
@@ -0,0 +1,150 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import logging
+import re
+from typing import Callable
+
+import torch
+
+
+def load_library_buck(buck_target: str) -> None:
+    import mslk  # noqa: F401
+
+    # pyre-ignore [16]
+    open_source: bool = getattr(mslk, "open_source", False)
+
+    try:
+        torch.ops.load_library(buck_target)
+    except OSError as e:
+        if open_source:
+            pass
+        else:
+            logging.error(
+                f"Failed to load buck target {buck_target}, ops will not be available via torch.ops! Error: {e}"
+            )
+
+
+class TorchLibraryFragment:
+    """
+    A wrapper class around PyTorch library fragments, which are used to define
+    and register PyTorch operators. Handles duplicate operator definitions and
+    registrations under the hood.
+    """
+
+    def __init__(self, namespace: str) -> None:
+        """
+        Constructs the TorchLibraryFragment class.
+
+        Args:
+            namespace: The namespace for the operators.
+
+        Returns:
+            None
+
+        Example:
+            lib = TorchLibraryFragment("mslk")
+        """
+        self.namespace = namespace
+        self.lib = torch.library.Library(namespace, "FRAGMENT")
+
+    def define(self, schema: str) -> None:
+        """
+        Defines an operator schema. This function handles the case where the
+        operator name has already been defined.
+
+        Args:
+            schema: The schema of the operator to be defined. The operator name
+                should NOT be prefixed with the operator namespace.
+
+        Returns:
+            None
+
+        Example:
+            lib = TorchLibraryFragment("mslk")
+            lib.define("sll_jagged_jagged_bmm(Tensor x, Tensor y, bool flag=True) -> Tensor")
+        """
+        pattern = re.compile(
+            r"""
+            (\w+)       # Match the function name (capturing group)
+            \s*\(       # Match the opening parenthesis with optional whitespace
+            ([^)]*)     # Match params list (capturing group)
+            \s*\)       # Match the closing parenthesis with optional whitespace
+            \s*->\s*.+  # Match '-> <Return Type>'
+            """,
+            re.VERBOSE,
+        )
+
+        match = pattern.search(schema.strip())
+        if match:
+            name = match.group(1)
+            if f"{self.namespace}::{name}" not in torch.library._defs:
+                self.lib.define(schema)
+        else:
+            raise ValueError(
+                f"PyTorch operator schema appears to be ill-defined: '''{schema}'''"
+            )
+
+    # pyre-ignore[24]
+    def register_dispatch(self, op_name: str, dispatch_key: str, fn: Callable) -> None:
+        """
+        Registers a single dispatch for an operator with the given name and dispatch key.
+
+        Args:
+            op_name: operator name
+            dispatch_key: dispatch key that the function should be registered for (e.g., "CUDA")
+            fn: a function that is the operator implementation for the input dispatch key
+
+        Returns:
+            None
+
+        Example:
+            lib = TorchLibraryFragment("mslk")
+            lib.define(...)
+            lib.register_dispatch("jagged_dense_bmm", "CUDA", jagged_dense_bmm)
+        """
+
+        valid_backends = [
+            "CUDA",
+            "AutogradCUDA",
+            "CPU",
+            "AutogradCPU",
+            "AutogradMeta",
+            "Meta",
+            "CompositeImplicitAutograd",
+        ]
+        assert dispatch_key in valid_backends
+
+        if not torch._C._dispatch_has_kernel_for_dispatch_key(
+            f"{self.namespace}::{op_name}", dispatch_key
+        ):
+            if dispatch_key == "Meta":
+                self.lib._register_fake(op_name, fn)
+            else:
+                self.lib.impl(op_name, fn, dispatch_key)
+
+    # pyre-ignore[24]
+    def register(self, op_name: str, functors: dict[str, Callable]) -> None:
+        """
+        Registers a set of dispatches for a defined operator.
+
+        Args:
+            op_name: operator name
+            functors: A dictionary of dispatch keys to dispatch implementations
+
+        Returns:
+            None
+
+        Example:
+            lib = TorchLibraryFragment("mslk")
+            lib.define(...)
+            lib.register("jagged_dense_bmm", {"CUDA": jagged_dense_bmm, "Meta": jagged_dense_bmm_meta})
+        """
+        for dispatch, func in functors.items():
+            self.register_dispatch(op_name, dispatch, func)
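A usage sketch for `TorchLibraryFragment` (not part of the wheel; the operator name, schema, and kernels are hypothetical, chosen only to show the define/register flow):

```python
import torch

from mslk.utils.torch.library import TorchLibraryFragment

lib = TorchLibraryFragment("mslk")

# Define the schema once; a duplicate definition is silently skipped.
lib.define("scale_rows(Tensor x, float alpha) -> Tensor")


def scale_rows_impl(x: torch.Tensor, alpha: float) -> torch.Tensor:
    return x * alpha


def scale_rows_meta(x: torch.Tensor, alpha: float) -> torch.Tensor:
    # Shape/dtype-only implementation for tracing and torch.compile.
    return torch.empty_like(x)


lib.register(
    "scale_rows",
    {"CPU": scale_rows_impl, "CUDA": scale_rows_impl, "Meta": scale_rows_meta},
)

# The op is now callable through torch.ops under the "mslk" namespace.
y = torch.ops.mslk.scale_rows(torch.ones(2, 3), 2.0)
```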
mslk/utils/triton/fp8_utils.py
ADDED
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+import functools
+import logging
+import os
+from typing import Tuple
+
+import torch
+import triton.language as tl  # @manual
+from triton.runtime.jit import reinterpret as tl_reinterpret, TensorWrapper  # @manual
+
+
+running_on_github: bool = os.getenv("GITHUB_ENV") is not None
+
+
+@functools.lru_cache
+def supports_float8_fnuz(throw_on_hip_incompatibility: bool = True) -> bool:
+    if torch.version.hip:
+        device_capability = torch.cuda.get_device_capability()
+
+        if device_capability < (9, 4):
+            gpu_arch = torch.cuda.get_device_properties("cuda").gcnArchName
+            msg = f"Unsupported GPU arch: {gpu_arch} for FP8"
+            if throw_on_hip_incompatibility:
+                raise RuntimeError(msg)
+            else:
+                logging.error(msg)
+                return False
+
+        elif device_capability == (9, 4):
+            return True
+
+    return False
+
+
+def get_fp8_constants() -> Tuple[torch.dtype, tl.dtype, float, float]:
+    """
+    Helper function to get constant values for the current platform.
+
+    Returns:
+        pt_dtype (torch.dtype): The correct torch fp8 datatype.
+        tl_dtype (tl.dtype): The correct triton fp8 datatype.
+        max_fp8 (float): The maximum representable value for the fp8 datatype.
+        eps (float): Minimum clip value to prevent divide by zero.
+    """
+    if supports_float8_fnuz(throw_on_hip_incompatibility=(not running_on_github)):
+        pt_fp8_dtype = torch.float8_e4m3fnuz
+        tl_fp8_dtype = tl.float8e4b8
+    else:
+        pt_fp8_dtype = torch.float8_e4m3fn
+        tl_fp8_dtype = tl.float8e4nv
+
+    return pt_fp8_dtype, tl_fp8_dtype, torch.finfo(pt_fp8_dtype).max, 1e-12
+
+
+def reinterpret_fp8_type(tensor: torch.Tensor, dtype: tl.dtype) -> TensorWrapper:
+    """
+    Converts tensor to triton fp8 type.
+
+    Args:
+        tensor (torch.Tensor): input tensor.
+        dtype (tl.dtype): target triton dtype.
+
+    Returns:
+        triton.TensorWrapper: fp8 tensor.
+    """
+    return tl_reinterpret(tensor, dtype=dtype)
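A usage sketch for the fp8 helpers when feeding a Triton kernel (not part of the wheel; the shapes and scaling scheme are illustrative assumptions):

```python
import torch

from mslk.utils.triton.fp8_utils import get_fp8_constants, reinterpret_fp8_type

# Resolve the fp8 flavor (fnuz on gfx942, e4m3fn elsewhere) for this platform.
pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()

x = torch.randn(16, 32, device="cuda")
# Scale into the representable fp8 range before casting; eps guards
# against division by zero for an all-zero tensor.
scale = max_fp8 / x.abs().max().clamp(min=eps)
x_fp8 = (x * scale).to(pt_dtype)

# Triton kernels see the tensor through a TensorWrapper carrying the
# matching triton fp8 dtype.
x_tl = reinterpret_fp8_type(x_fp8, tl_dtype)
```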
mslk/utils/triton/utils.py
ADDED
@@ -0,0 +1,128 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+import sys
+
+import torch
+import triton  # @manual
+import triton.language as tl  # @manual
+
+
+def map_dtype_to_triton(dtype: torch.dtype) -> tl.dtype:
+    """
+    Maps torch dtype to triton dtype.
+
+    Args:
+        dtype (torch.dtype): input dtype.
+
+    Returns:
+        tl.dtype: triton dtype.
+    """
+    if dtype == torch.float16:
+        return tl.float16
+    elif dtype == torch.bfloat16:
+        return tl.bfloat16
+    elif dtype == torch.float32:
+        return tl.float32
+    elif dtype == torch.int32:
+        return tl.int32
+    elif dtype == torch.float8_e4m3fn and torch.version.hip is None:
+        return tl.float8e4nv
+    else:
+        raise ValueError(f"Unsupported dtype {dtype}")
+
+
+# check if we have the TMA version in Triton PR #4498 (https://github.com/triton-lang/triton/pull/4498).
+HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl)
+
+if HAS_TMA_DESC:
+    print(
+        "TMA benchmarks will be running with experimental grid constant TMA descriptor.",
+        file=sys.stderr,
+    )
+else:
+    print(
+        "TMA benchmarks will be running without grid constant TMA descriptor.",
+        file=sys.stderr,
+    )
+
+
+class TmaAutoTuneHelper:
+    # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498
+    class KernelParamWrapper:
+        def __init__(self, desc):
+            self.desc = desc
+
+        def tma_desc_cpu_ptr(self):
+            return self.desc.data_ptr()
+
+    TMA_SIZE = 128
+
+    def __init__(self):
+        self.fill_1d_tma_descriptor_inner = (
+            triton.runtime.driver.active.utils.fill_1d_tma_descriptor
+        )
+        self.fill_2d_tma_descriptor_inner = (
+            triton.runtime.driver.active.utils.fill_2d_tma_descriptor
+        )
+        if HAS_TMA_DESC:
+            self.descriptors = {}
+        else:
+            self.cuda_descriptors = {}
+
+    # Call this method outside of the lambda function for grid size
+    def init_tma_descriptor(self, name):
+        if HAS_TMA_DESC:
+            self.descriptors[name] = torch.empty(
+                TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8
+            )
+        else:
+            self.cuda_descriptors[name] = torch.empty(
+                TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8
+            )
+
+    # Call this method inside the lambda function for grid size
+    def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):
+        if HAS_TMA_DESC:
+            desc_x = self.descriptors[name]
+            assert desc_x.data_ptr() % 64 == 0
+            self.fill_1d_tma_descriptor_inner(
+                ptr, dim, block_dim, element_size, desc_x.data_ptr()
+            )
+        else:
+            desc_x = self.cuda_descriptors[name]
+            buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
+            self.fill_1d_tma_descriptor_inner(
+                ptr, dim, block_dim, element_size, buf_x.data_ptr()
+            )
+            desc_x.copy_(buf_x, non_blocking=True)
+
+    # Call this method inside the lambda function for grid size
+    def fill_2d_tma_descriptor(
+        self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size
+    ):
+        if HAS_TMA_DESC:
+            desc_x = self.descriptors[name]
+            assert desc_x.data_ptr() % 64 == 0
+            self.fill_2d_tma_descriptor_inner(
+                ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
+            )
+        else:
+            desc_x = self.cuda_descriptors[name]
+            buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
+            self.fill_2d_tma_descriptor_inner(
+                ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr()
+            )
+            desc_x.copy_(buf_x, non_blocking=True)
+
+    def get_tma_descriptor_kernel_param(self, name):
+        if HAS_TMA_DESC:
+            assert self.descriptors[name] is not None
+            return self.KernelParamWrapper(self.descriptors[name])
+        else:
+            assert self.cuda_descriptors[name] is not None
+            return self.cuda_descriptors[name]
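A usage sketch of the `TmaAutoTuneHelper` lifecycle around a TMA-based Triton kernel launch (not part of the wheel; the kernel and tile sizes are hypothetical):

```python
import torch

from mslk.utils.triton.utils import TmaAutoTuneHelper

a = torch.randn(1024, 512, device="cuda", dtype=torch.float16)

helper = TmaAutoTuneHelper()
# Allocate descriptor storage outside the grid lambda.
helper.init_tma_descriptor("a")

BLOCK_M, BLOCK_K = 128, 64
# Fill the descriptor (inside the grid lambda when autotuning, so it is
# refreshed for each candidate config).
helper.fill_2d_tma_descriptor(
    "a", a.data_ptr(), a.shape[0], a.shape[1], BLOCK_M, BLOCK_K, a.element_size()
)
desc_a = helper.get_tma_descriptor_kernel_param("a")
# desc_a is then passed to the Triton kernel as a regular argument, e.g.:
# my_kernel[grid](desc_a, ...)
```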
mslk/version.py
ADDED
@@ -0,0 +1,11 @@
+
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+__version__: str = "2026.1.19"
+__target__: str = "default"
+__variant__: str = "cuda"
mslk_cuda_nightly-2026.1.19.dist-info/METADATA
ADDED
@@ -0,0 +1,102 @@
+Metadata-Version: 2.4
+Name: mslk-cuda-nightly
+Version: 2026.1.19
+Home-page: https://github.com/pytorch/MSLK
+Author: MSLK Team
+Author-email: packages@pytorch.org
+License: BSD-3
+Keywords: PyTorch,Generative AI,High Performance Computing,GPU,CUDA,ROCm
+Classifier: Development Status :: 4 - Beta
+Classifier: Intended Audience :: Developers
+Classifier: Intended Audience :: Science/Research
+Classifier: License :: OSI Approved :: BSD License
+Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
+Classifier: Programming Language :: Python :: 3.13
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: numpy
+Dynamic: author
+Dynamic: author-email
+Dynamic: classifier
+Dynamic: description
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: keywords
+Dynamic: license
+Dynamic: license-file
+Dynamic: requires-dist
+
+# MSLK Library
+
+MSLK (Meta Superintelligence Labs Kernels, formerly known as **[FBGEMM GenAI](https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gen_ai)**)
+is a collection of high-performance kernels and optimizations built on top of PyTorch
+primitives for GenAI training and inference.
+
+## **Installation**
+
+```bash
+# Install MSLK for CUDA
+pip install mslk-cuda==1.0.0
+# Install MSLK for ROCm
+pip install mslk-rocm==1.0.0
+# Install a nightly version
+pip3 install --pre mslk --index-url https://download.pytorch.org/whl/nightly/cu128
+```
+
+## Release Compatibility Table
+
+MSLK is released in accordance with the PyTorch release schedule, and each
+release is not guaranteed to work with PyTorch releases older than the one
+that the MSLK release corresponds to.
+
+| MSLK Release | Corresponding PyTorch Release | Supported Python Versions | Supported CUDA Versions | Supported CUDA Architectures | Supported ROCm Versions | Supported ROCm Architectures |
+|---------|---------|---------|---------|----------|-------------|-------------|
+| 1.0.0 | 2.10.x | 3.10, 3.11, 3.12, 3.13, 3.14 | 12.6, 12.8, 12.9, 13.0 | 8.0, 9.0a, 10.0a, 12.0a | 7.0, 7.1 | gfx908, gfx90a, gfx942, gfx950 |
+
+## **Running Benchmarks**
+```bash
+python bench/gemm/gemm_bench.py
+python bench/quantize/quantize_bench.py
+```
+
+## **Running Tests**
+```bash
+python test/gemm/gemm_test.py
+python test/gemm/quantize_test.py
+```
+
+## **Build From Source**
+We only support building on Linux. See the release compatibility table above for the supported versions of Python, CUDA, and ROCm.
+```bash
+# Clone repo
+git clone https://github.com/meta-pytorch/MSLK
+cd MSLK
+git submodule sync
+git submodule update --init --recursive
+# Build and install
+# The script will create a conda environment and install the required dependencies.
+# The conda environment will look something like: build-py3.14-torchnightly-cuda12.9.1
+./ci/integration/mslk_oss_build.bash
+# After the initial environment setup, you can activate the environment and iterate faster:
+conda activate build-py3.14-torchnightly-cuda12.9.1
+python setup.py install
+```
+
+## Join the MSLK community
+
+For questions, support, news updates, or feature requests, please feel free to:
+
+* File a ticket in [GitHub Issues](https://github.com/meta-pytorch/MSLK/issues)
+
+For contributions, please see the [`CONTRIBUTING`](./CONTRIBUTING.md) file for
+ways to help out.
+
+## License
+
+MSLK is BSD licensed, as found in the [`LICENSE`](./LICENSE) file.