fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of fbgemm-gpu-genai-nightly might be problematic.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py
@@ -0,0 +1,552 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, Optional
+
+import torch
+
+try:
+    # pyre-ignore[21]
+    # @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils
+    from fbgemm_gpu import open_source
+except Exception:
+    open_source: bool = False
+
+if open_source:
+    import os
+
+    torch.ops.load_library(
+        os.path.join(
+            os.path.dirname(os.path.dirname(__file__)),
+            "..",
+            "fbgemm_gpu_experimental_gen_ai.so",
+        )
+    )
+else:
+    torch.ops.load_library(
+        "//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai:blackwell_attention_ops_gpu"
+    )
+
+
+from enum import IntEnum
+
+
+class GenKernelType(IntEnum):
+    UMMA_I = 0
+    UMMA_P = 1
+
+
+def get_splitk_heuristic(
+    batch: int,
+    seqlen_kv: int,
+    kv_heads: int = 1,
+    tile_n: int = 256,
+    sm_count: int | None = None,
+) -> int:
+    """
+    Compute optimal split-K size for Shape<64, 256, 128> tile configuration.
+
+    Targets full GPU utilization by distributing work across all SMs.
+    First calculates SMs per batch, then per kv_head, then divides seqlen_kv by that number.
+    Ensures split size evenly divides seqlen_kv so all CTAs process same number of tiles.
+    Returns 0 (no split) when split would equal seqlen_kv (only 1 split).
+
+    Args:
+        batch: Batch size
+        seqlen_kv: Maximum sequence length for K/V
+        kv_heads: Number of KV heads (default 1 for MQA)
+        tile_n: TileN dimension (default 256 for Shape<64, 256, 128>)
+        sm_count: Number of SMs on the GPU. If None, queries the current device.
+
+    Returns:
+        Optimal split size along the K/V sequence dimension, or 0 to disable split-K
+    """
+    # Get SM count from current device if not provided
+    if sm_count is None:
+        sm_count = torch.cuda.get_device_properties(
+            torch.cuda.current_device()
+        ).multi_processor_count
+
+    # Calculate number of SMs available per batch element
+    sms_per_batch = max(1, sm_count // batch)
+    # Further divide by kv_heads for multi-head KV
+    sms_per_head_batch = max(1, sms_per_batch // kv_heads)
+
+    # Each (batch, kv_head) element should have sms_per_head_batch splits
+    # So split size = seqlen_kv / sms_per_head_batch
+    ideal_split = seqlen_kv // sms_per_head_batch
+
+    # Round up to multiple of tile_n
+    split = ((ideal_split + tile_n - 1) // tile_n) * tile_n
+
+    # Clamp to valid range: [tile_n, seqlen_kv]
+    split = max(split, tile_n)
+    split = min(split, seqlen_kv)
+
+    # If split equals seqlen_kv, there's only 1 split - disable split-K
+    if split == seqlen_kv:
+        split = 0
+
+    return split
+
+
+def maybe_contiguous(x: torch.Tensor) -> torch.Tensor:
+    """
+    We only require the head dim to be contiguous
+    """
+    return (
+        x.contiguous()
+        if x is not None and (x.stride(-1) != 1 or x.stride(-2) % 8 != 0)
+        else x
+    )
+
+
+def _cutlass_blackwell_fmha_forward(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens_q: torch.Tensor | None = None,
+    cu_seqlens_k: torch.Tensor | None = None,
+    max_seq_len_q: int | None = None,
+    max_seq_len_k: int | None = None,
+    softmax_scale: float | None = None,
+    causal: bool = False,
+    seqlen_kv: torch.Tensor | None = None,
+    page_table: torch.Tensor | None = None,
+    seqlen_k: int | None = None,
+    window_left: int = -1,
+    window_right: int = -1,
+    bottom_right: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    q = maybe_contiguous(q)
+    k = maybe_contiguous(k)
+    v = maybe_contiguous(v)
+    return torch.ops.fbgemm.fmha_fwd(
+        q,
+        k,
+        v,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        max_seq_len_q=max_seq_len_q,
+        max_seq_len_k=max_seq_len_k,
+        softmax_scale=softmax_scale,
+        causal=causal,
+        seqlen_kv=seqlen_kv,
+        page_table=page_table,
+        seqlen_k=seqlen_k,
+        window_size_left=window_left,
+        window_size_right=window_right,
+        bottom_right=bottom_right,
+    )
+
+
+def _cutlass_blackwell_fmha_backward(
+    dout: torch.Tensor,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    out: torch.Tensor,
+    softmax_lse: torch.Tensor,
+    cu_seqlens_q: torch.Tensor | None = None,
+    cu_seqlens_k: torch.Tensor | None = None,
+    max_seq_len_q: int | None = None,
+    max_seq_len_k: int | None = None,
+    softmax_scale: float | None = None,
+    causal: bool = False,
+    window_left: int = -1,
+    window_right: int = -1,
+    bottom_right: bool = True,
+    deterministic: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    deterministic = deterministic or torch.are_deterministic_algorithms_enabled()
+    dout = maybe_contiguous(dout)
+    q = maybe_contiguous(q)
+    k = maybe_contiguous(k)
+    v = maybe_contiguous(v)
+    out = maybe_contiguous(out)
+    return torch.ops.fbgemm.fmha_bwd(
+        dout,
+        q,
+        k,
+        v,
+        out,
+        softmax_lse,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        max_seq_len_q=max_seq_len_q,
+        max_seq_len_k=max_seq_len_k,
+        softmax_scale=softmax_scale,
+        causal=causal,
+        window_size_left=window_left,
+        window_size_right=window_right,
+        bottom_right=bottom_right,
+        deterministic=deterministic,
+    )
+
+
+def _validate_and_adjust_split_k_size(split_k_size: int) -> int:
+    """
+    Validate and adjust split_k_size parameter for optimal performance.
+
+    Args:
+        split_k_size: The requested split size along the K/V sequence dimension.
+
+    Returns:
+        Adjusted split_k_size that is valid for the kernel.
+
+    Valid values:
+    - split_k_size <= 0: Disable split-K (no splitting)
+    - split_k_size > 0: Enable split-K with specified split size
+    """
+    if not isinstance(split_k_size, int):
+        raise TypeError(
+            f"split_k_size must be an integer, got {type(split_k_size).__name__}"
+        )
+
+    # If split-K is disabled, return as-is
+    if split_k_size <= 0:
+        return split_k_size
+
+    # Constants
+    MIN_RECOMMENDED_SPLIT_SIZE = 256
+    TILE_SIZE = 128
+
+    # Adjust if split_k_size is too small
+    if split_k_size < MIN_RECOMMENDED_SPLIT_SIZE:
+        split_k_size = MIN_RECOMMENDED_SPLIT_SIZE
+
+    # Check if split_k_size is a power of 2
+    is_power_of_2 = (split_k_size & (split_k_size - 1)) == 0
+
+    # If not a power of 2, round to nearest multiple of tile size (128)
+    if not is_power_of_2:
+        split_k_size = ((split_k_size + TILE_SIZE - 1) // TILE_SIZE) * TILE_SIZE
+
+    return split_k_size
+
+
+def _validate_decode_inputs(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    seqlen_kv: torch.Tensor | None,
+) -> None:
+    assert seqlen_kv is not None, "seqlen_kv must be provided for decode"
+    tensors = {"q": q, "k": k, "v": v, "seqlen_kv": seqlen_kv}
+
+    for name, tensor in tensors.items():
+        # assert tensor.is_contiguous(), f"{name} is not contiguous"
+        assert tensor.is_cuda, f"{name} must be on GPU"
+
+
+def _prepare_decode_inputs(
+    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, bool, tuple[int, ...]]:
+    """
+    Prepare inputs for decode kernel by handling both varlen and batch formats.
+
+    Returns:
+        - Reshaped q, k, v tensors in batch format [B, 1, H, D]
+        - batch_size
+        - needs_reshape_output flag
+        - original_shape of q
+    """
+    original_shape = tuple(q.shape)
+    needs_reshape_output = False
+    batch_size = q.shape[0]
+
+    if q.dim() == 3:
+        # Varlen format: [total_queries, num_heads, head_dim]
+        q = q.view(batch_size, 1, q.shape[1], q.shape[2])
+        needs_reshape_output = True
+
+    if q.dim() != 4:
+        raise ValueError(
+            f"Invalid query shape: {q.shape}. Expected [B, 1, H, D] or [total_queries, H, D]"
+        )
+    assert q.shape[1] == 1, "Kernel have sq=1"
+
+    k = k.view(batch_size, -1, k.shape[1], k.shape[2]) if k.dim() == 3 else k
+    v = v.view(batch_size, -1, v.shape[1], v.shape[2]) if v.dim() == 3 else v
+
+    return q, k, v, batch_size, needs_reshape_output, original_shape
+
+
+def cutlass_blackwell_fmha_decode_forward(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    seqlen_kv: torch.Tensor | None = None,
+    cu_seqlens_q: torch.Tensor | None = None,
+    cu_seqlens_k: torch.Tensor | None = None,
+    max_seq_len_q: int | None = None,
+    max_seq_len_k: int | None = None,
+    softmax_scale: float | None = None,
+    causal: bool = False,
+    window_left: int = -1,
+    window_right: int = -1,
+    bottom_right: bool = True,
+    split_k_size: int = 0,
+    use_heuristic: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Decode-optimized forward pass using the gen kernel.
+    This is a wrapper to use the gen kernel which is optimized
+    for decode (query length = 1).
+
+    This function is called externally by xformers ops.
+
+    Accepts inputs in two formats:
+    - Varlen format: [total_queries, num_heads, head_dim] (3D)
+    - Batch format: [batch_size, 1, num_heads, head_dim] (4D)
+
+    Args:
+        q: Query tensor in varlen [B, H, D] or batch [B, 1, H, D] format
+        k: Key tensor [B, Sk, H_kv, D]
+        v: Value tensor [B, Sk, H_kv, D]
+        seqlen_kv: Per-batch sequence lengths [B] (required)
+        split_k_size: Size of each split along the K/V sequence dimension.
+            - split_k_size <= 0 with use_heuristic=True: auto-compute using heuristic
+            - split_k_size <= 0 with use_heuristic=False: disable split-K
+            - split_k_size > 0: use the provided split size directly
+            Values below 256 are adjusted to 256. Non-power-of-2 values
+            are rounded to the nearest multiple of 128.
+        use_heuristic: If True and split_k_size <= 0, automatically compute optimal
+            split size using the heuristic. Default is True.
+
+    Returns:
+        Kernel output with Q dimension added:
+        - out: [B, 1, H, num_splits, D] (num_splits=1 when split-K disabled)
+        - lse: [B, num_splits, H, 1]
+    """
+    _validate_decode_inputs(q, k, v, seqlen_kv)
+
+    # Prepare inputs and handle format conversion
+    q, k, v, batch_size, _, original_shape = _prepare_decode_inputs(q, k, v)
+
+    # Determine effective split_k_size
+    if split_k_size <= 0 and use_heuristic:
+        # Auto-compute using heuristic
+        max_seqlen_kv = k.shape[1]
+        kv_heads = k.shape[2]  # K shape is [B, Sk, H_kv, D]
+        split_k_size = get_splitk_heuristic(batch_size, max_seqlen_kv, kv_heads)
+
+    # Validate and adjust split_k_size
+    split_k_size = _validate_and_adjust_split_k_size(split_k_size)
+
+    # Validate window_right: decode kernel only supports causal attention (window_right <= 0)
+    if window_right > 0:
+        raise ValueError(
+            f"window_right={window_right} is not supported for decode attention. "
+            "The decode kernel only supports causal attention with window_right <= 0. "
+            "Use window_right=0 (causal, current position only)."
+        )
+
+    # Call the gen kernel (optimized for decode)
+    # Note: window_left specifies how many tokens to look back (exclusive)
+    # The kernel will attend to positions [seqlen_kv - window_left, seqlen_kv)
+    out, lse = torch.ops.fbgemm.fmha_gen_fwd(
+        q,
+        k,
+        v,
+        seqlen_kv,
+        None,
+        kernel_type=GenKernelType.UMMA_I,
+        window_left=window_left,
+        window_right=0,
+        split_k_size=split_k_size,
+    )
+
+    # Kernel returns: out [B, H, num_splits, D], lse [B, num_splits, H]
+    # Reshape to consistent format with Q dimension:
+    # out: [B, H, num_splits, D] -> [B, 1, H, num_splits, D]
+    # lse: [B, num_splits, H] -> [B, num_splits, H, 1]
+    out = out.unsqueeze(1)  # [B, 1, H, num_splits, D]
+    lse = lse.unsqueeze(-1)  # [B, num_splits, H, 1]
+    return out, lse
+
+
+class CutlassBlackwellFmhaFunc(torch.autograd.Function):
+    @staticmethod
+    def forward(  # type: ignore
+        ctx,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        softmax_scale: float | None = None,
+        causal: bool = False,
+        cu_seqlens_q: Optional[torch.Tensor] = None,
+        cu_seqlens_k: Optional[torch.Tensor] = None,
+        max_seq_len_q: Optional[int] = None,
+        max_seq_len_k: Optional[int] = None,
+        seqlen_kv: Optional[torch.Tensor] = None,
+        page_table: Optional[torch.Tensor] = None,
+        seqlen_k: Optional[int] = None,
+        window_size: tuple[int, int] = (-1, -1),
+        bottom_right: bool = True,
+        deterministic: bool = False,
+    ) -> torch.Tensor:
+        window_left, window_right = window_size
+        # Check if this is generation phase (sq = 1)
+        sq = q.shape[1]
+        if q.dim() == 4 and sq == 1:
+            # For gen case, we don't need to save tensors for backward
+            ctx.is_gen = True
+            out, _ = cutlass_blackwell_fmha_decode_forward(
+                q,
+                k,
+                v,
+                seqlen_kv,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                max_seq_len_q,
+                max_seq_len_k,
+                softmax_scale,
+                causal,
+                window_left,
+                window_right,
+                bottom_right,
+            )
+            return out
+
+        ctx.is_gen = False
+        # Only check dtype if cu_seqlens_q and cu_seqlens_k are provided
+        if cu_seqlens_q is not None and cu_seqlens_k is not None:
+            assert (
+                cu_seqlens_q.dtype == torch.int32
+                and cu_seqlens_q.dtype == cu_seqlens_k.dtype
+            ), "cu_seqlens_q and cu_seqlens_k must be int32"
+
+        # handle window_size
+        if causal and window_left >= 0:
+            window_right = 0
+        # Use regular FMHA for non-generation case
+        out, softmax_lse = _cutlass_blackwell_fmha_forward(
+            q,
+            k,
+            v,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seq_len_q,
+            max_seq_len_k,
+            softmax_scale,
+            causal,
+            seqlen_kv,
+            page_table,
+            seqlen_k,
+            window_left,
+            window_right,
+            bottom_right,
+        )
+        ctx.save_for_backward(q, k, v, out, softmax_lse)
+        ctx.softmax_scale = softmax_scale
+        ctx.causal = causal
+        ctx.window_size = window_size
+        ctx.max_seq_len_q = max_seq_len_q
+        ctx.max_seq_len_k = max_seq_len_k
+        ctx.cu_seqlens_q = cu_seqlens_q
+        ctx.cu_seqlens_k = cu_seqlens_k
+        ctx.bottom_right = bottom_right
+        ctx.deterministic = deterministic
+        return out
+
+    @staticmethod
+    def backward(ctx, dout: torch.Tensor, *args: Any) -> tuple[  # type: ignore
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+    ]:
+        if ctx.is_gen:
+            # For gen case, no backward pass is needed (generation is inference only)
+            raise RuntimeError(
+                "Backward pass is not supported for generation phase (sq=1)"
+            )
+
+        q, k, v, out, softmax_lse = ctx.saved_tensors
+        window_left, window_right = ctx.window_size
+        dq, dk, dv = _cutlass_blackwell_fmha_backward(
+            dout,
+            q,
+            k,
+            v,
+            out,
+            softmax_lse,
+            ctx.cu_seqlens_q,
+            ctx.cu_seqlens_k,
+            ctx.max_seq_len_q,
+            ctx.max_seq_len_k,
+            ctx.softmax_scale,
+            ctx.causal,
+            window_left,
+            window_right,
+            bottom_right=ctx.bottom_right,
+            deterministic=ctx.deterministic,
+        )
+        return (
+            dq,
+            dk,
+            dv,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )
+
+
+def cutlass_blackwell_fmha_func(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    softmax_scale: float | None = None,
+    causal: bool = False,
+    cu_seqlens_q: torch.Tensor | None = None,
+    cu_seqlens_k: torch.Tensor | None = None,
+    max_seq_len_q: int | None = None,
+    max_seq_len_k: int | None = None,
+    seqlen_kv: torch.Tensor | None = None,
+    page_table: torch.Tensor | None = None,
+    seqlen_k: int | None = None,
+    window_size: tuple[int, int] | None = (-1, -1),
+    bottom_right: bool = True,
+    deterministic: bool = False,
+):
+    return CutlassBlackwellFmhaFunc.apply(
+        q,
+        k,
+        v,
+        softmax_scale,
+        causal,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seq_len_q,
+        max_seq_len_k,
+        seqlen_kv,
+        page_table,
+        seqlen_k,
+        window_size,
+        bottom_right,
+        deterministic,
+    )
@@ -0,0 +1,13 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+try:
+    # pyre-ignore[21]
+    # @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils
+    from fbgemm_gpu import open_source
+except Exception:
+    open_source: bool = False