mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/testing/__init__.py ADDED
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
mslk/testing/attributes.py ADDED
@@ -0,0 +1,60 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import os
+import subprocess
+from typing import Tuple
+
+import mslk
+import torch
+
+################################################################################
+# Unit test skip attributes for environments
+################################################################################
+
+gpu_unavailable: Tuple[bool, str] = (
+    not torch.cuda.is_available() or torch.cuda.device_count() == 0,
+    "GPU is not available or no GPUs detected",
+)
+
+running_in_github: Tuple[bool, str] = (
+    os.getenv("GITHUB_ENV") is not None,
+    "Test fails or hangs when run in the GitHub runners",
+)
+
+running_in_oss: Tuple[bool, str] = (
+    # pyre-ignore [16]
+    getattr(mslk, "open_source", False),
+    "Test is currently known to fail in OSS mode",
+)
+
+################################################################################
+# Unit test skip attributes for platforms
+################################################################################
+
+running_on_arm: Tuple[bool, str] = (
+    subprocess.run(["uname", "-m"], stdout=subprocess.PIPE)
+    .stdout.decode("utf-8")
+    .strip()
+    == "aarch64",
+    "Test is currently known to fail when running on ARM platform",
+)
+
+running_on_cuda: Tuple[bool, str] = (
+    torch.cuda.is_available()
+    and torch.cuda.device_count() > 0
+    and torch.version.hip is None,
+    "Test currently doesn't work on the CUDA stack",
+)
+
+running_on_rocm: Tuple[bool, str] = (
+    torch.cuda.is_available()
+    and torch.cuda.device_count() > 0
+    and torch.version.hip is not None,
+    "Test currently doesn't work on the ROCm stack",
+)
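For context, each attribute above is a `(condition, reason)` tuple shaped to unpack directly into `unittest.skipIf`. A minimal usage sketch (the test class and test names here are hypothetical):

```python
import unittest

from mslk.testing.attributes import gpu_unavailable, running_on_rocm


class ExampleKernelTest(unittest.TestCase):
    # Each attribute is a (condition, reason) pair, so it unpacks
    # straight into unittest.skipIf.
    @unittest.skipIf(*gpu_unavailable)
    def test_requires_gpu(self) -> None:
        self.assertTrue(True)

    @unittest.skipIf(*running_on_rocm)
    def test_cuda_only(self) -> None:
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()
```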
mslk/testing/rocm.py ADDED
@@ -0,0 +1,91 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import os
+import unittest
+from functools import wraps
+from typing import Any, Callable
+
+import torch
+
+running_on_rocm: bool = (
+    torch.cuda.is_available()
+    and torch.cuda.device_count() > 0
+    and torch.version.hip is not None
+)
+
+
+def skipIfRocm(
+    reason: str = "Test does not work on ROCm",
+) -> Any:
+    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
+    def decorator(fn: Callable) -> Any:
+        @wraps(fn)
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
+            if running_on_rocm:
+                raise unittest.SkipTest(reason)
+            else:
+                fn(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+def skipIfNotRocm(
+    reason: str = "Test only works on ROCm",
+) -> Any:
+    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
+    def decorator(fn: Callable) -> Any:
+        @wraps(fn)
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
+            if running_on_rocm:
+                fn(*args, **kwargs)
+            else:
+                raise unittest.SkipTest(reason)
+
+        return wrapper
+
+    return decorator
+
+
+def skipIfRocmLessThan(min_version: int) -> Any:
+    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
+    def decorator(testfn: Callable) -> Any:
+        @wraps(testfn)
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
+            ROCM_VERSION_FILEPATH = "/opt/rocm/.info/version"
+            if running_on_rocm:
+                # Fail if the ROCm version file is missing.
+                if not os.path.isfile(ROCM_VERSION_FILEPATH):
+                    raise AssertionError(
+                        f"ROCm version file {ROCM_VERSION_FILEPATH} is missing!"
+                    )
+
+                # Parse the version number from the file.
+                with open(ROCM_VERSION_FILEPATH, "r") as file:
+                    version = file.read().strip()
+                version = version.replace("-", "").split(".")
+                version = (
+                    int(version[0]) * 10000 + int(version[1]) * 100 + int(version[2])
+                )
+
+                # Skip the test if the ROCm version is below the minimum version.
+                if version < min_version:
+                    raise unittest.SkipTest(
+                        f"Skip the test since the ROCm version is less than {min_version}"
+                    )
+                else:
+                    testfn(*args, **kwargs)
+
+            else:
+                testfn(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
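Note that `skipIfRocmLessThan` takes the version encoded as `major * 10000 + minor * 100 + patch`, matching the parsing logic above, so ROCm 6.2.0 is `60200`. A minimal sketch (the test names are hypothetical):

```python
import unittest

from mslk.testing.rocm import skipIfRocm, skipIfRocmLessThan


class ExampleRocmTest(unittest.TestCase):
    @skipIfRocm("Kernel has no HIP backend yet")
    def test_cuda_only_kernel(self) -> None:
        self.assertTrue(True)

    # 6 * 10000 + 2 * 100 + 0 = 60200, i.e. require ROCm >= 6.2.0.
    @skipIfRocmLessThan(60200)
    def test_needs_recent_rocm(self) -> None:
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()
```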
mslk/utils/__init__.py ADDED
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
mslk/utils/torch/__init__.py ADDED
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
mslk/utils/torch/library.py ADDED
@@ -0,0 +1,150 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import logging
+import re
+from typing import Callable
+
+import torch
+
+
+def load_library_buck(buck_target: str) -> None:
+    import mslk  # noqa: F401
+
+    # pyre-ignore [16]
+    open_source: bool = getattr(mslk, "open_source", False)
+
+    try:
+        torch.ops.load_library(buck_target)
+    except OSError as e:
+        if open_source:
+            pass
+        else:
+            logging.error(
+                f"Failed to load buck target {buck_target}, ops will not be available via torch.ops! Error: {e}"
+            )
+
+
+class TorchLibraryFragment:
+    """
+    A wrapper class around PyTorch library fragments, which are used to define
+    and register PyTorch operators. Handles duplicate operator definitions and
+    registrations under the hood.
+    """
+
+    def __init__(self, namespace: str) -> None:
+        """
+        Constructs the TorchLibraryFragment class.
+
+        Args:
+            namespace: The namespace for the operators.
+
+        Returns:
+            None
+
+        Example:
+            lib = TorchLibraryFragment("mslk")
+        """
+        self.namespace = namespace
+        self.lib = torch.library.Library(namespace, "FRAGMENT")
+
+    def define(self, schema: str) -> None:
+        """
+        Defines an operator schema. This function handles the case where the
+        operator name has already been defined.
+
+        Args:
+            schema: The schema of the operator to be defined. The operator name
+                should NOT be prefixed with the operator namespace.
+
+        Returns:
+            None
+
+        Example:
+            lib = TorchLibraryFragment("mslk")
+            lib.define("sll_jagged_jagged_bmm(Tensor x, Tensor y, bool flag=True) -> Tensor")
+        """
+        pattern = re.compile(
+            r"""
+            (\w+)       # Match the function name (capturing group)
+            \s*\(       # Match the opening parenthesis with optional whitespace
+            ([^)]*)     # Match params list (capturing group)
+            \s*\)       # Match the closing parenthesis with optional whitespace
+            \s*->\s*.+  # Match '-> <Return Type>'
+            """,
+            re.VERBOSE,
+        )
+
+        match = pattern.search(schema.strip())
+        if match:
+            name = match.group(1)
+            if f"{self.namespace}::{name}" not in torch.library._defs:
+                self.lib.define(schema)
+        else:
+            raise ValueError(
+                f"PyTorch operator schema appears to be ill-defined: '''{schema}'''"
+            )
+
+    # pyre-ignore[24]
+    def register_dispatch(self, op_name: str, dispatch_key: str, fn: Callable) -> None:
+        """
+        Registers a single dispatch for an operator with the given name and dispatch key.
+
+        Args:
+            op_name: operator name
+            dispatch_key: dispatch key that the function should be registered for (e.g., "CUDA")
+            fn: a function that is the operator implementation for the input dispatch key
+
+        Returns:
+            None
+
+        Example:
+            lib = TorchLibraryFragment("mslk")
+            lib.define(...)
+            lib.register_dispatch("jagged_dense_bmm", "CUDA", jagged_dense_bmm)
+        """
+
+        valid_backends = [
+            "CUDA",
+            "AutogradCUDA",
+            "CPU",
+            "AutogradCPU",
+            "AutogradMeta",
+            "Meta",
+            "CompositeImplicitAutograd",
+        ]
+        assert dispatch_key in valid_backends
+
+        if not torch._C._dispatch_has_kernel_for_dispatch_key(
+            f"{self.namespace}::{op_name}", dispatch_key
+        ):
+            if dispatch_key == "Meta":
+                self.lib._register_fake(op_name, fn)
+            else:
+                self.lib.impl(op_name, fn, dispatch_key)
+
+    # pyre-ignore[24]
+    def register(self, op_name: str, functors: dict[str, Callable]) -> None:
+        """
+        Registers a set of dispatches for a defined operator.
+
+        Args:
+            op_name: operator name
+            functors: A dictionary of dispatch keys to dispatch implementations
+
+        Returns:
+            None
+
+        Example:
+            lib = TorchLibraryFragment("mslk")
+            lib.define(...)
+            lib.register("jagged_dense_bmm", {"CUDA": jagged_dense_bmm, "Meta": jagged_dense_bmm_meta})
+        """
+        for dispatch, func in functors.items():
+            self.register_dispatch(op_name, dispatch, func)
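Putting the fragment API together, here is a hedged end-to-end sketch; the `my_scale` op and its implementations are illustrative stand-ins, not ops shipped by this package:

```python
import torch

from mslk.utils.torch.library import TorchLibraryFragment

lib = TorchLibraryFragment("mslk")
lib.define("my_scale(Tensor x, float alpha=1.0) -> Tensor")


def my_scale_cuda(x: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    return x * alpha


def my_scale_meta(x: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # Shape/dtype-only implementation used by tracing and torch.compile.
    return torch.empty_like(x)


lib.register("my_scale", {"CUDA": my_scale_cuda, "Meta": my_scale_meta})

# The op is now reachable through the dispatcher (requires a CUDA device):
# y = torch.ops.mslk.my_scale(torch.randn(4, device="cuda"), alpha=2.0)
```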
mslk/utils/triton/__init__.py ADDED
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
mslk/utils/triton/fp8_utils.py ADDED
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+import functools
+import logging
+import os
+from typing import Tuple
+
+import torch
+import triton.language as tl  # @manual
+from triton.runtime.jit import reinterpret as tl_reinterpret, TensorWrapper  # @manual
+
+
+running_on_github: bool = os.getenv("GITHUB_ENV") is not None
+
+
+@functools.lru_cache
+def supports_float8_fnuz(throw_on_hip_incompatibility: bool = True) -> bool:
+    if torch.version.hip:
+        device_capability = torch.cuda.get_device_capability()
+
+        if device_capability < (9, 4):
+            gpu_arch = torch.cuda.get_device_properties("cuda").gcnArchName
+            msg = f"Unsupported GPU arch: {gpu_arch} for FP8"
+            if throw_on_hip_incompatibility:
+                raise RuntimeError(msg)
+            else:
+                logging.error(msg)
+                return False
+
+        elif device_capability == (9, 4):
+            return True
+
+    return False
+
+
+def get_fp8_constants() -> Tuple[torch.dtype, tl.dtype, float, float]:
+    """
+    Helper function to get constant values for the current platform.
+
+    Returns:
+        pt_dtype (torch.dtype): The correct torch fp8 datatype.
+        tl_dtype (tl.dtype): The correct triton fp8 datatype.
+        max_fp8 (float): The maximum representable value for the fp8 datatype.
+        eps (float): Minimum clip value to prevent divide by zero.
+    """
+    if supports_float8_fnuz(throw_on_hip_incompatibility=(not running_on_github)):
+        pt_fp8_dtype = torch.float8_e4m3fnuz
+        tl_fp8_dtype = tl.float8e4b8
+    else:
+        pt_fp8_dtype = torch.float8_e4m3fn
+        tl_fp8_dtype = tl.float8e4nv
+
+    return pt_fp8_dtype, tl_fp8_dtype, torch.finfo(pt_fp8_dtype).max, 1e-12
+
+
+def reinterpret_fp8_type(tensor: torch.Tensor, dtype: tl.dtype) -> TensorWrapper:
+    """
+    Converts tensor to triton fp8 type.
+
+    Args:
+        tensor (torch.Tensor): input tensor.
+        dtype (tl.dtype): target triton dtype.
+
+    Returns:
+        triton.TensorWrapper: fp8 tensor.
+    """
+    return tl_reinterpret(tensor, dtype=dtype)
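As a usage note, a minimal per-tensor scaling sketch built on `get_fp8_constants`, assuming a CUDA device; the scaling recipe is illustrative and not one of the package's quantization kernels:

```python
import torch

from mslk.utils.triton.fp8_utils import get_fp8_constants

pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()

x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
# Scale so the largest |value| maps onto the fp8 maximum; eps guards
# against division by zero for an all-zero tensor.
scale = max_fp8 / torch.clamp(x.abs().max().float(), min=eps)
x_fp8 = (x.float() * scale).clamp(-max_fp8, max_fp8).to(pt_dtype)
x_approx = x_fp8.float() / scale  # dequantized reconstruction
```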
mslk/utils/triton/utils.py ADDED
@@ -0,0 +1,128 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+import sys
+
+import torch
+import triton  # @manual
+import triton.language as tl  # @manual
+
+
+def map_dtype_to_triton(dtype: torch.dtype) -> tl.dtype:
+    """
+    Maps torch dtype to triton dtype.
+
+    Args:
+        dtype (torch.dtype): input dtype.
+
+    Returns:
+        tl.dtype: triton dtype.
+    """
+    if dtype == torch.float16:
+        return tl.float16
+    elif dtype == torch.bfloat16:
+        return tl.bfloat16
+    elif dtype == torch.float32:
+        return tl.float32
+    elif dtype == torch.int32:
+        return tl.int32
+    elif dtype == torch.float8_e4m3fn and torch.version.hip is None:
+        return tl.float8e4nv
+    else:
+        raise ValueError(f"Unsupported dtype {dtype}")
+
+
+# Check if we have the TMA version in Triton PR #4498 (https://github.com/triton-lang/triton/pull/4498).
+HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl)
+
+if HAS_TMA_DESC:
+    print(
+        "TMA benchmarks will be running with experimental grid constant TMA descriptor.",
+        file=sys.stderr,
+    )
+else:
+    print(
+        "TMA benchmarks will be running without grid constant TMA descriptor.",
+        file=sys.stderr,
+    )
+
+
+class TmaAutoTuneHelper:
+    # Duck-typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498.
+    class KernelParamWrapper:
+        def __init__(self, desc):
+            self.desc = desc
+
+        def tma_desc_cpu_ptr(self):
+            return self.desc.data_ptr()
+
+    TMA_SIZE = 128
+
+    def __init__(self):
+        self.fill_1d_tma_descriptor_inner = (
+            triton.runtime.driver.active.utils.fill_1d_tma_descriptor
+        )
+        self.fill_2d_tma_descriptor_inner = (
+            triton.runtime.driver.active.utils.fill_2d_tma_descriptor
+        )
+        if HAS_TMA_DESC:
+            self.descriptors = {}
+        else:
+            self.cuda_descriptors = {}
+
+    # Call this method outside of the lambda function for grid size
+    def init_tma_descriptor(self, name):
+        if HAS_TMA_DESC:
+            self.descriptors[name] = torch.empty(
+                TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8
+            )
+        else:
+            self.cuda_descriptors[name] = torch.empty(
+                TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8
+            )
+
+    # Call this method inside the lambda function for grid size
+    def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):
+        if HAS_TMA_DESC:
+            desc_x = self.descriptors[name]
+            assert desc_x.data_ptr() % 64 == 0
+            self.fill_1d_tma_descriptor_inner(
+                ptr, dim, block_dim, element_size, desc_x.data_ptr()
+            )
+        else:
+            desc_x = self.cuda_descriptors[name]
+            buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
+            self.fill_1d_tma_descriptor_inner(
+                ptr, dim, block_dim, element_size, buf_x.data_ptr()
+            )
+            desc_x.copy_(buf_x, non_blocking=True)
+
+    # Call this method inside the lambda function for grid size
+    def fill_2d_tma_descriptor(
+        self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size
+    ):
+        if HAS_TMA_DESC:
+            desc_x = self.descriptors[name]
+            assert desc_x.data_ptr() % 64 == 0
+            self.fill_2d_tma_descriptor_inner(
+                ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
+            )
+        else:
+            desc_x = self.cuda_descriptors[name]
+            buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
+            self.fill_2d_tma_descriptor_inner(
+                ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr()
+            )
+            desc_x.copy_(buf_x, non_blocking=True)
+
+    def get_tma_descriptor_kernel_param(self, name):
+        if HAS_TMA_DESC:
+            assert self.descriptors[name] is not None
+            return self.KernelParamWrapper(self.descriptors[name])
+        else:
+            assert self.cuda_descriptors[name] is not None
+            return self.cuda_descriptors[name]
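For orientation, a sketch of the intended descriptor lifecycle, assuming a CUDA device and Triton installed; `my_kernel` and the block sizes are hypothetical:

```python
import torch
import triton

from mslk.utils.triton.utils import TmaAutoTuneHelper, map_dtype_to_triton

# Torch -> Triton dtype mapping for kernel constexpr arguments.
tl_dtype = map_dtype_to_triton(torch.bfloat16)

M, N = 1024, 512
a = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)

helper = TmaAutoTuneHelper()
helper.init_tma_descriptor("a")  # allocate once, outside the grid callback


def grid(meta):
    # Refill the descriptor whenever the autotuner picks new block sizes.
    helper.fill_2d_tma_descriptor(
        "a", a.data_ptr(), M, N, meta["BLOCK_M"], meta["BLOCK_N"], a.element_size()
    )
    return (triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),)


desc_a = helper.get_tma_descriptor_kernel_param("a")
# A hypothetical TMA-consuming kernel would then be launched as:
# my_kernel[grid](desc_a, M, N, BLOCK_M=128, BLOCK_N=64)
```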
mslk/version.py ADDED
@@ -0,0 +1,11 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+__version__: str = "2026.1.19"
+__target__: str = "default"
+__variant__: str = "cuda"
mslk_cuda_nightly-2026.1.19.dist-info/METADATA ADDED
@@ -0,0 +1,102 @@
+Metadata-Version: 2.4
+Name: mslk-cuda-nightly
+Version: 2026.1.19
+Home-page: https://github.com/pytorch/MSLK
+Author: MSLK Team
+Author-email: packages@pytorch.org
+License: BSD-3
+Keywords: PyTorch,Generative AI,High Performance Computing,GPU,CUDA,ROCm
+Classifier: Development Status :: 4 - Beta
+Classifier: Intended Audience :: Developers
+Classifier: Intended Audience :: Science/Research
+Classifier: License :: OSI Approved :: BSD License
+Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
+Classifier: Programming Language :: Python :: 3.13
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: numpy
+Dynamic: author
+Dynamic: author-email
+Dynamic: classifier
+Dynamic: description
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: keywords
+Dynamic: license
+Dynamic: license-file
+Dynamic: requires-dist
+
+# MSLK Library
+
+
+MSLK (Meta Superintelligence Labs Kernels, formerly known as **[FBGEMM GenAI](https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gen_ai)**)
+is a collection of high-performance kernels and optimizations built on top of PyTorch
+primitives for GenAI training and inference.
+
+## Installation
+
+```bash
+# Install MSLK for CUDA
+pip install mslk-cuda==1.0.0
+# Install MSLK for ROCm
+pip install mslk-rocm==1.0.0
+# Install a nightly version
+pip3 install --pre mslk --index-url https://download.pytorch.org/whl/nightly/cu128
+```
+
+## Release Compatibility Table
+
+MSLK is released in accordance with the PyTorch release schedule, and each
+release is not guaranteed to work with PyTorch releases older than the one
+that the MSLK release corresponds to.
+
+| MSLK Release | Corresponding PyTorch Release | Supported Python Versions | Supported CUDA Versions | Supported CUDA Architectures | Supported ROCm Versions | Supported ROCm Architectures |
+|---------|---------|---------|---------|----------|-------------|-------------|
+| 1.0.0 | 2.10.x | 3.10, 3.11, 3.12, 3.13, 3.14 | 12.6, 12.8, 12.9, 13.0 | 8.0, 9.0a, 10.0a, 12.0a | 7.0, 7.1 | gfx908, gfx90a, gfx942, gfx950 |
+
+## Running Benchmarks
+```bash
+python bench/gemm/gemm_bench.py
+python bench/quantize/quantize_bench.py
+```
+
+## Running Tests
+```bash
+python test/gemm/gemm_test.py
+python test/gemm/quantize_test.py
+```
+
+## Build From Source
+We only support building on Linux. See the release compatibility table above for supported versions of Python, CUDA, and ROCm.
+```bash
+# Clone the repo
+git clone https://github.com/meta-pytorch/MSLK
+cd MSLK
+git submodule sync
+git submodule update --init --recursive
+# Build and install
+# The script will create a conda environment and install the required dependencies.
+# The conda environment will look something like: build-py3.14-torchnightly-cuda12.9.1
+./ci/integration/mslk_oss_build.bash
+# After the initial environment setup, you can activate the environment and iterate faster:
+conda activate build-py3.14-torchnightly-cuda12.9.1
+python setup.py install
+```
+
+## Join the MSLK community
+
+For questions, support, news updates, or feature requests, please feel free to:
+
+* File a ticket in [GitHub Issues](https://github.com/meta-pytorch/MSLK/issues)
+
+For contributions, please see the [`CONTRIBUTING`](./CONTRIBUTING.md) file for
+ways to help out.
+
+## License
+
+MSLK is BSD licensed, as found in the [`LICENSE`](./LICENSE) file.