mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
@@ -0,0 +1,130 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # pyre-strict
7
+ from typing import List, Optional, Sequence, Tuple, Union
8
+
9
+ import torch
10
+
11
+ from .utils.op_common import _get_storage_base
12
+
13
+
14
def get_stack_strides(
    tensors: Sequence[torch.Tensor], dim: int
) -> Optional[Tuple[Union[int, torch.SymInt], ...]]:
    """
    If the tensors are already stacked on dimension :code:`dim`,
    returns the strides of the stacked tensors.
    Otherwise returns :code:`None`.

    A negative :code:`dim` is interpreted relative to the *stacked* tensor
    (which has one more dimension than the inputs), so :code:`-1` means
    "stack on a new last dimension". :code:`None` is returned when fewer than
    two tensors are given or when :code:`dim` is out of range.
    """
    if len(tensors) <= 1:
        return None

    first = tensors[0]
    ndim = first.ndim
    if dim < 0:
        # The stacked result has `ndim + 1` dimensions.
        dim += ndim + 1
    if dim < 0 or dim > ndim:
        return None

    base_shape = first.shape
    base_strides = first.stride()
    base_offset = first.storage_offset()
    # Candidate stride of the new dimension: distance between the first two tensors.
    # PyTorch 2.5 messed up the type annotations for SymInt, but 2.6 will fix it
    # https://github.com/pytorch/pytorch/issues/138478
    stack_stride = tensors[1].storage_offset() - base_offset  # type: ignore[operator]
    final_stride = [*base_strides[:dim], stack_stride, *base_strides[dim:]]

    storage_data_ptr: Optional[int] = None
    for position, x in enumerate(tensors[1:], start=1):
        # Sanity checks
        if x.shape != base_shape:
            return None
        if x.stride() != base_strides:
            return None
        # Every tensor must sit exactly `position` steps after the first one.
        if x.storage_offset() != base_offset + position * stack_stride:  # type: ignore[operator]
            return None
        # Only query the storage once the cheap checks have passed.
        if storage_data_ptr is None:
            storage_data_ptr = _get_storage_base(first)
        # Actual storage check
        if _get_storage_base(x) != storage_data_ptr:
            return None
    return tuple(final_stride)
58
+
59
+
60
def _stack_or_none_fw(
    tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]],
    dim: int,
) -> Optional[torch.Tensor]:
    """
    Return a strided view equivalent to stacking `tensors` on `dim`, or None.

    The view is built with `as_strided` on the first tensor, using the strides
    reported by `get_stack_strides`; None is returned when no strides are
    reported. A negative `dim` is normalised against the stacked tensor (one
    more dimension than the inputs) so that the new axis is inserted at the
    position `torch.stack` would use.
    """
    if dim < 0 and len(tensors) > 0:
        dim += tensors[0].ndim + 1
        if dim < 0:
            # Still out of range: this cannot be expressed as a view.
            return None
    strides = get_stack_strides(tensors, dim)
    if strides is None:
        return None
    stacked_shape = list(tensors[0].shape)
    stacked_shape.insert(dim, len(tensors))
    return tensors[0].as_strided(stacked_shape, strides)
70
+
71
+
72
def _stack_fw(
    tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]],
    dim: int,
) -> torch.Tensor:
    """
    Stack `tensors` on `dim`, using the view from `_stack_or_none_fw` when it
    returns one and falling back to `torch.stack` otherwise.
    """
    view = _stack_or_none_fw(tensors, dim)
    if view is not None:
        return view
    return torch.stack(tensors, dim=dim)
80
+
81
+
82
+ class _Unbind(torch.autograd.Function):
83
+ """
84
+ See function `unbind`
85
+ """
86
+
87
+ @staticmethod
88
+ # type: ignore
89
+ def forward(ctx, x: torch.Tensor, dim: int):
90
+ ctx.dim = dim
91
+ return x.unbind(dim)
92
+
93
+ @classmethod
94
+ # type: ignore
95
+ def backward(cls, ctx, *tensors: torch.Tensor):
96
+ return _stack_fw(tensors, ctx.dim), None
97
+
98
+
99
class _StackOrNone(torch.autograd.Function):
    """
    Autograd function behind `stack_or_none`: forward builds the stacked view
    (or None), backward unbinds the incoming gradient.
    """

    @staticmethod
    # type: ignore
    def forward(ctx, dim: int, *tensors: torch.Tensor):
        # Remember the axis so that backward can split the gradient along it.
        ctx.dim = dim
        stacked = _stack_or_none_fw(tensors, dim=dim)
        return stacked

    @classmethod
    # type: ignore
    def backward(cls, ctx, grad: torch.Tensor):
        # One entry per forward input: None for `dim`, then one slice per tensor.
        per_tensor_grads = grad.unbind(dim=ctx.dim)
        return (None, *per_tensor_grads)
114
+
115
+
116
def unbind(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, ...]:
    """
    Forward: same result as :attr:`torch.unbind`.
    Backward: the gradients are stacked through `_stack_fw`, which skips the
    copy when they are already views laid out contiguously in one storage.
    """
    pieces = _Unbind.apply(x, dim)
    return pieces
123
+
124
+
125
def stack_or_none(
    tensors: Sequence[torch.Tensor], dim: int
) -> Optional[torch.Tensor]:
    """
    Does exactly the same as :attr:`torch.stack` if the tensors can be concatenated
    without any memory operation. Otherwise returns None.

    The return annotation is ``Optional`` because the None case is part of the
    contract: callers must check the result before using it as a tensor.
    """
    return _StackOrNone.apply(dim, *tensors)
@@ -0,0 +1,6 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # pyre-unsafe
@@ -0,0 +1,74 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # pyre-unsafe
7
+
8
+ from typing import Callable, List, Optional
9
+
10
+ import torch
11
+
12
+
13
+ # from https://github.com/openai/triton/blob/95d9b7f4ae21710dc899e1de6a579b2136ea4f3d/python/triton/testing.py#L19
14
def do_bench_cudagraph(
    fn: Callable, rep: int = 20, grad_to_none: Optional[List[torch.Tensor]] = None
) -> float:
    """
    Benchmark `fn` by capturing it in CUDA graphs and timing graph replays.

    Args:
        fn: Function to benchmark
        rep: Repetition time (in ms); divided by a one-call estimate to pick how
            many calls of `fn` are unrolled into the measured graph
        grad_to_none: Tensors whose gradient is reset to None before `fn` runs
    Returns:
        Benchmarked runtime of one call in ms (mean over 10 replays)
    Raises:
        RuntimeError: if the current CUDA stream is the default stream
    """
    if torch.cuda.current_stream() == torch.cuda.default_stream():
        raise RuntimeError(
            "Cannot capture graph in default stream. "
            "Please use side stream in benchmark code."
        )

    def _timed_replay(captured) -> float:
        # Replay the graph once between two timing events; result is in ms.
        tic = torch.cuda.Event(enable_timing=True)
        toc = torch.cuda.Event(enable_timing=True)
        tic.record()
        captured.replay()
        toc.record()
        torch.cuda.synchronize()
        return tic.elapsed_time(toc)

    reset_grads = grad_to_none is not None

    # warmup
    fn()

    # step 1 - estimate how long a single call takes
    # NOTE: this estimate isn't super accurate because the GPU isn't warmed up
    # at this point, but it is probably good enough
    if reset_grads:
        for tensor in grad_to_none:
            tensor.detach_()
            tensor.requires_grad_(True)
            tensor.grad = None
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        fn()
    torch.cuda.synchronize()
    estimate_ms = _timed_replay(graph)
    n_repeat = max(1, int(rep / estimate_ms))

    # step 2 - capture `n_repeat` unrolled calls in one graph to minimize
    # host overhead
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(n_repeat):
            if reset_grads:
                for tensor in grad_to_none:
                    tensor.grad = None
            fn()
    torch.cuda.synchronize()

    # measure time and return
    n_retries = 10
    per_call_ms = []
    for _ in range(n_retries):
        per_call_ms.append(_timed_replay(graph) / n_repeat)
    return torch.mean(torch.tensor(per_call_ms)).item()
@@ -0,0 +1,148 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # pyre-unsafe
7
+
8
+ import dataclasses
9
+ import logging
10
+ import os
11
+ import platform
12
+ from pathlib import Path
13
+ from typing import Any, Dict, Optional
14
+
15
+ import torch
16
+
17
+ logger = logging.getLogger("mslk_fmha")
18
+
19
+ UNAVAILABLE_FEATURES_MSG = " Memory-efficient attention won't be available."
20
+
21
+
22
+ @dataclasses.dataclass
23
+ class _BuildInfo:
24
+ metadata: Dict[str, Any]
25
+
26
+ @property
27
+ def cuda_version(self) -> Optional[int]:
28
+ return self.metadata["version"]["cuda"]
29
+
30
+ @property
31
+ def hip_version(self) -> Optional[int]:
32
+ return self.metadata["version"]["hip"]
33
+
34
+ @property
35
+ def torch_version(self) -> str:
36
+ return self.metadata["version"]["torch"]
37
+
38
+ @property
39
+ def python_version(self) -> str:
40
+ return self.metadata["version"]["python"]
41
+
42
+ @property
43
+ def flash_version(self) -> str:
44
+ return self.metadata["version"].get("flash", "0.0.0")
45
+
46
+ @property
47
+ def use_torch_flash(self) -> bool:
48
+ return self.metadata["version"].get("use_torch_flash", False)
49
+
50
+ @property
51
+ def build_env(self) -> Dict[str, Any]:
52
+ return self.metadata["env"]
53
+
54
+
55
class xFormersWasNotBuiltException(Exception):
    """Raised when no compiled extension file is found for loading."""

    def __str__(self) -> str:
        # Fixed message followed by the shared "unavailable features" suffix.
        parts = (
            "Need to compile C++ extensions to use all fmha features.\n",
            " Please install xformers properly ",
            "(see https://github.com/facebookresearch/xformers#installing-xformers)\n",
            UNAVAILABLE_FEATURES_MSG,
        )
        return "".join(parts)
63
+
64
+
65
class xFormersInvalidLibException(Exception):
    """
    Raised when the compiled extension is found but cannot be loaded.

    `build_info`, when provided, is used to report which PyTorch/CUDA/Python
    versions the library was built for next to the versions currently running.
    """

    def __init__(self, build_info: Optional[_BuildInfo]) -> None:
        self.build_info = build_info

    def __str__(self) -> str:
        info = self.build_info
        if info is not None:
            msg = f"""fmha was built for:
PyTorch {info.torch_version} with CUDA {info.cuda_version} (you have {torch.__version__})
Python {info.python_version} (you have {platform.python_version()})"""
        else:
            msg = "fmha was built for a different version of PyTorch or Python."
        pieces = [
            "fmha can't load C++/CUDA extensions. ",
            msg,
            "\n Please reinstall mslk ",
            UNAVAILABLE_FEATURES_MSG,
        ]
        return "".join(pieces)
82
+
83
+
84
def _register_extensions():
    """
    Find the compiled extension and load it with ``torch.ops.load_library``.

    The extension is searched for in the fourth parent directory of this
    file's path, under the module name ``_C`` (or ``_C_hip`` when
    ``torch.version.hip`` is set and ``torch.version`` has no ``git_version``
    attribute). On Windows that directory is first added to the DLL search path.

    Raises:
        xFormersWasNotBuiltException: if no matching extension file is found.
        xFormersInvalidLibException: if loading the library raises ``OSError``.
    """
    # Import the submodule explicitly: `import importlib` alone does not
    # guarantee that `importlib.machinery` is available as an attribute.
    import importlib.machinery

    # load the custom_op_library from the mslk directory
    # and register the custom ops
    lib_dir = str(Path(__file__).parent.parent.parent.parent)
    if os.name == "nt":
        # Register the library location on the default DLL path
        import ctypes
        import sys

        kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
        with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
        # 0x0001 is SEM_FAILCRITICALERRORS in the Win32 API.
        prev_error_mode = kernel32.SetErrorMode(0x0001)
        try:
            if with_load_library_flags:
                kernel32.AddDllDirectory.restype = ctypes.c_void_p

            if sys.version_info >= (3, 8):
                os.add_dll_directory(lib_dir)
            elif with_load_library_flags:
                res = kernel32.AddDllDirectory(lib_dir)
                if res is None:
                    err = ctypes.WinError(ctypes.get_last_error())
                    err.strerror += f' Error adding "{lib_dir}" to the DLL directories.'
                    raise err
        finally:
            # Restore the process error mode even if adding the directory failed.
            kernel32.SetErrorMode(prev_error_mode)

    loader_details = (
        importlib.machinery.ExtensionFileLoader,
        importlib.machinery.EXTENSION_SUFFIXES,
    )

    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
    if torch.version.hip and not hasattr(torch.version, "git_version"):
        ext_specs = extfinder.find_spec("_C_hip")
    else:
        ext_specs = extfinder.find_spec("_C")
    if ext_specs is None:
        raise xFormersWasNotBuiltException()
    try:
        torch.ops.load_library(ext_specs.origin)
    except OSError as exc:
        raise xFormersInvalidLibException(None) from exc
132
+
133
+
134
# Exception raised while loading the C++ library at import time, or None if
# loading succeeded.
_cpp_library_load_exception = None

try:
    _register_extensions()
except (xFormersInvalidLibException, xFormersWasNotBuiltException) as e:
    # A load failure is logged rather than propagated, so importing still works.
    ENV_VAR_FOR_DETAILS = "XFORMERS_MORE_DETAILS"
    if not os.environ.get(ENV_VAR_FOR_DETAILS, False):
        logger.warning(
            f"WARNING[XFORMERS]: {e}\n Set {ENV_VAR_FOR_DETAILS}=1 for more details"
        )
    else:
        # Details requested: attach the exception so the traceback is logged.
        logger.warning(f"WARNING[XFORMERS]: {e}", exc_info=e)
    _cpp_library_load_exception = e
147
+
148
+ _built_with_cuda = True # XXXXX
@@ -0,0 +1,65 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # pyre-unsafe
7
+
8
+ from typing import Any, Dict, List, Type, TypeVar
9
+
10
+ import torch
11
+
12
+
13
def get_operator(library: str, name: str):
    """
    Look up ``torch.ops.<library>.<name>``.

    Returns the operator when the lookup succeeds. If the lookup raises
    ``RuntimeError`` or ``AttributeError``, returns a placeholder function named
    ``no_such_operator`` that raises ``RuntimeError`` when called.
    """
    try:
        namespace = getattr(torch.ops, library)
        return getattr(namespace, name)
    except (RuntimeError, AttributeError):
        pass

    # The placeholder's __name__ is what `BaseOperator.is_available` checks.
    def no_such_operator(*args, **kwargs):
        raise RuntimeError(
            f"No such operator {library}::{name} - did you forget to build xformers with `python setup.py develop`?"
        )

    return no_such_operator
23
+
24
+
25
def get_xformers_operator(name: str):
    """Shorthand for `get_operator` with the library fixed to ``"xformers"``."""
    return get_operator(library="xformers", name=name)
27
+
28
+
29
class BaseOperator:
    """
    Base class for operator descriptors.

    Subclasses are expected to set `OPERATOR`, `NAME` and `OPERATOR_CATEGORY`.
    """

    OPERATOR: Any  # pyre-ignore[13]
    NAME: str  # pyre-ignore[13]
    OPERATOR_CATEGORY: str  # pyre-ignore[13]

    @classmethod
    def is_available(cls) -> bool:
        """Return False when `OPERATOR` is None or is the missing-op placeholder."""
        op = cls.OPERATOR
        if op is None:
            return False
        # `op` can be either a kernel or a Triton Autotuner object, which
        # doesn't have __name__ -- hence the default.
        return getattr(op, "__name__", "") != "no_such_operator"
43
+
44
+
45
# Every class passed to `register_operator`, in registration order.
OPERATORS_REGISTRY: List[Type[BaseOperator]] = []
# Maps each registered class's `OPERATOR` attribute back to the class.
FUNC_TO_XFORMERS_OPERATOR: Dict[Any, Type[BaseOperator]] = {}

ClsT = TypeVar("ClsT")


def register_operator(cls: ClsT) -> ClsT:
    """
    Class decorator: record `cls` in both registries and return it unchanged.
    """
    registry, by_func = OPERATORS_REGISTRY, FUNC_TO_XFORMERS_OPERATOR
    registry.append(cls)  # type: ignore
    by_func[cls.OPERATOR] = cls  # type: ignore
    return cls
55
+
56
+
57
+ # post-2.0, avoids a warning
58
+ # (`torch.Tensor.storage` will also be deleted in the future)
59
+ _GET_TENSOR_STORAGE = getattr(torch.Tensor, "untyped_storage", None)
60
+ if _GET_TENSOR_STORAGE is None: # pre-2.0, `untyped_storage` didn't exist
61
+ _GET_TENSOR_STORAGE = torch.Tensor.storage
62
+
63
+
64
+ def _get_storage_base(x: torch.Tensor) -> int:
65
+ return _GET_TENSOR_STORAGE(x).data_ptr() # type: ignore
@@ -0,0 +1,11 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+
9
+ from mslk.utils.torch.library import load_library_buck
10
+
11
# Runs at import time. Presumably this loads the gqa_attn_splitk GPU ops
# library named by the target below -- confirm against `load_library_buck`.
load_library_buck(
    "//mslk/csrc/attention/cuda/gqa_attn_splitk:gqa_attn_splitk_ops_gpu"
)
@@ -0,0 +1,7 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict
@@ -0,0 +1,255 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import argparse
9
+ import os
10
+ import tempfile
11
+ import uuid
12
+ from functools import lru_cache
13
+ from pprint import pprint
14
+
15
+ import mslk.comm # noqa: F401
16
+ import pandas as pd
17
+ import torch
18
+ import torch.distributed as dist
19
+ import torch.distributed._symmetric_memory as symm_mem
20
+ from torch.distributed.launcher.api import elastic_launch, LaunchConfig
21
+
22
+
23
@lru_cache(maxsize=None)
def get_symm_buffer(group):
    """
    Allocate a symmetric-memory bfloat16 buffer on CUDA and rendezvous on `group`.

    Returns a ``(buffer, group_name)`` pair. The result is cached per `group`,
    so later calls with the same group return the same buffer.
    """
    numel = 16 * 1024 * 1024
    buf = symm_mem.empty(numel, device="cuda", dtype=torch.bfloat16)
    symm_mem.rendezvous(buf, group=group)
    return buf, group.group_name
30
+
31
+
32
def _setup(path: str) -> tuple[int, int]:
    """
    Initialise the communication backends for one benchmark process.

    Reads ``LOCAL_RANK`` and ``WORLD_SIZE`` from the environment, selects the
    matching CUDA device, initialises the mslk NCCL and `torch.distributed`
    process groups using rendezvous files under `path`, sets up the mslk
    ``car`` buffers, and creates the cached symmetric-memory buffer.

    Returns:
        ``(rank, world_size)``
    """
    rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    torch.cuda.set_device(torch.device(f"cuda:{rank}"))
    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"

    torch.ops.mslk.nccl_init(rank, world_size, os.path.join(path, "rdvz"))
    torch.distributed.init_process_group(
        backend="cpu:gloo,cuda:nccl",
        init_method=f"file://{os.path.join(path, 'gloo_rdvz')}",
        world_size=world_size,
        rank=rank,
    )

    def _gather_ipc_handles(tensor):
        # Collect every rank's IPC handle for `tensor`.
        handle = torch.ops.mslk.car_ipc_handle(tensor)
        gathered = [torch.empty_like(handle) for _ in range(world_size)]
        torch.distributed.all_gather(gathered, handle)
        return gathered

    buffer = torch.ops.mslk.car_tensor()
    barrier = torch.ops.mslk.car_tensor()
    barrier.zero_()

    # Buffer handles are exchanged first, then barrier handles.
    all_buffer_handles = _gather_ipc_handles(buffer)
    all_barrier_handles = _gather_ipc_handles(barrier)
    torch.ops.mslk.car_init(
        rank, world_size, barrier, all_barrier_handles, buffer, all_buffer_handles
    )
    torch.cuda.synchronize()
    torch.distributed.barrier()

    # Populate the symmetric-buffer cache for the world group.
    get_symm_buffer(dist.group.WORLD)
    return rank, world_size
66
+
67
+
68
def symm_one_shot_allreduce(dst_tensor, src_tensor, bias=None, comm_idx=None):
    """
    Sum-allreduce `src_tensor` into `dst_tensor` with the symmetric-memory
    one-shot op, then add `bias` in place when it is given.

    `comm_idx` is accepted but not used.
    """
    # get_symm_buffer should be called for the first time during model init,
    # and now return cached values. Make sure group is the same as during init
    shared, group_name = get_symm_buffer(dist.group.WORLD)
    staging = shared[: src_tensor.numel()].view_as(src_tensor)
    torch.ops.symm_mem.one_shot_all_reduce_copy_out(
        staging, src_tensor, "sum", group_name, dst_tensor
    )
    if bias is None:
        return
    dst_tensor.add_(bias)
78
+
79
+
80
def symm_two_shot_allreduce(dst_tensor, src_tensor, bias=None, comm_idx=None):
    """
    Sum-allreduce `src_tensor` into `dst_tensor` with the symmetric-memory
    two-shot op, then add `bias` in place when it is given.

    `comm_idx` is accepted but not used.
    """
    # get_symm_buffer should be called for the first time during model init,
    # and now return cached values. Make sure group is the same as during init
    shared, group_name = get_symm_buffer(dist.group.WORLD)
    staging = shared[: src_tensor.numel()].view_as(src_tensor)
    # car is also doing explicit copy
    staging.copy_(src_tensor)
    torch.ops.symm_mem.two_shot_all_reduce_out(
        staging, "sum", group_name, dst_tensor
    )
    if bias is None:
        return
    dst_tensor.add_(bias)
92
+
93
+
94
def symm_reduce_scatter(dst_tensor, src_tensor, comm_idx=None):
    """
    Copy `src_tensor` into the symmetric buffer and reduce-scatter it into
    `dst_tensor`.

    `comm_idx` is accepted but not used.
    """
    shared, group_name = get_symm_buffer(dist.group.WORLD)
    staging = shared[: src_tensor.numel()].view_as(src_tensor)
    staging.copy_(src_tensor)
    torch.ops.symm_mem.reduce_scatter_out(staging, group_name, False, dst_tensor)
99
+
100
+
101
def run_one_algo(fn, out, inp, num_iters, num_warmup_iters):
    """
    Time ``fn(out, inp)`` with CUDA events.

    Runs `num_warmup_iters` untimed calls, then `num_iters` timed calls, and
    returns the average time per timed call in milliseconds.
    """
    timer_start = torch.cuda.Event(enable_timing=True)
    timer_stop = torch.cuda.Event(enable_timing=True)

    warmup_left = num_warmup_iters
    while warmup_left > 0:
        fn(out, inp)
        warmup_left -= 1

    timer_start.record()
    for _ in range(num_iters):
        fn(out, inp)
    timer_stop.record()
    torch.cuda.synchronize()

    total_ms = timer_start.elapsed_time(timer_stop)
    return total_ms / num_iters
113
+
114
+
115
def run_benchmark(args, path):
    """
    Per-process benchmark entry point.

    Sets up communication via `_setup`, then for each size produced by
    ``torch.logspace(args.min_size, args.max_size, steps=args.size_steps, base=2)``
    times every implementation of the selected op (``args.op``) and records the
    time and bandwidth. Rank 0 prints the results and, with ``args.export_csv``,
    writes them to ``comm_ops_benchmark.csv`` in ``args.output_dir``.
    """
    rank, W = _setup(path)
    if rank == 0:
        print(f"Running benchmark with {W} ranks")

    def round_up(a: int, b: int) -> int:
        return ((a + b - 1) // b) * b

    # Sizes are rounded up to a multiple of this value.
    N_even_divisor = 8 * 64 if torch.version.hip else 8 * 32
    raw_sizes = torch.logspace(
        args.min_size, args.max_size, steps=args.size_steps, base=2
    ).tolist()

    benchmark_results = []
    for raw_size in raw_sizes:
        N = round_up(int(raw_size), N_even_divisor)
        inp = torch.rand(N, dtype=torch.bfloat16, device="cuda")

        # Pick the output buffer and the label -> implementation table.
        if args.op == "allreduce":
            out = torch.full_like(inp, -1)
            candidates = {
                "mslk_1shot": torch.ops.mslk.one_shot_car_allreduce,
                "symm_1shot": symm_one_shot_allreduce,
                "mslk_2shot": torch.ops.mslk.two_shot_car_allreduce,
                "symm_2shot": symm_two_shot_allreduce,
                "nccl": torch.ops.mslk.nccl_allreduce,
            }
        else:
            out = torch.full(
                (inp.shape[0] // W,), -1, dtype=inp.dtype, device=inp.device
            )
            candidates = {
                "mslk_rs": torch.ops.mslk.car_reducescatter,
                "symm_rs": symm_reduce_scatter,
                "nccl_rs": torch.ops.mslk.nccl_reducescatter,
            }

        results = {"N": N}
        for label, fn in candidates.items():
            elapsed_ms = run_one_algo(
                fn, out, inp, args.num_iters, args.num_warmup_iters
            )
            results[f"{label}_time"] = elapsed_ms
            # Input bytes / seconds / 1e9.
            results[f"{label}_bwidth"] = (
                N * inp.element_size() / (elapsed_ms * 1e-3) / 1e9
            )
        benchmark_results.append(results)

    if rank != 0:
        return
    pprint(benchmark_results)
    if args.export_csv:
        # Export results to a CSV file.
        csv_file = os.path.join(args.output_dir, "comm_ops_benchmark.csv")
        pd.DataFrame(benchmark_results).to_csv(csv_file, index=False)
193
+
194
+
195
def main(args, path):
    """
    Launch `run_benchmark` in ``args.num_ranks`` local processes.

    When ``args.export_csv`` is set, the output directory is created first.
    `path` is forwarded to every process together with `args`.
    """
    if args.export_csv:
        os.makedirs(args.output_dir, exist_ok=True)
        print("csv and images will be saved to " + args.output_dir)

    # Single-node launch with no restarts.
    launch_options = dict(
        min_nodes=1,
        max_nodes=1,
        nproc_per_node=args.num_ranks,
        run_id=str(uuid.uuid4()),
        rdzv_backend="c10d",
        rdzv_endpoint="localhost:0",
        max_restarts=0,
        monitor_interval=1,
    )
    launcher = elastic_launch(LaunchConfig(**launch_options), entrypoint=run_benchmark)
    launcher(args, path)
211
+
212
+
213
def invoke_main():
    """
    Parse the command-line options and run `main` with a temporary directory
    that is removed afterwards.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--output_dir", default="/tmp", help="Directory to save plots and csvs to"
    )
    cli.add_argument(
        "--export_csv",
        action="store_true",
        help="Export results to a CSV file.",
    )

    # Integer options without help text.
    for flag, default in (
        ("--num_ranks", 8),
        ("--num_iters", 20),
        ("--num_warmup_iters", 10),
    ):
        cli.add_argument(flag, type=int, default=default)

    # Integer options that control the benchmarked sizes.
    for flag, default, help_text in (
        ("--min_size", 10, "minimum size will be set to 2**min_size"),
        ("--max_size", 24, "maximum size will be set to 2**max_size"),
        ("--size_steps", 20, "number of size steps to run"),
    ):
        cli.add_argument(flag, type=int, default=default, help=help_text)

    cli.add_argument(
        "--op",
        type=str,
        default="allreduce",
        choices=["allreduce", "reduce_scatter"],
        help="op to benchmark, allreduce or reduce_scatter",
    )
    parsed = cli.parse_args()

    with tempfile.TemporaryDirectory() as path:
        main(parsed, path)
252
+
253
+
254
+ if __name__ == "__main__":
255
+ invoke_main()