fastvideo-kernel 0.2.6__tar.gz → 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94) hide show
  1. fastvideo_kernel-0.3.0/CMakeLists.txt +352 -0
  2. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/MANIFEST.in +1 -0
  3. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/PKG-INFO +34 -2
  4. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/README.md +32 -0
  5. fastvideo_kernel-0.3.0/attn_qat_infer/__init__.py +16 -0
  6. fastvideo_kernel-0.3.0/attn_qat_infer/api.py +189 -0
  7. fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/__init__.py +1 -0
  8. fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/api.cu +347 -0
  9. fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/block_config.h +28 -0
  10. fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/block_info.h +60 -0
  11. fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/blockscaled_layout.h +149 -0
  12. fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/cute_extension.h +327 -0
  13. fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/epilogue_tma_ws.h +222 -0
  14. fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/kernel_traits.h +202 -0
  15. fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/kernel_ws.h +204 -0
  16. fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/launch.h +114 -0
  17. fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/mainloop_tma_ws.h +926 -0
  18. fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/named_barrier.h +119 -0
  19. fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/params.h +180 -0
  20. fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/softmax_fused.h +190 -0
  21. fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/static_switch.h +83 -0
  22. fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/tile_scheduler.h +304 -0
  23. fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/utils.h +408 -0
  24. fastvideo_kernel-0.3.0/attn_qat_infer/quantization/__init__.py +1 -0
  25. fastvideo_kernel-0.3.0/attn_qat_infer/quantization/bench/bench_quant_k.py +90 -0
  26. fastvideo_kernel-0.3.0/attn_qat_infer/quantization/bench/bench_quant_q.py +86 -0
  27. fastvideo_kernel-0.3.0/attn_qat_infer/quantization/bench/bench_quant_v.py +86 -0
  28. fastvideo_kernel-0.3.0/attn_qat_infer/quantization/bench/bench_utils.py +169 -0
  29. fastvideo_kernel-0.3.0/attn_qat_infer/quantization/cuda_utils.h +52 -0
  30. fastvideo_kernel-0.3.0/attn_qat_infer/quantization/fp4_quantization_4d.cu +639 -0
  31. fastvideo_kernel-0.3.0/benchmarks/bench_fused_compress_topk.py +315 -0
  32. fastvideo_kernel-0.3.0/build.sh +146 -0
  33. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/attention/block_sparse_h100.cu +22 -3
  34. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/attention/st_attn_h100.cu +15 -1
  35. fastvideo_kernel-0.2.6/dist/fastvideo_kernel-0.2.6-cp312-cp312-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl → fastvideo_kernel-0.3.0/dist/fastvideo_kernel-0.3.0-cp312-cp312-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl +0 -0
  36. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/pyproject.toml +3 -3
  37. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/python/fastvideo_kernel/__init__.py +14 -0
  38. fastvideo_kernel-0.3.0/python/fastvideo_kernel/block_sparse_attn.py +424 -0
  39. fastvideo_kernel-0.3.0/python/fastvideo_kernel/block_sparse_attn_256.py +170 -0
  40. fastvideo_kernel-0.3.0/python/fastvideo_kernel/block_sparse_attn_cute_fwd.py +267 -0
  41. fastvideo_kernel-0.3.0/python/fastvideo_kernel/block_sparse_attn_varlen.py +207 -0
  42. fastvideo_kernel-0.3.0/python/fastvideo_kernel/ops.py +238 -0
  43. fastvideo_kernel-0.3.0/python/fastvideo_kernel/triton_kernels/attn_qat_train.py +1119 -0
  44. fastvideo_kernel-0.3.0/python/fastvideo_kernel/triton_kernels/fused_attention.py +55 -0
  45. fastvideo_kernel-0.3.0/python/fastvideo_kernel/triton_kernels/fused_compress_topk.py +334 -0
  46. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/python/fastvideo_kernel/triton_kernels/index.py +113 -1
  47. fastvideo_kernel-0.3.0/python/fastvideo_kernel/triton_kernels/nvfp4_utils.py +237 -0
  48. fastvideo_kernel-0.3.0/python/fastvideo_kernel/triton_kernels/quant_utils.py +80 -0
  49. fastvideo_kernel-0.3.0/python/fastvideo_kernel/version.py +1 -0
  50. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/python/fastvideo_kernel/vmoba.py +1 -1
  51. fastvideo_kernel-0.3.0/tests/test_attn_qat_infer.py +148 -0
  52. fastvideo_kernel-0.3.0/tests/test_attn_qat_train.py +1503 -0
  53. fastvideo_kernel-0.3.0/tests/test_fused_compress_topk.py +401 -0
  54. fastvideo_kernel-0.3.0/tests/test_vsa256_forward.py +131 -0
  55. fastvideo_kernel-0.3.0/tests/test_vsa256_forward_cross.py +139 -0
  56. fastvideo_kernel-0.3.0/tests/test_vsa256_forward_vbs.py +129 -0
  57. fastvideo_kernel-0.3.0/tests/test_vsa256_triton.py +150 -0
  58. fastvideo_kernel-0.3.0/tests/test_vsa_varlen.py +434 -0
  59. fastvideo_kernel-0.2.6/CMakeLists.txt +0 -185
  60. fastvideo_kernel-0.2.6/build.sh +0 -38
  61. fastvideo_kernel-0.2.6/dist/fastvideo_kernel-0.2.6-cp310-cp310-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl +0 -0
  62. fastvideo_kernel-0.2.6/dist/fastvideo_kernel-0.2.6-cp311-cp311-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl +0 -0
  63. fastvideo_kernel-0.2.6/python/fastvideo_kernel/block_sparse_attn.py +0 -294
  64. fastvideo_kernel-0.2.6/python/fastvideo_kernel/ops.py +0 -149
  65. fastvideo_kernel-0.2.6/python/fastvideo_kernel/version.py +0 -1
  66. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/LICENSE +0 -0
  67. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/benchmarks/bench_vsa.py +0 -0
  68. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/common_extension.cpp +0 -0
  69. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/turbodiffusion/common/common.hpp +0 -0
  70. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/turbodiffusion/common/launch.hpp +0 -0
  71. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/turbodiffusion/common/load.hpp +0 -0
  72. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/turbodiffusion/common/store.hpp +0 -0
  73. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/turbodiffusion/gemm/gemm.cu +0 -0
  74. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/turbodiffusion/gemm/kernel.hpp +0 -0
  75. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/turbodiffusion/gemm/launch.hpp +0 -0
  76. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/turbodiffusion/gemm/utils.hpp +0 -0
  77. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/turbodiffusion/norm/layernorm.cu +0 -0
  78. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/turbodiffusion/norm/layernorm.hpp +0 -0
  79. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/turbodiffusion/norm/rmsnorm.cu +0 -0
  80. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/turbodiffusion/norm/rmsnorm.hpp +0 -0
  81. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/turbodiffusion/quant/quant.cu +0 -0
  82. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/turbodiffusion/quant/quant.hpp +0 -0
  83. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/python/fastvideo_kernel/triton_kernels/block_sparse_attn_triton.py +0 -0
  84. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/python/fastvideo_kernel/triton_kernels/sla_triton.py +0 -0
  85. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/python/fastvideo_kernel/triton_kernels/st_attn_triton.py +0 -0
  86. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/python/fastvideo_kernel/turbodiffusion_ops.py +0 -0
  87. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/tests/__init__.py +0 -0
  88. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/tests/support_flex_sta.py +0 -0
  89. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/tests/test_sta.py +0 -0
  90. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/tests/test_turbodiffusion.py +0 -0
  91. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/tests/test_vmoba_correctness.py +0 -0
  92. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/tests/test_vsa.py +0 -0
  93. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/tests/test_vsa_forward.py +0 -0
  94. {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/tests/utils.py +0 -0
@@ -0,0 +1,352 @@
1
+ cmake_minimum_required(VERSION 3.26 FATAL_ERROR)
2
+ project(fastvideo-kernel LANGUAGES CXX)
3
+
4
+ # Prefer environment variable (used by CI or uv pip install git+repo_addr) if CMake var is not explicitly set.
5
+ if(NOT DEFINED GPU_BACKEND AND DEFINED ENV{GPU_BACKEND})
6
+ set(GPU_BACKEND "$ENV{GPU_BACKEND}")
7
+ endif()
8
+
9
+ if(GPU_BACKEND STREQUAL "ROCM")
10
+ enable_language(HIP)
11
+ else()
12
+ enable_language(CUDA)
13
+ # Ensure CUDA toolkit targets (CUDA::cudart, CUDA::cuda_driver, etc.) are available.
14
+ find_package(CUDAToolkit REQUIRED)
15
+ endif()
16
+
17
+ # Import common utils if needed, but we keep it simple for now
18
+
19
+ # Find Python and Torch
20
+ find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
21
+
22
+ # Robustly find Torch include paths using Python
23
+ execute_process(
24
+ COMMAND "${Python_EXECUTABLE}" -c "import torch; from torch.utils.cpp_extension import include_paths; print(';'.join(include_paths()))"
25
+ OUTPUT_VARIABLE TORCH_INCLUDE_PATHS
26
+ OUTPUT_STRIP_TRAILING_WHITESPACE
27
+ )
28
+ list(APPEND TORCH_INCLUDE_DIRS ${TORCH_INCLUDE_PATHS})
29
+
30
+ # Find Torch package (still useful for libraries)
31
+ find_package(Torch REQUIRED)
32
+
33
+ # Include directories
34
+ include_directories(
35
+ ${CMAKE_SOURCE_DIR}/include
36
+ ${CMAKE_SOURCE_DIR}/include/cutlass/include
37
+ ${CMAKE_SOURCE_DIR}/include/tk/include
38
+ ${CMAKE_SOURCE_DIR}/include/tk/prototype
39
+ ${CMAKE_SOURCE_DIR}/csrc
40
+ ${CMAKE_SOURCE_DIR}/csrc/turbodiffusion
41
+ ${TORCH_INCLUDE_DIRS}
42
+ )
43
+
44
+ # ---------------------------
45
+ # ThunderKittens (TK) toggles
46
+ # ---------------------------
47
+ # AUTO: enable TK only when we can confidently target Hopper (sm_90a).
48
+ # ON: force-enable TK kernels (intended for release wheels/images; does NOT require a GPU).
49
+ # OFF: never build TK kernels.
50
+ set(FASTVIDEO_KERNEL_BUILD_TK "AUTO" CACHE STRING "Build ThunderKittens kernels: AUTO/ON/OFF")
51
+ set_property(CACHE FASTVIDEO_KERNEL_BUILD_TK PROPERTY STRINGS AUTO ON OFF)
52
+
53
+ set(_FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER_DEFAULT "AUTO")
54
+ if(DEFINED FASTVIDEO_KERNEL_BUILD_MODIFIED_SAGE3 AND NOT DEFINED CACHE{FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER})
55
+ set(_FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER_DEFAULT "${FASTVIDEO_KERNEL_BUILD_MODIFIED_SAGE3}")
56
+ endif()
57
+
58
+ set(FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER "${_FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER_DEFAULT}" CACHE STRING
59
+ "Build attn_qat_infer Blackwell inference kernels: AUTO/ON/OFF")
60
+ set_property(CACHE FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER PROPERTY STRINGS AUTO ON OFF)
61
+
62
+ if(DEFINED FASTVIDEO_KERNEL_BUILD_MODIFIED_SAGE3)
63
+ message(DEPRECATION
64
+ "FASTVIDEO_KERNEL_BUILD_MODIFIED_SAGE3 is deprecated. "
65
+ "Use FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER instead.")
66
+ endif()
67
+
68
+ # Prefer environment variable (used by CI) if CMake var is not explicitly set.
69
+ if(NOT DEFINED TORCH_CUDA_ARCH_LIST AND DEFINED ENV{TORCH_CUDA_ARCH_LIST})
70
+ set(TORCH_CUDA_ARCH_LIST "$ENV{TORCH_CUDA_ARCH_LIST}")
71
+ endif()
72
+
73
+ message(STATUS "TORCH_CUDA_ARCH_LIST (cmake/env): ${TORCH_CUDA_ARCH_LIST}")
74
+ message(STATUS "FASTVIDEO_KERNEL_BUILD_TK: ${FASTVIDEO_KERNEL_BUILD_TK}")
75
+ message(STATUS "FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER: ${FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER}")
76
+
77
+ set(ENABLE_TK_KERNELS OFF)
78
+ if(FASTVIDEO_KERNEL_BUILD_TK STREQUAL "ON")
79
+ set(ENABLE_TK_KERNELS ON)
80
+ elseif(FASTVIDEO_KERNEL_BUILD_TK STREQUAL "OFF")
81
+ set(ENABLE_TK_KERNELS OFF)
82
+ else()
83
+ # AUTO: detect Hopper if possible.
84
+ if(TORCH_CUDA_ARCH_LIST)
85
+ # Accept common spellings: 9.0a, 90a, sm_90a.
86
+ string(REGEX MATCH "(^|[; ,])((9\\.0a)|(90a)|(sm_90a))([; ,]|$)" _HAS_90A "${TORCH_CUDA_ARCH_LIST}")
87
+ if(_HAS_90A)
88
+ set(ENABLE_TK_KERNELS ON)
89
+ endif()
90
+ else()
91
+ # Best-effort local detection (works when a CUDA device is visible).
92
+ execute_process(
93
+ COMMAND "${Python_EXECUTABLE}" -c "import torch; import sys; \nprint('1' if (torch.cuda.is_available() and torch.version.cuda and torch.cuda.get_device_capability()[0] >= 9) else '0')"
94
+ OUTPUT_VARIABLE _LOCAL_HAS_HOPPER
95
+ OUTPUT_STRIP_TRAILING_WHITESPACE
96
+ ERROR_QUIET
97
+ )
98
+ if(_LOCAL_HAS_HOPPER STREQUAL "1")
99
+ set(ENABLE_TK_KERNELS ON)
100
+ endif()
101
+ endif()
102
+ endif()
103
+
104
+ if(ENABLE_TK_KERNELS)
105
+ message(STATUS "ThunderKittens kernels: ENABLED")
106
+ else()
107
+ message(STATUS "ThunderKittens kernels: DISABLED (will use Triton fallbacks at runtime)")
108
+ endif()
109
+
110
+ set(ENABLE_ATTN_QAT_INFER OFF)
111
+ if(GPU_BACKEND STREQUAL "ROCM")
112
+ message(STATUS "attn_qat_infer kernels: DISABLED (ROCm build)")
113
+ else()
114
+ set(_WANTS_ATTN_QAT_INFER OFF)
115
+ if(FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER STREQUAL "ON")
116
+ set(_WANTS_ATTN_QAT_INFER ON)
117
+ elseif(FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER STREQUAL "AUTO")
118
+ if(TORCH_CUDA_ARCH_LIST)
119
+ string(REGEX MATCH
120
+ "(^|[; ,])((12\\.0a)|(120a)|(sm_120a))([; ,]|$)"
121
+ _HAS_120A "${TORCH_CUDA_ARCH_LIST}")
122
+ if(_HAS_120A)
123
+ set(_WANTS_ATTN_QAT_INFER ON)
124
+ endif()
125
+ else()
126
+ execute_process(
127
+ COMMAND "${Python_EXECUTABLE}" -c
128
+ "import torch; print('1' if (torch.cuda.is_available() and torch.version.cuda and torch.cuda.get_device_capability()[0] >= 12) else '0')"
129
+ OUTPUT_VARIABLE _LOCAL_HAS_BLACKWELL
130
+ OUTPUT_STRIP_TRAILING_WHITESPACE
131
+ ERROR_QUIET
132
+ )
133
+ if(_LOCAL_HAS_BLACKWELL STREQUAL "1")
134
+ set(_WANTS_ATTN_QAT_INFER ON)
135
+ endif()
136
+ endif()
137
+ endif()
138
+
139
+ if(_WANTS_ATTN_QAT_INFER)
140
+ if(CUDAToolkit_VERSION VERSION_LESS 12.8)
141
+ message(WARNING
142
+ "attn_qat_infer kernels require CUDA Toolkit 12.8+. "
143
+ "Skipping because CUDAToolkit_VERSION=${CUDAToolkit_VERSION}.")
144
+ else()
145
+ set(ENABLE_ATTN_QAT_INFER ON)
146
+ endif()
147
+ endif()
148
+
149
+ if(ENABLE_ATTN_QAT_INFER)
150
+ message(STATUS "attn_qat_infer kernels: ENABLED")
151
+ else()
152
+ message(STATUS
153
+ "attn_qat_infer kernels: DISABLED "
154
+ "(requires CUDA 12.8+ and Blackwell sm_120a)")
155
+ endif()
156
+ endif()
157
+
158
+ # Always build the main extension (BUILD_CXX_KERNELS is unconditionally ON); optional sources/flags (e.g. TK kernels) are added conditionally below
159
+ set(BUILD_CXX_KERNELS ON)
160
+
161
+ # ---------------------------------------------------------------------------
162
+ # Per-arch split for the Blackwell FP4 (attn_qat_infer) build
163
+ # ---------------------------------------------------------------------------
164
+ # The FP4 kernels are sm_120a-only (they emit `cvt.e2m1x2` etc.), while the main
165
+ # extension (Hopper-only TK + generic turbodiffusion) targets the full arch list.
166
+ # find_package(Torch) injects ONE global -gencode list into CMAKE_CUDA_FLAGS that
167
+ # forces every target onto every arch, so the FP4 sources also get the sm_90a pass
168
+ # and ptxas rejects their Blackwell instructions. Strip that global list and drive
169
+ # arch per target via CUDA_ARCHITECTURES instead (the fp4* targets pin 120a below;
170
+ # the main extension gets the full list). Only do this for the FP4 build with an
171
+ # explicit arch list, so the cu126 / local autodetect paths stay untouched.
172
+ if(ENABLE_ATTN_QAT_INFER AND TORCH_CUDA_ARCH_LIST)
173
+ message(STATUS "[per-arch] CMAKE_CUDA_FLAGS before strip: ${CMAKE_CUDA_FLAGS}")
174
+ string(REGEX REPLACE "-gencode[ =]+arch=[^ ]+" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
175
+ string(REGEX REPLACE " +" " " CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
176
+ message(STATUS "[per-arch] CMAKE_CUDA_FLAGS after strip: ${CMAKE_CUDA_FLAGS}")
177
+ # Normalize TORCH_CUDA_ARCH_LIST ("9.0a;12.0a", "9.0a 12.0a", "sm_90a", "9.0+PTX") into a CMake arch list ("90a;120a", "90a", "90"); space/comma separators (accepted by the AUTO regexes above) must become ";" or CUDA_ARCHITECTURES receives one invalid entry.
178
+ string(STRIP "${TORCH_CUDA_ARCH_LIST}" FASTVIDEO_MAIN_CUDA_ARCHS)
179
+ string(REGEX REPLACE "[ ,;]+" ";" FASTVIDEO_MAIN_CUDA_ARCHS "${FASTVIDEO_MAIN_CUDA_ARCHS}")
180
+ string(REGEX REPLACE "sm_|[.]|[+]PTX" "" FASTVIDEO_MAIN_CUDA_ARCHS "${FASTVIDEO_MAIN_CUDA_ARCHS}")
181
+ message(STATUS "[per-arch] main extension archs=${FASTVIDEO_MAIN_CUDA_ARCHS}, fp4* archs=120a")
182
+ endif()
183
+
184
+ # Compiler flags
185
+ set(CUDA_FLAGS
186
+ "-DNDEBUG"
187
+ "-O3"
188
+ "-std=c++20"
189
+ "--use_fast_math"
190
+ "--expt-extended-lambda"
191
+ "--expt-relaxed-constexpr"
192
+ "-Xcompiler=-fno-strict-aliasing"
193
+ "-Xcompiler=-fPIC"
194
+ "-DTORCH_COMPILE"
195
+ "-Xnvlink=--verbose"
196
+ "-Xptxas=--verbose"
197
+ "-Xptxas=--warn-on-spills"
198
+ )
199
+
200
+ # If TK is enabled, ensure we target Hopper. This is required even on GPU-less builders (CI).
201
+ if(ENABLE_TK_KERNELS)
202
+ if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES OR CMAKE_CUDA_ARCHITECTURES STREQUAL "")
203
+ set(CMAKE_CUDA_ARCHITECTURES "90a" CACHE STRING "CUDA architectures" FORCE)
204
+ endif()
205
+ list(APPEND CUDA_FLAGS "-DKITTENS_HOPPER")
206
+ message(STATUS "CMAKE_CUDA_ARCHITECTURES: ${CMAKE_CUDA_ARCHITECTURES}")
207
+ endif()
208
+
209
+ if(BUILD_CXX_KERNELS)
210
+ # Source files
211
+ set(EXTENSION_SOURCES
212
+ csrc/common_extension.cpp
213
+ csrc/turbodiffusion/gemm/gemm.cu
214
+ csrc/turbodiffusion/norm/rmsnorm.cu
215
+ csrc/turbodiffusion/norm/layernorm.cu
216
+ csrc/turbodiffusion/quant/quant.cu
217
+ )
218
+
219
+ # Conditionally add TK kernels
220
+ if(ENABLE_TK_KERNELS)
221
+ list(APPEND EXTENSION_SOURCES
222
+ csrc/attention/st_attn_h100.cu
223
+ csrc/attention/block_sparse_h100.cu
224
+ )
225
+ endif()
226
+
227
+ # Combined FastVideo Extension
228
+ # Using name 'fastvideo_kernel_ops' to distinguish from the python package namespace
229
+ Python_add_library(fastvideo_kernel_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI
230
+ ${EXTENSION_SOURCES}
231
+ )
232
+
233
+ # When the per-arch split is active (FP4 build), torch's global gencode was
234
+ # stripped above, so set this target's arch explicitly. TK is guarded down to
235
+ # sm_90a in source; the turbodiffusion kernels are generic, so the main
236
+ # extension targets the full list. (No-FP4 builds keep torch's global gencode.)
237
+ if(ENABLE_ATTN_QAT_INFER AND FASTVIDEO_MAIN_CUDA_ARCHS)
238
+ set_target_properties(fastvideo_kernel_ops PROPERTIES
239
+ CUDA_ARCHITECTURES "${FASTVIDEO_MAIN_CUDA_ARCHS}")
240
+ endif()
241
+
242
+ # Build compile definitions list
243
+ set(COMPILE_DEFS TORCH_EXTENSION_NAME=fastvideo_kernel_ops)
244
+ if(ENABLE_TK_KERNELS)
245
+ list(APPEND COMPILE_DEFS TK_COMPILE_ST_ATTN TK_COMPILE_BLOCK_SPARSE)
246
+ endif()
247
+
248
+ target_compile_definitions(fastvideo_kernel_ops PRIVATE ${COMPILE_DEFS})
249
+
250
+ target_compile_options(fastvideo_kernel_ops PRIVATE
251
+ $<$<COMPILE_LANGUAGE:CUDA>:${CUDA_FLAGS}>
252
+ )
253
+
254
+ # Link against Torch libraries to avoid undefined symbols at import time
255
+ # (e.g., torch::autograd vtables) when loading the extension module.
256
+ target_link_libraries(fastvideo_kernel_ops PRIVATE ${TORCH_LIBRARIES})
257
+
258
+ # Also link against libtorch_python to satisfy Python-binding symbols
259
+ # (e.g., torch::PyWarningHandler) required by torch/extension.h.
260
+ execute_process(
261
+ COMMAND "${Python_EXECUTABLE}" -c "import torch; from pathlib import Path; p=Path(torch.__file__).parent/'lib'; m=sorted(p.glob('libtorch_python*')); print(str(m[0]) if m else '')"
262
+ OUTPUT_VARIABLE TORCH_PYTHON_LIBRARY_PATH
263
+ OUTPUT_STRIP_TRAILING_WHITESPACE
264
+ ERROR_QUIET
265
+ )
266
+ if(TORCH_PYTHON_LIBRARY_PATH)
267
+ message(STATUS "TORCH_PYTHON_LIBRARY_PATH: ${TORCH_PYTHON_LIBRARY_PATH}")
268
+ target_link_libraries(fastvideo_kernel_ops PRIVATE "${TORCH_PYTHON_LIBRARY_PATH}")
269
+ else()
270
+ message(WARNING "Could not locate libtorch_python; fastvideo_kernel_ops may fail to import.")
271
+ endif()
272
+
273
+ # Link CUDA runtime + driver explicitly (fixes missing symbols like cuGetErrorString at import time)
274
+ if(NOT GPU_BACKEND STREQUAL "ROCM")
275
+ target_link_libraries(fastvideo_kernel_ops PRIVATE CUDA::cudart CUDA::cuda_driver)
276
+ endif()
277
+
278
+ # We install it to fastvideo_kernel/_C so we can load it to register the ops
279
+ install(TARGETS fastvideo_kernel_ops LIBRARY DESTINATION fastvideo_kernel/_C)
280
+ endif()
281
+
282
+ if(ENABLE_ATTN_QAT_INFER)
283
+ set(ATTN_QAT_INFER_DIR ${CMAKE_SOURCE_DIR}/attn_qat_infer)
284
+ set(ATTN_QAT_INFER_INCLUDE_DIRS
285
+ ${ATTN_QAT_INFER_DIR}
286
+ ${CMAKE_SOURCE_DIR}/include/cutlass/include
287
+ ${CMAKE_SOURCE_DIR}/include/cutlass/tools/util/include
288
+ ${TORCH_INCLUDE_DIRS}
289
+ )
290
+ set(ATTN_QAT_INFER_CUDA_FLAGS
291
+ "-O3"
292
+ "-std=c++17"
293
+ "-U__CUDA_NO_HALF_OPERATORS__"
294
+ "-U__CUDA_NO_HALF_CONVERSIONS__"
295
+ "-U__CUDA_NO_BFLOAT16_OPERATORS__"
296
+ "-U__CUDA_NO_BFLOAT16_CONVERSIONS__"
297
+ "-U__CUDA_NO_BFLOAT162_OPERATORS__"
298
+ "-U__CUDA_NO_BFLOAT162_CONVERSIONS__"
299
+ "--expt-relaxed-constexpr"
300
+ "--expt-extended-lambda"
301
+ "--use_fast_math"
302
+ "--ptxas-options=--verbose,--warn-on-local-memory-usage"
303
+ "-lineinfo"
304
+ "-DCUTLASS_DEBUG_TRACE_LEVEL=0"
305
+ "-DNDEBUG"
306
+ "-DQBLKSIZE=128"
307
+ "-DKBLKSIZE=128"
308
+ "-DCTA256"
309
+ "-DDQINRMEM"
310
+ )
311
+
312
+ Python_add_library(fp4attn_cuda MODULE WITH_SOABI
313
+ attn_qat_infer/blackwell/api.cu
314
+ )
315
+ target_include_directories(fp4attn_cuda PRIVATE ${ATTN_QAT_INFER_INCLUDE_DIRS})
316
+ target_compile_definitions(fp4attn_cuda PRIVATE TORCH_EXTENSION_NAME=fp4attn_cuda)
317
+ target_compile_options(fp4attn_cuda PRIVATE
318
+ $<$<COMPILE_LANGUAGE:CXX>:-O3 -std=c++17>
319
+ $<$<COMPILE_LANGUAGE:CUDA>:${ATTN_QAT_INFER_CUDA_FLAGS}>
320
+ )
321
+ set_target_properties(fp4attn_cuda PROPERTIES
322
+ CUDA_ARCHITECTURES "120a"
323
+ CXX_STANDARD 17
324
+ CUDA_STANDARD 17
325
+ )
326
+ target_link_libraries(fp4attn_cuda PRIVATE ${TORCH_LIBRARIES} CUDA::cudart CUDA::cuda_driver)
327
+
328
+ Python_add_library(fp4quant_cuda MODULE WITH_SOABI
329
+ attn_qat_infer/quantization/fp4_quantization_4d.cu
330
+ )
331
+ target_include_directories(fp4quant_cuda PRIVATE ${ATTN_QAT_INFER_INCLUDE_DIRS})
332
+ target_compile_definitions(fp4quant_cuda PRIVATE TORCH_EXTENSION_NAME=fp4quant_cuda)
333
+ target_compile_options(fp4quant_cuda PRIVATE
334
+ $<$<COMPILE_LANGUAGE:CXX>:-O3 -std=c++17>
335
+ $<$<COMPILE_LANGUAGE:CUDA>:${ATTN_QAT_INFER_CUDA_FLAGS}>
336
+ )
337
+ set_target_properties(fp4quant_cuda PROPERTIES
338
+ CUDA_ARCHITECTURES "120a"
339
+ CXX_STANDARD 17
340
+ CUDA_STANDARD 17
341
+ )
342
+ target_link_libraries(fp4quant_cuda PRIVATE ${TORCH_LIBRARIES} CUDA::cudart CUDA::cuda_driver)
343
+
344
+ if(TORCH_PYTHON_LIBRARY_PATH)
345
+ target_link_libraries(fp4attn_cuda PRIVATE "${TORCH_PYTHON_LIBRARY_PATH}")
346
+ target_link_libraries(fp4quant_cuda PRIVATE "${TORCH_PYTHON_LIBRARY_PATH}")
347
+ endif()
348
+
349
+ install(TARGETS fp4attn_cuda LIBRARY DESTINATION .)
350
+ install(TARGETS fp4quant_cuda LIBRARY DESTINATION .)
351
+ endif()
352
+
@@ -2,5 +2,6 @@ include LICENSE
2
2
  include README.md
3
3
  include pyproject.toml
4
4
  recursive-include python/fastvideo_kernel *.py
5
+ recursive-include attn_qat_infer *.py *.cu *.cuh *.cpp *.h
5
6
  recursive-include csrc *.cu *.cuh *.cpp *.h
6
7
  recursive-include include/tk *.cu *.cuh *.cpp *.h *.src
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: fastvideo-kernel
3
- Version: 0.2.6
3
+ Version: 0.3.0
4
4
  Summary: Unified CUDA kernels for FastVideo
5
5
  Author-Email: Hao AI Lab <contact@haoailab.com>
6
6
  License: Apache License
@@ -195,7 +195,7 @@ Classifier: Environment :: GPU :: NVIDIA CUDA
195
195
  Project-URL: Homepage, https://github.com/hao-ai-lab/FastVideo
196
196
  Requires-Python: >=3.10
197
197
  Requires-Dist: torch>=2.5.0
198
- Requires-Dist: triton>=2.0.0
198
+ Requires-Dist: triton>=2.0.0; sys_platform == "linux"
199
199
  Description-Content-Type: text/markdown
200
200
 
201
201
  # FastVideo Kernel
@@ -207,6 +207,13 @@ CUDA kernels for FastVideo video generation.
207
207
  ### Standard Installation (Local Development)
208
208
  This will automatically detect your GPU architecture. If an NVIDIA Hopper (H100/sm_90a) GPU is detected, ThunderKittens kernels will be enabled. Otherwise, they will be skipped, and the package will use Triton fallbacks at runtime.
209
209
 
210
+ Before installation, set CUDA toolchain paths:
211
+
212
+ ```bash
213
+ export CUDA_HOME=/usr/local/cuda
214
+ export CUDACXX=$CUDA_HOME/bin/nvcc
215
+ ```
216
+
210
217
  ```bash
211
218
  git submodule update --init --recursive
212
219
  cd fastvideo-kernel
@@ -221,6 +228,29 @@ cd fastvideo-kernel
221
228
  ./build.sh --rocm
222
229
  ```
223
230
 
231
+ ### Optional: FA4 CuTe block-sparse backend (VSA-256 fastpath)
232
+
233
+ The VSA-256 fastpath (tile volume 256, on NVIDIA Blackwell / sm_100) routes to the
234
+ FlashAttention-4 CuTe-DSL block-sparse kernel exposed as `flash_attn.cute`. This is
235
+ an **optional** dependency: it is imported lazily, and `video_sparse_attn`
236
+ transparently falls back to the Triton backend when it is absent (so the package is
237
+ fully usable without it).
238
+
239
+ The symbols the fastpath needs (`flash_attn.cute.block_sparsity.BlockSparseTensorsTorch`,
240
+ `flash_attn.cute.interface._flash_attn_fwd`) are provided upstream by
241
+ [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention). Pin to
242
+ commit `c19cd20e`: the wrapper targets that revision's `_flash_attn_fwd` signature
243
+ (`m_block_size` / `n_block_size`); later upstream revisions reshaped it into a
244
+ `tile_mn` tuple and are not drop-in compatible.
245
+
246
+ ```bash
247
+ pip install "nvidia-cutlass-dsl>=4.5.0" torchvision
248
+ pip install "git+https://github.com/Dao-AILab/flash-attention.git@c19cd20e#subdirectory=flash_attn/cute"
249
+ ```
250
+
251
+ The CuTe kernel JIT-compiles on first use. Verified on Blackwell (sm_100) against
252
+ `tests/test_vsa256_forward*.py`.
253
+
224
254
  ## Usage
225
255
 
226
256
  ### Sliding Tile Attention (STA) & Video Sparse Attention (VSA)
@@ -262,6 +292,8 @@ This package also includes kernels from [TurboDiffusion](https://github.com/thu-
262
292
  - Any CUDA GPU for Triton-based fallbacks.
263
293
  - **Build**:
264
294
  - CUDA Toolkit 12.3+
295
+ - `CUDA_HOME` must be set (for example, `/usr/local/cuda`)
296
+ - `CUDACXX` must be set (for example, `$CUDA_HOME/bin/nvcc`)
265
297
  - C++20 compatible compiler (GCC 10+, Clang 11+)
266
298
 
267
299
  ## Acknowledgement
@@ -7,6 +7,13 @@ CUDA kernels for FastVideo video generation.
7
7
  ### Standard Installation (Local Development)
8
8
  This will automatically detect your GPU architecture. If an NVIDIA Hopper (H100/sm_90a) GPU is detected, ThunderKittens kernels will be enabled. Otherwise, they will be skipped, and the package will use Triton fallbacks at runtime.
9
9
 
10
+ Before installation, set CUDA toolchain paths:
11
+
12
+ ```bash
13
+ export CUDA_HOME=/usr/local/cuda
14
+ export CUDACXX=$CUDA_HOME/bin/nvcc
15
+ ```
16
+
10
17
  ```bash
11
18
  git submodule update --init --recursive
12
19
  cd fastvideo-kernel
@@ -21,6 +28,29 @@ cd fastvideo-kernel
21
28
  ./build.sh --rocm
22
29
  ```
23
30
 
31
+ ### Optional: FA4 CuTe block-sparse backend (VSA-256 fastpath)
32
+
33
+ The VSA-256 fastpath (tile volume 256, on NVIDIA Blackwell / sm_100) routes to the
34
+ FlashAttention-4 CuTe-DSL block-sparse kernel exposed as `flash_attn.cute`. This is
35
+ an **optional** dependency: it is imported lazily, and `video_sparse_attn`
36
+ transparently falls back to the Triton backend when it is absent (so the package is
37
+ fully usable without it).
38
+
39
+ The symbols the fastpath needs (`flash_attn.cute.block_sparsity.BlockSparseTensorsTorch`,
40
+ `flash_attn.cute.interface._flash_attn_fwd`) are provided upstream by
41
+ [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention). Pin to
42
+ commit `c19cd20e`: the wrapper targets that revision's `_flash_attn_fwd` signature
43
+ (`m_block_size` / `n_block_size`); later upstream revisions reshaped it into a
44
+ `tile_mn` tuple and are not drop-in compatible.
45
+
46
+ ```bash
47
+ pip install "nvidia-cutlass-dsl>=4.5.0" torchvision
48
+ pip install "git+https://github.com/Dao-AILab/flash-attention.git@c19cd20e#subdirectory=flash_attn/cute"
49
+ ```
50
+
51
+ The CuTe kernel JIT-compiles on first use. Verified on Blackwell (sm_100) against
52
+ `tests/test_vsa256_forward*.py`.
53
+
24
54
  ## Usage
25
55
 
26
56
  ### Sliding Tile Attention (STA) & Video Sparse Attention (VSA)
@@ -62,6 +92,8 @@ This package also includes kernels from [TurboDiffusion](https://github.com/thu-
62
92
  - Any CUDA GPU for Triton-based fallbacks.
63
93
  - **Build**:
64
94
  - CUDA Toolkit 12.3+
95
+ - `CUDA_HOME` must be set (for example, `/usr/local/cuda`)
96
+ - `CUDACXX` must be set (for example, `$CUDA_HOME/bin/nvcc`)
65
97
  - C++20 compatible compiler (GCC 10+, Clang 11+)
66
98
 
67
99
  ## Acknowledgement
@@ -0,0 +1,16 @@
1
+ """
2
+ Copyright (c) 2025 by SageAttention team.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+ from .api import sageattn_blackwell
@@ -0,0 +1,189 @@
1
+ # Modified from the original SageAttention3 code
2
+ """
3
+ Copyright (c) 2025 by SageAttention team.
4
+
5
+ Licensed under the Apache License, Version 2.0 (the "License");
6
+ you may not use this file except in compliance with the License.
7
+ You may obtain a copy of the License at
8
+
9
+ http://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ Unless required by applicable law or agreed to in writing, software
12
+ distributed under the License is distributed on an "AS IS" BASIS,
13
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ See the License for the specific language governing permissions and
15
+ limitations under the License.
16
+ """
17
+ import torch
18
+ import triton
19
+ import triton.language as tl
20
+ import torch.nn.functional as F
21
+ from typing import Tuple
22
+ from torch.nn.functional import scaled_dot_product_attention as sdpa
23
+ import fp4attn_cuda
24
+ import fp4quant_cuda
25
+
26
# Centralized block size configuration for the sageattn_blackwell kernels.
# These must match the values compiled into the CUDA side of this package,
# attn_qat_infer/blackwell/block_config.h. BLOCK_M is also used by the Python
# helpers in this module to pad sequences and to size the delta_s buffer.
BLOCK_M = 128  # Block size for M dimension (query sequence length)
BLOCK_N = 128  # Block size for N dimension (key/value sequence length)
30
+
31
+
32
@triton.jit
def group_mean_kernel(
    q_ptr,
    q_out_ptr,
    qm_out_ptr,
    B, H, L, D: tl.constexpr,
    stride_qb, stride_qh, stride_ql, stride_qd,
    stride_qmb, stride_qmh, stride_qml, stride_qmd,
    GROUP_SIZE: tl.constexpr
):
    # Center one GROUP_SIZE x D tile of q and record its column mean.
    # Grid axes: 0 -> batch, 1 -> head, 2 -> row group along the sequence.
    # Loads/stores are unmasked, so the launcher must only cover whole groups.
    # B, H and L are accepted for the launch signature but not read here.
    batch = tl.program_id(0)
    head = tl.program_id(1)
    group = tl.program_id(2)

    rows = group * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
    cols = tl.arange(0, D)
    slice_base = batch * stride_qb + head * stride_qh

    # Gather the GROUP_SIZE x D tile for this (batch, head, group).
    tile_offsets = slice_base + rows[:, None] * stride_ql + cols[None, :] * stride_qd
    tile = tl.load(q_ptr + tile_offsets)

    # Mean over the group's rows (one value per column), broadcast back over the rows.
    group_mean = tl.sum(tile, axis=0) / GROUP_SIZE
    centered = tile - group_mean
    tl.store(q_out_ptr + tile_offsets, centered)

    # The mean tensor holds one D-vector per group.
    mean_offsets = batch * stride_qmb + head * stride_qmh + group * stride_qml + cols * stride_qmd
    tl.store(qm_out_ptr + mean_offsets, group_mean)
60
+
61
def triton_group_mean(q: torch.Tensor):
    """Subtract per-group means from ``q`` along its sequence axis.

    The sequence axis (dim 2) is split into consecutive groups of ``BLOCK_M``
    rows. Each group's mean over its rows is subtracted from that group.

    Args:
        q: 4-D tensor [B, H, L, D]. L must be a multiple of BLOCK_M. D is used
            as a ``tl.arange`` extent inside the kernel, which Triton requires
            to be a power of two.

    Returns:
        (q_out, qm): the mean-subtracted tensor (same shape and dtype as ``q``)
        and the group means, shape [B, H, L // BLOCK_M, D], in ``q``'s dtype.

    Raises:
        ValueError: if L is not a multiple of BLOCK_M. The kernel only covers
            whole groups, so the trailing rows of the ``empty_like`` output
            would otherwise be returned uninitialized.
    """
    B, H, L, D = q.shape
    GROUP_SIZE = BLOCK_M
    if L % GROUP_SIZE != 0:
        raise ValueError(
            f"triton_group_mean: sequence length {L} must be a multiple of "
            f"{GROUP_SIZE}; pad q along dim 2 first"
        )
    num_groups = L // GROUP_SIZE

    q_out = torch.empty_like(q)  # [B, H, L, D]; every row is written by the kernel
    qm = torch.empty(B, H, num_groups, D, device=q.device, dtype=q.dtype)

    # One Triton program per (batch, head, group).
    grid = (B, H, num_groups)

    group_mean_kernel[grid](
        q, q_out, qm,
        B, H, L, D,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        qm.stride(0), qm.stride(1), qm.stride(2), qm.stride(3),
        GROUP_SIZE=GROUP_SIZE,
    )
    return q_out, qm
79
+
80
+
81
def preprocess_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, per_block_mean: bool = True, enable_smoothing_q: bool = False, enable_smoothing_k: bool = False):
    """Pad Q/K/V along the sequence axis and optionally smooth Q and K.

    Args:
        q, k, v: 4-D tensors whose sequence axis is dim 2. Each is zero-padded
            along that axis up to a multiple of BLOCK_M and made contiguous.
        per_block_mean: Only used when ``enable_smoothing_q`` is set. If True,
            subtract a separate mean for every BLOCK_M-row group of Q. If False,
            subtract one mean over the whole (padded) Q sequence.
        enable_smoothing_q: Subtract the Q mean and return the correction term
            ``qm @ k^T`` as ``delta_s``.
        enable_smoothing_k: Subtract K's mean over its unpadded sequence axis.
            The caller's ``k`` is not modified.

    Returns:
        (q, k, v, delta_s) with the padded tensors and a float32 ``delta_s``:
        - with Q smoothing: ``qm @ k^T``, shape (B, H, L // BLOCK_M, Lk) when
          ``per_block_mean`` else (B, H, 1, Lk);
        - otherwise: zeros of shape (B, H, L // BLOCK_M, Lk).
        Here L and Lk are the padded Q and K sequence lengths.
    """

    def pad_to_block_size(x):
        seq_len = x.size(2)
        pad_len = (BLOCK_M - seq_len % BLOCK_M) % BLOCK_M
        if pad_len == 0:
            return x.contiguous()
        # F.pad takes (last-dim left, last-dim right, dim-2 left, dim-2 right).
        return F.pad(x, (0, 0, 0, pad_len), value=0).contiguous()

    if enable_smoothing_k:
        # Out-of-place so the caller's tensor (or the tensor it is a view of)
        # is left untouched; an in-place update would also fail on leaf tensors
        # that require grad. The mean is taken before padding, so the zero pad
        # rows do not bias it.
        k = k - k.mean(dim=-2, keepdim=True)

    q, k, v = map(pad_to_block_size, (q, k, v))

    if enable_smoothing_q:
        if per_block_mean:
            q, qm = triton_group_mean(q)
        else:
            qm = q.mean(dim=-2, keepdim=True)
            q = q - qm
        delta_s = torch.matmul(qm, k.transpose(-2, -1)).to(torch.float32).contiguous()
    else:
        # Zero correction term, used when Q smoothing is disabled.
        B, H, L, _ = q.shape
        delta_s = torch.zeros((B, H, L // BLOCK_M, k.shape[2]), device=q.device, dtype=torch.float32)

    return q, k, v, delta_s
105
+
106
def _fp4_quantize(x: torch.Tensor, op_name: str, transposed: bool) -> Tuple[torch.Tensor, torch.Tensor]:
    # Shared body of the scale_and_quant_fp4* wrappers.
    # It allocates two buffers:
    #   - the packed FP4 buffer: uint8, last dim // 2 (two 4-bit values per byte);
    #   - the FP8-E4M3 scale buffer: last dim // 16.
    # The named fp4quant_cuda entry point then fills both buffers in place.
    assert x.ndim == 4
    batch, heads, rows, cols = x.shape
    if transposed:
        # The outputs are laid out [B, H, D, N] instead of [B, H, N, D].
        rows, cols = cols, rows
    packed_fp4 = torch.empty((batch, heads, rows, cols // 2), device=x.device, dtype=torch.uint8)
    fp8_scale = torch.empty((batch, heads, rows, cols // 16), device=x.device, dtype=torch.float8_e4m3fn)
    # NOTE(review): the trailing literal 1 is passed unchanged by all three
    # wrappers; its meaning is defined by fp4quant_cuda. Confirm.
    getattr(fp4quant_cuda, op_name)(x, packed_fp4, fp8_scale, 1)
    return packed_fp4, fp8_scale


def scale_and_quant_fp4(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize a [B, H, N, D] tensor to block-scaled FP4 along D (used for Q).

    Returns ``(packed_fp4, fp8_scale)`` with shapes [B, H, N, D // 2] (uint8)
    and [B, H, N, D // 16] (float8_e4m3fn).
    """
    return _fp4_quantize(x, "scaled_fp4_quant", transposed=False)


def scale_and_quant_fp4_permute(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize a [B, H, N, D] tensor via ``scaled_fp4_quant_permute`` (used for K).

    The output shapes match ``scale_and_quant_fp4``. Any reordering of the data
    is done inside the CUDA op.
    """
    return _fp4_quantize(x, "scaled_fp4_quant_permute", transposed=False)


def scale_and_quant_fp4_transpose(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize a [B, H, N, D] tensor via ``scaled_fp4_quant_trans`` (used for V).

    Returns ``(packed_fp4, fp8_scale)`` with transposed shapes
    [B, H, D, N // 2] (uint8) and [B, H, D, N // 16] (float8_e4m3fn).
    """
    return _fp4_quantize(x, "scaled_fp4_quant_trans", transposed=True)
129
+
130
def blockscaled_fp4_attn(qlist: Tuple,
                         klist: Tuple,
                         vlist: Tuple,
                         delta_s: torch.Tensor,
                         KL: int,
                         is_causal: bool = False,
                         per_block_mean: bool = True,
                         is_bf16: bool = True,
                         single_level_p_quant: bool = False,
                         sm_scale: float | None = None
                         ):
    """Invoke the fused block-scaled FP4 attention forward, ``fp4attn_cuda.fwd``.

    Args:
        qlist, klist, vlist: ``(packed_fp4, fp8_scale)`` pairs as produced by the
            ``scale_and_quant_fp4*`` helpers. Only items 0 and 1 are read.
        delta_s: float32 Q-smoothing correction term from ``preprocess_qkv``.
        KL: key/value sequence length before padding.
        is_causal, per_block_mean, is_bf16, single_level_p_quant: forwarded
            to the kernel unchanged.
        sm_scale: softmax scale. When None, 1/sqrt(D) is used. D is recovered
            as twice the packed query's last dimension (two FP4 values per byte).

    Returns:
        The result of ``fp4attn_cuda.fwd``. Callers take item 0 as the output.
    """
    q_fp4, q_sf = qlist[0], qlist[1]
    k_fp4, k_sf = klist[0], klist[1]
    v_fp4, v_sf = vlist[0], vlist[1]

    if sm_scale is None:
        head_dim = q_fp4.shape[-1] * 2
        softmax_scale = head_dim ** (-0.5)
    else:
        softmax_scale = sm_scale

    return fp4attn_cuda.fwd(
        q_fp4, k_fp4, v_fp4,
        q_sf, k_sf, v_sf,
        delta_s,
        KL,
        None,  # NOTE(review): opaque positional slot passed as None; defined by the extension. Confirm.
        softmax_scale,
        is_causal,
        per_block_mean,
        is_bf16,
        single_level_p_quant,
    )
143
+
144
+
145
+ def sageattn_blackwell(q, k, v, attn_mask = None, is_causal = False, per_block_mean = True, single_level_p_quant = True, sm_scale: float | None = None, **kwargs):
146
+ """
147
+ SageAttention3 Blackwell kernel for FP4 attention.
148
+
149
+ Args:
150
+ q: Query tensor [B, H, L, D]
151
+ k: Key tensor [B, H, L, D]
152
+ v: Value tensor [B, H, L, D]
153
+ attn_mask: Attention mask (not used)
154
+ is_causal: Whether to use causal masking
155
+ per_block_mean: Whether to use per-block mean for Q smoothing
156
+ single_level_p_quant: If True, use single-level quantization: s_P2, P̂_2 = φ(P̃) directly
157
+ (standard per-block FP4 quantization like V, no s_P1).
158
+ If False (default), use two-level quantization:
159
+ s_P1 = rowmax(P̃)/(448×6), then s_P2, P̂_2 = φ(P̃/s_P1).
160
+ sm_scale: Softmax scale to pass through to the CUDA kernel. If None,
161
+ defaults to the kernel's 1/sqrt(D) scale.
162
+ **kwargs: Additional arguments (ignored)
163
+
164
+ Returns:
165
+ Output tensor [B, H, L, D]
166
+ """
167
+ if q.size(-1) >= 256:
168
+ print(f"Unsupported Headdim {q.size(-1)}")
169
+ return sdpa(q, k, v, is_causal = is_causal)
170
+ QL = q.size(2)
171
+ KL = k.size(2)
172
+ is_bf16 = q.dtype == torch.bfloat16
173
+ q, k, v, delta_s = preprocess_qkv(q, k, v, per_block_mean)
174
+ qlist_from_cuda = scale_and_quant_fp4(q)
175
+ klist_from_cuda = scale_and_quant_fp4_permute(k)
176
+ vlist_from_cuda = scale_and_quant_fp4_transpose(v)
177
+ o_fp4 = blockscaled_fp4_attn(
178
+ qlist_from_cuda,
179
+ klist_from_cuda,
180
+ vlist_from_cuda,
181
+ delta_s,
182
+ KL,
183
+ is_causal,
184
+ per_block_mean,
185
+ is_bf16,
186
+ single_level_p_quant,
187
+ sm_scale
188
+ )[0][:, :, :QL, :].contiguous()
189
+ return o_fp4
@@ -0,0 +1 @@
1
# NOTE(review): this differs from the distribution version (0.3.0); presumably
# it tracks the upstream SageAttention3 Blackwell kernel version. Confirm.
__version__ = "3.0.0.b1"