fastvideo-kernel 0.2.6__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastvideo_kernel-0.3.0/CMakeLists.txt +352 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/MANIFEST.in +1 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/PKG-INFO +34 -2
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/README.md +32 -0
- fastvideo_kernel-0.3.0/attn_qat_infer/__init__.py +16 -0
- fastvideo_kernel-0.3.0/attn_qat_infer/api.py +189 -0
- fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/__init__.py +1 -0
- fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/api.cu +347 -0
- fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/block_config.h +28 -0
- fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/block_info.h +60 -0
- fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/blockscaled_layout.h +149 -0
- fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/cute_extension.h +327 -0
- fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/epilogue_tma_ws.h +222 -0
- fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/kernel_traits.h +202 -0
- fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/kernel_ws.h +204 -0
- fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/launch.h +114 -0
- fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/mainloop_tma_ws.h +926 -0
- fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/named_barrier.h +119 -0
- fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/params.h +180 -0
- fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/softmax_fused.h +190 -0
- fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/static_switch.h +83 -0
- fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/tile_scheduler.h +304 -0
- fastvideo_kernel-0.3.0/attn_qat_infer/blackwell/utils.h +408 -0
- fastvideo_kernel-0.3.0/attn_qat_infer/quantization/__init__.py +1 -0
- fastvideo_kernel-0.3.0/attn_qat_infer/quantization/bench/bench_quant_k.py +90 -0
- fastvideo_kernel-0.3.0/attn_qat_infer/quantization/bench/bench_quant_q.py +86 -0
- fastvideo_kernel-0.3.0/attn_qat_infer/quantization/bench/bench_quant_v.py +86 -0
- fastvideo_kernel-0.3.0/attn_qat_infer/quantization/bench/bench_utils.py +169 -0
- fastvideo_kernel-0.3.0/attn_qat_infer/quantization/cuda_utils.h +52 -0
- fastvideo_kernel-0.3.0/attn_qat_infer/quantization/fp4_quantization_4d.cu +639 -0
- fastvideo_kernel-0.3.0/benchmarks/bench_fused_compress_topk.py +315 -0
- fastvideo_kernel-0.3.0/build.sh +146 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/attention/block_sparse_h100.cu +22 -3
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/attention/st_attn_h100.cu +15 -1
- fastvideo_kernel-0.2.6/dist/fastvideo_kernel-0.2.6-cp312-cp312-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl → fastvideo_kernel-0.3.0/dist/fastvideo_kernel-0.3.0-cp312-cp312-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl +0 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/pyproject.toml +3 -3
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/python/fastvideo_kernel/__init__.py +14 -0
- fastvideo_kernel-0.3.0/python/fastvideo_kernel/block_sparse_attn.py +424 -0
- fastvideo_kernel-0.3.0/python/fastvideo_kernel/block_sparse_attn_256.py +170 -0
- fastvideo_kernel-0.3.0/python/fastvideo_kernel/block_sparse_attn_cute_fwd.py +267 -0
- fastvideo_kernel-0.3.0/python/fastvideo_kernel/block_sparse_attn_varlen.py +207 -0
- fastvideo_kernel-0.3.0/python/fastvideo_kernel/ops.py +238 -0
- fastvideo_kernel-0.3.0/python/fastvideo_kernel/triton_kernels/attn_qat_train.py +1119 -0
- fastvideo_kernel-0.3.0/python/fastvideo_kernel/triton_kernels/fused_attention.py +55 -0
- fastvideo_kernel-0.3.0/python/fastvideo_kernel/triton_kernels/fused_compress_topk.py +334 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/python/fastvideo_kernel/triton_kernels/index.py +113 -1
- fastvideo_kernel-0.3.0/python/fastvideo_kernel/triton_kernels/nvfp4_utils.py +237 -0
- fastvideo_kernel-0.3.0/python/fastvideo_kernel/triton_kernels/quant_utils.py +80 -0
- fastvideo_kernel-0.3.0/python/fastvideo_kernel/version.py +1 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/python/fastvideo_kernel/vmoba.py +1 -1
- fastvideo_kernel-0.3.0/tests/test_attn_qat_infer.py +148 -0
- fastvideo_kernel-0.3.0/tests/test_attn_qat_train.py +1503 -0
- fastvideo_kernel-0.3.0/tests/test_fused_compress_topk.py +401 -0
- fastvideo_kernel-0.3.0/tests/test_vsa256_forward.py +131 -0
- fastvideo_kernel-0.3.0/tests/test_vsa256_forward_cross.py +139 -0
- fastvideo_kernel-0.3.0/tests/test_vsa256_forward_vbs.py +129 -0
- fastvideo_kernel-0.3.0/tests/test_vsa256_triton.py +150 -0
- fastvideo_kernel-0.3.0/tests/test_vsa_varlen.py +434 -0
- fastvideo_kernel-0.2.6/CMakeLists.txt +0 -185
- fastvideo_kernel-0.2.6/build.sh +0 -38
- fastvideo_kernel-0.2.6/dist/fastvideo_kernel-0.2.6-cp310-cp310-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl +0 -0
- fastvideo_kernel-0.2.6/dist/fastvideo_kernel-0.2.6-cp311-cp311-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl +0 -0
- fastvideo_kernel-0.2.6/python/fastvideo_kernel/block_sparse_attn.py +0 -294
- fastvideo_kernel-0.2.6/python/fastvideo_kernel/ops.py +0 -149
- fastvideo_kernel-0.2.6/python/fastvideo_kernel/version.py +0 -1
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/LICENSE +0 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/benchmarks/bench_vsa.py +0 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/common_extension.cpp +0 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/turbodiffusion/common/common.hpp +0 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/turbodiffusion/common/launch.hpp +0 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/turbodiffusion/common/load.hpp +0 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/turbodiffusion/common/store.hpp +0 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/turbodiffusion/gemm/gemm.cu +0 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/turbodiffusion/gemm/kernel.hpp +0 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/turbodiffusion/gemm/launch.hpp +0 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/turbodiffusion/gemm/utils.hpp +0 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/turbodiffusion/norm/layernorm.cu +0 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/turbodiffusion/norm/layernorm.hpp +0 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/turbodiffusion/norm/rmsnorm.cu +0 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/turbodiffusion/norm/rmsnorm.hpp +0 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/turbodiffusion/quant/quant.cu +0 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/csrc/turbodiffusion/quant/quant.hpp +0 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/python/fastvideo_kernel/triton_kernels/block_sparse_attn_triton.py +0 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/python/fastvideo_kernel/triton_kernels/sla_triton.py +0 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/python/fastvideo_kernel/triton_kernels/st_attn_triton.py +0 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/python/fastvideo_kernel/turbodiffusion_ops.py +0 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/tests/__init__.py +0 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/tests/support_flex_sta.py +0 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/tests/test_sta.py +0 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/tests/test_turbodiffusion.py +0 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/tests/test_vmoba_correctness.py +0 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/tests/test_vsa.py +0 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/tests/test_vsa_forward.py +0 -0
- {fastvideo_kernel-0.2.6 → fastvideo_kernel-0.3.0}/tests/utils.py +0 -0
|
@@ -0,0 +1,352 @@
|
|
|
1
|
+
cmake_minimum_required(VERSION 3.26 FATAL_ERROR)
|
|
2
|
+
project(fastvideo-kernel LANGUAGES CXX)
|
|
3
|
+
|
|
4
|
+
# Prefer environment variable (used by CI or uv pip install git+repo_addr) if CMake var is not explicitly set.
|
|
5
|
+
if(NOT DEFINED GPU_BACKEND AND DEFINED ENV{GPU_BACKEND})
|
|
6
|
+
set(GPU_BACKEND "$ENV{GPU_BACKEND}")
|
|
7
|
+
endif()
|
|
8
|
+
|
|
9
|
+
if(GPU_BACKEND STREQUAL "ROCM")
|
|
10
|
+
enable_language(HIP)
|
|
11
|
+
else()
|
|
12
|
+
enable_language(CUDA)
|
|
13
|
+
# Ensure CUDA toolkit targets (CUDA::cudart, CUDA::cuda_driver, etc.) are available.
|
|
14
|
+
find_package(CUDAToolkit REQUIRED)
|
|
15
|
+
endif()
|
|
16
|
+
|
|
17
|
+
# Import common utils if needed, but we keep it simple for now
|
|
18
|
+
|
|
19
|
+
# Find Python and Torch
|
|
20
|
+
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
|
|
21
|
+
|
|
22
|
+
# Robustly find Torch include paths using Python
|
|
23
|
+
execute_process(
|
|
24
|
+
COMMAND "${Python_EXECUTABLE}" -c "import torch; from torch.utils.cpp_extension import include_paths; print(';'.join(include_paths()))"
|
|
25
|
+
OUTPUT_VARIABLE TORCH_INCLUDE_PATHS
|
|
26
|
+
OUTPUT_STRIP_TRAILING_WHITESPACE
|
|
27
|
+
)
|
|
28
|
+
list(APPEND TORCH_INCLUDE_DIRS ${TORCH_INCLUDE_PATHS})
|
|
29
|
+
|
|
30
|
+
# Find Torch package (still useful for libraries)
|
|
31
|
+
find_package(Torch REQUIRED)
|
|
32
|
+
|
|
33
|
+
# Include directories
|
|
34
|
+
include_directories(
|
|
35
|
+
${CMAKE_SOURCE_DIR}/include
|
|
36
|
+
${CMAKE_SOURCE_DIR}/include/cutlass/include
|
|
37
|
+
${CMAKE_SOURCE_DIR}/include/tk/include
|
|
38
|
+
${CMAKE_SOURCE_DIR}/include/tk/prototype
|
|
39
|
+
${CMAKE_SOURCE_DIR}/csrc
|
|
40
|
+
${CMAKE_SOURCE_DIR}/csrc/turbodiffusion
|
|
41
|
+
${TORCH_INCLUDE_DIRS}
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
# ---------------------------
|
|
45
|
+
# ThunderKittens (TK) toggles
|
|
46
|
+
# ---------------------------
|
|
47
|
+
# AUTO: enable TK only when we can confidently target Hopper (sm_90a).
|
|
48
|
+
# ON: force-enable TK kernels (intended for release wheels/images; does NOT require a GPU).
|
|
49
|
+
# OFF: never build TK kernels.
|
|
50
|
+
set(FASTVIDEO_KERNEL_BUILD_TK "AUTO" CACHE STRING "Build ThunderKittens kernels: AUTO/ON/OFF")
|
|
51
|
+
set_property(CACHE FASTVIDEO_KERNEL_BUILD_TK PROPERTY STRINGS AUTO ON OFF)
|
|
52
|
+
|
|
53
|
+
set(_FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER_DEFAULT "AUTO")
|
|
54
|
+
if(DEFINED FASTVIDEO_KERNEL_BUILD_MODIFIED_SAGE3 AND NOT DEFINED CACHE{FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER})
|
|
55
|
+
set(_FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER_DEFAULT "${FASTVIDEO_KERNEL_BUILD_MODIFIED_SAGE3}")
|
|
56
|
+
endif()
|
|
57
|
+
|
|
58
|
+
set(FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER "${_FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER_DEFAULT}" CACHE STRING
|
|
59
|
+
"Build attn_qat_infer Blackwell inference kernels: AUTO/ON/OFF")
|
|
60
|
+
set_property(CACHE FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER PROPERTY STRINGS AUTO ON OFF)
|
|
61
|
+
|
|
62
|
+
if(DEFINED FASTVIDEO_KERNEL_BUILD_MODIFIED_SAGE3)
|
|
63
|
+
message(DEPRECATION
|
|
64
|
+
"FASTVIDEO_KERNEL_BUILD_MODIFIED_SAGE3 is deprecated. "
|
|
65
|
+
"Use FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER instead.")
|
|
66
|
+
endif()
|
|
67
|
+
|
|
68
|
+
# Prefer environment variable (used by CI) if CMake var is not explicitly set.
|
|
69
|
+
if(NOT DEFINED TORCH_CUDA_ARCH_LIST AND DEFINED ENV{TORCH_CUDA_ARCH_LIST})
|
|
70
|
+
set(TORCH_CUDA_ARCH_LIST "$ENV{TORCH_CUDA_ARCH_LIST}")
|
|
71
|
+
endif()
|
|
72
|
+
|
|
73
|
+
message(STATUS "TORCH_CUDA_ARCH_LIST (cmake/env): ${TORCH_CUDA_ARCH_LIST}")
|
|
74
|
+
message(STATUS "FASTVIDEO_KERNEL_BUILD_TK: ${FASTVIDEO_KERNEL_BUILD_TK}")
|
|
75
|
+
message(STATUS "FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER: ${FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER}")
|
|
76
|
+
|
|
77
|
+
set(ENABLE_TK_KERNELS OFF)
|
|
78
|
+
if(FASTVIDEO_KERNEL_BUILD_TK STREQUAL "ON")
|
|
79
|
+
set(ENABLE_TK_KERNELS ON)
|
|
80
|
+
elseif(FASTVIDEO_KERNEL_BUILD_TK STREQUAL "OFF")
|
|
81
|
+
set(ENABLE_TK_KERNELS OFF)
|
|
82
|
+
else()
|
|
83
|
+
# AUTO: detect Hopper if possible.
|
|
84
|
+
if(TORCH_CUDA_ARCH_LIST)
|
|
85
|
+
# Accept common spellings: 9.0a, 90a, sm_90a.
|
|
86
|
+
string(REGEX MATCH "(^|[; ,])((9\\.0a)|(90a)|(sm_90a))([; ,]|$)" _HAS_90A "${TORCH_CUDA_ARCH_LIST}")
|
|
87
|
+
if(_HAS_90A)
|
|
88
|
+
set(ENABLE_TK_KERNELS ON)
|
|
89
|
+
endif()
|
|
90
|
+
else()
|
|
91
|
+
# Best-effort local detection (works when a CUDA device is visible).
|
|
92
|
+
execute_process(
|
|
93
|
+
COMMAND "${Python_EXECUTABLE}" -c "import torch; import sys; \nprint('1' if (torch.cuda.is_available() and torch.version.cuda and torch.cuda.get_device_capability()[0] >= 9) else '0')"
|
|
94
|
+
OUTPUT_VARIABLE _LOCAL_HAS_HOPPER
|
|
95
|
+
OUTPUT_STRIP_TRAILING_WHITESPACE
|
|
96
|
+
ERROR_QUIET
|
|
97
|
+
)
|
|
98
|
+
if(_LOCAL_HAS_HOPPER STREQUAL "1")
|
|
99
|
+
set(ENABLE_TK_KERNELS ON)
|
|
100
|
+
endif()
|
|
101
|
+
endif()
|
|
102
|
+
endif()
|
|
103
|
+
|
|
104
|
+
if(ENABLE_TK_KERNELS)
|
|
105
|
+
message(STATUS "ThunderKittens kernels: ENABLED")
|
|
106
|
+
else()
|
|
107
|
+
message(STATUS "ThunderKittens kernels: DISABLED (will use Triton fallbacks at runtime)")
|
|
108
|
+
endif()
|
|
109
|
+
|
|
110
|
+
set(ENABLE_ATTN_QAT_INFER OFF)
|
|
111
|
+
if(GPU_BACKEND STREQUAL "ROCM")
|
|
112
|
+
message(STATUS "attn_qat_infer kernels: DISABLED (ROCm build)")
|
|
113
|
+
else()
|
|
114
|
+
set(_WANTS_ATTN_QAT_INFER OFF)
|
|
115
|
+
if(FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER STREQUAL "ON")
|
|
116
|
+
set(_WANTS_ATTN_QAT_INFER ON)
|
|
117
|
+
elseif(FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER STREQUAL "AUTO")
|
|
118
|
+
if(TORCH_CUDA_ARCH_LIST)
|
|
119
|
+
string(REGEX MATCH
|
|
120
|
+
"(^|[; ,])((12\\.0a)|(120a)|(sm_120a))([; ,]|$)"
|
|
121
|
+
_HAS_120A "${TORCH_CUDA_ARCH_LIST}")
|
|
122
|
+
if(_HAS_120A)
|
|
123
|
+
set(_WANTS_ATTN_QAT_INFER ON)
|
|
124
|
+
endif()
|
|
125
|
+
else()
|
|
126
|
+
execute_process(
|
|
127
|
+
COMMAND "${Python_EXECUTABLE}" -c
|
|
128
|
+
"import torch; print('1' if (torch.cuda.is_available() and torch.version.cuda and torch.cuda.get_device_capability()[0] >= 12) else '0')"
|
|
129
|
+
OUTPUT_VARIABLE _LOCAL_HAS_BLACKWELL
|
|
130
|
+
OUTPUT_STRIP_TRAILING_WHITESPACE
|
|
131
|
+
ERROR_QUIET
|
|
132
|
+
)
|
|
133
|
+
if(_LOCAL_HAS_BLACKWELL STREQUAL "1")
|
|
134
|
+
set(_WANTS_ATTN_QAT_INFER ON)
|
|
135
|
+
endif()
|
|
136
|
+
endif()
|
|
137
|
+
endif()
|
|
138
|
+
|
|
139
|
+
if(_WANTS_ATTN_QAT_INFER)
|
|
140
|
+
if(CUDAToolkit_VERSION VERSION_LESS 12.8)
|
|
141
|
+
message(WARNING
|
|
142
|
+
"attn_qat_infer kernels require CUDA Toolkit 12.8+. "
|
|
143
|
+
"Skipping because CUDAToolkit_VERSION=${CUDAToolkit_VERSION}.")
|
|
144
|
+
else()
|
|
145
|
+
set(ENABLE_ATTN_QAT_INFER ON)
|
|
146
|
+
endif()
|
|
147
|
+
endif()
|
|
148
|
+
|
|
149
|
+
if(ENABLE_ATTN_QAT_INFER)
|
|
150
|
+
message(STATUS "attn_qat_infer kernels: ENABLED")
|
|
151
|
+
else()
|
|
152
|
+
message(STATUS
|
|
153
|
+
"attn_qat_infer kernels: DISABLED "
|
|
154
|
+
"(requires CUDA 12.8+ and Blackwell sm_120a)")
|
|
155
|
+
endif()
|
|
156
|
+
endif()
|
|
157
|
+
|
|
158
|
+
# Always try to build the extension if CUDA is available, but conditionally add sources/flags
|
|
159
|
+
set(BUILD_CXX_KERNELS ON)
|
|
160
|
+
|
|
161
|
+
# ---------------------------------------------------------------------------
|
|
162
|
+
# Per-arch split for the Blackwell FP4 (attn_qat_infer) build
|
|
163
|
+
# ---------------------------------------------------------------------------
|
|
164
|
+
# The FP4 kernels are sm_120a-only (they emit `cvt.e2m1x2` etc.), while the main
|
|
165
|
+
# extension (Hopper-only TK + generic turbodiffusion) targets the full arch list.
|
|
166
|
+
# find_package(Torch) injects ONE global -gencode list into CMAKE_CUDA_FLAGS that
|
|
167
|
+
# forces every target onto every arch, so the FP4 sources also get the sm_90a pass
|
|
168
|
+
# and ptxas rejects their Blackwell instructions. Strip that global list and drive
|
|
169
|
+
# arch per target via CUDA_ARCHITECTURES instead (the fp4* targets pin 120a below;
|
|
170
|
+
# the main extension gets the full list). Only do this for the FP4 build with an
|
|
171
|
+
# explicit arch list, so the cu126 / local autodetect paths stay untouched.
|
|
172
|
+
if(ENABLE_ATTN_QAT_INFER AND TORCH_CUDA_ARCH_LIST)
|
|
173
|
+
message(STATUS "[per-arch] CMAKE_CUDA_FLAGS before strip: ${CMAKE_CUDA_FLAGS}")
|
|
174
|
+
string(REGEX REPLACE "-gencode[ =]+arch=[^ ]+" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
|
|
175
|
+
string(REGEX REPLACE " +" " " CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
|
|
176
|
+
message(STATUS "[per-arch] CMAKE_CUDA_FLAGS after strip: ${CMAKE_CUDA_FLAGS}")
|
|
177
|
+
# Convert TORCH_CUDA_ARCH_LIST ("9.0a;12.0a") to CMake form ("90a;120a").
|
|
178
|
+
set(FASTVIDEO_MAIN_CUDA_ARCHS "${TORCH_CUDA_ARCH_LIST}")
|
|
179
|
+
string(REPLACE "sm_" "" FASTVIDEO_MAIN_CUDA_ARCHS "${FASTVIDEO_MAIN_CUDA_ARCHS}")
|
|
180
|
+
string(REPLACE "." "" FASTVIDEO_MAIN_CUDA_ARCHS "${FASTVIDEO_MAIN_CUDA_ARCHS}")
|
|
181
|
+
message(STATUS "[per-arch] main extension archs=${FASTVIDEO_MAIN_CUDA_ARCHS}, fp4* archs=120a")
|
|
182
|
+
endif()
|
|
183
|
+
|
|
184
|
+
# Compiler flags
|
|
185
|
+
set(CUDA_FLAGS
|
|
186
|
+
"-DNDEBUG"
|
|
187
|
+
"-O3"
|
|
188
|
+
"-std=c++20"
|
|
189
|
+
"--use_fast_math"
|
|
190
|
+
"--expt-extended-lambda"
|
|
191
|
+
"--expt-relaxed-constexpr"
|
|
192
|
+
"-Xcompiler=-fno-strict-aliasing"
|
|
193
|
+
"-Xcompiler=-fPIC"
|
|
194
|
+
"-DTORCH_COMPILE"
|
|
195
|
+
"-Xnvlink=--verbose"
|
|
196
|
+
"-Xptxas=--verbose"
|
|
197
|
+
"-Xptxas=--warn-on-spills"
|
|
198
|
+
)
|
|
199
|
+
|
|
200
|
+
# If TK is enabled, ensure we target Hopper. This is required even on GPU-less builders (CI).
|
|
201
|
+
if(ENABLE_TK_KERNELS)
|
|
202
|
+
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES OR CMAKE_CUDA_ARCHITECTURES STREQUAL "")
|
|
203
|
+
set(CMAKE_CUDA_ARCHITECTURES "90a" CACHE STRING "CUDA architectures" FORCE)
|
|
204
|
+
endif()
|
|
205
|
+
list(APPEND CUDA_FLAGS "-DKITTENS_HOPPER")
|
|
206
|
+
message(STATUS "CMAKE_CUDA_ARCHITECTURES: ${CMAKE_CUDA_ARCHITECTURES}")
|
|
207
|
+
endif()
|
|
208
|
+
|
|
209
|
+
if(BUILD_CXX_KERNELS)
|
|
210
|
+
# Source files
|
|
211
|
+
set(EXTENSION_SOURCES
|
|
212
|
+
csrc/common_extension.cpp
|
|
213
|
+
csrc/turbodiffusion/gemm/gemm.cu
|
|
214
|
+
csrc/turbodiffusion/norm/rmsnorm.cu
|
|
215
|
+
csrc/turbodiffusion/norm/layernorm.cu
|
|
216
|
+
csrc/turbodiffusion/quant/quant.cu
|
|
217
|
+
)
|
|
218
|
+
|
|
219
|
+
# Conditionally add TK kernels
|
|
220
|
+
if(ENABLE_TK_KERNELS)
|
|
221
|
+
list(APPEND EXTENSION_SOURCES
|
|
222
|
+
csrc/attention/st_attn_h100.cu
|
|
223
|
+
csrc/attention/block_sparse_h100.cu
|
|
224
|
+
)
|
|
225
|
+
endif()
|
|
226
|
+
|
|
227
|
+
# Combined FastVideo Extension
|
|
228
|
+
# Using name 'fastvideo_kernel_ops' to distinguish from the python package namespace
|
|
229
|
+
Python_add_library(fastvideo_kernel_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI
|
|
230
|
+
${EXTENSION_SOURCES}
|
|
231
|
+
)
|
|
232
|
+
|
|
233
|
+
# When the per-arch split is active (FP4 build), torch's global gencode was
|
|
234
|
+
# stripped above, so set this target's arch explicitly. TK is guarded down to
|
|
235
|
+
# sm_90a in source; the turbodiffusion kernels are generic, so the main
|
|
236
|
+
# extension targets the full list. (No-FP4 builds keep torch's global gencode.)
|
|
237
|
+
if(ENABLE_ATTN_QAT_INFER AND FASTVIDEO_MAIN_CUDA_ARCHS)
|
|
238
|
+
set_target_properties(fastvideo_kernel_ops PROPERTIES
|
|
239
|
+
CUDA_ARCHITECTURES "${FASTVIDEO_MAIN_CUDA_ARCHS}")
|
|
240
|
+
endif()
|
|
241
|
+
|
|
242
|
+
# Build compile definitions list
|
|
243
|
+
set(COMPILE_DEFS TORCH_EXTENSION_NAME=fastvideo_kernel_ops)
|
|
244
|
+
if(ENABLE_TK_KERNELS)
|
|
245
|
+
list(APPEND COMPILE_DEFS TK_COMPILE_ST_ATTN TK_COMPILE_BLOCK_SPARSE)
|
|
246
|
+
endif()
|
|
247
|
+
|
|
248
|
+
target_compile_definitions(fastvideo_kernel_ops PRIVATE ${COMPILE_DEFS})
|
|
249
|
+
|
|
250
|
+
target_compile_options(fastvideo_kernel_ops PRIVATE
|
|
251
|
+
$<$<COMPILE_LANGUAGE:CUDA>:${CUDA_FLAGS}>
|
|
252
|
+
)
|
|
253
|
+
|
|
254
|
+
# Link against Torch libraries to avoid undefined symbols at import time
|
|
255
|
+
# (e.g., torch::autograd vtables) when loading the extension module.
|
|
256
|
+
target_link_libraries(fastvideo_kernel_ops PRIVATE ${TORCH_LIBRARIES})
|
|
257
|
+
|
|
258
|
+
# Also link against libtorch_python to satisfy Python-binding symbols
|
|
259
|
+
# (e.g., torch::PyWarningHandler) required by torch/extension.h.
|
|
260
|
+
execute_process(
|
|
261
|
+
COMMAND "${Python_EXECUTABLE}" -c "import torch; from pathlib import Path; p=Path(torch.__file__).parent/'lib'; m=sorted(p.glob('libtorch_python*')); print(str(m[0]) if m else '')"
|
|
262
|
+
OUTPUT_VARIABLE TORCH_PYTHON_LIBRARY_PATH
|
|
263
|
+
OUTPUT_STRIP_TRAILING_WHITESPACE
|
|
264
|
+
ERROR_QUIET
|
|
265
|
+
)
|
|
266
|
+
if(TORCH_PYTHON_LIBRARY_PATH)
|
|
267
|
+
message(STATUS "TORCH_PYTHON_LIBRARY_PATH: ${TORCH_PYTHON_LIBRARY_PATH}")
|
|
268
|
+
target_link_libraries(fastvideo_kernel_ops PRIVATE "${TORCH_PYTHON_LIBRARY_PATH}")
|
|
269
|
+
else()
|
|
270
|
+
message(WARNING "Could not locate libtorch_python; fastvideo_kernel_ops may fail to import.")
|
|
271
|
+
endif()
|
|
272
|
+
|
|
273
|
+
# Link CUDA runtime + driver explicitly (fixes missing symbols like cuGetErrorString at import time)
|
|
274
|
+
if(NOT GPU_BACKEND STREQUAL "ROCM")
|
|
275
|
+
target_link_libraries(fastvideo_kernel_ops PRIVATE CUDA::cudart CUDA::cuda_driver)
|
|
276
|
+
endif()
|
|
277
|
+
|
|
278
|
+
# We install it to fastvideo_kernel/_C so we can load it to register the ops
|
|
279
|
+
install(TARGETS fastvideo_kernel_ops LIBRARY DESTINATION fastvideo_kernel/_C)
|
|
280
|
+
endif()
|
|
281
|
+
|
|
282
|
+
if(ENABLE_ATTN_QAT_INFER)
|
|
283
|
+
set(ATTN_QAT_INFER_DIR ${CMAKE_SOURCE_DIR}/attn_qat_infer)
|
|
284
|
+
set(ATTN_QAT_INFER_INCLUDE_DIRS
|
|
285
|
+
${ATTN_QAT_INFER_DIR}
|
|
286
|
+
${CMAKE_SOURCE_DIR}/include/cutlass/include
|
|
287
|
+
${CMAKE_SOURCE_DIR}/include/cutlass/tools/util/include
|
|
288
|
+
${TORCH_INCLUDE_DIRS}
|
|
289
|
+
)
|
|
290
|
+
set(ATTN_QAT_INFER_CUDA_FLAGS
|
|
291
|
+
"-O3"
|
|
292
|
+
"-std=c++17"
|
|
293
|
+
"-U__CUDA_NO_HALF_OPERATORS__"
|
|
294
|
+
"-U__CUDA_NO_HALF_CONVERSIONS__"
|
|
295
|
+
"-U__CUDA_NO_BFLOAT16_OPERATORS__"
|
|
296
|
+
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__"
|
|
297
|
+
"-U__CUDA_NO_BFLOAT162_OPERATORS__"
|
|
298
|
+
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__"
|
|
299
|
+
"--expt-relaxed-constexpr"
|
|
300
|
+
"--expt-extended-lambda"
|
|
301
|
+
"--use_fast_math"
|
|
302
|
+
"--ptxas-options=--verbose,--warn-on-local-memory-usage"
|
|
303
|
+
"-lineinfo"
|
|
304
|
+
"-DCUTLASS_DEBUG_TRACE_LEVEL=0"
|
|
305
|
+
"-DNDEBUG"
|
|
306
|
+
"-DQBLKSIZE=128"
|
|
307
|
+
"-DKBLKSIZE=128"
|
|
308
|
+
"-DCTA256"
|
|
309
|
+
"-DDQINRMEM"
|
|
310
|
+
)
|
|
311
|
+
|
|
312
|
+
Python_add_library(fp4attn_cuda MODULE WITH_SOABI
|
|
313
|
+
attn_qat_infer/blackwell/api.cu
|
|
314
|
+
)
|
|
315
|
+
target_include_directories(fp4attn_cuda PRIVATE ${ATTN_QAT_INFER_INCLUDE_DIRS})
|
|
316
|
+
target_compile_definitions(fp4attn_cuda PRIVATE TORCH_EXTENSION_NAME=fp4attn_cuda)
|
|
317
|
+
target_compile_options(fp4attn_cuda PRIVATE
|
|
318
|
+
$<$<COMPILE_LANGUAGE:CXX>:-O3 -std=c++17>
|
|
319
|
+
$<$<COMPILE_LANGUAGE:CUDA>:${ATTN_QAT_INFER_CUDA_FLAGS}>
|
|
320
|
+
)
|
|
321
|
+
set_target_properties(fp4attn_cuda PROPERTIES
|
|
322
|
+
CUDA_ARCHITECTURES "120a"
|
|
323
|
+
CXX_STANDARD 17
|
|
324
|
+
CUDA_STANDARD 17
|
|
325
|
+
)
|
|
326
|
+
target_link_libraries(fp4attn_cuda PRIVATE ${TORCH_LIBRARIES} CUDA::cudart CUDA::cuda_driver)
|
|
327
|
+
|
|
328
|
+
Python_add_library(fp4quant_cuda MODULE WITH_SOABI
|
|
329
|
+
attn_qat_infer/quantization/fp4_quantization_4d.cu
|
|
330
|
+
)
|
|
331
|
+
target_include_directories(fp4quant_cuda PRIVATE ${ATTN_QAT_INFER_INCLUDE_DIRS})
|
|
332
|
+
target_compile_definitions(fp4quant_cuda PRIVATE TORCH_EXTENSION_NAME=fp4quant_cuda)
|
|
333
|
+
target_compile_options(fp4quant_cuda PRIVATE
|
|
334
|
+
$<$<COMPILE_LANGUAGE:CXX>:-O3 -std=c++17>
|
|
335
|
+
$<$<COMPILE_LANGUAGE:CUDA>:${ATTN_QAT_INFER_CUDA_FLAGS}>
|
|
336
|
+
)
|
|
337
|
+
set_target_properties(fp4quant_cuda PROPERTIES
|
|
338
|
+
CUDA_ARCHITECTURES "120a"
|
|
339
|
+
CXX_STANDARD 17
|
|
340
|
+
CUDA_STANDARD 17
|
|
341
|
+
)
|
|
342
|
+
target_link_libraries(fp4quant_cuda PRIVATE ${TORCH_LIBRARIES} CUDA::cudart CUDA::cuda_driver)
|
|
343
|
+
|
|
344
|
+
if(TORCH_PYTHON_LIBRARY_PATH)
|
|
345
|
+
target_link_libraries(fp4attn_cuda PRIVATE "${TORCH_PYTHON_LIBRARY_PATH}")
|
|
346
|
+
target_link_libraries(fp4quant_cuda PRIVATE "${TORCH_PYTHON_LIBRARY_PATH}")
|
|
347
|
+
endif()
|
|
348
|
+
|
|
349
|
+
install(TARGETS fp4attn_cuda LIBRARY DESTINATION .)
|
|
350
|
+
install(TARGETS fp4quant_cuda LIBRARY DESTINATION .)
|
|
351
|
+
endif()
|
|
352
|
+
|
|
@@ -2,5 +2,6 @@ include LICENSE
|
|
|
2
2
|
include README.md
|
|
3
3
|
include pyproject.toml
|
|
4
4
|
recursive-include python/fastvideo_kernel *.py
|
|
5
|
+
recursive-include attn_qat_infer *.py *.cu *.cuh *.cpp *.h
|
|
5
6
|
recursive-include csrc *.cu *.cuh *.cpp *.h
|
|
6
7
|
recursive-include include/tk *.cu *.cuh *.cpp *.h *.src
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: fastvideo-kernel
|
|
3
|
-
Version: 0.2.6
|
|
3
|
+
Version: 0.3.0
|
|
4
4
|
Summary: Unified CUDA kernels for FastVideo
|
|
5
5
|
Author-Email: Hao AI Lab <contact@haoailab.com>
|
|
6
6
|
License: Apache License
|
|
@@ -195,7 +195,7 @@ Classifier: Environment :: GPU :: NVIDIA CUDA
|
|
|
195
195
|
Project-URL: Homepage, https://github.com/hao-ai-lab/FastVideo
|
|
196
196
|
Requires-Python: >=3.10
|
|
197
197
|
Requires-Dist: torch>=2.5.0
|
|
198
|
-
Requires-Dist: triton>=2.0.0
|
|
198
|
+
Requires-Dist: triton>=2.0.0; sys_platform == "linux"
|
|
199
199
|
Description-Content-Type: text/markdown
|
|
200
200
|
|
|
201
201
|
# FastVideo Kernel
|
|
@@ -207,6 +207,13 @@ CUDA kernels for FastVideo video generation.
|
|
|
207
207
|
### Standard Installation (Local Development)
|
|
208
208
|
This will automatically detect your GPU architecture. If an NVIDIA Hopper (H100/sm_90a) GPU is detected, ThunderKittens kernels will be enabled. Otherwise, they will be skipped, and the package will use Triton fallbacks at runtime.
|
|
209
209
|
|
|
210
|
+
Before installation, set CUDA toolchain paths:
|
|
211
|
+
|
|
212
|
+
```bash
|
|
213
|
+
export CUDA_HOME=/usr/local/cuda
|
|
214
|
+
export CUDACXX=$CUDA_HOME/bin/nvcc
|
|
215
|
+
```
|
|
216
|
+
|
|
210
217
|
```bash
|
|
211
218
|
git submodule update --init --recursive
|
|
212
219
|
cd fastvideo-kernel
|
|
@@ -221,6 +228,29 @@ cd fastvideo-kernel
|
|
|
221
228
|
./build.sh --rocm
|
|
222
229
|
```
|
|
223
230
|
|
|
231
|
+
### Optional: FA4 CuTe block-sparse backend (VSA-256 fastpath)
|
|
232
|
+
|
|
233
|
+
The VSA-256 fastpath (tile volume 256, on NVIDIA Blackwell / sm_100) routes to the
|
|
234
|
+
FlashAttention-4 CuTe-DSL block-sparse kernel exposed as `flash_attn.cute`. This is
|
|
235
|
+
an **optional** dependency: it is imported lazily, and `video_sparse_attn`
|
|
236
|
+
transparently falls back to the Triton backend when it is absent (so the package is
|
|
237
|
+
fully usable without it).
|
|
238
|
+
|
|
239
|
+
The symbols the fastpath needs (`flash_attn.cute.block_sparsity.BlockSparseTensorsTorch`,
|
|
240
|
+
`flash_attn.cute.interface._flash_attn_fwd`) are provided upstream by
|
|
241
|
+
[Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention). Pin to
|
|
242
|
+
commit `c19cd20e`: the wrapper targets that revision's `_flash_attn_fwd` signature
|
|
243
|
+
(`m_block_size` / `n_block_size`); later upstream revisions reshaped it into a
|
|
244
|
+
`tile_mn` tuple and are not drop-in compatible.
|
|
245
|
+
|
|
246
|
+
```bash
|
|
247
|
+
pip install "nvidia-cutlass-dsl>=4.5.0" torchvision
|
|
248
|
+
pip install "git+https://github.com/Dao-AILab/flash-attention.git@c19cd20e#subdirectory=flash_attn/cute"
|
|
249
|
+
```
|
|
250
|
+
|
|
251
|
+
The CuTe kernel JIT-compiles on first use. Verified on Blackwell (sm_100) against
|
|
252
|
+
`tests/test_vsa256_forward*.py`.
|
|
253
|
+
|
|
224
254
|
## Usage
|
|
225
255
|
|
|
226
256
|
### Sliding Tile Attention (STA) & Video Sparse Attention (VSA)
|
|
@@ -262,6 +292,8 @@ This package also includes kernels from [TurboDiffusion](https://github.com/thu-
|
|
|
262
292
|
- Any CUDA GPU for Triton-based fallbacks.
|
|
263
293
|
- **Build**:
|
|
264
294
|
- CUDA Toolkit 12.3+
|
|
295
|
+
- `CUDA_HOME` must be set (for example, `/usr/local/cuda`)
|
|
296
|
+
- `CUDACXX` must be set (for example, `$CUDA_HOME/bin/nvcc`)
|
|
265
297
|
- C++20 compatible compiler (GCC 10+, Clang 11+)
|
|
266
298
|
|
|
267
299
|
## Acknowledgement
|
|
@@ -7,6 +7,13 @@ CUDA kernels for FastVideo video generation.
|
|
|
7
7
|
### Standard Installation (Local Development)
|
|
8
8
|
This will automatically detect your GPU architecture. If an NVIDIA Hopper (H100/sm_90a) GPU is detected, ThunderKittens kernels will be enabled. Otherwise, they will be skipped, and the package will use Triton fallbacks at runtime.
|
|
9
9
|
|
|
10
|
+
Before installation, set CUDA toolchain paths:
|
|
11
|
+
|
|
12
|
+
```bash
|
|
13
|
+
export CUDA_HOME=/usr/local/cuda
|
|
14
|
+
export CUDACXX=$CUDA_HOME/bin/nvcc
|
|
15
|
+
```
|
|
16
|
+
|
|
10
17
|
```bash
|
|
11
18
|
git submodule update --init --recursive
|
|
12
19
|
cd fastvideo-kernel
|
|
@@ -21,6 +28,29 @@ cd fastvideo-kernel
|
|
|
21
28
|
./build.sh --rocm
|
|
22
29
|
```
|
|
23
30
|
|
|
31
|
+
### Optional: FA4 CuTe block-sparse backend (VSA-256 fastpath)
|
|
32
|
+
|
|
33
|
+
The VSA-256 fastpath (tile volume 256, on NVIDIA Blackwell / sm_100) routes to the
|
|
34
|
+
FlashAttention-4 CuTe-DSL block-sparse kernel exposed as `flash_attn.cute`. This is
|
|
35
|
+
an **optional** dependency: it is imported lazily, and `video_sparse_attn`
|
|
36
|
+
transparently falls back to the Triton backend when it is absent (so the package is
|
|
37
|
+
fully usable without it).
|
|
38
|
+
|
|
39
|
+
The symbols the fastpath needs (`flash_attn.cute.block_sparsity.BlockSparseTensorsTorch`,
|
|
40
|
+
`flash_attn.cute.interface._flash_attn_fwd`) are provided upstream by
|
|
41
|
+
[Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention). Pin to
|
|
42
|
+
commit `c19cd20e`: the wrapper targets that revision's `_flash_attn_fwd` signature
|
|
43
|
+
(`m_block_size` / `n_block_size`); later upstream revisions reshaped it into a
|
|
44
|
+
`tile_mn` tuple and are not drop-in compatible.
|
|
45
|
+
|
|
46
|
+
```bash
|
|
47
|
+
pip install "nvidia-cutlass-dsl>=4.5.0" torchvision
|
|
48
|
+
pip install "git+https://github.com/Dao-AILab/flash-attention.git@c19cd20e#subdirectory=flash_attn/cute"
|
|
49
|
+
```
|
|
50
|
+
|
|
51
|
+
The CuTe kernel JIT-compiles on first use. Verified on Blackwell (sm_100) against
|
|
52
|
+
`tests/test_vsa256_forward*.py`.
|
|
53
|
+
|
|
24
54
|
## Usage
|
|
25
55
|
|
|
26
56
|
### Sliding Tile Attention (STA) & Video Sparse Attention (VSA)
|
|
@@ -62,6 +92,8 @@ This package also includes kernels from [TurboDiffusion](https://github.com/thu-
|
|
|
62
92
|
- Any CUDA GPU for Triton-based fallbacks.
|
|
63
93
|
- **Build**:
|
|
64
94
|
- CUDA Toolkit 12.3+
|
|
95
|
+
- `CUDA_HOME` must be set (for example, `/usr/local/cuda`)
|
|
96
|
+
- `CUDACXX` must be set (for example, `$CUDA_HOME/bin/nvcc`)
|
|
65
97
|
- C++20 compatible compiler (GCC 10+, Clang 11+)
|
|
66
98
|
|
|
67
99
|
## Acknowledgement
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2025 by SageAttention team.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
from .api import sageattn_blackwell
|
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
# Modified from the original SageAttention3 code
|
|
2
|
+
"""
|
|
3
|
+
Copyright (c) 2025 by SageAttention team.
|
|
4
|
+
|
|
5
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
you may not use this file except in compliance with the License.
|
|
7
|
+
You may obtain a copy of the License at
|
|
8
|
+
|
|
9
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
|
|
11
|
+
Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
See the License for the specific language governing permissions and
|
|
15
|
+
limitations under the License.
|
|
16
|
+
"""
|
|
17
|
+
import torch
|
|
18
|
+
import triton
|
|
19
|
+
import triton.language as tl
|
|
20
|
+
import torch.nn.functional as F
|
|
21
|
+
from typing import Tuple
|
|
22
|
+
from torch.nn.functional import scaled_dot_product_attention as sdpa
|
|
23
|
+
import fp4attn_cuda
|
|
24
|
+
import fp4quant_cuda
|
|
25
|
+
|
|
26
|
+
# Centralized block size configuration for sageattn_blackwell kernels.
# These must match the values in attn_qat_infer/blackwell/block_config.h.
BLOCK_M = 128  # Block size for M dimension (query sequence length)
BLOCK_N = 128  # Block size for N dimension (key/value sequence length)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@triton.jit
def group_mean_kernel(
    q_ptr,
    q_out_ptr,
    qm_out_ptr,
    B, H, L, D: tl.constexpr,
    stride_qb, stride_qh, stride_ql, stride_qd,
    stride_qmb, stride_qmh, stride_qml, stride_qmd,
    GROUP_SIZE: tl.constexpr
):
    # One program per (batch, head, group). Each program:
    #   1. loads a GROUP_SIZE x D tile of q,
    #   2. stores the tile minus its column-wise mean to q_out (using q's strides),
    #   3. stores that mean row to qm_out.
    # Loads/stores are unmasked, so every group must be complete; the host
    # wrapper launches exactly L // GROUP_SIZE groups along grid axis 2.
    b_idx = tl.program_id(0)
    h_idx = tl.program_id(1)
    g_idx = tl.program_id(2)

    rows = g_idx * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
    cols = tl.arange(0, D)

    # The same offsets address both the input tile and the output tile.
    base = b_idx * stride_qb + h_idx * stride_qh
    tile_offs = base + rows[:, None] * stride_ql + cols[None, :] * stride_qd

    tile = tl.load(q_ptr + tile_offs)
    mean_row = tl.sum(tile, axis=0) / GROUP_SIZE
    tl.store(q_out_ptr + tile_offs, tile - mean_row)

    mean_offs = b_idx * stride_qmb + h_idx * stride_qmh + g_idx * stride_qml + cols * stride_qmd
    tl.store(qm_out_ptr + mean_offs, mean_row)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def triton_group_mean(q: torch.Tensor):
    """
    Remove the per-group mean from ``q`` along the sequence axis.

    ``q`` is unpacked as [B, H, L, D]. Groups are consecutive runs of ``BLOCK_M``
    rows along dim 2, and ``group_mean_kernel`` is launched once per
    (batch, head, group).

    Returns:
        (q_out, qm): ``q_out`` has the shape of ``q`` with each group's mean
        subtracted; ``qm`` has shape [B, H, L // BLOCK_M, D] and holds those means.

    Raises:
        ValueError: if L is not a multiple of ``BLOCK_M``. The kernel loads whole
            groups without masking. A partial trailing group would therefore
            never be processed, and its rows of ``q_out`` would be left
            uninitialised.
    """
    # The kernel addresses q and q_out with the same strides. Making q contiguous
    # guarantees that ``empty_like`` produces an output with matching strides.
    q = q.contiguous()
    B, H, L, D = q.shape
    GROUP_SIZE = BLOCK_M
    if L % GROUP_SIZE != 0:
        raise ValueError(
            f"triton_group_mean requires the sequence length ({L}) to be a multiple "
            f"of BLOCK_M ({GROUP_SIZE}); pad the input first."
        )
    num_groups = L // GROUP_SIZE

    q_out = torch.empty_like(q)  # [B, H, L, D]
    qm = torch.empty(B, H, num_groups, D, device=q.device, dtype=q.dtype)

    grid = (B, H, num_groups)

    group_mean_kernel[grid](
        q, q_out, qm,
        B, H, L, D,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        qm.stride(0), qm.stride(1), qm.stride(2), qm.stride(3),
        GROUP_SIZE=GROUP_SIZE
    )
    return q_out, qm
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def preprocess_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, per_block_mean: bool = True, enable_smoothing_q: bool = False, enable_smoothing_k: bool = False):
    """
    Pad Q/K/V along the sequence axis and optionally apply Q/K smoothing.

    Tensors are indexed as [B, H, L, D]. Dim 2 is zero-padded up to the next
    multiple of the kernel tile: ``BLOCK_M`` for Q, ``BLOCK_N`` for K and V.

    Args:
        q, k, v: Query/key/value tensors. They are never modified in place.
        per_block_mean: When Q smoothing is enabled, subtract a separate mean for
            each ``BLOCK_M``-row group (via ``triton_group_mean``). Otherwise a
            single mean is taken over the whole padded sequence.
        enable_smoothing_q: Subtract the Q mean ``qm`` and return
            ``delta_s = qm @ k^T``, the score term removed by that subtraction.
            It is presumably re-added by the CUDA kernel (TODO confirm against
            fp4attn_cuda).
        enable_smoothing_k: Subtract the mean of K over the sequence axis. The
            mean is computed before padding.

    Returns:
        ``(q, k, v, delta_s)``: q/k/v are contiguous and padded, and ``delta_s``
        is float32. With Q smoothing disabled, ``delta_s`` is all zeros with
        shape [B, H, L_q_padded // BLOCK_M, L_k_padded].
    """

    def pad_to_block_size(x, block_size):
        L = x.size(2)
        pad_len = (block_size - L % block_size) % block_size
        if pad_len == 0:
            return x.contiguous()
        return F.pad(x, (0, 0, 0, pad_len), value=0).contiguous()

    if enable_smoothing_k:
        # Out-of-place on purpose: ``k -= ...`` would silently modify the caller's
        # tensor, and it fails on leaf tensors that require grad. Subtracting a
        # constant key vector shifts each query's scores by a constant, which
        # softmax is invariant to, so no compensation term is needed.
        k = k - k.mean(dim=-2, keepdim=True)
    q = pad_to_block_size(q, BLOCK_M)
    k = pad_to_block_size(k, BLOCK_N)
    v = pad_to_block_size(v, BLOCK_N)
    if per_block_mean and enable_smoothing_q:
        q, qm = triton_group_mean(q)
    elif enable_smoothing_q:
        qm = q.mean(dim=-2, keepdim=True)
        q = q - qm
    if enable_smoothing_q:
        delta_s = torch.matmul(qm, k.transpose(-2, -1)).to(torch.float32).contiguous()
    else:  # used to disable q smoothing
        B, H, L, D = q.shape
        delta_s = torch.zeros((B, H, L // BLOCK_M, k.shape[2]), device=q.device, dtype=torch.float32)

    return q, k, v, delta_s
|
|
105
|
+
|
|
106
|
+
def scale_and_quant_fp4(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a 4-D tensor [B, H, N, D] with ``fp4quant_cuda.scaled_fp4_quant``.

    Returns:
        (packed, scales): ``packed`` is uint8 with shape [B, H, N, D // 2]
        (two FP4 codes per byte). ``scales`` is float8_e4m3fn with shape
        [B, H, N, D // 16] (one scale per 16 elements along D). Both buffers
        are allocated here and filled by the CUDA op.
    """
    assert x.ndim == 4
    batch, heads, seq_len, head_dim = x.shape
    dev = x.device
    packed = torch.empty((batch, heads, seq_len, head_dim // 2), device=dev, dtype=torch.uint8)
    scales = torch.empty((batch, heads, seq_len, head_dim // 16), device=dev, dtype=torch.float8_e4m3fn)
    # The op writes into ``packed``/``scales`` in place; its return value is unused.
    # The trailing ``1`` is an op-specific argument defined by fp4quant_cuda (TODO confirm).
    fp4quant_cuda.scaled_fp4_quant(x, packed, scales, 1)
    return packed, scales
|
|
113
|
+
|
|
114
|
+
def scale_and_quant_fp4_permute(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a 4-D tensor [B, H, N, D] with ``fp4quant_cuda.scaled_fp4_quant_permute``.

    Output shapes match ``scale_and_quant_fp4``: uint8 [B, H, N, D // 2] and
    float8_e4m3fn [B, H, N, D // 16]. The element layout inside those buffers is
    chosen by the CUDA op; judging by its name it is permuted (TODO confirm
    against fp4quant_cuda). ``sageattn_blackwell`` uses this variant for K.
    """
    assert x.ndim == 4
    batch, heads, seq_len, head_dim = x.shape
    dev = x.device
    packed = torch.empty((batch, heads, seq_len, head_dim // 2), device=dev, dtype=torch.uint8)
    scales = torch.empty((batch, heads, seq_len, head_dim // 16), device=dev, dtype=torch.float8_e4m3fn)
    # Filled in place by the CUDA op; the trailing ``1`` is op-specific (TODO confirm).
    fp4quant_cuda.scaled_fp4_quant_permute(x, packed, scales, 1)
    return packed, scales
|
|
121
|
+
|
|
122
|
+
def scale_and_quant_fp4_transpose(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a 4-D tensor [B, H, N, D] with ``fp4quant_cuda.scaled_fp4_quant_trans``.

    The outputs have N and D swapped relative to the input: ``packed`` is uint8
    with shape [B, H, D, N // 2], and ``scales`` is float8_e4m3fn with shape
    [B, H, D, N // 16]. ``sageattn_blackwell`` uses this variant for V.
    """
    assert x.ndim == 4
    batch, heads, seq_len, head_dim = x.shape
    dev = x.device
    packed = torch.empty((batch, heads, head_dim, seq_len // 2), device=dev, dtype=torch.uint8)
    scales = torch.empty((batch, heads, head_dim, seq_len // 16), device=dev, dtype=torch.float8_e4m3fn)
    # Filled in place by the CUDA op; the trailing ``1`` is op-specific (TODO confirm).
    fp4quant_cuda.scaled_fp4_quant_trans(x, packed, scales, 1)
    return packed, scales
|
|
129
|
+
|
|
130
|
+
def blockscaled_fp4_attn(qlist: Tuple,
                         klist: Tuple,
                         vlist: Tuple,
                         delta_s: torch.Tensor,
                         KL: int,
                         is_causal: bool = False,
                         per_block_mean: bool = True,
                         is_bf16: bool = True,
                         single_level_p_quant: bool = False,
                         sm_scale: float | None = None
                         ):
    """
    Invoke ``fp4attn_cuda.fwd`` on pre-quantized Q/K/V.

    Args:
        qlist, klist, vlist: ``(packed_values, scales)`` pairs as returned by the
            ``scale_and_quant_fp4*`` helpers. Only items 0 and 1 are read.
        delta_s: Score-correction tensor produced by ``preprocess_qkv``.
        KL: Key sequence length. ``sageattn_blackwell`` passes the length before
            padding.
        is_causal, per_block_mean, is_bf16, single_level_p_quant: Forwarded
            unchanged to the kernel.
        sm_scale: Softmax scale. When None, 1/sqrt(D) is used, where D is
            recovered as twice the last dim of the packed Q tensor.

    Returns:
        Whatever ``fp4attn_cuda.fwd`` returns. The caller takes item 0 as the
        attention output.
    """
    q_values, q_scales = qlist[0], qlist[1]
    k_values, k_scales = klist[0], klist[1]
    v_values, v_scales = vlist[0], vlist[1]

    if sm_scale is None:
        # Packed Q stores D // 2 bytes per row, so doubling recovers head_dim.
        head_dim = q_values.shape[-1] * 2
        softmax_scale = head_dim ** (-0.5)
    else:
        softmax_scale = sm_scale

    return fp4attn_cuda.fwd(
        q_values, k_values, v_values,
        q_scales, k_scales, v_scales,
        delta_s, KL, None, softmax_scale,
        is_causal, per_block_mean, is_bf16, single_level_p_quant,
    )
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def sageattn_blackwell(q, k, v, attn_mask = None, is_causal = False, per_block_mean = True, single_level_p_quant = True, sm_scale: float | None = None, **kwargs):
|
|
146
|
+
"""
|
|
147
|
+
SageAttention3 Blackwell kernel for FP4 attention.
|
|
148
|
+
|
|
149
|
+
Args:
|
|
150
|
+
q: Query tensor [B, H, L, D]
|
|
151
|
+
k: Key tensor [B, H, L, D]
|
|
152
|
+
v: Value tensor [B, H, L, D]
|
|
153
|
+
attn_mask: Attention mask (not used)
|
|
154
|
+
is_causal: Whether to use causal masking
|
|
155
|
+
per_block_mean: Whether to use per-block mean for Q smoothing
|
|
156
|
+
single_level_p_quant: If True, use single-level quantization: s_P2, P̂_2 = φ(P̃) directly
|
|
157
|
+
(standard per-block FP4 quantization like V, no s_P1).
|
|
158
|
+
If False (default), use two-level quantization:
|
|
159
|
+
s_P1 = rowmax(P̃)/(448×6), then s_P2, P̂_2 = φ(P̃/s_P1).
|
|
160
|
+
sm_scale: Softmax scale to pass through to the CUDA kernel. If None,
|
|
161
|
+
defaults to the kernel's 1/sqrt(D) scale.
|
|
162
|
+
**kwargs: Additional arguments (ignored)
|
|
163
|
+
|
|
164
|
+
Returns:
|
|
165
|
+
Output tensor [B, H, L, D]
|
|
166
|
+
"""
|
|
167
|
+
if q.size(-1) >= 256:
|
|
168
|
+
print(f"Unsupported Headdim {q.size(-1)}")
|
|
169
|
+
return sdpa(q, k, v, is_causal = is_causal)
|
|
170
|
+
QL = q.size(2)
|
|
171
|
+
KL = k.size(2)
|
|
172
|
+
is_bf16 = q.dtype == torch.bfloat16
|
|
173
|
+
q, k, v, delta_s = preprocess_qkv(q, k, v, per_block_mean)
|
|
174
|
+
qlist_from_cuda = scale_and_quant_fp4(q)
|
|
175
|
+
klist_from_cuda = scale_and_quant_fp4_permute(k)
|
|
176
|
+
vlist_from_cuda = scale_and_quant_fp4_transpose(v)
|
|
177
|
+
o_fp4 = blockscaled_fp4_attn(
|
|
178
|
+
qlist_from_cuda,
|
|
179
|
+
klist_from_cuda,
|
|
180
|
+
vlist_from_cuda,
|
|
181
|
+
delta_s,
|
|
182
|
+
KL,
|
|
183
|
+
is_causal,
|
|
184
|
+
per_block_mean,
|
|
185
|
+
is_bf16,
|
|
186
|
+
single_level_p_quant,
|
|
187
|
+
sm_scale
|
|
188
|
+
)[0][:, :, :QL, :].contiguous()
|
|
189
|
+
return o_fp4
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Module-level version string. NOTE(review): it differs from the distribution
# version (0.3.0); confirm which upstream release it is meant to track.
__version__: str = "3.0.0.b1"
|