fbgemm-gpu-genai-nightly 2025.10.20-cp312-cp312-manylinux_2_28_x86_64.whl → 2026.1.8-cp312-cp312-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fbgemm_gpu/__init__.py +3 -1
- fbgemm_gpu/config/feature_list.py +3 -0
- fbgemm_gpu/docs/target.genai.json.py +1 -1
- fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
- fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +4 -3
- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +11 -1
- fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +135 -172
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +15 -1
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +75 -3
- fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +299 -61
- fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +11 -8
- fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
- fbgemm_gpu/fbgemm.so +0 -0
- fbgemm_gpu/quantize_comm.py +15 -4
- fbgemm_gpu/quantize_utils.py +54 -6
- fbgemm_gpu/sparse_ops.py +53 -0
- fbgemm_gpu/split_embedding_configs.py +34 -0
- fbgemm_gpu/split_table_batched_embeddings_ops_common.py +53 -11
- fbgemm_gpu/split_table_batched_embeddings_ops_training.py +470 -94
- fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +44 -1
- fbgemm_gpu/tbe/bench/bench_runs.py +7 -0
- fbgemm_gpu/tbe/bench/tbe_data_config.py +15 -1
- fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +6 -1
- fbgemm_gpu/tbe/ssd/training.py +335 -50
- fbgemm_gpu/tbe/stats/bench_params_reporter.py +5 -2
- fbgemm_gpu/triton/quantize.py +13 -7
- fbgemm_gpu/utils/writeback_util.py +124 -0
- {fbgemm_gpu_genai_nightly-2025.10.20.dist-info → fbgemm_gpu_genai_nightly-2026.1.8.dist-info}/METADATA +2 -2
- {fbgemm_gpu_genai_nightly-2025.10.20.dist-info → fbgemm_gpu_genai_nightly-2026.1.8.dist-info}/RECORD +31 -30
- {fbgemm_gpu_genai_nightly-2025.10.20.dist-info → fbgemm_gpu_genai_nightly-2026.1.8.dist-info}/WHEEL +0 -0
- {fbgemm_gpu_genai_nightly-2025.10.20.dist-info → fbgemm_gpu_genai_nightly-2026.1.8.dist-info}/top_level.txt +0 -0
fbgemm_gpu/__init__.py
CHANGED
@@ -15,6 +15,7 @@ import torch
 # Based on the FBGEMM-PyTorch compatibility table at
 # https://docs.pytorch.org/FBGEMM/general/Releases.html#fbgemm-releases-compatibility
 _fbgemm_torch_compat_table = {
+    "1.5": "2.10",
     "1.4": "2.9",
     "1.3": "2.8",
     "1.2": "2.7",
@@ -81,7 +82,7 @@ def _load_library(filename: str, version: str, no_throw: bool = False) -> None:
         """
         )

-    elif str(torch.__version__)
+    elif not str(torch.__version__).startswith(_fbgemm_torch_compat_table[keys[0]]):
         logging.warning(
             f"""
             \033[31m
@@ -132,6 +133,7 @@ fbgemm_gpu_libraries = [
     "fbgemm_gpu_config",
     "fbgemm_gpu_tbe_utils",
     "fbgemm_gpu_tbe_index_select",
+    "fbgemm_gpu_tbe_cache",
     "fbgemm_gpu_tbe_optimizers",
     "fbgemm_gpu_tbe_inference",
     "fbgemm_gpu_tbe_training_forward",
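The hunks above add a "1.5" → "2.10" row to the compatibility table and make the loader warn whenever the installed torch version does not start with the PyTorch version mapped to the newest FBGEMM release. Below is a minimal sketch of that check; it is not the packaged code, and the assumption that `keys[0]` is the newest table entry is taken from the hunk above.

```python
# Minimal sketch of the compatibility check, assuming keys[0] is the newest
# FBGEMM release in the table (as in the hunk above); not the packaged code.
_fbgemm_torch_compat_table = {
    "1.5": "2.10",
    "1.4": "2.9",
}

def is_torch_compatible(torch_version: str) -> bool:
    keys = list(_fbgemm_torch_compat_table.keys())
    # The loader logs a red warning when this returns False.
    return torch_version.startswith(_fbgemm_torch_compat_table[keys[0]])

print(is_torch_compatible("2.10.0+cu128"))  # True
print(is_torch_compatible("2.9.1"))         # False -> warning path
```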
fbgemm_gpu/config/feature_list.py
CHANGED
@@ -63,6 +63,9 @@ class FeatureGateName(Enum):
     # Enable TBE input parameters extraction
     TBE_REPORT_INPUT_PARAMS = auto()

+    # Enable tuned max segment length per CTA for B200
+    TBE_USE_TUNED_SEGMENT_LENGTHS_CTA_B200 = auto()
+
     def is_enabled(self) -> bool:
         return FeatureGate.is_enabled(self)

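The new gate is queried through the existing `is_enabled()` path shown above. A hedged usage sketch follows; the import of `FeatureGateName` from `fbgemm_gpu.config` is an assumption based on this file's location.

```python
# Hedged sketch: querying the new feature gate. Assumes FeatureGateName is
# importable from fbgemm_gpu.config, where feature_list.py lives.
from fbgemm_gpu.config import FeatureGateName

if FeatureGateName.TBE_USE_TUNED_SEGMENT_LENGTHS_CTA_B200.is_enabled():
    # Downstream TBE code would pick the tuned max segment length per CTA
    # on B200 when this gate is on.
    pass
```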
fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so
CHANGED
Binary file
fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py
CHANGED
@@ -289,7 +289,7 @@ def triton_quantize_mx4_unpack(
         stochastic_casting (bool): Whether to use stochastic casting.

     Returns:
-        torch.Tensor: [M / 2] mx4 scaled tensor packed into
+        torch.Tensor: [M / 2] mx4 scaled tensor packed into uint8
         torch.Tensor: [M / group_size] mx4 shared exponents into int8

         eg.
@@ -1410,8 +1410,9 @@ def _kernel_nvfp4_quantize(

     # Apply scale_ to input. We do this by broadcasting scale.
     # scaled_a = a * global_scale (fp32) / local_scale (fp8)
-    scaled_a = tl.
-
+    scaled_a = tl.div_rn(
+        tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]).to(tl.float32),
+        tl.reshape(scale_ / input_global_scale, [GROUP_LOAD, 1]).to(tl.float32),
     )
     # Reshape back to a flat array.
     scaled_a = tl.reshape(scaled_a, [GROUP_LOAD * GROUP_SIZE])
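The docstring fix pins the packed output dtype to uint8, and the kernel change performs the per-group scaling as a single round-to-nearest fp32 division, `a / (local_scale / global_scale)`, as described by the comment in the hunk. A plain-torch sketch of that arithmetic follows; the shapes and names mirror the kernel but are assumptions for illustration only.

```python
# Illustrative plain-torch version of the per-group scaling in
# _kernel_nvfp4_quantize; GROUP_SIZE and the tensor shapes are assumptions.
import torch

GROUP_SIZE = 16
a = torch.randn(4, GROUP_SIZE, dtype=torch.float32)  # one group per row
local_scale = torch.rand(4, 1) + 0.5                 # per-group scale (fp8 in the kernel)
input_global_scale = torch.tensor(1.5)               # global fp32 scale

# Equivalent to a * global_scale / local_scale, done as one fp32 division
scaled_a = a / (local_scale / input_global_scale)
```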
fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
CHANGED
@@ -1212,6 +1212,8 @@ def matmul_fp8_row(
     imprecise_acc: bool = False,
     tma_persistent: bool = True,
     no_use_persistent: Optional[bool] = None,
+    # add an option to explicitly require the use of persistent process
+    use_persistent: Optional[bool] = None,
     use_warp_specialization: bool = False,
 ) -> torch.Tensor:
     """
@@ -1232,12 +1234,16 @@ def matmul_fp8_row(
     Returns:
         torch.Tensor: [M, N] Output tensor a @ b / (a_scale[:, None] * b_scale[None, :])
     """
-    if
+    if use_persistent:
+        no_use_persistent = False
+    elif no_use_persistent is None:
         # Default True for AMD and False for Nvidia.
         if torch.version.hip is not None:
             no_use_persistent = True
         else:
             no_use_persistent = False
+    # if use_persistent is explicitly requested, set o_use_persistent to False
+
     # Get datatypes and constants to use.
     pt_fp8_dtype, _, _, _ = get_fp8_constants()
     # Handle 3D+ a shape
@@ -3840,6 +3846,10 @@ _MATMUL_CONFIG_TUPLES_PINGPONG_4K_8K_16K = [
     (256, 128, 128, 1, 1, 2, 16, 1, 8, 2),
     (128, 256, 128, 2, 1, 2, 16, 2, 4, 1),
     (256, 128, 64, 2, 1, 2, 16, 1, 4, 2),
+    (128, 128, 256, 2, 1, 0, 16, 2, 8, 2),
+    (128, 64, 128, 2, 1, 2, 16, 2, 4, 2),
+    (128, 128, 64, 2, 1, 0, 16, 1, 4, 2),
+    (128, 128, 128, 1, 1, 2, 16, 1, 4, 2),
 ]

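The new `use_persistent` keyword lets callers force the persistent kernel even where the default would disable it (the default picks the non-persistent path on AMD via `no_use_persistent=True`). A hedged usage sketch follows; the `quantize_fp8_row` helper and the exact positional signature of `matmul_fp8_row` are assumptions based on this module, not something this diff guarantees.

```python
# Hedged sketch: forcing the persistent path via the new keyword.
import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    matmul_fp8_row,
    quantize_fp8_row,  # assumed helper from the same module
)

a = torch.randn(1024, 512, device="cuda", dtype=torch.bfloat16)
b = torch.randn(2048, 512, device="cuda", dtype=torch.bfloat16)
a_fp8, a_scale = quantize_fp8_row(a)
b_fp8, b_scale = quantize_fp8_row(b)

# use_persistent=True overrides the platform default (non-persistent on AMD).
out = matmul_fp8_row(a_fp8, b_fp8, a_scale, b_scale, use_persistent=True)
```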
fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py
CHANGED
@@ -509,14 +509,13 @@ def _fbgemm_grouped_gemm_ws(
         num_tiles = num_m_tiles * NUM_N_TILES

         if USE_TMA_STORE:
-
-
-
-
-
-
-
-            )
+            c_desc_ptr = tl.make_tensor_descriptor(
+                c_ptr + M_start_offset * N,
+                shape=[m_size, N],
+                # pyre-ignore
+                strides=[N, 1],
+                block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
+            )

         # Move across tiles
         next_iterated_tiles = iterated_tiles + num_tiles
@@ -534,72 +533,59 @@ def _fbgemm_grouped_gemm_ws(
             m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
             n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
             for k_offset in range(0, K, BLOCK_SIZE_K):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                else:
-                    accumulator += tl.dot(a, b.T)
+                a = tl._experimental_descriptor_load(
+                    a_desc_ptr,
+                    [m_offset, k_offset],
+                    [BLOCK_SIZE_M, BLOCK_SIZE_K],
+                    dtype,
+                )
+                b = tl._experimental_descriptor_load(
+                    b_desc_ptr,
+                    [n_offset, k_offset],
+                    [BLOCK_SIZE_N, BLOCK_SIZE_K],
+                    dtype,
+                )
+                if USE_FAST_ACCUM:
+                    accumulator = tl.dot(a, b.T, accumulator)
+                else:
+                    accumulator += tl.dot(a, b.T)

             if USE_TMA_STORE:
-
-
-
-
-
-
-
-                )
+                m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
+                n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
+                # pyre-ignore
+                c_desc_ptr.store(
+                    [m_offset, n_offset],
+                    accumulator.to(c_ptr.dtype.element_ty),
+                )
             elif FUSE_SCATTER_ADD:
-
-
-
-
-                mask
-
-
-
-
-
-
-
-
-
-
-                    c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
-                    c,
-                    mask=mask[:, None],
-                    sem="relaxed",
-                )
+                offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                mask = offs_am < m_size
+                m_offsets = tl.load(
+                    scatter_add_indices + M_start_offset + offs_am,
+                    mask=mask,
+                    cache_modifier=".ca",
+                )
+                offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                c = accumulator.to(c_ptr.dtype.element_ty)
+                tl.atomic_add(
+                    c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
+                    c,
+                    mask=mask[:, None],
+                    sem="relaxed",
+                )
             else:
-
-
-
-
-
-
-
-                c
-
-
-
-                    + offs_bn[None, :],
-                    c,
-                    mask=offs_am[:, None] < m_size,
-                    cache_modifier=".cs",
-                )
+                offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                c = accumulator.to(c_ptr.dtype.element_ty)
+                tl.store(
+                    c_ptr
+                    + (M_start_offset + offs_am[:, None]) * N
+                    + offs_bn[None, :],
+                    c,
+                    mask=offs_am[:, None] < m_size,
+                    cache_modifier=".cs",
+                )
             tidx += NUM_SMS

         iterated_tiles += num_tiles
@@ -841,14 +827,13 @@ def _fbgemm_grouped_gemm_fp8_rowwise_ws(
         num_tiles = num_m_tiles * NUM_N_TILES

         if USE_TMA_STORE:
-
-
-
-
-
-
-
-            )
+            c_desc_ptr = tl.make_tensor_descriptor(
+                c_ptr + M_start_offset * N,
+                shape=[m_size, N],
+                # pyre-ignore
+                strides=[N, 1],
+                block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
+            )

         # Move across tiles
         next_iterated_tiles = iterated_tiles + num_tiles
@@ -867,107 +852,85 @@ def _fbgemm_grouped_gemm_fp8_rowwise_ws(
             m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
             n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
             for k_offset in range(0, K, BLOCK_SIZE_K):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                else:
-                    accumulator += tl.dot(a, b.T)
+                a = tl._experimental_descriptor_load(
+                    a_desc_ptr,
+                    [m_offset, k_offset],
+                    [BLOCK_SIZE_M, BLOCK_SIZE_K],
+                    dtype,
+                )
+                b = tl._experimental_descriptor_load(
+                    b_desc_ptr,
+                    [n_offset, k_offset],
+                    [BLOCK_SIZE_N, BLOCK_SIZE_K],
+                    dtype,
+                )
+                if USE_FAST_ACCUM:
+                    accumulator = tl.dot(a, b.T, accumulator)
+                else:
+                    accumulator += tl.dot(a, b.T)

             if USE_TMA_LOAD_ON_SCALES:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                    mask=offs_am[:, None] < m_size,
-                    cache_modifier=".ca",
-                )
-                c = accumulator.to(tl.float32) * a_scale * b_scale[None, :]
+                b_scale = tl._experimental_descriptor_load(
+                    b_scale_desc_ptr,
+                    [n_offset],
+                    [BLOCK_SIZE_N],
+                    tl.float32,
+                )
+
+                offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                a_scale = tl.load(
+                    a_scale_ptr + M_start_offset + offs_am[:, None],
+                    mask=offs_am[:, None] < m_size,
+                    cache_modifier=".ca",
+                )
+                c = accumulator.to(tl.float32) * a_scale * b_scale[None, :]
             else:
-
-
-
-
-
-
-
-
-
-
-
-
-                b_scale = tl.load(
-                    b_scale_ptr + N_start_offset + offs_bn[None, :],
-                    cache_modifier=".ca",
-                )
-                c = accumulator.to(tl.float32) * a_scale * b_scale
+                offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                a_scale = tl.load(
+                    a_scale_ptr + M_start_offset + offs_am[:, None],
+                    mask=offs_am[:, None] < m_size,
+                    cache_modifier=".ca",
+                )
+                b_scale = tl.load(
+                    b_scale_ptr + N_start_offset + offs_bn[None, :],
+                    cache_modifier=".ca",
+                )
+                c = accumulator.to(tl.float32) * a_scale * b_scale

             if USE_TMA_STORE:
-
-
-
-
-
-
-                )
+                m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
+                n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
+                # pyre-ignore
+                c_desc_ptr.store(
+                    [m_offset, n_offset], c.to(c_ptr.dtype.element_ty)
+                )
             elif FUSE_SCATTER_ADD:
-
-
-
-
-                mask
-
-
-
-
-
-
-
-
-
-                    c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
-                    c,
-                    mask=mask[:, None],
-                    sem="relaxed",
-                )
+                offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                mask = offs_am < m_size
+                m_offsets = tl.load(
+                    scatter_add_indices + M_start_offset + offs_am,
+                    mask=mask,
+                    cache_modifier=".ca",
+                )
+                offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                tl.atomic_add(
+                    c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
+                    c,
+                    mask=mask[:, None],
+                    sem="relaxed",
+                )
             else:
-
-
-
-
-
-
-
-
-
-
-                    + offs_bn[None, :],
-                    c,
-                    mask=offs_am[:, None] < m_size,
-                    cache_modifier=".cs",
-                )
+                offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                tl.store(
+                    c_ptr
+                    + (M_start_offset + offs_am[:, None]) * N
+                    + offs_bn[None, :],
+                    c,
+                    mask=offs_am[:, None] < m_size,
+                    cache_modifier=".cs",
+                )
             tidx += NUM_SMS

         iterated_tiles += num_tiles
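In both warp-specialized kernels the rewritten epilogue has three explicit paths: a TMA tensor-descriptor store, a fused scatter-add that accumulates each tile row into an arbitrary output row with relaxed atomics, and a plain masked store. Below is a plain-torch sketch of the scatter-add semantics only; it is illustrative, and the kernel itself does this per tile with `tl.atomic_add`.

```python
# Illustrative semantics of the FUSE_SCATTER_ADD epilogue: tile rows are
# accumulated into output rows selected by scatter_add_indices.
import torch

N = 8
c = torch.zeros(16, N)                   # output buffer (c_ptr)
tile = torch.randn(4, N)                 # one accumulator tile
m_offsets = torch.tensor([3, 7, 7, 12])  # scatter_add_indices for these rows

# Duplicate indices accumulate, matching the atomic-add behavior.
c.index_add_(0, m_offsets, tile)
```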
fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py
CHANGED
@@ -29,4 +29,18 @@ else:
     )

     from . import cutlass_blackwell_fmha_custom_op  # noqa: F401
-    from .cutlass_blackwell_fmha_interface import
+    from .cutlass_blackwell_fmha_interface import (  # noqa: F401
+        _cutlass_blackwell_fmha_forward,
+        cutlass_blackwell_fmha_decode_forward,
+        cutlass_blackwell_fmha_func,
+    )
+
+    # Note: _cutlass_blackwell_fmha_forward is an internal function (indicated by leading underscore)
+    # that is exported here specifically for testing purposes. It allows tests to access the LSE
+    # (log-sum-exp) values returned by the forward pass without modifying the public API.
+    # Production code should use cutlass_blackwell_fmha_func instead.
+    __all__ = [
+        "_cutlass_blackwell_fmha_forward",
+        "cutlass_blackwell_fmha_decode_forward",
+        "cutlass_blackwell_fmha_func",
+    ]
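With the re-exports above, the public autograd entry point can be imported directly from the subpackage. A hedged call sketch follows; the [batch, seqlen, heads, head_dim] layout and the `causal` keyword are assumptions based on the interface changes in this release, not something this hunk documents.

```python
# Hedged sketch of calling the re-exported public entry point.
import torch
from fbgemm_gpu.experimental.gen_ai.attention.cutlass_blackwell_fmha import (
    cutlass_blackwell_fmha_func,
)

q = torch.randn(2, 128, 8, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 128, 8, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 128, 8, 128, device="cuda", dtype=torch.bfloat16)

out = cutlass_blackwell_fmha_func(q, k, v, causal=True)  # assumed signature
```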
fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py
CHANGED
@@ -12,13 +12,13 @@ from torch.library import register_fake

 torch.library.define(
     "blackwell_fmha::fmha_fwd",
-    "(Tensor q, Tensor k, Tensor v, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seq_len_q, int? max_seq_len_k, float? softmax_scale, bool? causal, Tensor? seqlen_kv) -> (Tensor, Tensor)",
+    "(Tensor q, Tensor k, Tensor v, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seq_len_q, int? max_seq_len_k, float? softmax_scale, bool? causal, Tensor? seqlen_kv, Tensor? page_table, int seqlen_k=-1, int window_size_left=-1, int window_size_right=-1, bool bottom_right=True) -> (Tensor, Tensor)",
     tags=torch.Tag.pt2_compliant_tag,
 )

 torch.library.define(
     "blackwell_fmha::fmha_bwd",
-    "(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seq_len_q, int? max_seq_len_k, bool? causal) -> (Tensor, Tensor, Tensor)",
+    "(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seq_len_q, int? max_seq_len_k, float? softmax_scale, bool? causal, int window_size_left=-1, int window_size_right=-1, bool bottom_right=True, bool deterministic=False) -> (Tensor, Tensor, Tensor)",
     tags=torch.Tag.pt2_compliant_tag,
 )

@@ -35,6 +35,11 @@ def custom_op_fmha(
     softmax_scale: Optional[float] = None,
     causal: bool = False,
     seqlen_kv: Optional[torch.Tensor] = None,
+    page_table: Optional[torch.Tensor] = None,
+    seqlen_k: Optional[int] = None,
+    window_size_left: int = -1,
+    window_size_right: int = -1,
+    bottom_right: bool = True,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     assert q.is_contiguous(), "q is not contiguous"
     assert k.is_contiguous(), "k is not contiguous"
@@ -42,6 +47,7 @@ def custom_op_fmha(
     assert q.is_cuda, "q must be on GPU"
     assert k.is_cuda, "k must be on GPU"
     assert v.is_cuda, "v must be on GPU"
+
     return torch.ops.fbgemm.fmha_fwd(
         q,
         k,
@@ -53,6 +59,11 @@ def custom_op_fmha(
         softmax_scale=softmax_scale,
         causal=causal,
         seqlen_kv=seqlen_kv,
+        page_table=page_table,
+        seqlen_k=seqlen_k,
+        window_size_left=window_size_left,
+        window_size_right=window_size_right,
+        bottom_right=bottom_right,
     )


@@ -68,6 +79,11 @@ def fmha_fwd_meta(
     softmax_scale: Optional[float] = None,
     causal: bool = False,
     seqlen_kv: Optional[torch.Tensor] = None,
+    page_table: Optional[torch.Tensor] = None,
+    seqlen_k: Optional[int] = None,
+    window_size_left: int = -1,
+    window_size_right: int = -1,
+    bottom_right: bool = True,
 ):
     if q.dtype == torch.float16:
         out_dtype = torch.float16
@@ -122,8 +138,14 @@ def custom_op_fmha_bwd(
     cu_seqlens_k: Optional[torch.Tensor] = None,
     max_seq_len_q: Optional[int] = None,
     max_seq_len_k: Optional[int] = None,
+    softmax_scale: Optional[float] = None,
     causal: bool = False,
+    window_size_left: int = -1,
+    window_size_right: int = -1,
+    bottom_right: bool = True,
+    deterministic: bool = False,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
     return torch.ops.fbgemm.fmha_bwd(
         dOutput,
         query,
@@ -135,7 +157,12 @@ def custom_op_fmha_bwd(
         cu_seqlens_k=cu_seqlens_k,
         max_seq_len_q=max_seq_len_q,
         max_seq_len_k=max_seq_len_k,
+        softmax_scale=softmax_scale,
         causal=causal,
+        window_size_left=window_size_left,
+        window_size_right=window_size_right,
+        bottom_right=bottom_right,
+        deterministic=deterministic,
     )


@@ -151,7 +178,12 @@ def fmha_bwd_meta(
     cu_seqlens_k: Optional[torch.Tensor] = None,
     max_seq_len_q: Optional[int] = None,
     max_seq_len_k: Optional[int] = None,
+    softmax_scale: Optional[float] = None,
     causal: bool = False,
+    window_size_left: int = -1,
+    window_size_right: int = -1,
+    bottom_right: bool = True,
+    deterministic: bool = False,
 ):
     return (
         torch.empty_like(query),
@@ -198,9 +230,30 @@ def _backward(ctx, *grad):
         ctx.cu_seqlens_k,
         ctx.max_seq_len_q,
         ctx.max_seq_len_k,
+        ctx.softmax_scale,
         ctx.causal,
+        ctx.window_size_left,
+        ctx.window_size_right,
+        ctx.bottom_right,
+        ctx.deterministic,
+    )
+    return (
+        dq,
+        dk,
+        dv,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
     )
-    return dq, dk, dv, None, None, None, None, None, None, None


 def _setup_context(ctx, inputs, output):
@@ -215,6 +268,11 @@ def _setup_context(ctx, inputs, output):
         softmax_scale,
         causal,
         seqlen_kv,
+        page_table,
+        seqlen_k,
+        window_size_left,
+        window_size_right,
+        bottom_right,
     ) = inputs
     (out, softmax_lse) = output
     ctx.save_for_backward(q, k, v, out, softmax_lse)
@@ -224,6 +282,10 @@ def _setup_context(ctx, inputs, output):
     ctx.max_seq_len_k = max_seq_len_k
     ctx.cu_seqlens_q = cu_seqlens_q
     ctx.cu_seqlens_k = cu_seqlens_k
+    ctx.window_size_left = window_size_left
+    ctx.window_size_right = window_size_right
+    ctx.bottom_right = bottom_right
+    ctx.deterministic = False  # Set default value
     ctx.is_gen = False


@@ -246,6 +308,11 @@ def cutlass_blackwell_fmha_custom_op(
     max_seq_len_q: int | None = None,
     max_seq_len_k: int | None = None,
     seqlen_kv: torch.Tensor | None = None,
+    page_table: torch.Tensor | None = None,
+    seqlen_k: int | None = -1,
+    window_size_left: int | None = -1,
+    window_size_right: int | None = -1,
+    bottom_right: bool | None = True,
 ):
     return torch.ops.blackwell_fmha.fmha_fwd(
         q=q,
@@ -258,4 +325,9 @@ def cutlass_blackwell_fmha_custom_op(
         softmax_scale=softmax_scale,
         causal=causal,
         seqlen_kv=seqlen_kv,
+        page_table=page_table,
+        seqlen_k=seqlen_k,
+        window_size_left=window_size_left,
+        window_size_right=window_size_right,
+        bottom_right=bottom_right,
     )[0]
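The extended schema threads paged-KV (`page_table`, `seqlen_k`) and sliding-window (`window_size_left`, `window_size_right`, `bottom_right`) parameters through the forward custom op, and the backward now also receives `softmax_scale` and `deterministic`. Below is a hedged sketch of a local-attention call through the wrapper; only arguments visible in the schema above are used, and the tensor shapes are illustrative assumptions.

```python
# Hedged sketch: causal sliding-window attention via the extended wrapper.
import torch
from fbgemm_gpu.experimental.gen_ai.attention.cutlass_blackwell_fmha.cutlass_blackwell_fmha_custom_op import (
    cutlass_blackwell_fmha_custom_op,
)

q = torch.randn(2, 256, 8, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 256, 8, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 256, 8, 128, device="cuda", dtype=torch.bfloat16)

# Each query attends to at most 128 keys on its left; window_size_right=0
# together with causal=True keeps the causal mask.
out = cutlass_blackwell_fmha_custom_op(
    q,
    k,
    v,
    causal=True,
    window_size_left=128,
    window_size_right=0,
)
```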