mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/__init__.py
ADDED
@@ -0,0 +1,56 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+
+import torch
+
+open_source: bool = True
+
+
+def _load_library(filename: str, version: str, no_throw: bool = False) -> None:
+    """Load a shared library from the given filename."""
+    try:
+        library_path = os.path.join(os.path.dirname(__file__), filename)
+        torch.ops.load_library(library_path)
+        torch.classes.load_library(library_path)
+        logging.info(f"Successfully loaded: '{filename}'")
+
+    except Exception as error:
+        logging.error(f"Could not load the library '{filename}'!\n\n\n{error}\n\n\n")
+        if not no_throw:
+            raise error
+
+
+try:
+    # Export the version string from the version file auto-generated by setup.py
+    from .version import __target__, __variant__, __version__  # noqa: F401, E402
+except Exception:
+    __variant__: str = "INTERNAL"
+    __version__: str = "INTERNAL"
+    __target__: str = "default"
+
+_default_libraries = [
+    "mslk",
+]
+
+libraries_to_load = {
+    "default": _default_libraries,
+}
+
+for library in libraries_to_load.get(__target__, []):
+    # NOTE: In all cases, we want to throw an error if we cannot load the
+    # library. However, this appears to break the OSS documentation build,
+    # where the Python documentation doesn't show up in the generated docs.
+    #
+    # To work around this problem, we introduce a fake build variant called
+    # `docs` and we only throw a library load error when the variant is not
+    # `docs`. For more information, see:
+    #
+    # https://github.com/pytorch/FBGEMM/pull/3477
+    # https://github.com/pytorch/FBGEMM/pull/3717
+    _load_library(f"{library}.so", __version__, __variant__ == "docs")
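
The __init__ module above loads the native extension at import time: for each library name under the active __target__, _load_library resolves the .so next to the package and registers it through torch.ops.load_library / torch.classes.load_library, swallowing load errors only for the fake `docs` variant. A minimal usage sketch, assuming the wheel is installed in a CUDA-enabled environment (the calls below are illustrative, not part of this diff):

import torch  # torch must be importable; mslk imports it before loading its extension
import mslk   # runs the loop above and loads mslk.so via torch.ops.load_library

# Version metadata comes from the auto-generated version.py shipped in the wheel.
print(mslk.__version__, mslk.__variant__, mslk.__target__)

# After the load, the package's operators are reachable under the torch.ops.mslk
# namespace, e.g. the FMHA ops registered further down in this diff.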

mslk/attention/cutlass_blackwell_fmha/__init__.py
ADDED

@@ -0,0 +1,30 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from mslk.utils.torch.library import load_library_buck
+
+from . import cutlass_blackwell_fmha_custom_op  # noqa: F401
+from .cutlass_blackwell_fmha_interface import (  # noqa: F401
+    _cutlass_blackwell_fmha_forward,
+    cutlass_blackwell_fmha_decode_forward,
+    cutlass_blackwell_fmha_func,
+)
+
+load_library_buck(
+    "//mslk/csrc/attention/cuda/cutlass_blackwell_fmha:blackwell_attention_ops_gpu"
+)
+
+# Note: _cutlass_blackwell_fmha_forward is an internal function (indicated by leading underscore)
+# that is exported here specifically for testing purposes. It allows tests to access the LSE
+# (log-sum-exp) values returned by the forward pass without modifying the public API.
+# Production code should use cutlass_blackwell_fmha_func instead.
+__all__ = [
+    "_cutlass_blackwell_fmha_forward",
+    "cutlass_blackwell_fmha_decode_forward",
+    "cutlass_blackwell_fmha_func",
+]
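
Per the note above, production callers should go through cutlass_blackwell_fmha_func. A minimal sketch, assuming it accepts the q/k/v, softmax_scale, and causal arguments mirrored by the custom-op wrapper in the next file; its exact signature lives in cutlass_blackwell_fmha_interface.py, which is listed but not reproduced in this section:

import torch
from mslk.attention.cutlass_blackwell_fmha import cutlass_blackwell_fmha_func

# Illustrative (batch, seqlen, heads, head_dim) layout, matching the shape handling
# in fmha_fwd_meta below; bf16 on a supported GPU is assumed.
q = torch.randn(2, 1024, 16, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 1024, 16, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 1024, 16, 128, device="cuda", dtype=torch.bfloat16)

out = cutlass_blackwell_fmha_func(q, k, v, causal=True)  # assumed keywords, see note above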

mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py
ADDED

@@ -0,0 +1,332 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional
+
+import torch
+from torch.library import register_fake
+
+
+torch.library.define(
+    "mslk::cutlass_blackwell_fmha_fwd",
+    "(Tensor q, Tensor k, Tensor v, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seq_len_q, int? max_seq_len_k, float? softmax_scale, bool? causal, Tensor? seqlen_kv, Tensor? page_table, int seqlen_k=-1, int window_size_left=-1, int window_size_right=-1, bool bottom_right=True) -> (Tensor, Tensor)",
+    tags=torch.Tag.pt2_compliant_tag,
+)
+
+torch.library.define(
+    "mslk::cutlass_blackwell_fmha_bwd",
+    "(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seq_len_q, int? max_seq_len_k, float? softmax_scale, bool? causal, int window_size_left=-1, int window_size_right=-1, bool bottom_right=True, bool deterministic=False) -> (Tensor, Tensor, Tensor)",
+    tags=torch.Tag.pt2_compliant_tag,
+)
+
+
+@torch.library.impl("mslk::cutlass_blackwell_fmha_fwd", "cuda")
+def custom_op_fmha(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens_q: Optional[torch.Tensor] = None,
+    cu_seqlens_k: Optional[torch.Tensor] = None,
+    max_seq_len_q: Optional[int] = None,
+    max_seq_len_k: Optional[int] = None,
+    softmax_scale: Optional[float] = None,
+    causal: bool = False,
+    seqlen_kv: Optional[torch.Tensor] = None,
+    page_table: Optional[torch.Tensor] = None,
+    seqlen_k: Optional[int] = None,
+    window_size_left: int = -1,
+    window_size_right: int = -1,
+    bottom_right: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    assert q.is_contiguous(), "q is not contiguous"
+    assert k.is_contiguous(), "k is not contiguous"
+    assert v.is_contiguous(), "v is not contiguous"
+    assert q.is_cuda, "q must be on GPU"
+    assert k.is_cuda, "k must be on GPU"
+    assert v.is_cuda, "v must be on GPU"
+
+    return torch.ops.mslk.fmha_fwd(
+        q,
+        k,
+        v,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        max_seq_len_q=max_seq_len_q,
+        max_seq_len_k=max_seq_len_k,
+        softmax_scale=softmax_scale,
+        causal=causal,
+        seqlen_kv=seqlen_kv,
+        page_table=page_table,
+        seqlen_k=seqlen_k,
+        window_size_left=window_size_left,
+        window_size_right=window_size_right,
+        bottom_right=bottom_right,
+    )
+
+
+@register_fake("mslk::cutlass_blackwell_fmha_fwd")
+def fmha_fwd_meta(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens_q: Optional[torch.Tensor] = None,
+    cu_seqlens_k: Optional[torch.Tensor] = None,
+    max_seq_len_q: Optional[int] = None,
+    max_seq_len_k: Optional[int] = None,
+    softmax_scale: Optional[float] = None,
+    causal: bool = False,
+    seqlen_kv: Optional[torch.Tensor] = None,
+    page_table: Optional[torch.Tensor] = None,
+    seqlen_k: Optional[int] = None,
+    window_size_left: int = -1,
+    window_size_right: int = -1,
+    bottom_right: bool = True,
+):
+    if q.dtype == torch.float16:
+        out_dtype = torch.float16
+    elif q.dtype == torch.bfloat16:
+        out_dtype = torch.bfloat16
+    elif q.dtype == torch.float8_e4m3fn:
+        # Output is BF16 when input is FP8
+        out_dtype = torch.bfloat16
+    else:
+        raise RuntimeError(f"Unsupported dtype for q: {q.dtype}")
+
+    kIsVarlen = max_seq_len_q is not None
+    if kIsVarlen:
+        assert cu_seqlens_q is not None
+        SQ = q.shape[0]
+        H_Q = q.shape[1]
+        B = cu_seqlens_q.shape[0] - 1
+    else:
+        SQ = q.shape[1]
+        H_Q = q.shape[2]
+        B = q.shape[0]
+    device = q.device
+    options2 = {"dtype": torch.float32, "device": device}
+    if kIsVarlen:
+        assert max_seq_len_q is not None
+        out = torch.empty_like(q, dtype=out_dtype)
+        size = out.size()
+        stride = out.stride()
+        storage_offset = q.shape[-1] * max_seq_len_q * H_Q  # example scalar offset
+        out1 = torch.as_strided(
+            out, size=size, stride=stride, storage_offset=storage_offset
+        )
+    else:
+        out1 = torch.empty_like(q, dtype=out_dtype)
+
+    if kIsVarlen:
+        out2 = torch.empty((1, H_Q, SQ), **options2)  # type: ignore
+    else:
+        out2 = torch.empty((B, H_Q, SQ), **options2)  # type: ignore
+    return out1, out2
+
+
+@torch.library.impl("mslk::cutlass_blackwell_fmha_bwd", "cuda")
+def custom_op_fmha_bwd(
+    dOutput: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    softmax_lse: torch.Tensor,
+    cu_seqlens_q: Optional[torch.Tensor] = None,
+    cu_seqlens_k: Optional[torch.Tensor] = None,
+    max_seq_len_q: Optional[int] = None,
+    max_seq_len_k: Optional[int] = None,
+    softmax_scale: Optional[float] = None,
+    causal: bool = False,
+    window_size_left: int = -1,
+    window_size_right: int = -1,
+    bottom_right: bool = True,
+    deterministic: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    return torch.ops.mslk.fmha_bwd(
+        dOutput,
+        query,
+        key,
+        value,
+        output,
+        softmax_lse,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        max_seq_len_q=max_seq_len_q,
+        max_seq_len_k=max_seq_len_k,
+        softmax_scale=softmax_scale,
+        causal=causal,
+        window_size_left=window_size_left,
+        window_size_right=window_size_right,
+        bottom_right=bottom_right,
+        deterministic=deterministic,
+    )
+
+
+@register_fake("mslk::cutlass_blackwell_fmha_bwd")
+def fmha_bwd_meta(
+    dOutput: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    softmax_lse: torch.Tensor,
+    cu_seqlens_q: Optional[torch.Tensor] = None,
+    cu_seqlens_k: Optional[torch.Tensor] = None,
+    max_seq_len_q: Optional[int] = None,
+    max_seq_len_k: Optional[int] = None,
+    softmax_scale: Optional[float] = None,
+    causal: bool = False,
+    window_size_left: int = -1,
+    window_size_right: int = -1,
+    bottom_right: bool = True,
+    deterministic: bool = False,
+):
+    return (
+        torch.empty_like(query),
+        torch.empty_like(key),
+        torch.empty_like(value),
+    )
+
+
+def _backward(ctx, *grad):
+    if ctx.is_gen:
+        # For gen case, no backward pass is needed (generation is inference only)
+        raise RuntimeError("Backward pass is not supported for generation phase (sq=1)")
+    q, k, v, out, softmax_lse = ctx.saved_tensors
+    if not grad[0].is_contiguous():
+        grad0 = grad[0].contiguous()
+    else:
+        grad0 = grad[0]
+    if not softmax_lse.is_contiguous:
+        softmax_lse = softmax_lse.contiguous()
+    if not out.is_contiguous:
+        out = out.contiguous()
+    if not q.is_contiguous:
+        q = q.contiguous()
+    if not k.is_contiguous:
+        k = k.contiguous()
+
+    if not softmax_lse.is_contiguous:
+        softmax_lse = softmax_lse.contiguous()
+    if not out.is_contiguous:
+        out = out.contiguous()
+    if not q.is_contiguous:
+        q = q.contiguous()
+    if not k.is_contiguous:
+        k = k.contiguous()
+
+    dq, dk, dv = torch.ops.mslk.cutlass_blackwell_fmha_bwd(
+        grad0,
+        q,
+        k,
+        v,
+        out,
+        softmax_lse,
+        ctx.cu_seqlens_q,
+        ctx.cu_seqlens_k,
+        ctx.max_seq_len_q,
+        ctx.max_seq_len_k,
+        ctx.softmax_scale,
+        ctx.causal,
+        ctx.window_size_left,
+        ctx.window_size_right,
+        ctx.bottom_right,
+        ctx.deterministic,
+    )
+    return (
+        dq,
+        dk,
+        dv,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+    )
+
+
+def _setup_context(ctx, inputs, output):
+    (
+        q,
+        k,
+        v,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seq_len_q,
+        max_seq_len_k,
+        softmax_scale,
+        causal,
+        seqlen_kv,
+        page_table,
+        seqlen_k,
+        window_size_left,
+        window_size_right,
+        bottom_right,
+    ) = inputs
+    (out, softmax_lse) = output
+    ctx.save_for_backward(q, k, v, out, softmax_lse)
+    ctx.softmax_scale = softmax_scale
+    ctx.causal = causal
+    ctx.max_seq_len_q = max_seq_len_q
+    ctx.max_seq_len_k = max_seq_len_k
+    ctx.cu_seqlens_q = cu_seqlens_q
+    ctx.cu_seqlens_k = cu_seqlens_k
+    ctx.window_size_left = window_size_left
+    ctx.window_size_right = window_size_right
+    ctx.bottom_right = bottom_right
+    ctx.deterministic = False  # Set default value
+    ctx.is_gen = False
+
+
+# This code adds training support for the operator. You must provide us
+# the backward formula for the operator and a `setup_context` function
+# to save values to be used in the backward.
+torch.library.register_autograd(
+    "mslk::cutlass_blackwell_fmha_fwd", _backward, setup_context=_setup_context
+)
+
+
+def cutlass_blackwell_fmha_custom_op(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    softmax_scale: float | None = None,
+    causal: bool = False,
+    cu_seqlens_q: torch.Tensor | None = None,
+    cu_seqlens_k: torch.Tensor | None = None,
+    max_seq_len_q: int | None = None,
+    max_seq_len_k: int | None = None,
+    seqlen_kv: torch.Tensor | None = None,
+    page_table: torch.Tensor | None = None,
+    seqlen_k: int | None = -1,
+    window_size_left: int | None = -1,
+    window_size_right: int | None = -1,
+    bottom_right: bool | None = True,
+):
+    return torch.ops.mslk.cutlass_blackwell_fmha_fwd(
+        q=q,
+        k=k,
+        v=v,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        max_seq_len_q=max_seq_len_q,
+        max_seq_len_k=max_seq_len_k,
+        softmax_scale=softmax_scale,
+        causal=causal,
+        seqlen_kv=seqlen_kv,
+        page_table=page_table,
+        seqlen_k=seqlen_k,
+        window_size_left=window_size_left,
+        window_size_right=window_size_right,
+        bottom_right=bottom_right,
+    )[0]
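
The module above follows the standard torch.library custom-op recipe: torch.library.define declares the fwd/bwd schemas, torch.library.impl(..., "cuda") binds them to the compiled kernels, register_fake provides shape/dtype propagation for fake-tensor tracing and torch.compile, and register_autograd wires _backward and _setup_context into autograd. A minimal call sketch through the Python wrapper, assuming the CUDA extension is loaded; shapes follow the (batch, seqlen, heads, head_dim) layout implied by fmha_fwd_meta and are illustrative only:

import torch
from mslk.attention.cutlass_blackwell_fmha.cutlass_blackwell_fmha_custom_op import (
    cutlass_blackwell_fmha_custom_op,
)

q = torch.randn(1, 512, 8, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 512, 8, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 512, 8, 128, device="cuda", dtype=torch.bfloat16)

# The wrapper forwards to torch.ops.mslk.cutlass_blackwell_fmha_fwd and returns only
# the attention output (index [0]); the op itself also returns the softmax LSE, which
# fmha_fwd_meta models as a (B, H, S) float32 tensor.
out = cutlass_blackwell_fmha_custom_op(q, k, v, causal=True)
assert out.shape == q.shape

# Because the op registers a fake implementation and carries the pt2_compliant tag,
# the same call should also trace under torch.compile.
compiled = torch.compile(lambda q, k, v: cutlass_blackwell_fmha_custom_op(q, k, v, causal=True))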