mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
# pyre-strict
|
|
7
|
+
from typing import List, Optional, Sequence, Tuple, Union
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
from .utils.op_common import _get_storage_base
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def get_stack_strides(
    tensors: Sequence[torch.Tensor], dim: int
) -> Optional[Tuple[Union[int, torch.SymInt], ...]]:
    """
    If the tensors are already stacked on dimension :code:`dim`, returns the
    strides of the stacked tensors. Otherwise returns :code:`None`.

    "Already stacked" means: every tensor has the shape and strides of
    ``tensors[0]``, shares its storage, and tensor ``k`` starts exactly
    ``k * step`` elements after ``tensors[0]`` for a single positive ``step``.
    ``tensors[0].as_strided(stacked_shape, strides)`` then views the same
    values :func:`torch.stack` would produce, without copying.

    Args:
        tensors: candidate tensors, in stacking order.
        dim: position of the new dimension in the stacked result. Negative
            values count from the end of the result, as in :func:`torch.stack`.

    Returns:
        ``tensors[0].ndim + 1`` strides, or ``None`` if fewer than two tensors
        are given, ``dim`` is out of range, or the layout does not match
        (including a non-positive ``step``: a negative one cannot be passed to
        :meth:`torch.Tensor.as_strided`, and a zero one would make every slice
        alias the same memory).
    """
    if len(tensors) <= 1:
        return None
    ndim = tensors[0].ndim
    # Normalize negative dims like torch.stack does: relative to the stacked
    # result, which has one more dimension than the inputs.
    if dim < 0:
        dim += ndim + 1
    if dim < 0 or dim > ndim:
        return None

    # PyTorch 2.5 messed up the type annotations for SymInt, but 2.6 will fix it
    # https://github.com/pytorch/pytorch/issues/138478
    step = tensors[1].storage_offset() - tensors[0].storage_offset()  # type: ignore[operator]
    if step <= 0:
        # Not expressible as a copy-free stack; let the caller copy instead.
        return None

    # Strides of the inputs, with the inter-tensor step inserted at `dim`.
    final_stride: List[Union[int, torch.SymInt]] = list(tensors[0].stride())
    final_stride.insert(dim, step)

    storage_data_ptr: Optional[int] = None
    for k, x in enumerate(tensors[1:], start=1):
        # Sanity checks: identical layout, evenly spaced in stacking order.
        if x.shape != tensors[0].shape:
            return None
        if x.stride() != tensors[0].stride():
            return None
        if x.storage_offset() != tensors[0].storage_offset() + k * step:  # type: ignore[operator]
            return None
        if storage_data_ptr is None:
            storage_data_ptr = _get_storage_base(tensors[0])
        # Actual storage check: equal offsets are meaningless across storages.
        if _get_storage_base(x) != storage_data_ptr:
            return None
    return tuple(final_stride)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _stack_or_none_fw(
|
|
61
|
+
tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]],
|
|
62
|
+
dim: int,
|
|
63
|
+
) -> Optional[torch.Tensor]:
|
|
64
|
+
strides = get_stack_strides(tensors, dim)
|
|
65
|
+
if strides is not None:
|
|
66
|
+
input_shape = list(tensors[0].shape)
|
|
67
|
+
input_shape.insert(dim, len(tensors))
|
|
68
|
+
return tensors[0].as_strided(input_shape, strides)
|
|
69
|
+
return None
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _stack_fw(
    tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]],
    dim: int,
) -> torch.Tensor:
    """
    Stack ``tensors`` along ``dim``, reusing their storage when possible.

    Tries the zero-copy view from :func:`_stack_or_none_fw` first and only
    falls back to :func:`torch.stack` (which allocates and copies) when that
    returns ``None``.
    """
    view = _stack_or_none_fw(tensors, dim)
    return torch.stack(tensors, dim=dim) if view is None else view
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class _Unbind(torch.autograd.Function):
|
|
83
|
+
"""
|
|
84
|
+
See function `unbind`
|
|
85
|
+
"""
|
|
86
|
+
|
|
87
|
+
@staticmethod
|
|
88
|
+
# type: ignore
|
|
89
|
+
def forward(ctx, x: torch.Tensor, dim: int):
|
|
90
|
+
ctx.dim = dim
|
|
91
|
+
return x.unbind(dim)
|
|
92
|
+
|
|
93
|
+
@classmethod
|
|
94
|
+
# type: ignore
|
|
95
|
+
def backward(cls, ctx, *tensors: torch.Tensor):
|
|
96
|
+
return _stack_fw(tensors, ctx.dim), None
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class _StackOrNone(torch.autograd.Function):
|
|
100
|
+
"""
|
|
101
|
+
See function `stack_or_none`
|
|
102
|
+
"""
|
|
103
|
+
|
|
104
|
+
@staticmethod
|
|
105
|
+
# type: ignore
|
|
106
|
+
def forward(ctx, dim: int, *tensors: torch.Tensor):
|
|
107
|
+
ctx.dim = dim
|
|
108
|
+
return _stack_or_none_fw(tensors, dim=dim)
|
|
109
|
+
|
|
110
|
+
@classmethod
|
|
111
|
+
# type: ignore
|
|
112
|
+
def backward(cls, ctx, grad: torch.Tensor):
|
|
113
|
+
return (None, *grad.unbind(dim=ctx.dim))
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def unbind(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, ...]:
    """
    Drop-in replacement for :func:`torch.unbind`.

    The forward result is exactly ``x.unbind(dim)``. In backward, the
    per-slice gradients are re-assembled as a zero-copy view when they already
    sit side by side in one storage, instead of always paying for a
    :func:`torch.stack` copy.
    """
    slices = _Unbind.apply(x, dim)
    return slices
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def stack_or_none(
    tensors: Sequence[torch.Tensor], dim: int
) -> Optional[torch.Tensor]:
    """
    Zero-copy :func:`torch.stack`.

    Returns the same values as ``torch.stack(tensors, dim)`` when the tensors
    can be stacked without any memory operation (equally shaped and strided,
    evenly spaced views into one storage); the result is then a view that
    shares memory with the inputs. Otherwise returns ``None``, so callers
    must handle that case.
    """
    return _StackOrNone.apply(dim, *tensors)
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
# pyre-unsafe
|
|
7
|
+
|
|
8
|
+
from typing import Callable, List, Optional
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# from https://github.com/openai/triton/blob/95d9b7f4ae21710dc899e1de6a579b2136ea4f3d/python/triton/testing.py#L19
|
|
14
|
+
def do_bench_cudagraph(
    fn: Callable, rep: int = 20, grad_to_none: Optional[List[torch.Tensor]] = None
) -> float:
    """
    Benchmark the runtime of ``fn`` by replaying it from a CUDA graph.

    Args:
        fn: Function to benchmark
        rep: Target replay duration (in ms) used to choose how many calls
            are unrolled into the measured graph
        grad_to_none: Tensors whose ``.grad`` is reset to None before each
            call
    Returns:
        Mean runtime of a single ``fn`` call in ms, averaged over 10 replays
    Raises:
        RuntimeError: when called on the default CUDA stream, where graph
            capture is not possible
    """
    if torch.cuda.current_stream() == torch.cuda.default_stream():
        raise RuntimeError(
            "Cannot capture graph in default stream. "
            "Please use side stream in benchmark code."
        )

    def _replay_ms(graph_to_time: "torch.cuda.CUDAGraph") -> float:
        # Time one replay of the graph with a pair of CUDA events.
        begin = torch.cuda.Event(enable_timing=True)
        finish = torch.cuda.Event(enable_timing=True)
        begin.record()
        graph_to_time.replay()
        finish.record()
        torch.cuda.synchronize()
        return begin.elapsed_time(finish)

    # Warm up once, outside of any graph.
    fn()

    # Step 1: capture a single call and time one replay as a rough estimate
    # (the GPU is not warmed up yet, but it is good enough to size step 2).
    if grad_to_none is not None:
        for t in grad_to_none:
            t.detach_()
            t.requires_grad_(True)
            t.grad = None
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        fn()
    torch.cuda.synchronize()
    estimate_ms = _replay_ms(graph)
    n_repeat = max(1, int(rep / estimate_ms))

    # Step 2: unroll `n_repeat` calls into one graph to minimize host overhead.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(n_repeat):
            if grad_to_none is not None:
                for t in grad_to_none:
                    t.grad = None
            fn()
    torch.cuda.synchronize()

    # Step 3: replay several times and report the mean per-call time.
    n_retries = 10
    per_call_ms = [_replay_ms(graph) / n_repeat for _ in range(n_retries)]
    return torch.mean(torch.tensor(per_call_ms)).item()
|
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
# pyre-unsafe
|
|
7
|
+
|
|
8
|
+
import dataclasses
|
|
9
|
+
import logging
|
|
10
|
+
import os
|
|
11
|
+
import platform
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import Any, Dict, Optional
|
|
14
|
+
|
|
15
|
+
import torch
|
|
16
|
+
|
|
17
|
+
# Logger for this module; extension-loading failures are reported on it.
logger = logging.getLogger("mslk_fmha")

# Suffix appended to the extension-loading error messages below.
UNAVAILABLE_FEATURES_MSG = " Memory-efficient attention won't be available."
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclasses.dataclass
|
|
23
|
+
class _BuildInfo:
|
|
24
|
+
metadata: Dict[str, Any]
|
|
25
|
+
|
|
26
|
+
@property
|
|
27
|
+
def cuda_version(self) -> Optional[int]:
|
|
28
|
+
return self.metadata["version"]["cuda"]
|
|
29
|
+
|
|
30
|
+
@property
|
|
31
|
+
def hip_version(self) -> Optional[int]:
|
|
32
|
+
return self.metadata["version"]["hip"]
|
|
33
|
+
|
|
34
|
+
@property
|
|
35
|
+
def torch_version(self) -> str:
|
|
36
|
+
return self.metadata["version"]["torch"]
|
|
37
|
+
|
|
38
|
+
@property
|
|
39
|
+
def python_version(self) -> str:
|
|
40
|
+
return self.metadata["version"]["python"]
|
|
41
|
+
|
|
42
|
+
@property
|
|
43
|
+
def flash_version(self) -> str:
|
|
44
|
+
return self.metadata["version"].get("flash", "0.0.0")
|
|
45
|
+
|
|
46
|
+
@property
|
|
47
|
+
def use_torch_flash(self) -> bool:
|
|
48
|
+
return self.metadata["version"].get("use_torch_flash", False)
|
|
49
|
+
|
|
50
|
+
@property
|
|
51
|
+
def build_env(self) -> Dict[str, Any]:
|
|
52
|
+
return self.metadata["env"]
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class xFormersWasNotBuiltException(Exception):
    """Raised when no compiled C++/CUDA extension library can be found."""

    def __str__(self) -> str:
        parts = (
            "Need to compile C++ extensions to use all fmha features.\n",
            " Please install xformers properly ",
            "(see https://github.com/facebookresearch/xformers#installing-xformers)\n",
        )
        return "".join(parts) + UNAVAILABLE_FEATURES_MSG
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class xFormersInvalidLibException(Exception):
    """
    Raised when the extension library is present but cannot be loaded.

    Args:
        build_info: metadata describing how the library was built, used to
            produce a more specific message; ``None`` when unavailable.
    """

    def __init__(self, build_info: Optional["_BuildInfo"]) -> None:
        # Forward the argument so `self.args` holds it. Exceptions are
        # unpickled as `cls(*self.args)`; with empty args that call would fail
        # with a missing-argument TypeError.
        super().__init__(build_info)
        self.build_info = build_info

    def __str__(self) -> str:
        if self.build_info is None:
            msg = "fmha was built for a different version of PyTorch or Python."
        else:
            msg = f"""fmha was built for:
PyTorch {self.build_info.torch_version} with CUDA {self.build_info.cuda_version} (you have {torch.__version__})
Python {self.build_info.python_version} (you have {platform.python_version()})"""
        return (
            "fmha can't load C++/CUDA extensions. "
            + msg
            + "\n Please reinstall mslk "
            + UNAVAILABLE_FEATURES_MSG
        )
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _register_extensions():
    """
    Find the compiled extension module next to the package and load it into
    :mod:`torch.ops`.

    Raises:
        xFormersWasNotBuiltException: no ``_C`` (or, on the HIP branch below,
            ``_C_hip``) extension module exists in the library directory.
        xFormersInvalidLibException: the module was found but
            :func:`torch.ops.load_library` failed with :class:`OSError`.
    """
    # `import importlib` alone does not guarantee the `machinery` submodule is
    # loaded; import it explicitly since it is used below.
    import importlib.machinery

    # load the custom_op_library from the directory four levels above this
    # file (presumably the top-level `mslk` package) and register the custom ops
    lib_dir = str(Path(__file__).parent.parent.parent.parent)
    if os.name == "nt":
        # Register the library location on the default DLL search path
        import ctypes
        import sys

        kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
        with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
        # Temporarily change the process-wide error mode; restored in `finally`.
        prev_error_mode = kernel32.SetErrorMode(0x0001)
        try:
            if with_load_library_flags:
                kernel32.AddDllDirectory.restype = ctypes.c_void_p

            if sys.version_info >= (3, 8):
                os.add_dll_directory(lib_dir)
            elif with_load_library_flags:
                res = kernel32.AddDllDirectory(lib_dir)
                if res is None:
                    err = ctypes.WinError(ctypes.get_last_error())
                    err.strerror += f' Error adding "{lib_dir}" to the DLL directories.'
                    raise err
        finally:
            # Restore the previous mode even when registration raised.
            kernel32.SetErrorMode(prev_error_mode)

    loader_details = (
        importlib.machinery.ExtensionFileLoader,
        importlib.machinery.EXTENSION_SUFFIXES,
    )

    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
    # HIP builds whose `torch.version` lacks `git_version` look for `_C_hip`.
    if torch.version.hip and not hasattr(torch.version, "git_version"):
        ext_specs = extfinder.find_spec("_C_hip")
    else:
        ext_specs = extfinder.find_spec("_C")
    if ext_specs is None:
        raise xFormersWasNotBuiltException()
    try:
        torch.ops.load_library(ext_specs.origin)
    except OSError as exc:
        raise xFormersInvalidLibException(None) from exc
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
# Exception raised while loading the C++/CUDA extensions at import time, or
# None if loading succeeded.
_cpp_library_load_exception = None

# Load the extensions on import. A missing or incompatible library only logs
# a warning (and is remembered above) instead of failing the import.
try:
    _register_extensions()
except (xFormersInvalidLibException, xFormersWasNotBuiltException) as e:
    # NOTE: this name is only bound at module level when loading failed.
    ENV_VAR_FOR_DETAILS = "XFORMERS_MORE_DETAILS"
    # NOTE(review): this is a plain truthiness check on the string, so any
    # non-empty value (including "0") enables the detailed warning.
    if os.environ.get(ENV_VAR_FOR_DETAILS, False):
        logger.warning(f"WARNING[XFORMERS]: {e}", exc_info=e)
    else:
        logger.warning(
            f"WARNING[XFORMERS]: {e}\n Set {ENV_VAR_FOR_DETAILS}=1 for more details"
        )
    _cpp_library_load_exception = e

# NOTE(review): hard-coded placeholder (see the "XXXXX" marker); it is not
# derived from the loaded library or build metadata — confirm before relying
# on it.
_built_with_cuda = True # XXXXX
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
# pyre-unsafe
|
|
7
|
+
|
|
8
|
+
from typing import Any, Dict, List, Type, TypeVar
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def get_operator(library: str, name: str):
    """
    Look up ``torch.ops.<library>.<name>``.

    If the lookup fails, return a stub named ``no_such_operator`` that raises
    :class:`RuntimeError` when called. :meth:`BaseOperator.is_available`
    recognizes the stub by that exact function name, so it must not change.
    """

    def no_such_operator(*args, **kwargs):
        raise RuntimeError(
            f"No such operator {library}::{name} - did you forget to build xformers with `python setup.py develop`?"
        )

    try:
        namespace = getattr(torch.ops, library)
        return getattr(namespace, name)
    except (RuntimeError, AttributeError):
        # Unknown namespace/operator: defer the failure to call time.
        return no_such_operator
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def get_xformers_operator(name: str):
    """Shorthand for :func:`get_operator` in the ``xformers`` op library."""
    return get_operator(library="xformers", name=name)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class BaseOperator:
    """
    Base class for operator wrappers.

    Subclasses set ``OPERATOR`` (the underlying callable, e.g. the result of
    :func:`get_operator` or a Triton Autotuner object), ``NAME`` and
    ``OPERATOR_CATEGORY``.
    """

    OPERATOR: Any  # pyre-ignore[13]
    NAME: str  # pyre-ignore[13]
    OPERATOR_CATEGORY: str  # pyre-ignore[13]

    @classmethod
    def is_available(cls) -> bool:
        """
        Return ``False`` when ``OPERATOR`` is ``None`` or is the
        ``no_such_operator`` stub produced for unregistered ops.
        """
        operator = cls.OPERATOR
        if operator is None:
            return False
        # A Triton Autotuner object has no __name__, hence the default.
        return getattr(operator, "__name__", "") != "no_such_operator"
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# Every class passed through `register_operator`, in registration order.
OPERATORS_REGISTRY: List[Type[BaseOperator]] = []
# Maps each registered class's `OPERATOR` value to that class.
FUNC_TO_XFORMERS_OPERATOR: Dict[Any, Type[BaseOperator]] = {}

# Lets `register_operator` return exactly the class type it was given.
ClsT = TypeVar("ClsT")
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def register_operator(cls: ClsT) -> ClsT:
    """
    Class decorator: append ``cls`` to ``OPERATORS_REGISTRY`` and index it by
    its ``OPERATOR`` in ``FUNC_TO_XFORMERS_OPERATOR``.

    Returns ``cls`` unchanged, so it can be applied as ``@register_operator``.
    """
    registry, by_operator = OPERATORS_REGISTRY, FUNC_TO_XFORMERS_OPERATOR
    registry.append(cls)  # type: ignore
    by_operator[cls.OPERATOR] = cls  # type: ignore
    return cls
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# Unbound method used to fetch a tensor's storage.
# post-2.0, avoids a warning
# (`torch.Tensor.storage` will also be deleted in the future)
_GET_TENSOR_STORAGE = getattr(torch.Tensor, "untyped_storage", None)
if _GET_TENSOR_STORAGE is None:  # pre-2.0, `untyped_storage` didn't exist
    # Fall back to the legacy accessor on older PyTorch.
    _GET_TENSOR_STORAGE = torch.Tensor.storage
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _get_storage_base(x: torch.Tensor) -> int:
    """
    Return the data pointer of the storage backing ``x`` (the start of the
    storage, rather than the address of ``x``'s first element).
    """
    storage = _GET_TENSOR_STORAGE(x)
    return storage.data_ptr()  # type: ignore
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
# pyre-unsafe
|
|
8
|
+
|
|
9
|
+
from mslk.utils.torch.library import load_library_buck
|
|
10
|
+
|
|
11
|
+
# Import-time side effect: load the `gqa_attn_splitk` GPU op library (Buck
# target below) through `load_library_buck`.
load_library_buck("//mslk/csrc/attention/cuda/gqa_attn_splitk:gqa_attn_splitk_ops_gpu")
|
|
@@ -0,0 +1,255 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
import argparse
|
|
9
|
+
import os
|
|
10
|
+
import tempfile
|
|
11
|
+
import uuid
|
|
12
|
+
from functools import lru_cache
|
|
13
|
+
from pprint import pprint
|
|
14
|
+
|
|
15
|
+
import mslk.comm # noqa: F401
|
|
16
|
+
import pandas as pd
|
|
17
|
+
import torch
|
|
18
|
+
import torch.distributed as dist
|
|
19
|
+
import torch.distributed._symmetric_memory as symm_mem
|
|
20
|
+
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@lru_cache(None)
def get_symm_buffer(group):
    """
    Allocate a bf16 symmetric-memory buffer of 16 * 1024 * 1024 elements on
    the current CUDA device and rendezvous it across ``group``.

    Cached per ``group``: later calls with the same group return the same
    ``(buffer, group_name)`` pair without allocating again.
    """
    numel = 16 * 1024 * 1024
    buf = symm_mem.empty(numel, device="cuda", dtype=torch.bfloat16)
    symm_mem.rendezvous(buf, group=group)
    return buf, group.group_name
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _setup(path: str) -> tuple[int, int]:
    """
    Per-worker initialization for the benchmark.

    Binds the process to ``cuda:<LOCAL_RANK>``, initializes mslk's NCCL
    communicator and a ``torch.distributed`` process group (both rendezvous
    through files under ``path``), exchanges IPC handles for the mslk
    ``car_tensor`` buffer and barrier and passes them to ``car_init``, then
    allocates the symmetric-memory buffer via :func:`get_symm_buffer`.

    Returns:
        ``(rank, world_size)``.
    """
    # NOTE(review): LOCAL_RANK is used as the global rank below, which only
    # holds on a single node (main() launches with max_nodes=1).
    rank = int(os.environ["LOCAL_RANK"])
    W = int(os.environ["WORLD_SIZE"])
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"

    torch.ops.mslk.nccl_init(rank, W, os.path.join(path, "rdvz"))
    torch.distributed.init_process_group(
        backend="cpu:gloo,cuda:nccl",
        init_method=f"file://{os.path.join(path, 'gloo_rdvz')}",
        world_size=W,
        rank=rank,
    )

    buffer = torch.ops.mslk.car_tensor()
    barrier = torch.ops.mslk.car_tensor()
    barrier.zero_()

    # Every rank gathers every other rank's IPC handle for the buffer ...
    buffer_handle = torch.ops.mslk.car_ipc_handle(buffer)
    all_buffer_handles = [torch.empty_like(buffer_handle) for _ in range(W)]
    torch.distributed.all_gather(all_buffer_handles, buffer_handle)

    # ... and for the barrier, then hands both sets to car_init.
    barrier_handle = torch.ops.mslk.car_ipc_handle(barrier)
    all_barrier_handles = [torch.empty_like(barrier_handle) for _ in range(W)]
    torch.distributed.all_gather(all_barrier_handles, barrier_handle)
    torch.ops.mslk.car_init(
        rank, W, barrier, all_barrier_handles, buffer, all_buffer_handles
    )
    torch.cuda.synchronize()
    torch.distributed.barrier()
    # Populate the get_symm_buffer cache now, so the symm_* benchmark
    # functions reuse this buffer instead of allocating during timing.
    group = dist.group.WORLD
    _ = get_symm_buffer(group)
    return rank, W
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def symm_one_shot_allreduce(dst_tensor, src_tensor, bias=None, comm_idx=None):
    """
    Sum-all-reduce ``src_tensor`` over ``dist.group.WORLD`` into
    ``dst_tensor`` using symmetric memory's one-shot kernel, then add ``bias``
    in place if given. ``comm_idx`` is unused.
    """
    # get_symm_buffer should already have been called for this group during
    # setup, so this returns the cached buffer. The group must match.
    buf, group_name = get_symm_buffer(dist.group.WORLD)
    staging = buf[: src_tensor.numel()].view_as(src_tensor)
    torch.ops.symm_mem.one_shot_all_reduce_copy_out(
        staging, src_tensor, "sum", group_name, dst_tensor
    )
    if bias is None:
        return
    dst_tensor.add_(bias)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def symm_two_shot_allreduce(dst_tensor, src_tensor, bias=None, comm_idx=None):
    """
    Sum-all-reduce ``src_tensor`` over ``dist.group.WORLD`` into
    ``dst_tensor`` using symmetric memory's two-shot kernel, then add ``bias``
    in place if given. ``comm_idx`` is unused.
    """
    # get_symm_buffer should already have been called for this group during
    # setup, so this returns the cached buffer. The group must match.
    buf, group_name = get_symm_buffer(dist.group.WORLD)
    staging = buf[: src_tensor.numel()].view_as(src_tensor)
    # Explicit copy into symmetric memory (the car path copies as well).
    staging.copy_(src_tensor)
    torch.ops.symm_mem.two_shot_all_reduce_out(
        staging, "sum", group_name, dst_tensor
    )
    if bias is None:
        return
    dst_tensor.add_(bias)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def symm_reduce_scatter(dst_tensor, src_tensor, comm_idx=None):
    """
    Reduce-scatter ``src_tensor`` over ``dist.group.WORLD`` into
    ``dst_tensor`` through the cached symmetric-memory buffer. ``comm_idx``
    is unused.
    """
    buf, group_name = get_symm_buffer(dist.group.WORLD)
    staging = buf[: src_tensor.numel()].view_as(src_tensor)
    staging.copy_(src_tensor)
    torch.ops.symm_mem.reduce_scatter_out(staging, group_name, False, dst_tensor)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def run_one_algo(fn, out, inp, num_iters, num_warmup_iters):
    """
    Return the mean time in ms of ``fn(out, inp)`` over ``num_iters`` calls,
    measured with CUDA events after ``num_warmup_iters`` untimed calls.
    """
    start, stop = (torch.cuda.Event(enable_timing=True) for _ in range(2))
    for _ in range(num_warmup_iters):
        fn(out, inp)
    start.record()
    for _ in range(num_iters):
        fn(out, inp)
    stop.record()
    torch.cuda.synchronize()
    return start.elapsed_time(stop) / num_iters
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def run_benchmark(args, path):
    """
    Worker entry point: time every implementation of ``args.op`` over a
    log-spaced range of sizes and, on rank 0, print (and optionally export
    as CSV) the per-size timings in ms and bandwidths in GB/s.
    """
    rank, W = _setup(path)
    if rank == 0:
        print(f"Running benchmark with {W} ranks")

    def round_up(a: int, b: int) -> int:
        return ((a + b - 1) // b) * b

    # Sizes are rounded up to a multiple of this granularity.
    N_even_divisor = 8 * 64 if torch.version.hip else 8 * 32
    sizes = torch.logspace(
        args.min_size, args.max_size, steps=args.size_steps, base=2
    ).tolist()

    benchmark_results = []
    for raw_size in sizes:
        N = round_up(int(raw_size), N_even_divisor)
        inp = torch.rand(N, dtype=torch.bfloat16, device="cuda")
        results = {"N": N}
        if args.op == "allreduce":
            out = torch.full_like(inp, -1)
            candidates = (
                ("mslk_1shot", torch.ops.mslk.one_shot_car_allreduce),
                ("symm_1shot", symm_one_shot_allreduce),
                ("mslk_2shot", torch.ops.mslk.two_shot_car_allreduce),
                ("symm_2shot", symm_two_shot_allreduce),
                ("nccl", torch.ops.mslk.nccl_allreduce),
            )
        else:
            out = torch.full(
                (inp.shape[0] // W,), -1, dtype=inp.dtype, device=inp.device
            )
            candidates = (
                ("mslk_rs", torch.ops.mslk.car_reducescatter),
                ("symm_rs", symm_reduce_scatter),
                ("nccl_rs", torch.ops.mslk.nccl_reducescatter),
            )

        # Bandwidth is the full input size moved per second, in GB/s.
        nbytes = N * inp.element_size()
        for label, fn in candidates:
            elapsed_ms = run_one_algo(
                fn,
                out,
                inp,
                args.num_iters,
                args.num_warmup_iters,
            )
            results[f"{label}_time"] = elapsed_ms
            results[f"{label}_bwidth"] = nbytes / (elapsed_ms * 1e-3) / 1e9

        benchmark_results.append(results)

    if rank == 0:
        pprint(benchmark_results)
        if args.export_csv:
            csv_file = os.path.join(args.output_dir, "comm_ops_benchmark.csv")
            # Export results to a CSV file.
            pd.DataFrame(benchmark_results).to_csv(csv_file, index=False)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def main(args, path):
    """
    Launch ``args.num_ranks`` local worker processes running
    :func:`run_benchmark` with ``torch.distributed`` elastic launch.

    When CSV export is requested, ``args.output_dir`` is created first.
    """
    if args.export_csv:
        os.makedirs(args.output_dir, exist_ok=True)
        print("csv and images will be saved to " + args.output_dir)

    launch_config = LaunchConfig(
        min_nodes=1,
        max_nodes=1,
        nproc_per_node=args.num_ranks,
        run_id=str(uuid.uuid4()),
        rdzv_backend="c10d",
        rdzv_endpoint="localhost:0",
        max_restarts=0,
        monitor_interval=1,
    )
    launcher = elastic_launch(launch_config, entrypoint=run_benchmark)
    launcher(args, path)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def invoke_main():
    """
    Command-line entry point: parse the benchmark options and run
    :func:`main` with a fresh temporary directory for rendezvous files.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output_dir", default="/tmp", help="Directory to save plots and csvs to"
    )
    parser.add_argument(
        "--export_csv",
        action="store_true",
        help="Export results to a CSV file.",
    )
    # Integer options, registered in the same order as before (--help order).
    int_options = (
        ("--num_ranks", 8, None),
        ("--num_iters", 20, None),
        ("--num_warmup_iters", 10, None),
        ("--min_size", 10, "minimum size will be set to 2**min_size"),
        ("--max_size", 24, "maximum size will be set to 2**max_size"),
        ("--size_steps", 20, "number of size steps to run"),
    )
    for flag, default, help_text in int_options:
        parser.add_argument(flag, type=int, default=default, help=help_text)
    parser.add_argument(
        "--op",
        type=str,
        default="allreduce",
        choices=["allreduce", "reduce_scatter"],
        help="op to benchmark, allreduce or reduce_scatter",
    )
    args = parser.parse_args()

    # The directory, including the rendezvous files _setup creates in it, is
    # removed once main() returns.
    with tempfile.TemporaryDirectory() as rdzv_dir:
        main(args, rdzv_dir)
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
if __name__ == "__main__":
    # Script entry point: parse CLI options and launch the benchmark.
    invoke_main()
|