fbgemm-gpu-genai-nightly 2025.11.4__cp313-cp313-manylinux_2_28_x86_64.whl → 2025.12.17__cp313-cp313-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. fbgemm_gpu/__init__.py +4 -1
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/docs/target.genai.json.py +1 -1
  4. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  5. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +4 -3
  6. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +7 -1
  7. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +135 -172
  8. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +15 -1
  9. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +75 -3
  10. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +278 -62
  11. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +9 -6
  12. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  13. fbgemm_gpu/fbgemm.so +0 -0
  14. fbgemm_gpu/quantize_comm.py +15 -2
  15. fbgemm_gpu/sparse_ops.py +53 -0
  16. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +22 -6
  17. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +173 -67
  18. fbgemm_gpu/tbe/bench/bench_runs.py +7 -0
  19. fbgemm_gpu/tbe/bench/tbe_data_config.py +15 -1
  20. fbgemm_gpu/tbe/ssd/training.py +174 -30
  21. fbgemm_gpu/tbe/stats/bench_params_reporter.py +5 -2
  22. {fbgemm_gpu_genai_nightly-2025.11.4.dist-info → fbgemm_gpu_genai_nightly-2025.12.17.dist-info}/METADATA +2 -2
  23. {fbgemm_gpu_genai_nightly-2025.11.4.dist-info → fbgemm_gpu_genai_nightly-2025.12.17.dist-info}/RECORD +25 -25
  24. {fbgemm_gpu_genai_nightly-2025.11.4.dist-info → fbgemm_gpu_genai_nightly-2025.12.17.dist-info}/WHEEL +0 -0
  25. {fbgemm_gpu_genai_nightly-2025.11.4.dist-info → fbgemm_gpu_genai_nightly-2025.12.17.dist-info}/top_level.txt +0 -0
fbgemm_gpu/__init__.py CHANGED
@@ -15,6 +15,8 @@ import torch
 # Based on the FBGEMM-PyTorch compatibility table at
 # https://docs.pytorch.org/FBGEMM/general/Releases.html#fbgemm-releases-compatibility
 _fbgemm_torch_compat_table = {
+    "1.5": "2.10",
+    "1.4": "2.9",
     "1.3": "2.8",
     "1.2": "2.7",
     "1.1": "2.6",
@@ -80,7 +82,7 @@ def _load_library(filename: str, version: str, no_throw: bool = False) -> None:
             """
         )

-    elif str(torch.__version__) != _fbgemm_torch_compat_table[keys[0]]:
+    elif not str(torch.__version__).startswith(_fbgemm_torch_compat_table[keys[0]]):
         logging.warning(
             f"""
             \033[31m
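
The switch from equality to a prefix check matters because `torch.__version__` carries patch and build suffixes that never compare equal to a bare "major.minor" entry. A standalone sketch of the check (not the package's code; the table literal is copied from the hunk above):

```python
_fbgemm_torch_compat_table = {"1.5": "2.10", "1.4": "2.9", "1.3": "2.8"}

expected = _fbgemm_torch_compat_table["1.5"]
for torch_version in ("2.10.0", "2.10.0+cu128", "2.9.1"):
    # Old check: exact equality never matches a full version string (False for all three).
    print(torch_version == expected)
    # New check: prefix match accepts any patch/build of the expected minor (True, True, False).
    print(torch_version.startswith(expected))
```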
@@ -131,6 +133,7 @@ fbgemm_gpu_libraries = [
     "fbgemm_gpu_config",
     "fbgemm_gpu_tbe_utils",
     "fbgemm_gpu_tbe_index_select",
+    "fbgemm_gpu_tbe_cache",
     "fbgemm_gpu_tbe_optimizers",
     "fbgemm_gpu_tbe_inference",
     "fbgemm_gpu_tbe_training_forward",
fbgemm_gpu/asmjit.so CHANGED
Binary file
fbgemm_gpu/docs/target.genai.json.py CHANGED
@@ -1,6 +1,6 @@
 
 {
-    "version": "2025.11.4",
+    "version": "2025.12.17",
     "target": "genai",
     "variant": "cuda"
 }
fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py CHANGED
@@ -289,7 +289,7 @@ def triton_quantize_mx4_unpack(
         stochastic_casting (bool): Whether to use stochastic casting.
 
     Returns:
-        torch.Tensor: [M / 2] mx4 scaled tensor packed into in8
+        torch.Tensor: [M / 2] mx4 scaled tensor packed into uint8
         torch.Tensor: [M / group_size] mx4 shared exponents into int8
 
     eg.
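
For orientation, a hedged call sketch for the documented return pair (packed uint8 payload plus int8 shared exponents). The input dtype and the exact keyword set beyond `group_size` and `stochastic_casting` mentioned in the docstring are assumptions, not a documented signature:

```python
import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import (
    triton_quantize_mx4_unpack,
)

x = torch.randn(4096, device="cuda", dtype=torch.bfloat16)
# Per the docstring: [M / 2] packed mx4 values (uint8) and
# [M / group_size] shared exponents (int8).
packed, shared_exp = triton_quantize_mx4_unpack(x, group_size=32)
print(packed.dtype, shared_exp.dtype)
```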
@@ -1410,8 +1410,9 @@ def _kernel_nvfp4_quantize(
 
     # Apply scale_ to input. We do this by broadcasting scale.
     # scaled_a = a * global_scale (fp32) / local_scale (fp8)
-    scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) * tl.reshape(
-        input_global_scale / scale_, [GROUP_LOAD, 1]
+    scaled_a = tl.div_rn(
+        tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]).to(tl.float32),
+        tl.reshape(scale_ / input_global_scale, [GROUP_LOAD, 1]).to(tl.float32),
     )
     # Reshape back to a flat array.
     scaled_a = tl.reshape(scaled_a, [GROUP_LOAD * GROUP_SIZE])
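
The rewritten kernel line computes the same scaling algebraically, since a * (g / s) == a / (s / g), but performs a round-to-nearest division on fp32 operands instead of multiplying by a pre-rounded quotient. A standalone PyTorch sketch of the equivalence (names are stand-ins for the kernel's values, not its tensors):

```python
import torch

a = torch.randn(4, 16, dtype=torch.float32)      # a group of input values
g = torch.tensor(0.0123, dtype=torch.float32)    # stand-in for input_global_scale
s = torch.rand(4, 1, dtype=torch.float32) + 0.5  # stand-in for the per-group local scale

old = a * (g / s)   # previous formulation: multiply by a pre-computed quotient
new = a / (s / g)   # new formulation: a single rounded division per element
print(torch.allclose(old, new, rtol=1e-6, atol=1e-7))  # True up to rounding
```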
fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py CHANGED
@@ -1212,6 +1212,8 @@ def matmul_fp8_row(
     imprecise_acc: bool = False,
     tma_persistent: bool = True,
     no_use_persistent: Optional[bool] = None,
+    # add an option to explicitly require the use of the persistent kernel
+    use_persistent: Optional[bool] = None,
     use_warp_specialization: bool = False,
 ) -> torch.Tensor:
     """
@@ -1232,12 +1234,16 @@ def matmul_fp8_row(
     Returns:
         torch.Tensor: [M, N] Output tensor a @ b / (a_scale[:, None] * b_scale[None, :])
     """
-    if no_use_persistent is None:
+    if use_persistent:
+        no_use_persistent = False
+    elif no_use_persistent is None:
         # Default True for AMD and False for Nvidia.
         if torch.version.hip is not None:
             no_use_persistent = True
         else:
             no_use_persistent = False
+    # if use_persistent is explicitly requested, no_use_persistent is set to False above
+
     # Get datatypes and constants to use.
     pt_fp8_dtype, _, _, _ = get_fp8_constants()
     # Handle 3D+ a shape
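
A hedged call sketch for the new flag: `use_persistent=True` forces the persistent kernel even on ROCm, where `no_use_persistent` would otherwise default to True. The row-wise quantization helper and the leading positional arguments `(a, b, a_scale, b_scale)` follow the docstring above but are assumptions about the full signature:

```python
import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    matmul_fp8_row,
    quantize_fp8_row,  # assumed companion row-wise quantizer in the same module
)

a = torch.randn(512, 256, device="cuda", dtype=torch.bfloat16)
b = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
a_fp8, a_scale = quantize_fp8_row(a)
b_fp8, b_scale = quantize_fp8_row(b)

# Explicitly request the persistent kernel path regardless of platform defaults.
out = matmul_fp8_row(a_fp8, b_fp8, a_scale, b_scale, use_persistent=True)
print(out.shape)  # torch.Size([512, 128])
```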
fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py CHANGED
@@ -509,14 +509,13 @@ def _fbgemm_grouped_gemm_ws(
             num_tiles = num_m_tiles * NUM_N_TILES

             if USE_TMA_STORE:
-                with tl.async_task([0]):
-                    c_desc_ptr = tl.make_tensor_descriptor(
-                        c_ptr + M_start_offset * N,
-                        shape=[m_size, N],
-                        # pyre-ignore
-                        strides=[N, 1],
-                        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
-                    )
+                c_desc_ptr = tl.make_tensor_descriptor(
+                    c_ptr + M_start_offset * N,
+                    shape=[m_size, N],
+                    # pyre-ignore
+                    strides=[N, 1],
+                    block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
+                )

             # Move across tiles
             next_iterated_tiles = iterated_tiles + num_tiles
@@ -534,72 +533,59 @@ def _fbgemm_grouped_gemm_ws(
                     m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                     n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                     for k_offset in range(0, K, BLOCK_SIZE_K):
-                        with tl.async_task([0]):
-                            a = tl._experimental_descriptor_load(
-                                a_desc_ptr,
-                                [m_offset, k_offset],
-                                [BLOCK_SIZE_M, BLOCK_SIZE_K],
-                                dtype,
-                            )
-                            b = tl._experimental_descriptor_load(
-                                b_desc_ptr,
-                                [n_offset, k_offset],
-                                [BLOCK_SIZE_N, BLOCK_SIZE_K],
-                                dtype,
-                            )
-                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
-                            if USE_FAST_ACCUM:
-                                accumulator = tl.dot(a, b.T, accumulator)
-                            else:
-                                accumulator += tl.dot(a, b.T)
+                        a = tl._experimental_descriptor_load(
+                            a_desc_ptr,
+                            [m_offset, k_offset],
+                            [BLOCK_SIZE_M, BLOCK_SIZE_K],
+                            dtype,
+                        )
+                        b = tl._experimental_descriptor_load(
+                            b_desc_ptr,
+                            [n_offset, k_offset],
+                            [BLOCK_SIZE_N, BLOCK_SIZE_K],
+                            dtype,
+                        )
+                        if USE_FAST_ACCUM:
+                            accumulator = tl.dot(a, b.T, accumulator)
+                        else:
+                            accumulator += tl.dot(a, b.T)

                     if USE_TMA_STORE:
-                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
-                            m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
-                            n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
-                            # pyre-ignore
-                            c_desc_ptr.store(
-                                [m_offset, n_offset],
-                                accumulator.to(c_ptr.dtype.element_ty),
-                            )
+                        m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
+                        n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
+                        # pyre-ignore
+                        c_desc_ptr.store(
+                            [m_offset, n_offset],
+                            accumulator.to(c_ptr.dtype.element_ty),
+                        )
                     elif FUSE_SCATTER_ADD:
-                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
-                            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
-                                0, BLOCK_SIZE_M
-                            )
-                            mask = offs_am < m_size
-                            m_offsets = tl.load(
-                                scatter_add_indices + M_start_offset + offs_am,
-                                mask=mask,
-                                cache_modifier=".ca",
-                            )
-                            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(
-                                0, BLOCK_SIZE_N
-                            )
-                            c = accumulator.to(c_ptr.dtype.element_ty)
-                            tl.atomic_add(
-                                c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
-                                c,
-                                mask=mask[:, None],
-                                sem="relaxed",
-                            )
+                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                        mask = offs_am < m_size
+                        m_offsets = tl.load(
+                            scatter_add_indices + M_start_offset + offs_am,
+                            mask=mask,
+                            cache_modifier=".ca",
+                        )
+                        offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                        c = accumulator.to(c_ptr.dtype.element_ty)
+                        tl.atomic_add(
+                            c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
+                            c,
+                            mask=mask[:, None],
+                            sem="relaxed",
+                        )
                     else:
-                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
-                            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
-                                0, BLOCK_SIZE_M
-                            )
-                            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(
-                                0, BLOCK_SIZE_N
-                            )
-                            c = accumulator.to(c_ptr.dtype.element_ty)
-                            tl.store(
-                                c_ptr
-                                + (M_start_offset + offs_am[:, None]) * N
-                                + offs_bn[None, :],
-                                c,
-                                mask=offs_am[:, None] < m_size,
-                                cache_modifier=".cs",
-                            )
+                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                        offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                        c = accumulator.to(c_ptr.dtype.element_ty)
+                        tl.store(
+                            c_ptr
+                            + (M_start_offset + offs_am[:, None]) * N
+                            + offs_bn[None, :],
+                            c,
+                            mask=offs_am[:, None] < m_size,
+                            cache_modifier=".cs",
+                        )
                     tidx += NUM_SMS

         iterated_tiles += num_tiles
@@ -841,14 +827,13 @@ def _fbgemm_grouped_gemm_fp8_rowwise_ws(
             num_tiles = num_m_tiles * NUM_N_TILES

             if USE_TMA_STORE:
-                with tl.async_task([0]):
-                    c_desc_ptr = tl.make_tensor_descriptor(
-                        c_ptr + M_start_offset * N,
-                        shape=[m_size, N],
-                        # pyre-ignore
-                        strides=[N, 1],
-                        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
-                    )
+                c_desc_ptr = tl.make_tensor_descriptor(
+                    c_ptr + M_start_offset * N,
+                    shape=[m_size, N],
+                    # pyre-ignore
+                    strides=[N, 1],
+                    block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
+                )

             # Move across tiles
             next_iterated_tiles = iterated_tiles + num_tiles
@@ -867,107 +852,85 @@ def _fbgemm_grouped_gemm_fp8_rowwise_ws(
                     m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
                     n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
                     for k_offset in range(0, K, BLOCK_SIZE_K):
-                        with tl.async_task([0]):
-                            a = tl._experimental_descriptor_load(
-                                a_desc_ptr,
-                                [m_offset, k_offset],
-                                [BLOCK_SIZE_M, BLOCK_SIZE_K],
-                                dtype,
-                            )
-                            b = tl._experimental_descriptor_load(
-                                b_desc_ptr,
-                                [n_offset, k_offset],
-                                [BLOCK_SIZE_N, BLOCK_SIZE_K],
-                                dtype,
-                            )
-                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
-                            if USE_FAST_ACCUM:
-                                accumulator = tl.dot(a, b.T, accumulator)
-                            else:
-                                accumulator += tl.dot(a, b.T)
+                        a = tl._experimental_descriptor_load(
+                            a_desc_ptr,
+                            [m_offset, k_offset],
+                            [BLOCK_SIZE_M, BLOCK_SIZE_K],
+                            dtype,
+                        )
+                        b = tl._experimental_descriptor_load(
+                            b_desc_ptr,
+                            [n_offset, k_offset],
+                            [BLOCK_SIZE_N, BLOCK_SIZE_K],
+                            dtype,
+                        )
+                        if USE_FAST_ACCUM:
+                            accumulator = tl.dot(a, b.T, accumulator)
+                        else:
+                            accumulator += tl.dot(a, b.T)

                     if USE_TMA_LOAD_ON_SCALES:
-                        with tl.async_task([0]):
-                            b_scale = tl._experimental_descriptor_load(
-                                b_scale_desc_ptr,
-                                [n_offset],
-                                [BLOCK_SIZE_N],
-                                tl.float32,
-                            )
-
-                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
-                            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
-                                0, BLOCK_SIZE_M
-                            )
-                            a_scale = tl.load(
-                                a_scale_ptr + M_start_offset + offs_am[:, None],
-                                mask=offs_am[:, None] < m_size,
-                                cache_modifier=".ca",
-                            )
-                            c = accumulator.to(tl.float32) * a_scale * b_scale[None, :]
+                        b_scale = tl._experimental_descriptor_load(
+                            b_scale_desc_ptr,
+                            [n_offset],
+                            [BLOCK_SIZE_N],
+                            tl.float32,
+                        )
+
+                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                        a_scale = tl.load(
+                            a_scale_ptr + M_start_offset + offs_am[:, None],
+                            mask=offs_am[:, None] < m_size,
+                            cache_modifier=".ca",
+                        )
+                        c = accumulator.to(tl.float32) * a_scale * b_scale[None, :]
                     else:
-                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
-                            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
-                                0, BLOCK_SIZE_M
-                            )
-                            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(
-                                0, BLOCK_SIZE_N
-                            )
-                            a_scale = tl.load(
-                                a_scale_ptr + M_start_offset + offs_am[:, None],
-                                mask=offs_am[:, None] < m_size,
-                                cache_modifier=".ca",
-                            )
-                            b_scale = tl.load(
-                                b_scale_ptr + N_start_offset + offs_bn[None, :],
-                                cache_modifier=".ca",
-                            )
-                            c = accumulator.to(tl.float32) * a_scale * b_scale
+                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                        offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                        a_scale = tl.load(
+                            a_scale_ptr + M_start_offset + offs_am[:, None],
+                            mask=offs_am[:, None] < m_size,
+                            cache_modifier=".ca",
+                        )
+                        b_scale = tl.load(
+                            b_scale_ptr + N_start_offset + offs_bn[None, :],
+                            cache_modifier=".ca",
+                        )
+                        c = accumulator.to(tl.float32) * a_scale * b_scale

                     if USE_TMA_STORE:
-                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
-                            m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
-                            n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
-                            # pyre-ignore
-                            c_desc_ptr.store(
-                                [m_offset, n_offset], c.to(c_ptr.dtype.element_ty)
-                            )
+                        m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
+                        n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
+                        # pyre-ignore
+                        c_desc_ptr.store(
+                            [m_offset, n_offset], c.to(c_ptr.dtype.element_ty)
+                        )
                     elif FUSE_SCATTER_ADD:
-                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
-                            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
-                                0, BLOCK_SIZE_M
-                            )
-                            mask = offs_am < m_size
-                            m_offsets = tl.load(
-                                scatter_add_indices + M_start_offset + offs_am,
-                                mask=mask,
-                                cache_modifier=".ca",
-                            )
-                            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(
-                                0, BLOCK_SIZE_N
-                            )
-                            tl.atomic_add(
-                                c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
-                                c,
-                                mask=mask[:, None],
-                                sem="relaxed",
-                            )
+                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                        mask = offs_am < m_size
+                        m_offsets = tl.load(
+                            scatter_add_indices + M_start_offset + offs_am,
+                            mask=mask,
+                            cache_modifier=".ca",
+                        )
+                        offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                        tl.atomic_add(
+                            c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
+                            c,
+                            mask=mask[:, None],
+                            sem="relaxed",
+                        )
                     else:
-                        with tl.async_task([1, NUM_CONSUMER_GROUPS]):
-                            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(
-                                0, BLOCK_SIZE_M
-                            )
-                            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(
-                                0, BLOCK_SIZE_N
-                            )
-                            tl.store(
-                                c_ptr
-                                + (M_start_offset + offs_am[:, None]) * N
-                                + offs_bn[None, :],
-                                c,
-                                mask=offs_am[:, None] < m_size,
-                                cache_modifier=".cs",
-                            )
+                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                        offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                        tl.store(
+                            c_ptr
+                            + (M_start_offset + offs_am[:, None]) * N
+                            + offs_bn[None, :],
+                            c,
+                            mask=offs_am[:, None] < m_size,
+                            cache_modifier=".cs",
+                        )
                     tidx += NUM_SMS

         iterated_tiles += num_tiles
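
These hunks drop the `tl.async_task` producer/consumer partitioning from the warp-specialized kernels; the loads, dot-products, and stores now run in a single task stream, and the computed result is unchanged. For orientation, here is a plain PyTorch reference of the grouped GEMM semantics the kernels implement (a sketch only, not the Triton kernels' API; the stacked `[G * N, K]` weight layout is an assumption based on the per-group offsets above):

```python
import torch

def grouped_gemm_reference(x, w, m_sizes):
    """x: [sum(M_g), K]; w: [G * N, K]; m_sizes: [G]. Returns [sum(M_g), N]."""
    G = m_sizes.numel()
    N = w.shape[0] // G
    out = x.new_empty(x.shape[0], N)
    m_start = 0
    for g in range(G):
        m_end = m_start + int(m_sizes[g])
        # Each group multiplies its row slice of x by its own [N, K] weight block.
        out[m_start:m_end] = x[m_start:m_end] @ w[g * N : (g + 1) * N].T
        m_start = m_end
    return out

x = torch.randn(6, 8)
w = torch.randn(2 * 4, 8)          # G=2 groups, N=4, K=8
m_sizes = torch.tensor([2, 4])
print(grouped_gemm_reference(x, w, m_sizes).shape)  # torch.Size([6, 4])
```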
fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py CHANGED
@@ -29,4 +29,18 @@ else:
     )
 
     from . import cutlass_blackwell_fmha_custom_op  # noqa: F401
-    from .cutlass_blackwell_fmha_interface import cutlass_blackwell_fmha_func  # noqa: F401
+    from .cutlass_blackwell_fmha_interface import (  # noqa: F401
+        _cutlass_blackwell_fmha_forward,
+        cutlass_blackwell_fmha_decode_forward,
+        cutlass_blackwell_fmha_func,
+    )
+
+    # Note: _cutlass_blackwell_fmha_forward is an internal function (indicated by leading underscore)
+    # that is exported here specifically for testing purposes. It allows tests to access the LSE
+    # (log-sum-exp) values returned by the forward pass without modifying the public API.
+    # Production code should use cutlass_blackwell_fmha_func instead.
+    __all__ = [
+        "_cutlass_blackwell_fmha_forward",
+        "cutlass_blackwell_fmha_decode_forward",
+        "cutlass_blackwell_fmha_func",
+    ]
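
A hedged sketch of the testing pattern the note describes: import the internal forward to inspect the LSE alongside the public entry point. Tensor shapes, the `(q, k, v)` ordering, the `causal` keyword, and the `(out, softmax_lse)` return pair are assumptions based on the fmha_fwd schema in the custom-op file below, not a documented API:

```python
import torch
from fbgemm_gpu.experimental.gen_ai.attention.cutlass_blackwell_fmha import (
    _cutlass_blackwell_fmha_forward,
    cutlass_blackwell_fmha_func,
)

q = torch.randn(2, 128, 8, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 128, 8, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 128, 8, 128, device="cuda", dtype=torch.bfloat16)

# The internal forward is assumed to return (out, softmax_lse);
# the public function returns only the attention output.
out, softmax_lse = _cutlass_blackwell_fmha_forward(q, k, v, causal=True)
out_public = cutlass_blackwell_fmha_func(q, k, v, causal=True)
torch.testing.assert_close(out, out_public)
print(softmax_lse.shape)
```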
fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py CHANGED
@@ -12,13 +12,13 @@ from torch.library import register_fake
 
 torch.library.define(
     "blackwell_fmha::fmha_fwd",
-    "(Tensor q, Tensor k, Tensor v, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seq_len_q, int? max_seq_len_k, float? softmax_scale, bool? causal, Tensor? seqlen_kv) -> (Tensor, Tensor)",
+    "(Tensor q, Tensor k, Tensor v, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seq_len_q, int? max_seq_len_k, float? softmax_scale, bool? causal, Tensor? seqlen_kv, Tensor? page_table, int seqlen_k=-1, int window_size_left=-1, int window_size_right=-1, bool bottom_right=True) -> (Tensor, Tensor)",
     tags=torch.Tag.pt2_compliant_tag,
 )
 
 torch.library.define(
     "blackwell_fmha::fmha_bwd",
-    "(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seq_len_q, int? max_seq_len_k, bool? causal) -> (Tensor, Tensor, Tensor)",
+    "(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seq_len_q, int? max_seq_len_k, float? softmax_scale, bool? causal, int window_size_left=-1, int window_size_right=-1, bool bottom_right=True, bool deterministic=False) -> (Tensor, Tensor, Tensor)",
     tags=torch.Tag.pt2_compliant_tag,
 )
 
@@ -35,6 +35,11 @@ def custom_op_fmha(
     softmax_scale: Optional[float] = None,
     causal: bool = False,
     seqlen_kv: Optional[torch.Tensor] = None,
+    page_table: Optional[torch.Tensor] = None,
+    seqlen_k: Optional[int] = None,
+    window_size_left: int = -1,
+    window_size_right: int = -1,
+    bottom_right: bool = True,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     assert q.is_contiguous(), "q is not contiguous"
     assert k.is_contiguous(), "k is not contiguous"
@@ -42,6 +47,7 @@ def custom_op_fmha(
     assert q.is_cuda, "q must be on GPU"
     assert k.is_cuda, "k must be on GPU"
     assert v.is_cuda, "v must be on GPU"
+
     return torch.ops.fbgemm.fmha_fwd(
         q,
         k,
@@ -53,6 +59,11 @@ def custom_op_fmha(
         softmax_scale=softmax_scale,
         causal=causal,
         seqlen_kv=seqlen_kv,
+        page_table=page_table,
+        seqlen_k=seqlen_k,
+        window_size_left=window_size_left,
+        window_size_right=window_size_right,
+        bottom_right=bottom_right,
     )
 
 
@@ -68,6 +79,11 @@ def fmha_fwd_meta(
     softmax_scale: Optional[float] = None,
     causal: bool = False,
     seqlen_kv: Optional[torch.Tensor] = None,
+    page_table: Optional[torch.Tensor] = None,
+    seqlen_k: Optional[int] = None,
+    window_size_left: int = -1,
+    window_size_right: int = -1,
+    bottom_right: bool = True,
 ):
     if q.dtype == torch.float16:
         out_dtype = torch.float16
@@ -122,8 +138,14 @@ def custom_op_fmha_bwd(
     cu_seqlens_k: Optional[torch.Tensor] = None,
     max_seq_len_q: Optional[int] = None,
     max_seq_len_k: Optional[int] = None,
+    softmax_scale: Optional[float] = None,
     causal: bool = False,
+    window_size_left: int = -1,
+    window_size_right: int = -1,
+    bottom_right: bool = True,
+    deterministic: bool = False,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
     return torch.ops.fbgemm.fmha_bwd(
         dOutput,
         query,
@@ -135,7 +157,12 @@ def custom_op_fmha_bwd(
         cu_seqlens_k=cu_seqlens_k,
         max_seq_len_q=max_seq_len_q,
         max_seq_len_k=max_seq_len_k,
+        softmax_scale=softmax_scale,
         causal=causal,
+        window_size_left=window_size_left,
+        window_size_right=window_size_right,
+        bottom_right=bottom_right,
+        deterministic=deterministic,
     )
 
 
@@ -151,7 +178,12 @@ def fmha_bwd_meta(
     cu_seqlens_k: Optional[torch.Tensor] = None,
     max_seq_len_q: Optional[int] = None,
     max_seq_len_k: Optional[int] = None,
+    softmax_scale: Optional[float] = None,
     causal: bool = False,
+    window_size_left: int = -1,
+    window_size_right: int = -1,
+    bottom_right: bool = True,
+    deterministic: bool = False,
 ):
     return (
         torch.empty_like(query),
@@ -198,9 +230,30 @@ def _backward(ctx, *grad):
         ctx.cu_seqlens_k,
         ctx.max_seq_len_q,
         ctx.max_seq_len_k,
+        ctx.softmax_scale,
         ctx.causal,
+        ctx.window_size_left,
+        ctx.window_size_right,
+        ctx.bottom_right,
+        ctx.deterministic,
+    )
+    return (
+        dq,
+        dk,
+        dv,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
     )
-    return dq, dk, dv, None, None, None, None, None, None, None
 
 
 def _setup_context(ctx, inputs, output):
@@ -215,6 +268,11 @@ def _setup_context(ctx, inputs, output):
         softmax_scale,
         causal,
         seqlen_kv,
+        page_table,
+        seqlen_k,
+        window_size_left,
+        window_size_right,
+        bottom_right,
     ) = inputs
     (out, softmax_lse) = output
     ctx.save_for_backward(q, k, v, out, softmax_lse)
@@ -224,6 +282,10 @@ def _setup_context(ctx, inputs, output):
     ctx.max_seq_len_k = max_seq_len_k
     ctx.cu_seqlens_q = cu_seqlens_q
     ctx.cu_seqlens_k = cu_seqlens_k
+    ctx.window_size_left = window_size_left
+    ctx.window_size_right = window_size_right
+    ctx.bottom_right = bottom_right
+    ctx.deterministic = False  # Set default value
     ctx.is_gen = False
 
 
@@ -246,6 +308,11 @@ def cutlass_blackwell_fmha_custom_op(
     max_seq_len_q: int | None = None,
     max_seq_len_k: int | None = None,
     seqlen_kv: torch.Tensor | None = None,
+    page_table: torch.Tensor | None = None,
+    seqlen_k: int | None = -1,
+    window_size_left: int | None = -1,
+    window_size_right: int | None = -1,
+    bottom_right: bool | None = True,
 ):
     return torch.ops.blackwell_fmha.fmha_fwd(
         q=q,
@@ -258,4 +325,9 @@ def cutlass_blackwell_fmha_custom_op(
         softmax_scale=softmax_scale,
         causal=causal,
         seqlen_kv=seqlen_kv,
+        page_table=page_table,
+        seqlen_k=seqlen_k,
+        window_size_left=window_size_left,
+        window_size_right=window_size_right,
+        bottom_right=bottom_right,
     )[0]
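
A hedged usage sketch for the extended custom op: a causal, sliding-window attention call. The window and `bottom_right` keywords come from the fmha_fwd schema above; the positional `(q, k, v)` ordering, the `causal` keyword, and the tensor layout are assumptions about the parts of the signature not shown in this diff:

```python
import torch
from fbgemm_gpu.experimental.gen_ai.attention.cutlass_blackwell_fmha.cutlass_blackwell_fmha_custom_op import (
    cutlass_blackwell_fmha_custom_op,
)

q = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16)

# Attend only to the previous 256 keys (none to the right), with the window
# anchored bottom-right as in the new defaults.
out = cutlass_blackwell_fmha_custom_op(
    q, k, v,
    causal=True,
    window_size_left=256,
    window_size_right=0,
    bottom_right=True,
)
print(out.shape)
```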