fastvideo-kernel 0.3.0__tar.gz → 0.3.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/CMakeLists.txt +77 -1
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/PKG-INFO +5 -5
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/README.md +4 -4
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/build.sh +7 -2
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/gemm/gemm.cu +7 -3
- fastvideo_kernel-0.3.2/dist/fastvideo_kernel-0.3.2-cp312-cp312-manylinux_2_34_aarch64.manylinux_2_35_aarch64.whl +0 -0
- fastvideo_kernel-0.3.0/dist/fastvideo_kernel-0.3.0-cp312-cp312-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl → fastvideo_kernel-0.3.2/dist/fastvideo_kernel-0.3.2-cp312-cp312-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/pyproject.toml +1 -1
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/__init__.py +15 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/block_sparse_attn_cute_fwd.py +4 -5
- fastvideo_kernel-0.3.2/python/fastvideo_kernel/version.py +1 -0
- fastvideo_kernel-0.3.2/python/fastvideo_kernel/vsa_utils.py +160 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/test_attn_qat_infer.py +5 -0
- fastvideo_kernel-0.3.2/tests/test_vsa_utils.py +276 -0
- fastvideo_kernel-0.3.0/python/fastvideo_kernel/version.py +0 -1
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/LICENSE +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/MANIFEST.in +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/__init__.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/api.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/__init__.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/api.cu +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/block_config.h +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/block_info.h +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/blockscaled_layout.h +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/cute_extension.h +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/epilogue_tma_ws.h +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/kernel_traits.h +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/kernel_ws.h +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/launch.h +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/mainloop_tma_ws.h +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/named_barrier.h +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/params.h +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/softmax_fused.h +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/static_switch.h +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/tile_scheduler.h +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/utils.h +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/quantization/__init__.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/quantization/bench/bench_quant_k.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/quantization/bench/bench_quant_q.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/quantization/bench/bench_quant_v.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/quantization/bench/bench_utils.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/quantization/cuda_utils.h +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/quantization/fp4_quantization_4d.cu +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/benchmarks/bench_fused_compress_topk.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/benchmarks/bench_vsa.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/attention/block_sparse_h100.cu +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/attention/st_attn_h100.cu +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/common_extension.cpp +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/common/common.hpp +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/common/launch.hpp +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/common/load.hpp +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/common/store.hpp +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/gemm/kernel.hpp +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/gemm/launch.hpp +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/gemm/utils.hpp +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/norm/layernorm.cu +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/norm/layernorm.hpp +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/norm/rmsnorm.cu +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/norm/rmsnorm.hpp +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/quant/quant.cu +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/quant/quant.hpp +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/block_sparse_attn.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/block_sparse_attn_256.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/block_sparse_attn_varlen.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/ops.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/triton_kernels/attn_qat_train.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/triton_kernels/block_sparse_attn_triton.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/triton_kernels/fused_attention.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/triton_kernels/fused_compress_topk.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/triton_kernels/index.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/triton_kernels/nvfp4_utils.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/triton_kernels/quant_utils.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/triton_kernels/sla_triton.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/triton_kernels/st_attn_triton.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/turbodiffusion_ops.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/vmoba.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/__init__.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/support_flex_sta.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/test_attn_qat_train.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/test_fused_compress_topk.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/test_sta.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/test_turbodiffusion.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/test_vmoba_correctness.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/test_vsa.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/test_vsa256_forward.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/test_vsa256_forward_cross.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/test_vsa256_forward_vbs.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/test_vsa256_triton.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/test_vsa_forward.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/test_vsa_varlen.py +0 -0
- {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/utils.py +0 -0
|
@@ -1,6 +1,13 @@
|
|
|
1
1
|
cmake_minimum_required(VERSION 3.26 FATAL_ERROR)
|
|
2
2
|
project(fastvideo-kernel LANGUAGES CXX)
|
|
3
3
|
|
|
4
|
+
# Capture any caller-provided -DCMAKE_CUDA_ARCHITECTURES *before* enable_language(CUDA)
|
|
5
|
+
# auto-populates it with CMake's built-in default (an old arch, e.g. sm_75 on CUDA 13).
|
|
6
|
+
# torch's cmake actually ignores CMAKE_CUDA_ARCHITECTURES (it drives arch selection via
|
|
7
|
+
# TORCH_CUDA_ARCH_LIST), so we only use this captured value to honor an explicit pin by
|
|
8
|
+
# translating it into TORCH_CUDA_ARCH_LIST below.
|
|
9
|
+
set(_FASTVIDEO_USER_CUDA_ARCH "${CMAKE_CUDA_ARCHITECTURES}")
|
|
10
|
+
|
|
4
11
|
# Prefer environment variable (used by CI or uv pip install git+repo_addr) if CMake var is not explicitly set.
|
|
5
12
|
if(NOT DEFINED GPU_BACKEND AND DEFINED ENV{GPU_BACKEND})
|
|
6
13
|
set(GPU_BACKEND "$ENV{GPU_BACKEND}")
|
|
@@ -19,6 +26,71 @@ endif()
|
|
|
19
26
|
# Find Python and Torch
|
|
20
27
|
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
|
|
21
28
|
|
|
29
|
+
# ---------------------------------------------------------------------------
|
|
30
|
+
# Resolve the target CUDA architecture, BEFORE find_package(Torch) below.
|
|
31
|
+
#
|
|
32
|
+
# torch's cmake (Caffe2 public/cuda.cmake) takes over arch selection: it emits
|
|
33
|
+
# the real -gencode flags from TORCH_CUDA_ARCH_LIST and forces
|
|
34
|
+
# CMAKE_CUDA_ARCHITECTURES to OFF. So the *effective* arch is whatever
|
|
35
|
+
# TORCH_CUDA_ARCH_LIST is when find_package(Torch) runs. build.sh exports it;
|
|
36
|
+
# standards-based builds (pip / uv pip install, sdist) don't, and torch then
|
|
37
|
+
# auto-detects an arch that does not match the GPU -- the kernels build but fail
|
|
38
|
+
# at runtime ("no kernel image is available for execution on the device").
|
|
39
|
+
# Resolve it here when absent (mirrors build.sh): honor a pinned
|
|
40
|
+
# CMAKE_CUDA_ARCHITECTURES if given, else probe the visible GPU with torch.
|
|
41
|
+
# ---------------------------------------------------------------------------
|
|
42
|
+
if(NOT GPU_BACKEND STREQUAL "ROCM")
|
|
43
|
+
if(DEFINED ENV{TORCH_CUDA_ARCH_LIST})
|
|
44
|
+
message(STATUS "CUDA arch: TORCH_CUDA_ARCH_LIST=$ENV{TORCH_CUDA_ARCH_LIST} (from environment)")
|
|
45
|
+
elseif(TORCH_CUDA_ARCH_LIST)
|
|
46
|
+
set(ENV{TORCH_CUDA_ARCH_LIST} "${TORCH_CUDA_ARCH_LIST}")
|
|
47
|
+
message(STATUS "CUDA arch: TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST} (from cmake)")
|
|
48
|
+
else()
|
|
49
|
+
set(_FV_ARCH_LIST "")
|
|
50
|
+
if(_FASTVIDEO_USER_CUDA_ARCH)
|
|
51
|
+
# Caller pinned -DCMAKE_CUDA_ARCHITECTURES (which torch ignores); translate it
|
|
52
|
+
# to the TORCH_CUDA_ARCH_LIST spelling: "121" -> "12.1", "90a" -> "9.0a".
|
|
53
|
+
foreach(_fv_arch IN LISTS _FASTVIDEO_USER_CUDA_ARCH)
|
|
54
|
+
string(REGEX MATCH "[af]$" _fv_suffix "${_fv_arch}")
|
|
55
|
+
string(REGEX REPLACE "[af]$" "" _fv_num "${_fv_arch}")
|
|
56
|
+
string(REGEX REPLACE "(.)$" ".\\1" _fv_num "${_fv_num}") # dot before the last digit
|
|
57
|
+
list(APPEND _FV_ARCH_LIST "${_fv_num}${_fv_suffix}")
|
|
58
|
+
endforeach()
|
|
59
|
+
message(STATUS "CUDA arch: TORCH_CUDA_ARCH_LIST=${_FV_ARCH_LIST} (from -DCMAKE_CUDA_ARCHITECTURES=${_FASTVIDEO_USER_CUDA_ARCH})")
|
|
60
|
+
else()
|
|
61
|
+
# Best-effort probe of the visible GPU (mirrors build.sh detect_with_torch).
|
|
62
|
+
execute_process(
|
|
63
|
+
COMMAND "${Python_EXECUTABLE}" -c "import torch; assert torch.cuda.is_available(); mj, mn = torch.cuda.get_device_capability(0); print(f'{mj}.{mn}a' if (mj, mn) in ((9, 0), (12, 0)) else f'{mj}.{mn}')"
|
|
64
|
+
OUTPUT_VARIABLE _FV_ARCH_LIST
|
|
65
|
+
OUTPUT_STRIP_TRAILING_WHITESPACE
|
|
66
|
+
RESULT_VARIABLE _fv_detect_rc
|
|
67
|
+
ERROR_QUIET
|
|
68
|
+
)
|
|
69
|
+
if(_fv_detect_rc EQUAL 0 AND _FV_ARCH_LIST)
|
|
70
|
+
message(STATUS "CUDA arch: TORCH_CUDA_ARCH_LIST=${_FV_ARCH_LIST} (detected via torch, live GPU)")
|
|
71
|
+
else()
|
|
72
|
+
set(_FV_ARCH_LIST "")
|
|
73
|
+
endif()
|
|
74
|
+
endif()
|
|
75
|
+
|
|
76
|
+
if(_FV_ARCH_LIST)
|
|
77
|
+
set(TORCH_CUDA_ARCH_LIST "${_FV_ARCH_LIST}")
|
|
78
|
+
set(ENV{TORCH_CUDA_ARCH_LIST} "${_FV_ARCH_LIST}")
|
|
79
|
+
else()
|
|
80
|
+
message(FATAL_ERROR
|
|
81
|
+
"fastvideo-kernel: could not determine the target CUDA architecture.\n"
|
|
82
|
+
"Refusing to let torch auto-detect an arch that may not run on this GPU. "
|
|
83
|
+
"Fix with one of:\n"
|
|
84
|
+
" - set TORCH_CUDA_ARCH_LIST (e.g. 12.1, or 9.0a for Hopper), or\n"
|
|
85
|
+
" - pass -DCMAKE_CUDA_ARCHITECTURES=<arch> (e.g. 121), or\n"
|
|
86
|
+
" - build where the target GPU is visible to torch.\n"
|
|
87
|
+
"Note: 'pip/uv pip install' builds under build isolation, which hides the "
|
|
88
|
+
"GPU; set TORCH_CUDA_ARCH_LIST or add --no-build-isolation. "
|
|
89
|
+
"fastvideo-kernel/build.sh sets all of this for you.")
|
|
90
|
+
endif()
|
|
91
|
+
endif()
|
|
92
|
+
endif()
|
|
93
|
+
|
|
22
94
|
# Robustly find Torch include paths using Python
|
|
23
95
|
execute_process(
|
|
24
96
|
COMMAND "${Python_EXECUTABLE}" -c "import torch; from torch.utils.cpp_extension import include_paths; print(';'.join(include_paths()))"
|
|
@@ -191,6 +263,11 @@ set(CUDA_FLAGS
|
|
|
191
263
|
"--expt-relaxed-constexpr"
|
|
192
264
|
"-Xcompiler=-fno-strict-aliasing"
|
|
193
265
|
"-Xcompiler=-fPIC"
|
|
266
|
+
# ARM/aarch64 defaults `char` to unsigned, but ThunderKittens headers assume the
|
|
267
|
+
# x86 signed-char behavior (else base_types.cuh hits "narrowing conversion from
|
|
268
|
+
# char to signed char"). Force signed char so TK compiles on Grace Hopper; this
|
|
269
|
+
# is a no-op on x86_64, where char is already signed.
|
|
270
|
+
"-Xcompiler=-fsigned-char"
|
|
194
271
|
"-DTORCH_COMPILE"
|
|
195
272
|
"-Xnvlink=--verbose"
|
|
196
273
|
"-Xptxas=--verbose"
|
|
@@ -349,4 +426,3 @@ if(ENABLE_ATTN_QAT_INFER)
|
|
|
349
426
|
install(TARGETS fp4attn_cuda LIBRARY DESTINATION .)
|
|
350
427
|
install(TARGETS fp4quant_cuda LIBRARY DESTINATION .)
|
|
351
428
|
endif()
|
|
352
|
-
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: fastvideo-kernel
|
|
3
|
-
Version: 0.3.
|
|
3
|
+
Version: 0.3.2
|
|
4
4
|
Summary: Unified CUDA kernels for FastVideo
|
|
5
5
|
Author-Email: Hao AI Lab <contact@haoailab.com>
|
|
6
6
|
License: Apache License
|
|
@@ -239,13 +239,13 @@ fully usable without it).
|
|
|
239
239
|
The symbols the fastpath needs (`flash_attn.cute.block_sparsity.BlockSparseTensorsTorch`,
|
|
240
240
|
`flash_attn.cute.interface._flash_attn_fwd`) are provided upstream by
|
|
241
241
|
[Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention). Pin to
|
|
242
|
-
commit `
|
|
243
|
-
|
|
244
|
-
|
|
242
|
+
commit `940cd9680f3315f2f06b43ab5bea2c2cf2d96806`, the revision FastVideo pins as
|
|
243
|
+
the `flash-attn-4` source in the repo-root `pyproject.toml`; other revisions may
|
|
244
|
+
have an incompatible `_flash_attn_fwd` signature.
|
|
245
245
|
|
|
246
246
|
```bash
|
|
247
247
|
pip install "nvidia-cutlass-dsl>=4.5.0" torchvision
|
|
248
|
-
pip install "git+https://github.com/Dao-AILab/flash-attention.git@
|
|
248
|
+
pip install "git+https://github.com/Dao-AILab/flash-attention.git@940cd9680f3315f2f06b43ab5bea2c2cf2d96806#subdirectory=flash_attn/cute"
|
|
249
249
|
```
|
|
250
250
|
|
|
251
251
|
The CuTe kernel JIT-compiles on first use. Verified on Blackwell (sm_100) against
|
|
@@ -39,13 +39,13 @@ fully usable without it).
|
|
|
39
39
|
The symbols the fastpath needs (`flash_attn.cute.block_sparsity.BlockSparseTensorsTorch`,
|
|
40
40
|
`flash_attn.cute.interface._flash_attn_fwd`) are provided upstream by
|
|
41
41
|
[Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention). Pin to
|
|
42
|
-
commit `
|
|
43
|
-
|
|
44
|
-
|
|
42
|
+
commit `940cd9680f3315f2f06b43ab5bea2c2cf2d96806`, the revision FastVideo pins as
|
|
43
|
+
the `flash-attn-4` source in the repo-root `pyproject.toml`; other revisions may
|
|
44
|
+
have an incompatible `_flash_attn_fwd` signature.
|
|
45
45
|
|
|
46
46
|
```bash
|
|
47
47
|
pip install "nvidia-cutlass-dsl>=4.5.0" torchvision
|
|
48
|
-
pip install "git+https://github.com/Dao-AILab/flash-attention.git@
|
|
48
|
+
pip install "git+https://github.com/Dao-AILab/flash-attention.git@940cd9680f3315f2f06b43ab5bea2c2cf2d96806#subdirectory=flash_attn/cute"
|
|
49
49
|
```
|
|
50
50
|
|
|
51
51
|
The CuTe kernel JIT-compiles on first use. Verified on Blackwell (sm_100) against
|
|
@@ -39,8 +39,13 @@ if [[ -n "${CONDA_PREFIX:-}" ]]; then
|
|
|
39
39
|
unset _need_clean _host_arch
|
|
40
40
|
fi
|
|
41
41
|
|
|
42
|
-
# Ensure
|
|
43
|
-
|
|
42
|
+
# Ensure only the kernel's required headers are initialized. A repository-wide
|
|
43
|
+
# update also clones the unrelated VBench evaluation submodule. Skip outside a
|
|
44
|
+
# git checkout (e.g. Docker contexts that exclude .git), where the submodule
|
|
45
|
+
# contents must already be present.
|
|
46
|
+
if git rev-parse --git-dir >/dev/null 2>&1; then
|
|
47
|
+
git submodule update --init --recursive include/cutlass include/tk
|
|
48
|
+
fi
|
|
44
49
|
|
|
45
50
|
# Install build dependencies
|
|
46
51
|
uv pip install scikit-build-core cmake ninja
|
|
@@ -25,12 +25,16 @@
|
|
|
25
25
|
#include "gemm/launch.hpp"
|
|
26
26
|
|
|
27
27
|
void int8_gemm(
|
|
28
|
-
at::Tensor const& A, at::Tensor const& A_S,
|
|
29
|
-
at::Tensor const& B, at::Tensor const& B_S,
|
|
28
|
+
at::Tensor const& A, at::Tensor const& A_S,
|
|
29
|
+
at::Tensor const& B, at::Tensor const& B_S,
|
|
30
30
|
torch::Tensor& C
|
|
31
31
|
) {
|
|
32
32
|
|
|
33
|
-
|
|
33
|
+
// The kernel dereferences raw pointers; a CPU tensor here (e.g. an Int8Linear
|
|
34
|
+
// never moved to CUDA) would otherwise fail as an illegal memory access.
|
|
35
|
+
TORCH_CHECK(A.is_cuda() && A_S.is_cuda() && B.is_cuda() && B_S.is_cuda() && C.is_cuda(),
|
|
36
|
+
"int8_gemm: all tensors must be on CUDA (move Int8Linear to CUDA before forward)");
|
|
37
|
+
|
|
34
38
|
static constexpr int swizzle_dir = 1;
|
|
35
39
|
static constexpr int swizzle_size_log = 5;
|
|
36
40
|
|
|
Binary file
|
|
@@ -29,6 +29,15 @@ from fastvideo_kernel.block_sparse_attn_varlen import (
|
|
|
29
29
|
block_sparse_attn_varlen,
|
|
30
30
|
)
|
|
31
31
|
|
|
32
|
+
from fastvideo_kernel.vsa_utils import (
|
|
33
|
+
VSA_TILE_SIZE,
|
|
34
|
+
get_tile_partition_indices,
|
|
35
|
+
get_reverse_tile_partition_indices,
|
|
36
|
+
construct_variable_block_sizes,
|
|
37
|
+
get_non_pad_index,
|
|
38
|
+
build_vsa_metadata,
|
|
39
|
+
)
|
|
40
|
+
|
|
32
41
|
__all__ = [
|
|
33
42
|
"sliding_tile_attention",
|
|
34
43
|
"video_sparse_attn",
|
|
@@ -44,5 +53,11 @@ __all__ = [
|
|
|
44
53
|
"FastLayerNorm",
|
|
45
54
|
"int8_linear",
|
|
46
55
|
"int8_quant",
|
|
56
|
+
"VSA_TILE_SIZE",
|
|
57
|
+
"get_tile_partition_indices",
|
|
58
|
+
"get_reverse_tile_partition_indices",
|
|
59
|
+
"construct_variable_block_sizes",
|
|
60
|
+
"get_non_pad_index",
|
|
61
|
+
"build_vsa_metadata",
|
|
47
62
|
"__version__",
|
|
48
63
|
]
|
|
@@ -46,8 +46,7 @@ def _load_fa4_cute():
|
|
|
46
46
|
return BlockSparseTensorsTorch, _flash_attn_fwd
|
|
47
47
|
|
|
48
48
|
|
|
49
|
-
#
|
|
50
|
-
# Q-side tile, kv_block_size comes from the caller's VSA logical KV block.
|
|
49
|
+
# Q-side tile size; kv_block_size comes from the caller's VSA logical KV block.
|
|
51
50
|
_M_BLOCK_SIZE_DEFAULT = 128
|
|
52
51
|
|
|
53
52
|
|
|
@@ -182,18 +181,18 @@ def _cute_forward(
|
|
|
182
181
|
block_size=(q_sparse_block_size, kv_block_size),
|
|
183
182
|
)
|
|
184
183
|
|
|
184
|
+
# _flash_attn_fwd returns (out, lse, p, row_max); keep the first two.
|
|
185
185
|
out, lse = _flash_attn_fwd(
|
|
186
186
|
q_bshd,
|
|
187
187
|
k_bshd,
|
|
188
188
|
v_bshd,
|
|
189
|
-
|
|
190
|
-
n_block_size=kv_block_size,
|
|
189
|
+
tile_mn=(_M_BLOCK_SIZE_DEFAULT, kv_block_size),
|
|
191
190
|
mask_mod=_build_vbs_mask_mod(kv_block_size),
|
|
192
191
|
block_sparse_tensors=sparse_tensors,
|
|
193
192
|
aux_tensors=[variable_block_sizes],
|
|
194
193
|
causal=False,
|
|
195
194
|
return_lse=True,
|
|
196
|
-
)
|
|
195
|
+
)[:2]
|
|
197
196
|
return out, lse
|
|
198
197
|
|
|
199
198
|
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = "0.3.2"
|
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
"""VSA metadata utilities — standalone, no fastvideo framework dependency.
|
|
2
|
+
|
|
3
|
+
Provides the tile-partition index helpers and variable-block-size
|
|
4
|
+
computations that are needed to call `video_sparse_attn` or
|
|
5
|
+
`block_sparse_attn_from_indices` without depending on the full
|
|
6
|
+
fastvideo package.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import functools
|
|
12
|
+
import math
|
|
13
|
+
|
|
14
|
+
import torch
|
|
15
|
+
|
|
16
|
+
VSA_TILE_SIZE = (4, 4, 4)
|
|
17
|
+
_SUPPORTED_VSA_BLOCK_VOLUMES = (64, 256)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _canonicalize_device(device: torch.device | str) -> torch.device:
|
|
21
|
+
"""Resolve an indexless CUDA device before it is used as a cache key."""
|
|
22
|
+
device = torch.device(device)
|
|
23
|
+
if device.type == "cuda" and device.index is None:
|
|
24
|
+
return torch.device("cuda", torch.cuda.current_device())
|
|
25
|
+
return device
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@functools.lru_cache(maxsize=10)
|
|
29
|
+
def get_tile_partition_indices(
|
|
30
|
+
dit_seq_shape: tuple[int, int, int],
|
|
31
|
+
tile_size: tuple[int, int, int],
|
|
32
|
+
device: torch.device,
|
|
33
|
+
) -> torch.LongTensor:
|
|
34
|
+
"""Map raster-order token indices to tile-contiguous order.
|
|
35
|
+
|
|
36
|
+
Groups spatially adjacent tokens into (ts_t x ts_h x ts_w) tiles
|
|
37
|
+
so that each tile's tokens are contiguous in the output.
|
|
38
|
+
"""
|
|
39
|
+
T, H, W = dit_seq_shape
|
|
40
|
+
ts, hs, ws = tile_size
|
|
41
|
+
indices = torch.arange(T * H * W, device=device, dtype=torch.long).reshape(T, H, W)
|
|
42
|
+
ls = []
|
|
43
|
+
for t in range(math.ceil(T / ts)):
|
|
44
|
+
for h in range(math.ceil(H / hs)):
|
|
45
|
+
for w in range(math.ceil(W / ws)):
|
|
46
|
+
ls.append(indices[
|
|
47
|
+
t * ts:min(t * ts + ts, T),
|
|
48
|
+
h * hs:min(h * hs + hs, H),
|
|
49
|
+
w * ws:min(w * ws + ws, W),
|
|
50
|
+
].flatten())
|
|
51
|
+
return torch.cat(ls, dim=0)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@functools.lru_cache(maxsize=10)
|
|
55
|
+
def get_reverse_tile_partition_indices(
|
|
56
|
+
dit_seq_shape: tuple[int, int, int],
|
|
57
|
+
tile_size: tuple[int, int, int],
|
|
58
|
+
device: torch.device,
|
|
59
|
+
) -> torch.LongTensor:
|
|
60
|
+
"""Inverse of get_tile_partition_indices: tile order back to raster."""
|
|
61
|
+
return torch.argsort(get_tile_partition_indices(dit_seq_shape, tile_size, device))
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@functools.lru_cache(maxsize=10)
|
|
65
|
+
def construct_variable_block_sizes(
|
|
66
|
+
dit_seq_shape: tuple[int, int, int],
|
|
67
|
+
num_tiles: tuple[int, int, int],
|
|
68
|
+
device: torch.device,
|
|
69
|
+
tile_size: tuple[int, int, int] = VSA_TILE_SIZE,
|
|
70
|
+
) -> torch.LongTensor:
|
|
71
|
+
"""Compute the number of valid tokens in each tile.
|
|
72
|
+
|
|
73
|
+
Tiles at the boundary of each dimension may contain fewer tokens
|
|
74
|
+
when the video shape is not evenly divisible by tile_size.
|
|
75
|
+
"""
|
|
76
|
+
t, h, w = dit_seq_shape
|
|
77
|
+
ts_t, ts_h, ts_w = tile_size
|
|
78
|
+
n_t, n_h, n_w = num_tiles
|
|
79
|
+
|
|
80
|
+
def _sizes(dim_len: int, tile: int, n: int) -> torch.LongTensor:
|
|
81
|
+
sizes = torch.full((n,), tile, dtype=torch.int, device=device)
|
|
82
|
+
remainder = dim_len - (n - 1) * tile
|
|
83
|
+
sizes[-1] = remainder if remainder > 0 else tile
|
|
84
|
+
return sizes
|
|
85
|
+
|
|
86
|
+
t_sizes = _sizes(t, ts_t, n_t)
|
|
87
|
+
h_sizes = _sizes(h, ts_h, n_h)
|
|
88
|
+
w_sizes = _sizes(w, ts_w, n_w)
|
|
89
|
+
|
|
90
|
+
return (
|
|
91
|
+
t_sizes[:, None, None]
|
|
92
|
+
* h_sizes[None, :, None]
|
|
93
|
+
* w_sizes[None, None, :]
|
|
94
|
+
).reshape(-1)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def get_non_pad_index(
|
|
98
|
+
variable_block_sizes: torch.LongTensor,
|
|
99
|
+
max_block_size: int,
|
|
100
|
+
) -> torch.LongTensor:
|
|
101
|
+
"""Find positions of real tokens within a block-padded layout.
|
|
102
|
+
|
|
103
|
+
Each block occupies max_block_size slots. This returns the flat
|
|
104
|
+
indices of the valid (non-padding) positions.
|
|
105
|
+
"""
|
|
106
|
+
n_win = variable_block_sizes.shape[0]
|
|
107
|
+
device = variable_block_sizes.device
|
|
108
|
+
starts_pad = torch.arange(n_win, device=device) * max_block_size
|
|
109
|
+
index_pad = starts_pad[:, None] + torch.arange(max_block_size, device=device)[None, :]
|
|
110
|
+
index_mask = torch.arange(max_block_size, device=device)[None, :] < variable_block_sizes[:, None]
|
|
111
|
+
return index_pad[index_mask]
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def build_vsa_metadata(
|
|
115
|
+
dit_seq_shape: tuple[int, int, int],
|
|
116
|
+
tile_size: tuple[int, int, int] = VSA_TILE_SIZE,
|
|
117
|
+
device: torch.device | str = "cuda",
|
|
118
|
+
) -> dict:
|
|
119
|
+
"""Build all VSA metadata from a video latent shape in one call.
|
|
120
|
+
|
|
121
|
+
Args:
|
|
122
|
+
dit_seq_shape: (T, H, W) — temporal frames, spatial height, width.
|
|
123
|
+
tile_size: (ts_t, ts_h, ts_w) — tokens per tile in each dimension.
|
|
124
|
+
The resulting tile volume must be supported by the VSA kernels.
|
|
125
|
+
device: Target device for index tensors.
|
|
126
|
+
|
|
127
|
+
Returns:
|
|
128
|
+
Dict with keys: tile_partition_indices, reverse_tile_partition_indices,
|
|
129
|
+
variable_block_sizes, non_pad_index, num_tiles, max_block_size.
|
|
130
|
+
"""
|
|
131
|
+
device = _canonicalize_device(device)
|
|
132
|
+
|
|
133
|
+
T, H, W = dit_seq_shape
|
|
134
|
+
ts_t, ts_h, ts_w = tile_size
|
|
135
|
+
max_block_size = math.prod(tile_size)
|
|
136
|
+
if max_block_size not in _SUPPORTED_VSA_BLOCK_VOLUMES:
|
|
137
|
+
raise ValueError(
|
|
138
|
+
f"Unsupported VSA tile volume {max_block_size} for tile_size={tile_size}; "
|
|
139
|
+
f"supported volumes are {_SUPPORTED_VSA_BLOCK_VOLUMES}."
|
|
140
|
+
)
|
|
141
|
+
|
|
142
|
+
num_tiles = (
|
|
143
|
+
math.ceil(T / ts_t),
|
|
144
|
+
math.ceil(H / ts_h),
|
|
145
|
+
math.ceil(W / ts_w),
|
|
146
|
+
)
|
|
147
|
+
|
|
148
|
+
tile_indices = get_tile_partition_indices(dit_seq_shape, tile_size, device)
|
|
149
|
+
reverse_tile_indices = get_reverse_tile_partition_indices(dit_seq_shape, tile_size, device)
|
|
150
|
+
vbs = construct_variable_block_sizes(dit_seq_shape, num_tiles, device, tile_size)
|
|
151
|
+
npi = get_non_pad_index(vbs, max_block_size)
|
|
152
|
+
|
|
153
|
+
return {
|
|
154
|
+
"tile_partition_indices": tile_indices,
|
|
155
|
+
"reverse_tile_partition_indices": reverse_tile_indices,
|
|
156
|
+
"variable_block_sizes": vbs,
|
|
157
|
+
"non_pad_index": npi,
|
|
158
|
+
"num_tiles": num_tiles,
|
|
159
|
+
"max_block_size": max_block_size,
|
|
160
|
+
}
|
|
@@ -18,10 +18,15 @@ import os
|
|
|
18
18
|
import sys
|
|
19
19
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
|
20
20
|
|
|
21
|
+
import pytest
|
|
21
22
|
import torch
|
|
22
23
|
import torch.nn.functional as F
|
|
23
24
|
from torch.nn.attention import SDPBackend, sdpa_kernel
|
|
24
25
|
|
|
26
|
+
# The FP4 extensions are only compiled under the sm_120a (Blackwell) arch
|
|
27
|
+
# gate; on other GPUs the api import below would die at collection time.
|
|
28
|
+
pytest.importorskip("fp4attn_cuda", reason="ATTN_QAT_INFER FP4 kernels require a sm_120a build")
|
|
29
|
+
|
|
25
30
|
from attn_qat_infer.api import sageattn_blackwell
|
|
26
31
|
|
|
27
32
|
DEVICE = torch.device("cuda")
|
|
@@ -0,0 +1,276 @@
|
|
|
1
|
+
"""Tests for vsa_utils — standalone VSA metadata utilities.
|
|
2
|
+
|
|
3
|
+
These tests use CPU index computation except for an optional multi-GPU
|
|
4
|
+
regression covering cache isolation between CUDA devices.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import math
|
|
8
|
+
import pytest
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
from fastvideo_kernel.vsa_utils import (
|
|
12
|
+
VSA_TILE_SIZE,
|
|
13
|
+
_canonicalize_device,
|
|
14
|
+
get_tile_partition_indices,
|
|
15
|
+
get_reverse_tile_partition_indices,
|
|
16
|
+
construct_variable_block_sizes,
|
|
17
|
+
get_non_pad_index,
|
|
18
|
+
build_vsa_metadata,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class TestDeviceCanonicalization:
|
|
23
|
+
|
|
24
|
+
def test_unindexed_cuda_resolves_current_device(self, monkeypatch):
|
|
25
|
+
monkeypatch.setattr(torch.cuda, "current_device", lambda: 3)
|
|
26
|
+
assert _canonicalize_device("cuda") == torch.device("cuda:3")
|
|
27
|
+
assert _canonicalize_device(torch.device("cuda")) == torch.device("cuda:3")
|
|
28
|
+
|
|
29
|
+
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="requires two CUDA devices")
|
|
30
|
+
def test_metadata_cache_isolated_between_cuda_devices(self):
|
|
31
|
+
shape = (7, 8, 8)
|
|
32
|
+
tensor_keys = (
|
|
33
|
+
"tile_partition_indices",
|
|
34
|
+
"reverse_tile_partition_indices",
|
|
35
|
+
"variable_block_sizes",
|
|
36
|
+
"non_pad_index",
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
with torch.cuda.device(0):
|
|
40
|
+
metadata_0 = build_vsa_metadata(shape, device="cuda")
|
|
41
|
+
with torch.cuda.device(1):
|
|
42
|
+
metadata_1 = build_vsa_metadata(shape, device="cuda")
|
|
43
|
+
|
|
44
|
+
for key in tensor_keys:
|
|
45
|
+
assert metadata_0[key].device == torch.device("cuda:0")
|
|
46
|
+
assert metadata_1[key].device == torch.device("cuda:1")
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class TestGetTilePartitionIndices:
|
|
50
|
+
|
|
51
|
+
@pytest.mark.parametrize("dit_seq_shape,tile_size", [
|
|
52
|
+
((8, 16, 16), (4, 4, 4)),
|
|
53
|
+
((4, 8, 8), (4, 4, 4)),
|
|
54
|
+
((9, 10, 7), (4, 4, 4)),
|
|
55
|
+
])
|
|
56
|
+
def test_is_valid_permutation(self, dit_seq_shape, tile_size):
|
|
57
|
+
"""Output must be a permutation of 0..N-1."""
|
|
58
|
+
device = torch.device("cpu")
|
|
59
|
+
idx = get_tile_partition_indices(dit_seq_shape, tile_size, device)
|
|
60
|
+
n = math.prod(dit_seq_shape)
|
|
61
|
+
assert idx.shape == (n,)
|
|
62
|
+
assert idx.dtype == torch.long
|
|
63
|
+
assert set(idx.tolist()) == set(range(n))
|
|
64
|
+
|
|
65
|
+
def test_exact_values_small(self):
|
|
66
|
+
"""Manually verify a small case: (2,2,2) with tile (2,2,2) = 1 tile."""
|
|
67
|
+
device = torch.device("cpu")
|
|
68
|
+
idx = get_tile_partition_indices((2, 2, 2), (2, 2, 2), device)
|
|
69
|
+
assert idx.tolist() == list(range(8))
|
|
70
|
+
|
|
71
|
+
def test_non_divisible_shape(self):
|
|
72
|
+
"""When shape doesn't divide evenly by tile_size, all tokens still covered."""
|
|
73
|
+
device = torch.device("cpu")
|
|
74
|
+
shape = (5, 7, 3)
|
|
75
|
+
idx = get_tile_partition_indices(shape, (4, 4, 4), device)
|
|
76
|
+
assert idx.shape == (5 * 7 * 3,)
|
|
77
|
+
assert set(idx.tolist()) == set(range(5 * 7 * 3))
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class TestGetReverseTilePartitionIndices:
    """Checks that the reverse permutation undoes the forward one."""

    @pytest.mark.parametrize("dit_seq_shape", [
        (8, 16, 16),
        (9, 10, 7),
    ])
    def test_inverse_of_forward(self, dit_seq_shape):
        """Composing forward and reverse in either order yields the identity."""
        cpu = torch.device("cpu")
        tile = (4, 4, 4)
        forward = get_tile_partition_indices(dit_seq_shape, tile, cpu)
        backward = get_reverse_tile_partition_indices(dit_seq_shape, tile, cpu)
        expected = torch.arange(math.prod(dit_seq_shape), device=cpu)
        assert torch.equal(backward[forward], expected)
        assert torch.equal(forward[backward], expected)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class TestConstructVariableBlockSizes:
    """Checks for the per-tile token counts from construct_variable_block_sizes."""

    @staticmethod
    def _tiles_for(shape, tile_size):
        """Tiles needed along each axis, counting a partial trailing tile as one."""
        return tuple(-(-extent // tile) for extent, tile in zip(shape, tile_size))

    def test_sum_equals_total_tokens(self):
        """Block sizes add up to T*H*W."""
        shape, tile_size = (8, 16, 16), (4, 4, 4)
        num_tiles = self._tiles_for(shape, tile_size)
        sizes = construct_variable_block_sizes(shape, num_tiles, torch.device("cpu"), tile_size)
        assert sizes.sum().item() == math.prod(shape)

    def test_max_block_size(self):
        """No block is larger than the tile volume."""
        shape, tile_size = (9, 10, 7), (4, 4, 4)
        num_tiles = self._tiles_for(shape, tile_size)
        sizes = construct_variable_block_sizes(shape, num_tiles, torch.device("cpu"), tile_size)
        assert sizes.max().item() <= math.prod(tile_size)

    def test_num_blocks(self):
        """One block per tile: count equals the product of num_tiles."""
        shape, tile_size = (8, 16, 16), (4, 4, 4)
        num_tiles = self._tiles_for(shape, tile_size)
        sizes = construct_variable_block_sizes(shape, num_tiles, torch.device("cpu"), tile_size)
        assert sizes.shape[0] == math.prod(num_tiles)

    def test_exact_divisible(self):
        """An evenly divisible shape gives identical block sizes."""
        sizes = construct_variable_block_sizes((8, 8, 8), (2, 2, 2), torch.device("cpu"), (4, 4, 4))
        assert (sizes == 64).all()

    def test_non_divisible_last_tile_smaller(self):
        """An uneven shape produces at least one block below the tile volume."""
        tile_size = (4, 4, 4)
        sizes = construct_variable_block_sizes((9, 8, 8), (3, 2, 2), torch.device("cpu"), tile_size)
        assert sizes.min().item() < math.prod(tile_size)

    def test_custom_tile_size(self):
        """An explicit tile_size argument replaces the default VSA_TILE_SIZE."""
        shape = (6, 8, 16)
        sizes = construct_variable_block_sizes(shape, (3, 2, 2), torch.device("cpu"), (2, 4, 8))
        assert sizes.sum().item() == math.prod(shape)
        assert (sizes == 64).all()
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
class TestGetNonPadIndex:
    """Checks for get_non_pad_index.

    Per the assertions below, block i's real tokens start at i * max_block_size
    and all indices fall in [0, num_blocks * max_block_size).
    """

    def test_length_equals_sum_block_sizes(self):
        """Output length must equal sum of variable_block_sizes."""
        vbs = torch.tensor([32, 48, 64], dtype=torch.long)
        idx = get_non_pad_index(vbs, 64)
        assert idx.shape[0] == 32 + 48 + 64

    def test_indices_in_valid_range(self):
        """All indices must be in [0, num_blocks * max_block_size)."""
        vbs = torch.tensor([32, 48], dtype=torch.long)
        idx = get_non_pad_index(vbs, 64)
        assert idx.min().item() >= 0
        assert idx.max().item() < 2 * 64

    def test_block_boundary_alignment(self):
        """Block i's tokens occupy [i * max_block_size, i * max_block_size + vbs[i])."""
        vbs = torch.tensor([20, 40], dtype=torch.long)
        idx = get_non_pad_index(vbs, 64)
        assert idx[0].item() == 0
        assert idx[20].item() == 64
        # Pin the whole vector, not just the two boundary points: the padded
        # slots [20, 64) of block 0 must be skipped, and block 1 must be
        # contiguous from its aligned start.
        expected = torch.cat([torch.arange(0, 20), torch.arange(64, 64 + 40)])
        assert torch.equal(idx, expected)

    def test_full_blocks(self):
        """When all blocks are full, output is just 0..N-1."""
        vbs = torch.tensor([64, 64], dtype=torch.long)
        idx = get_non_pad_index(vbs, 64)
        assert torch.equal(idx, torch.arange(128))
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
class TestBuildVsaMetadata:
    """Checks for the one-shot metadata builder build_vsa_metadata."""

    # Entries of the metadata dict that are expected to be tensors.
    _TENSOR_KEYS = (
        "tile_partition_indices",
        "reverse_tile_partition_indices",
        "variable_block_sizes",
        "non_pad_index",
    )

    def test_all_keys_present(self):
        """The returned mapping has exactly the documented keys."""
        meta = build_vsa_metadata((8, 16, 16), device="cpu")
        assert set(meta.keys()) == {*self._TENSOR_KEYS, "num_tiles", "max_block_size"}

    def test_types(self):
        """Index/size entries are tensors; num_tiles is a tuple; max_block_size is an int."""
        meta = build_vsa_metadata((8, 16, 16), device="cpu")
        for key in self._TENSOR_KEYS:
            assert isinstance(meta[key], torch.Tensor)
        assert isinstance(meta["num_tiles"], tuple)
        assert isinstance(meta["max_block_size"], int)

    def test_num_tiles_correct(self):
        """Partial tiles round up: (9, 10, 7) over (4, 4, 4) gives (3, 3, 2) tiles of 64."""
        meta = build_vsa_metadata((9, 10, 7), tile_size=(4, 4, 4), device="cpu")
        assert meta["num_tiles"] == (3, 3, 2)
        assert meta["max_block_size"] == 64

    @pytest.mark.parametrize("tile_size,expected_num_tiles,expected_block_size", [
        ((2, 4, 8), (3, 2, 2), 64),
        ((4, 8, 8), (2, 1, 2), 256),
    ])
    def test_supported_custom_tile_size(self, tile_size, expected_num_tiles, expected_block_size):
        """Non-cubic tiles with a supported volume are accepted."""
        meta = build_vsa_metadata((6, 8, 16), tile_size=tile_size, device="cpu")
        assert meta["num_tiles"] == expected_num_tiles
        assert meta["max_block_size"] == expected_block_size

    def test_unsupported_tile_volume(self):
        """A tile volume outside the supported set raises ValueError."""
        with pytest.raises(ValueError, match="Unsupported VSA tile volume 27"):
            build_vsa_metadata((6, 6, 6), tile_size=(3, 3, 3), device="cpu")

    def test_consistency(self):
        """Sizes of the individual components agree with one another."""
        shape = (8, 16, 16)
        total = math.prod(shape)
        meta = build_vsa_metadata(shape, device="cpu")
        assert meta["tile_partition_indices"].shape == (total,)
        assert meta["reverse_tile_partition_indices"].shape == (total,)
        assert meta["variable_block_sizes"].sum().item() == total
        assert meta["non_pad_index"].shape[0] == total
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
class TestConsistencyWithFramework:
    """Verify vsa_utils matches the framework-level functions exactly.

    Only runs if fastvideo is importable (skip otherwise).
    """

    @pytest.fixture(autouse=True)
    def _skip_if_no_framework(self):
        # Probe every framework symbol the tests below compare against. If only
        # one were checked, a framework build missing the other would make the
        # corresponding test error out instead of being skipped.
        try:
            from fastvideo.attention.backends.video_sparse_attn import (  # noqa: F401
                construct_variable_block_sizes as _fw_construct_vbs,
                get_tile_partition_indices as _fw_get_tile,
            )
        except ImportError:
            pytest.skip("fastvideo framework not installed")

    @pytest.mark.parametrize("shape", [(8, 16, 16), (9, 10, 7)])
    def test_tile_indices_match(self, shape):
        """Forward tile permutation is identical to the framework's."""
        from fastvideo.attention.backends.video_sparse_attn import (
            get_tile_partition_indices as fw_get_tile,
        )
        device = torch.device("cpu")
        tile_size = (4, 4, 4)
        ours = get_tile_partition_indices(shape, tile_size, device)
        theirs = fw_get_tile(shape, tile_size, device)
        assert torch.equal(ours, theirs)

    @pytest.mark.parametrize("shape", [(8, 16, 16), (9, 10, 7)])
    def test_variable_block_sizes_match(self, shape):
        """Per-tile block sizes are identical to the framework's."""
        from fastvideo.attention.backends.video_sparse_attn import (
            construct_variable_block_sizes as fw_construct_vbs,
        )
        device = torch.device("cpu")
        tile_size = (4, 4, 4)
        num_tiles = tuple(math.ceil(s / t) for s, t in zip(shape, tile_size))
        ours = construct_variable_block_sizes(shape, num_tiles, device, tile_size)
        # NOTE(review): the framework call takes no tile_size; this assumes its
        # default tile is (4, 4, 4) — confirm against the framework source.
        theirs = fw_construct_vbs(shape, num_tiles, device)
        assert torch.equal(ours, theirs)
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
if __name__ == "__main__":
    # pytest.main returns an exit code; raise it so a direct `python file.py`
    # run exits non-zero when tests fail (otherwise the status is discarded).
    raise SystemExit(pytest.main([__file__, "-v", "-s"]))
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
__version__ = "0.3.0"
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/blockscaled_layout.h
RENAMED
|
File without changes
|
|
File without changes
|
{fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/epilogue_tma_ws.h
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/mainloop_tma_ws.h
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/quantization/bench/bench_quant_k.py
RENAMED
|
File without changes
|
{fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/quantization/bench/bench_quant_q.py
RENAMED
|
File without changes
|
{fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/quantization/bench/bench_quant_v.py
RENAMED
|
File without changes
|
{fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/quantization/bench/bench_utils.py
RENAMED
|
File without changes
|
|
File without changes
|
{fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/quantization/fp4_quantization_4d.cu
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/block_sparse_attn.py
RENAMED
|
File without changes
|
{fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/block_sparse_attn_256.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/triton_kernels/index.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/turbodiffusion_ops.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|