fbgemm-gpu-genai-nightly 2025.12.19-cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fbgemm-gpu-genai-nightly might be problematic.
- fbgemm_gpu/__init__.py +186 -0
- fbgemm_gpu/asmjit.so +0 -0
- fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
- fbgemm_gpu/config/__init__.py +9 -0
- fbgemm_gpu/config/feature_list.py +88 -0
- fbgemm_gpu/docs/__init__.py +18 -0
- fbgemm_gpu/docs/common.py +9 -0
- fbgemm_gpu/docs/examples.py +73 -0
- fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
- fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
- fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
- fbgemm_gpu/docs/quantize_ops.py +41 -0
- fbgemm_gpu/docs/sparse_ops.py +616 -0
- fbgemm_gpu/docs/target.genai.json.py +6 -0
- fbgemm_gpu/enums.py +24 -0
- fbgemm_gpu/experimental/example/__init__.py +29 -0
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/example/utils.py +20 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
- fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
- fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
- fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
- fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
- fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
- fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
- fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
- fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
- fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
- fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/metrics.py +160 -0
- fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
- fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
- fbgemm_gpu/quantize/__init__.py +43 -0
- fbgemm_gpu/quantize/quantize_ops.py +64 -0
- fbgemm_gpu/quantize_comm.py +315 -0
- fbgemm_gpu/quantize_utils.py +246 -0
- fbgemm_gpu/runtime_monitor.py +237 -0
- fbgemm_gpu/sll/__init__.py +189 -0
- fbgemm_gpu/sll/cpu/__init__.py +80 -0
- fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
- fbgemm_gpu/sll/meta/__init__.py +35 -0
- fbgemm_gpu/sll/meta/meta_sll.py +337 -0
- fbgemm_gpu/sll/triton/__init__.py +127 -0
- fbgemm_gpu/sll/triton/common.py +38 -0
- fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
- fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
- fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
- fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
- fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
- fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
- fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
- fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
- fbgemm_gpu/sparse_ops.py +1455 -0
- fbgemm_gpu/split_embedding_configs.py +452 -0
- fbgemm_gpu/split_embedding_inference_converter.py +175 -0
- fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
- fbgemm_gpu/split_embedding_utils.py +29 -0
- fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
- fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
- fbgemm_gpu/tbe/__init__.py +6 -0
- fbgemm_gpu/tbe/bench/__init__.py +55 -0
- fbgemm_gpu/tbe/bench/bench_config.py +156 -0
- fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
- fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
- fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
- fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
- fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
- fbgemm_gpu/tbe/bench/reporter.py +35 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
- fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
- fbgemm_gpu/tbe/bench/utils.py +48 -0
- fbgemm_gpu/tbe/cache/__init__.py +11 -0
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
- fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
- fbgemm_gpu/tbe/ssd/__init__.py +15 -0
- fbgemm_gpu/tbe/ssd/common.py +46 -0
- fbgemm_gpu/tbe/ssd/inference.py +586 -0
- fbgemm_gpu/tbe/ssd/training.py +4908 -0
- fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
- fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
- fbgemm_gpu/tbe/stats/__init__.py +10 -0
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
- fbgemm_gpu/tbe/utils/__init__.py +13 -0
- fbgemm_gpu/tbe/utils/common.py +42 -0
- fbgemm_gpu/tbe/utils/offsets.py +65 -0
- fbgemm_gpu/tbe/utils/quantize.py +251 -0
- fbgemm_gpu/tbe/utils/requests.py +556 -0
- fbgemm_gpu/tbe_input_multiplexer.py +108 -0
- fbgemm_gpu/triton/__init__.py +22 -0
- fbgemm_gpu/triton/common.py +77 -0
- fbgemm_gpu/triton/jagged/__init__.py +8 -0
- fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
- fbgemm_gpu/triton/quantize.py +647 -0
- fbgemm_gpu/triton/quantize_ref.py +286 -0
- fbgemm_gpu/utils/__init__.py +11 -0
- fbgemm_gpu/utils/filestore.py +211 -0
- fbgemm_gpu/utils/loader.py +36 -0
- fbgemm_gpu/utils/torch_library.py +132 -0
- fbgemm_gpu/uvm.py +40 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
- fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
- list_versions/__init__.py +12 -0
- list_versions/cli_run.py +163 -0
fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py
ADDED
@@ -0,0 +1,552 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, Optional
+
+import torch
+
+try:
+    # pyre-ignore[21]
+    # @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils
+    from fbgemm_gpu import open_source
+except Exception:
+    open_source: bool = False
+
+if open_source:
+    import os
+
+    torch.ops.load_library(
+        os.path.join(
+            os.path.dirname(os.path.dirname(__file__)),
+            "..",
+            "fbgemm_gpu_experimental_gen_ai.so",
+        )
+    )
+else:
+    torch.ops.load_library(
+        "//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai:blackwell_attention_ops_gpu"
+    )
+
+
+from enum import IntEnum
+
+
+class GenKernelType(IntEnum):
+    UMMA_I = 0
+    UMMA_P = 1
+
+
+def get_splitk_heuristic(
+    batch: int,
+    seqlen_kv: int,
+    kv_heads: int = 1,
+    tile_n: int = 256,
+    sm_count: int | None = None,
+) -> int:
+    """
+    Compute optimal split-K size for Shape<64, 256, 128> tile configuration.
+
+    Targets full GPU utilization by distributing work across all SMs.
+    First calculates SMs per batch, then per kv_head, then divides seqlen_kv by that number.
+    Ensures split size evenly divides seqlen_kv so all CTAs process same number of tiles.
+    Returns 0 (no split) when split would equal seqlen_kv (only 1 split).
+
+    Args:
+        batch: Batch size
+        seqlen_kv: Maximum sequence length for K/V
+        kv_heads: Number of KV heads (default 1 for MQA)
+        tile_n: TileN dimension (default 256 for Shape<64, 256, 128>)
+        sm_count: Number of SMs on the GPU. If None, queries the current device.
+
+    Returns:
+        Optimal split size along the K/V sequence dimension, or 0 to disable split-K
+    """
+    # Get SM count from current device if not provided
+    if sm_count is None:
+        sm_count = torch.cuda.get_device_properties(
+            torch.cuda.current_device()
+        ).multi_processor_count
+
+    # Calculate number of SMs available per batch element
+    sms_per_batch = max(1, sm_count // batch)
+    # Further divide by kv_heads for multi-head KV
+    sms_per_head_batch = max(1, sms_per_batch // kv_heads)
+
+    # Each (batch, kv_head) element should have sms_per_head_batch splits
+    # So split size = seqlen_kv / sms_per_head_batch
+    ideal_split = seqlen_kv // sms_per_head_batch
+
+    # Round up to multiple of tile_n
+    split = ((ideal_split + tile_n - 1) // tile_n) * tile_n
+
+    # Clamp to valid range: [tile_n, seqlen_kv]
+    split = max(split, tile_n)
+    split = min(split, seqlen_kv)
+
+    # If split equals seqlen_kv, there's only 1 split - disable split-K
+    if split == seqlen_kv:
+        split = 0
+
+    return split
+
+
+def maybe_contiguous(x: torch.Tensor) -> torch.Tensor:
+    """
+    We only require the head dim to be contiguous
+    """
+    return (
+        x.contiguous()
+        if x is not None and (x.stride(-1) != 1 or x.stride(-2) % 8 != 0)
+        else x
+    )
+
+
+def _cutlass_blackwell_fmha_forward(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens_q: torch.Tensor | None = None,
+    cu_seqlens_k: torch.Tensor | None = None,
+    max_seq_len_q: int | None = None,
+    max_seq_len_k: int | None = None,
+    softmax_scale: float | None = None,
+    causal: bool = False,
+    seqlen_kv: torch.Tensor | None = None,
+    page_table: torch.Tensor | None = None,
+    seqlen_k: int | None = None,
+    window_left: int = -1,
+    window_right: int = -1,
+    bottom_right: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    q = maybe_contiguous(q)
+    k = maybe_contiguous(k)
+    v = maybe_contiguous(v)
+    return torch.ops.fbgemm.fmha_fwd(
+        q,
+        k,
+        v,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        max_seq_len_q=max_seq_len_q,
+        max_seq_len_k=max_seq_len_k,
+        softmax_scale=softmax_scale,
+        causal=causal,
+        seqlen_kv=seqlen_kv,
+        page_table=page_table,
+        seqlen_k=seqlen_k,
+        window_size_left=window_left,
+        window_size_right=window_right,
+        bottom_right=bottom_right,
+    )
+
+
+def _cutlass_blackwell_fmha_backward(
+    dout: torch.Tensor,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    out: torch.Tensor,
+    softmax_lse: torch.Tensor,
+    cu_seqlens_q: torch.Tensor | None = None,
+    cu_seqlens_k: torch.Tensor | None = None,
+    max_seq_len_q: int | None = None,
+    max_seq_len_k: int | None = None,
+    softmax_scale: float | None = None,
+    causal: bool = False,
+    window_left: int = -1,
+    window_right: int = -1,
+    bottom_right: bool = True,
+    deterministic: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    deterministic = deterministic or torch.are_deterministic_algorithms_enabled()
+    dout = maybe_contiguous(dout)
+    q = maybe_contiguous(q)
+    k = maybe_contiguous(k)
+    v = maybe_contiguous(v)
+    out = maybe_contiguous(out)
+    return torch.ops.fbgemm.fmha_bwd(
+        dout,
+        q,
+        k,
+        v,
+        out,
+        softmax_lse,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        max_seq_len_q=max_seq_len_q,
+        max_seq_len_k=max_seq_len_k,
+        softmax_scale=softmax_scale,
+        causal=causal,
+        window_size_left=window_left,
+        window_size_right=window_right,
+        bottom_right=bottom_right,
+        deterministic=deterministic,
+    )
+
+
+def _validate_and_adjust_split_k_size(split_k_size: int) -> int:
+    """
+    Validate and adjust split_k_size parameter for optimal performance.
+
+    Args:
+        split_k_size: The requested split size along the K/V sequence dimension.
+
+    Returns:
+        Adjusted split_k_size that is valid for the kernel.
+
+    Valid values:
+    - split_k_size <= 0: Disable split-K (no splitting)
+    - split_k_size > 0: Enable split-K with specified split size
+    """
+    if not isinstance(split_k_size, int):
+        raise TypeError(
+            f"split_k_size must be an integer, got {type(split_k_size).__name__}"
+        )
+
+    # If split-K is disabled, return as-is
+    if split_k_size <= 0:
+        return split_k_size
+
+    # Constants
+    MIN_RECOMMENDED_SPLIT_SIZE = 256
+    TILE_SIZE = 128
+
+    # Adjust if split_k_size is too small
+    if split_k_size < MIN_RECOMMENDED_SPLIT_SIZE:
+        split_k_size = MIN_RECOMMENDED_SPLIT_SIZE
+
+    # Check if split_k_size is a power of 2
+    is_power_of_2 = (split_k_size & (split_k_size - 1)) == 0
+
+    # If not a power of 2, round to nearest multiple of tile size (128)
+    if not is_power_of_2:
+        split_k_size = ((split_k_size + TILE_SIZE - 1) // TILE_SIZE) * TILE_SIZE
+
+    return split_k_size
+
+
+def _validate_decode_inputs(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    seqlen_kv: torch.Tensor | None,
+) -> None:
+    assert seqlen_kv is not None, "seqlen_kv must be provided for decode"
+    tensors = {"q": q, "k": k, "v": v, "seqlen_kv": seqlen_kv}
+
+    for name, tensor in tensors.items():
+        # assert tensor.is_contiguous(), f"{name} is not contiguous"
+        assert tensor.is_cuda, f"{name} must be on GPU"
+
+
+def _prepare_decode_inputs(
+    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, bool, tuple[int, ...]]:
+    """
+    Prepare inputs for decode kernel by handling both varlen and batch formats.
+
+    Returns:
+        - Reshaped q, k, v tensors in batch format [B, 1, H, D]
+        - batch_size
+        - needs_reshape_output flag
+        - original_shape of q
+    """
+    original_shape = tuple(q.shape)
+    needs_reshape_output = False
+    batch_size = q.shape[0]
+
+    if q.dim() == 3:
+        # Varlen format: [total_queries, num_heads, head_dim]
+        q = q.view(batch_size, 1, q.shape[1], q.shape[2])
+        needs_reshape_output = True
+
+    if q.dim() != 4:
+        raise ValueError(
+            f"Invalid query shape: {q.shape}. Expected [B, 1, H, D] or [total_queries, H, D]"
+        )
+    assert q.shape[1] == 1, "Kernel have sq=1"
+
+    k = k.view(batch_size, -1, k.shape[1], k.shape[2]) if k.dim() == 3 else k
+    v = v.view(batch_size, -1, v.shape[1], v.shape[2]) if v.dim() == 3 else v
+
+    return q, k, v, batch_size, needs_reshape_output, original_shape
+
+
+def cutlass_blackwell_fmha_decode_forward(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    seqlen_kv: torch.Tensor | None = None,
+    cu_seqlens_q: torch.Tensor | None = None,
+    cu_seqlens_k: torch.Tensor | None = None,
+    max_seq_len_q: int | None = None,
+    max_seq_len_k: int | None = None,
+    softmax_scale: float | None = None,
+    causal: bool = False,
+    window_left: int = -1,
+    window_right: int = -1,
+    bottom_right: bool = True,
+    split_k_size: int = 0,
+    use_heuristic: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Decode-optimized forward pass using the gen kernel.
+    This is a wrapper to use the gen kernel which is optimized
+    for decode (query length = 1).
+
+    This function is called externally by xformers ops.
+
+    Accepts inputs in two formats:
+    - Varlen format: [total_queries, num_heads, head_dim] (3D)
+    - Batch format: [batch_size, 1, num_heads, head_dim] (4D)
+
+    Args:
+        q: Query tensor in varlen [B, H, D] or batch [B, 1, H, D] format
+        k: Key tensor [B, Sk, H_kv, D]
+        v: Value tensor [B, Sk, H_kv, D]
+        seqlen_kv: Per-batch sequence lengths [B] (required)
+        split_k_size: Size of each split along the K/V sequence dimension.
+            - split_k_size <= 0 with use_heuristic=True: auto-compute using heuristic
+            - split_k_size <= 0 with use_heuristic=False: disable split-K
+            - split_k_size > 0: use the provided split size directly
+            Values below 256 are adjusted to 256. Non-power-of-2 values
+            are rounded to the nearest multiple of 128.
+        use_heuristic: If True and split_k_size <= 0, automatically compute optimal
+            split size using the heuristic. Default is True.
+
+    Returns:
+        Kernel output with Q dimension added:
+        - out: [B, 1, H, num_splits, D] (num_splits=1 when split-K disabled)
+        - lse: [B, num_splits, H, 1]
+    """
+    _validate_decode_inputs(q, k, v, seqlen_kv)
+
+    # Prepare inputs and handle format conversion
+    q, k, v, batch_size, _, original_shape = _prepare_decode_inputs(q, k, v)
+
+    # Determine effective split_k_size
+    if split_k_size <= 0 and use_heuristic:
+        # Auto-compute using heuristic
+        max_seqlen_kv = k.shape[1]
+        kv_heads = k.shape[2]  # K shape is [B, Sk, H_kv, D]
+        split_k_size = get_splitk_heuristic(batch_size, max_seqlen_kv, kv_heads)
+
+    # Validate and adjust split_k_size
+    split_k_size = _validate_and_adjust_split_k_size(split_k_size)
+
+    # Validate window_right: decode kernel only supports causal attention (window_right <= 0)
+    if window_right > 0:
+        raise ValueError(
+            f"window_right={window_right} is not supported for decode attention. "
+            "The decode kernel only supports causal attention with window_right <= 0. "
+            "Use window_right=0 (causal, current position only)."
+        )
+
+    # Call the gen kernel (optimized for decode)
+    # Note: window_left specifies how many tokens to look back (exclusive)
+    # The kernel will attend to positions [seqlen_kv - window_left, seqlen_kv)
+    out, lse = torch.ops.fbgemm.fmha_gen_fwd(
+        q,
+        k,
+        v,
+        seqlen_kv,
+        None,
+        kernel_type=GenKernelType.UMMA_I,
+        window_left=window_left,
+        window_right=0,
+        split_k_size=split_k_size,
+    )
+
+    # Kernel returns: out [B, H, num_splits, D], lse [B, num_splits, H]
+    # Reshape to consistent format with Q dimension:
+    # out: [B, H, num_splits, D] -> [B, 1, H, num_splits, D]
+    # lse: [B, num_splits, H] -> [B, num_splits, H, 1]
+    out = out.unsqueeze(1)  # [B, 1, H, num_splits, D]
+    lse = lse.unsqueeze(-1)  # [B, num_splits, H, 1]
+    return out, lse
+
+
+class CutlassBlackwellFmhaFunc(torch.autograd.Function):
+    @staticmethod
+    def forward(  # type: ignore
+        ctx,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        softmax_scale: float | None = None,
+        causal: bool = False,
+        cu_seqlens_q: Optional[torch.Tensor] = None,
+        cu_seqlens_k: Optional[torch.Tensor] = None,
+        max_seq_len_q: Optional[int] = None,
+        max_seq_len_k: Optional[int] = None,
+        seqlen_kv: Optional[torch.Tensor] = None,
+        page_table: Optional[torch.Tensor] = None,
+        seqlen_k: Optional[int] = None,
+        window_size: tuple[int, int] = (-1, -1),
+        bottom_right: bool = True,
+        deterministic: bool = False,
+    ) -> torch.Tensor:
+        window_left, window_right = window_size
+        # Check if this is generation phase (sq = 1)
+        sq = q.shape[1]
+        if q.dim() == 4 and sq == 1:
+            # For gen case, we don't need to save tensors for backward
+            ctx.is_gen = True
+            out, _ = cutlass_blackwell_fmha_decode_forward(
+                q,
+                k,
+                v,
+                seqlen_kv,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                max_seq_len_q,
+                max_seq_len_k,
+                softmax_scale,
+                causal,
+                window_left,
+                window_right,
+                bottom_right,
+            )
+            return out
+
+        ctx.is_gen = False
+        # Only check dtype if cu_seqlens_q and cu_seqlens_k are provided
+        if cu_seqlens_q is not None and cu_seqlens_k is not None:
+            assert (
+                cu_seqlens_q.dtype == torch.int32
+                and cu_seqlens_q.dtype == cu_seqlens_k.dtype
+            ), "cu_seqlens_q and cu_seqlens_k must be int32"
+
+        # handle window_size
+        if causal and window_left >= 0:
+            window_right = 0
+        # Use regular FMHA for non-generation case
+        out, softmax_lse = _cutlass_blackwell_fmha_forward(
+            q,
+            k,
+            v,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seq_len_q,
+            max_seq_len_k,
+            softmax_scale,
+            causal,
+            seqlen_kv,
+            page_table,
+            seqlen_k,
+            window_left,
+            window_right,
+            bottom_right,
+        )
+        ctx.save_for_backward(q, k, v, out, softmax_lse)
+        ctx.softmax_scale = softmax_scale
+        ctx.causal = causal
+        ctx.window_size = window_size
+        ctx.max_seq_len_q = max_seq_len_q
+        ctx.max_seq_len_k = max_seq_len_k
+        ctx.cu_seqlens_q = cu_seqlens_q
+        ctx.cu_seqlens_k = cu_seqlens_k
+        ctx.bottom_right = bottom_right
+        ctx.deterministic = deterministic
+        return out
+
+    @staticmethod
+    def backward(ctx, dout: torch.Tensor, *args: Any) -> tuple[  # type: ignore
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+    ]:
+        if ctx.is_gen:
+            # For gen case, no backward pass is needed (generation is inference only)
+            raise RuntimeError(
+                "Backward pass is not supported for generation phase (sq=1)"
+            )
+
+        q, k, v, out, softmax_lse = ctx.saved_tensors
+        window_left, window_right = ctx.window_size
+        dq, dk, dv = _cutlass_blackwell_fmha_backward(
+            dout,
+            q,
+            k,
+            v,
+            out,
+            softmax_lse,
+            ctx.cu_seqlens_q,
+            ctx.cu_seqlens_k,
+            ctx.max_seq_len_q,
+            ctx.max_seq_len_k,
+            ctx.softmax_scale,
+            ctx.causal,
+            window_left,
+            window_right,
+            bottom_right=ctx.bottom_right,
+            deterministic=ctx.deterministic,
+        )
+        return (
+            dq,
+            dk,
+            dv,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )
+
+
+def cutlass_blackwell_fmha_func(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    softmax_scale: float | None = None,
+    causal: bool = False,
+    cu_seqlens_q: torch.Tensor | None = None,
+    cu_seqlens_k: torch.Tensor | None = None,
+    max_seq_len_q: int | None = None,
+    max_seq_len_k: int | None = None,
+    seqlen_kv: torch.Tensor | None = None,
+    page_table: torch.Tensor | None = None,
+    seqlen_k: int | None = None,
+    window_size: tuple[int, int] | None = (-1, -1),
+    bottom_right: bool = True,
+    deterministic: bool = False,
+):
+    return CutlassBlackwellFmhaFunc.apply(
+        q,
+        k,
+        v,
+        softmax_scale,
+        causal,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seq_len_q,
+        max_seq_len_k,
+        seqlen_kv,
+        page_table,
+        seqlen_k,
+        window_size,
+        bottom_right,
+        deterministic,
+    )
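For a concrete sense of the split-K heuristic documented in get_splitk_heuristic above, the following illustrative sketch calls the helper with an explicit sm_count so no CUDA device is queried. The import path simply mirrors the wheel layout listed earlier (running it assumes the wheel and its GPU extension are installed), and the SM count of 132 is an assumed value chosen for illustration, not a property of any particular GPU.

from fbgemm_gpu.experimental.gen_ai.attention.cutlass_blackwell_fmha.cutlass_blackwell_fmha_interface import (
    get_splitk_heuristic,
)

SM_COUNT = 132  # assumed SM count, passed explicitly so no CUDA device is queried

# batch=2, seqlen_kv=8192, kv_heads=1:
#   sms_per_batch      = max(1, 132 // 2) = 66
#   sms_per_head_batch = max(1, 66 // 1)  = 66
#   ideal_split        = 8192 // 66       = 124
#   rounded up to tile_n=256 and clamped to [256, 8192] -> 256,
#   i.e. 8192 / 256 = 32 splits per (batch, kv_head) element
print(get_splitk_heuristic(2, 8192, kv_heads=1, sm_count=SM_COUNT))   # 256

# batch=256, seqlen_kv=512: the rounded split equals seqlen_kv, so only one
# split would remain and the heuristic returns 0 (split-K disabled).
print(get_splitk_heuristic(256, 512, kv_heads=1, sm_count=SM_COUNT))  # 0

The returned value is then passed through _validate_and_adjust_split_k_size, which bumps sizes below 256 up to 256 and rounds non-power-of-2 sizes to a multiple of the 128-wide tile.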
@@ -0,0 +1,13 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+try:
+    # pyre-ignore[21]
+    # @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils
+    from fbgemm_gpu import open_source
+except Exception:
+    open_source: bool = False
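Finally, a minimal usage sketch of the attention interface added in this release. It is illustrative only: it assumes the wheel is installed on a machine with a supported Blackwell GPU, and the shapes and bfloat16 dtype are assumptions chosen to match the [B, S, H, D] convention in the docstrings above.

import torch

from fbgemm_gpu.experimental.gen_ai.attention.cutlass_blackwell_fmha.cutlass_blackwell_fmha_interface import (
    cutlass_blackwell_fmha_func,
)

B, S, H, D = 2, 1024, 8, 128
q = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# Prefill/training path: routed to torch.ops.fbgemm.fmha_fwd, backward supported.
out = cutlass_blackwell_fmha_func(q, k, v, causal=True)
out.sum().backward()

# Decode path: a 4D query with sq == 1 plus per-batch seqlen_kv is routed to the
# gen kernel (torch.ops.fbgemm.fmha_gen_fwd); inference only, no backward pass.
q_dec = torch.randn(B, 1, H, D, device="cuda", dtype=torch.bfloat16)
seqlen_kv = torch.full((B,), S, device="cuda", dtype=torch.int32)
out_dec = cutlass_blackwell_fmha_func(
    q_dec, k.detach(), v.detach(), causal=True, seqlen_kv=seqlen_kv
)

In the decode branch the output keeps the num_splits axis ([B, 1, H, num_splits, D]) described in the cutlass_blackwell_fmha_decode_forward docstring.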