fastvideo-kernel 0.3.0__tar.gz → 0.3.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/CMakeLists.txt +77 -1
  2. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/PKG-INFO +5 -5
  3. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/README.md +4 -4
  4. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/build.sh +7 -2
  5. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/gemm/gemm.cu +7 -3
  6. fastvideo_kernel-0.3.2/dist/fastvideo_kernel-0.3.2-cp312-cp312-manylinux_2_34_aarch64.manylinux_2_35_aarch64.whl +0 -0
  7. fastvideo_kernel-0.3.0/dist/fastvideo_kernel-0.3.0-cp312-cp312-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl → fastvideo_kernel-0.3.2/dist/fastvideo_kernel-0.3.2-cp312-cp312-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl +0 -0
  8. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/pyproject.toml +1 -1
  9. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/__init__.py +15 -0
  10. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/block_sparse_attn_cute_fwd.py +4 -5
  11. fastvideo_kernel-0.3.2/python/fastvideo_kernel/version.py +1 -0
  12. fastvideo_kernel-0.3.2/python/fastvideo_kernel/vsa_utils.py +160 -0
  13. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/test_attn_qat_infer.py +5 -0
  14. fastvideo_kernel-0.3.2/tests/test_vsa_utils.py +276 -0
  15. fastvideo_kernel-0.3.0/python/fastvideo_kernel/version.py +0 -1
  16. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/LICENSE +0 -0
  17. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/MANIFEST.in +0 -0
  18. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/__init__.py +0 -0
  19. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/api.py +0 -0
  20. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/__init__.py +0 -0
  21. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/api.cu +0 -0
  22. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/block_config.h +0 -0
  23. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/block_info.h +0 -0
  24. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/blockscaled_layout.h +0 -0
  25. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/cute_extension.h +0 -0
  26. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/epilogue_tma_ws.h +0 -0
  27. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/kernel_traits.h +0 -0
  28. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/kernel_ws.h +0 -0
  29. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/launch.h +0 -0
  30. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/mainloop_tma_ws.h +0 -0
  31. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/named_barrier.h +0 -0
  32. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/params.h +0 -0
  33. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/softmax_fused.h +0 -0
  34. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/static_switch.h +0 -0
  35. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/tile_scheduler.h +0 -0
  36. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/utils.h +0 -0
  37. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/quantization/__init__.py +0 -0
  38. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/quantization/bench/bench_quant_k.py +0 -0
  39. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/quantization/bench/bench_quant_q.py +0 -0
  40. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/quantization/bench/bench_quant_v.py +0 -0
  41. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/quantization/bench/bench_utils.py +0 -0
  42. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/quantization/cuda_utils.h +0 -0
  43. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/attn_qat_infer/quantization/fp4_quantization_4d.cu +0 -0
  44. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/benchmarks/bench_fused_compress_topk.py +0 -0
  45. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/benchmarks/bench_vsa.py +0 -0
  46. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/attention/block_sparse_h100.cu +0 -0
  47. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/attention/st_attn_h100.cu +0 -0
  48. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/common_extension.cpp +0 -0
  49. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/common/common.hpp +0 -0
  50. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/common/launch.hpp +0 -0
  51. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/common/load.hpp +0 -0
  52. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/common/store.hpp +0 -0
  53. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/gemm/kernel.hpp +0 -0
  54. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/gemm/launch.hpp +0 -0
  55. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/gemm/utils.hpp +0 -0
  56. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/norm/layernorm.cu +0 -0
  57. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/norm/layernorm.hpp +0 -0
  58. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/norm/rmsnorm.cu +0 -0
  59. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/norm/rmsnorm.hpp +0 -0
  60. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/quant/quant.cu +0 -0
  61. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/quant/quant.hpp +0 -0
  62. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/block_sparse_attn.py +0 -0
  63. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/block_sparse_attn_256.py +0 -0
  64. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/block_sparse_attn_varlen.py +0 -0
  65. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/ops.py +0 -0
  66. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/triton_kernels/attn_qat_train.py +0 -0
  67. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/triton_kernels/block_sparse_attn_triton.py +0 -0
  68. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/triton_kernels/fused_attention.py +0 -0
  69. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/triton_kernels/fused_compress_topk.py +0 -0
  70. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/triton_kernels/index.py +0 -0
  71. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/triton_kernels/nvfp4_utils.py +0 -0
  72. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/triton_kernels/quant_utils.py +0 -0
  73. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/triton_kernels/sla_triton.py +0 -0
  74. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/triton_kernels/st_attn_triton.py +0 -0
  75. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/turbodiffusion_ops.py +0 -0
  76. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/vmoba.py +0 -0
  77. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/__init__.py +0 -0
  78. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/support_flex_sta.py +0 -0
  79. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/test_attn_qat_train.py +0 -0
  80. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/test_fused_compress_topk.py +0 -0
  81. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/test_sta.py +0 -0
  82. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/test_turbodiffusion.py +0 -0
  83. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/test_vmoba_correctness.py +0 -0
  84. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/test_vsa.py +0 -0
  85. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/test_vsa256_forward.py +0 -0
  86. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/test_vsa256_forward_cross.py +0 -0
  87. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/test_vsa256_forward_vbs.py +0 -0
  88. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/test_vsa256_triton.py +0 -0
  89. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/test_vsa_forward.py +0 -0
  90. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/test_vsa_varlen.py +0 -0
  91. {fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/utils.py +0 -0
{fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/CMakeLists.txt
@@ -1,6 +1,13 @@
1
1
  cmake_minimum_required(VERSION 3.26 FATAL_ERROR)
2
2
  project(fastvideo-kernel LANGUAGES CXX)
3
3
 
4
+ # Capture any caller-provided -DCMAKE_CUDA_ARCHITECTURES *before* enable_language(CUDA)
5
+ # auto-populates it with CMake's built-in default (an old arch, e.g. sm_75 on CUDA 13).
6
+ # torch's cmake actually ignores CMAKE_CUDA_ARCHITECTURES (it drives arch selection via
7
+ # TORCH_CUDA_ARCH_LIST), so we only use this captured value to honor an explicit pin by
8
+ # translating it into TORCH_CUDA_ARCH_LIST below.
9
+ set(_FASTVIDEO_USER_CUDA_ARCH "${CMAKE_CUDA_ARCHITECTURES}")
10
+
4
11
  # Prefer environment variable (used by CI or uv pip install git+repo_addr) if CMake var is not explicitly set.
5
12
  if(NOT DEFINED GPU_BACKEND AND DEFINED ENV{GPU_BACKEND})
6
13
  set(GPU_BACKEND "$ENV{GPU_BACKEND}")
@@ -19,6 +26,71 @@ endif()
19
26
  # Find Python and Torch
20
27
  find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
21
28
 
29
+ # ---------------------------------------------------------------------------
30
+ # Resolve the target CUDA architecture, BEFORE find_package(Torch) below.
31
+ #
32
+ # torch's cmake (Caffe2 public/cuda.cmake) takes over arch selection: it emits
33
+ # the real -gencode flags from TORCH_CUDA_ARCH_LIST and forces
34
+ # CMAKE_CUDA_ARCHITECTURES to OFF. So the *effective* arch is whatever
35
+ # TORCH_CUDA_ARCH_LIST is when find_package(Torch) runs. build.sh exports it;
36
+ # standards-based builds (pip / uv pip install, sdist) don't, and torch then
37
+ # auto-detects an arch that does not match the GPU -- the kernels build but fail
38
+ # at runtime ("no kernel image is available for execution on the device").
39
+ # Resolve it here when absent (mirrors build.sh): honor a pinned
40
+ # CMAKE_CUDA_ARCHITECTURES if given, else probe the visible GPU with torch.
41
+ # ---------------------------------------------------------------------------
42
+ if(NOT GPU_BACKEND STREQUAL "ROCM")
43
+ if(DEFINED ENV{TORCH_CUDA_ARCH_LIST})
44
+ message(STATUS "CUDA arch: TORCH_CUDA_ARCH_LIST=$ENV{TORCH_CUDA_ARCH_LIST} (from environment)")
45
+ elseif(TORCH_CUDA_ARCH_LIST)
46
+ set(ENV{TORCH_CUDA_ARCH_LIST} "${TORCH_CUDA_ARCH_LIST}")
47
+ message(STATUS "CUDA arch: TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST} (from cmake)")
48
+ else()
49
+ set(_FV_ARCH_LIST "")
50
+ if(_FASTVIDEO_USER_CUDA_ARCH)
51
+ # Caller pinned -DCMAKE_CUDA_ARCHITECTURES (which torch ignores); translate it
52
+ # to the TORCH_CUDA_ARCH_LIST spelling: "121" -> "12.1", "90a" -> "9.0a".
53
+ foreach(_fv_arch IN LISTS _FASTVIDEO_USER_CUDA_ARCH)
54
+ string(REGEX MATCH "[af]$" _fv_suffix "${_fv_arch}")
55
+ string(REGEX REPLACE "[af]$" "" _fv_num "${_fv_arch}")
56
+ string(REGEX REPLACE "(.)$" ".\\1" _fv_num "${_fv_num}") # dot before the last digit
57
+ list(APPEND _FV_ARCH_LIST "${_fv_num}${_fv_suffix}")
58
+ endforeach()
59
+ message(STATUS "CUDA arch: TORCH_CUDA_ARCH_LIST=${_FV_ARCH_LIST} (from -DCMAKE_CUDA_ARCHITECTURES=${_FASTVIDEO_USER_CUDA_ARCH})")
60
+ else()
61
+ # Best-effort probe of the visible GPU (mirrors build.sh detect_with_torch).
62
+ execute_process(
63
+ COMMAND "${Python_EXECUTABLE}" -c "import torch; assert torch.cuda.is_available(); mj, mn = torch.cuda.get_device_capability(0); print(f'{mj}.{mn}a' if (mj, mn) in ((9, 0), (12, 0)) else f'{mj}.{mn}')"
64
+ OUTPUT_VARIABLE _FV_ARCH_LIST
65
+ OUTPUT_STRIP_TRAILING_WHITESPACE
66
+ RESULT_VARIABLE _fv_detect_rc
67
+ ERROR_QUIET
68
+ )
69
+ if(_fv_detect_rc EQUAL 0 AND _FV_ARCH_LIST)
70
+ message(STATUS "CUDA arch: TORCH_CUDA_ARCH_LIST=${_FV_ARCH_LIST} (detected via torch, live GPU)")
71
+ else()
72
+ set(_FV_ARCH_LIST "")
73
+ endif()
74
+ endif()
75
+
76
+ if(_FV_ARCH_LIST)
77
+ set(TORCH_CUDA_ARCH_LIST "${_FV_ARCH_LIST}")
78
+ set(ENV{TORCH_CUDA_ARCH_LIST} "${_FV_ARCH_LIST}")
79
+ else()
80
+ message(FATAL_ERROR
81
+ "fastvideo-kernel: could not determine the target CUDA architecture.\n"
82
+ "Refusing to let torch auto-detect an arch that may not run on this GPU. "
83
+ "Fix with one of:\n"
84
+ " - set TORCH_CUDA_ARCH_LIST (e.g. 12.1, or 9.0a for Hopper), or\n"
85
+ " - pass -DCMAKE_CUDA_ARCHITECTURES=<arch> (e.g. 121), or\n"
86
+ " - build where the target GPU is visible to torch.\n"
87
+ "Note: 'pip/uv pip install' builds under build isolation, which hides the "
88
+ "GPU; set TORCH_CUDA_ARCH_LIST or add --no-build-isolation. "
89
+ "fastvideo-kernel/build.sh sets all of this for you.")
90
+ endif()
91
+ endif()
92
+ endif()
93
+
22
94
  # Robustly find Torch include paths using Python
23
95
  execute_process(
24
96
  COMMAND "${Python_EXECUTABLE}" -c "import torch; from torch.utils.cpp_extension import include_paths; print(';'.join(include_paths()))"
@@ -191,6 +263,11 @@ set(CUDA_FLAGS
191
263
  "--expt-relaxed-constexpr"
192
264
  "-Xcompiler=-fno-strict-aliasing"
193
265
  "-Xcompiler=-fPIC"
266
+ # ARM/aarch64 defaults `char` to unsigned, but ThunderKittens headers assume the
267
+ # x86 signed-char behavior (else base_types.cuh hits "narrowing conversion from
268
+ # char to signed char"). Force signed char so TK compiles on Grace Hopper; this
269
+ # is a no-op on x86_64, where char is already signed.
270
+ "-Xcompiler=-fsigned-char"
194
271
  "-DTORCH_COMPILE"
195
272
  "-Xnvlink=--verbose"
196
273
  "-Xptxas=--verbose"
@@ -349,4 +426,3 @@ if(ENABLE_ATTN_QAT_INFER)
349
426
  install(TARGETS fp4attn_cuda LIBRARY DESTINATION .)
350
427
  install(TARGETS fp4quant_cuda LIBRARY DESTINATION .)
351
428
  endif()
352
-
{fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/PKG-INFO
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: fastvideo-kernel
3
- Version: 0.3.0
3
+ Version: 0.3.2
4
4
  Summary: Unified CUDA kernels for FastVideo
5
5
  Author-Email: Hao AI Lab <contact@haoailab.com>
6
6
  License: Apache License
@@ -239,13 +239,13 @@ fully usable without it).
239
239
  The symbols the fastpath needs (`flash_attn.cute.block_sparsity.BlockSparseTensorsTorch`,
240
240
  `flash_attn.cute.interface._flash_attn_fwd`) are provided upstream by
241
241
  [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention). Pin to
242
- commit `c19cd20e`: the wrapper targets that revision's `_flash_attn_fwd` signature
243
- (`m_block_size` / `n_block_size`); later upstream revisions reshaped it into a
244
- `tile_mn` tuple and are not drop-in compatible.
242
+ commit `940cd9680f3315f2f06b43ab5bea2c2cf2d96806`, the revision FastVideo pins as
243
+ the `flash-attn-4` source in the repo-root `pyproject.toml`; other revisions may
244
+ have an incompatible `_flash_attn_fwd` signature.
245
245
 
246
246
  ```bash
247
247
  pip install "nvidia-cutlass-dsl>=4.5.0" torchvision
248
- pip install "git+https://github.com/Dao-AILab/flash-attention.git@c19cd20e#subdirectory=flash_attn/cute"
248
+ pip install "git+https://github.com/Dao-AILab/flash-attention.git@940cd9680f3315f2f06b43ab5bea2c2cf2d96806#subdirectory=flash_attn/cute"
249
249
  ```
250
250
 
251
251
  The CuTe kernel JIT-compiles on first use. Verified on Blackwell (sm_100) against
{fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/README.md
@@ -39,13 +39,13 @@ fully usable without it).
39
39
  The symbols the fastpath needs (`flash_attn.cute.block_sparsity.BlockSparseTensorsTorch`,
40
40
  `flash_attn.cute.interface._flash_attn_fwd`) are provided upstream by
41
41
  [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention). Pin to
42
- commit `c19cd20e`: the wrapper targets that revision's `_flash_attn_fwd` signature
43
- (`m_block_size` / `n_block_size`); later upstream revisions reshaped it into a
44
- `tile_mn` tuple and are not drop-in compatible.
42
+ commit `940cd9680f3315f2f06b43ab5bea2c2cf2d96806`, the revision FastVideo pins as
43
+ the `flash-attn-4` source in the repo-root `pyproject.toml`; other revisions may
44
+ have an incompatible `_flash_attn_fwd` signature.
45
45
 
46
46
  ```bash
47
47
  pip install "nvidia-cutlass-dsl>=4.5.0" torchvision
48
- pip install "git+https://github.com/Dao-AILab/flash-attention.git@c19cd20e#subdirectory=flash_attn/cute"
48
+ pip install "git+https://github.com/Dao-AILab/flash-attention.git@940cd9680f3315f2f06b43ab5bea2c2cf2d96806#subdirectory=flash_attn/cute"
49
49
  ```
50
50
 
51
51
  The CuTe kernel JIT-compiles on first use. Verified on Blackwell (sm_100) against
{fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/build.sh
@@ -39,8 +39,13 @@ if [[ -n "${CONDA_PREFIX:-}" ]]; then
39
39
  unset _need_clean _host_arch
40
40
  fi
41
41
 
42
- # Ensure submodules are initialized if needed (tk)
43
- git submodule update --init --recursive
42
+ # Ensure only the kernel's required headers are initialized. A repository-wide
43
+ # update also clones the unrelated VBench evaluation submodule. Skip outside a
44
+ # git checkout (e.g. Docker contexts that exclude .git), where the submodule
45
+ # contents must already be present.
46
+ if git rev-parse --git-dir >/dev/null 2>&1; then
47
+ git submodule update --init --recursive include/cutlass include/tk
48
+ fi
44
49
 
45
50
  # Install build dependencies
46
51
  uv pip install scikit-build-core cmake ninja
{fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/gemm/gemm.cu
@@ -25,12 +25,16 @@
25
25
  #include "gemm/launch.hpp"
26
26
 
27
27
  void int8_gemm(
28
- at::Tensor const& A, at::Tensor const& A_S,
29
- at::Tensor const& B, at::Tensor const& B_S,
28
+ at::Tensor const& A, at::Tensor const& A_S,
29
+ at::Tensor const& B, at::Tensor const& B_S,
30
30
  torch::Tensor& C
31
31
  ) {
32
32
 
33
-
33
+ // The kernel dereferences raw pointers; a CPU tensor here (e.g. an Int8Linear
34
+ // never moved to CUDA) would otherwise fail as an illegal memory access.
35
+ TORCH_CHECK(A.is_cuda() && A_S.is_cuda() && B.is_cuda() && B_S.is_cuda() && C.is_cuda(),
36
+ "int8_gemm: all tensors must be on CUDA (move Int8Linear to CUDA before forward)");
37
+
34
38
  static constexpr int swizzle_dir = 1;
35
39
  static constexpr int swizzle_size_log = 5;
36
40
 
{fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/pyproject.toml
@@ -9,7 +9,7 @@ build-backend = "scikit_build_core.build"
9
9
 
10
10
  [project]
11
11
  name = "fastvideo-kernel"
12
- version = "0.3.0"
12
+ version = "0.3.2"
13
13
  description = "Unified CUDA kernels for FastVideo"
14
14
  readme = "README.md"
15
15
  requires-python = ">=3.10"
{fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/__init__.py
@@ -29,6 +29,15 @@ from fastvideo_kernel.block_sparse_attn_varlen import (
29
29
  block_sparse_attn_varlen,
30
30
  )
31
31
 
32
+ from fastvideo_kernel.vsa_utils import (
33
+ VSA_TILE_SIZE,
34
+ get_tile_partition_indices,
35
+ get_reverse_tile_partition_indices,
36
+ construct_variable_block_sizes,
37
+ get_non_pad_index,
38
+ build_vsa_metadata,
39
+ )
40
+
32
41
  __all__ = [
33
42
  "sliding_tile_attention",
34
43
  "video_sparse_attn",
@@ -44,5 +53,11 @@ __all__ = [
44
53
  "FastLayerNorm",
45
54
  "int8_linear",
46
55
  "int8_quant",
56
+ "VSA_TILE_SIZE",
57
+ "get_tile_partition_indices",
58
+ "get_reverse_tile_partition_indices",
59
+ "construct_variable_block_sizes",
60
+ "get_non_pad_index",
61
+ "build_vsa_metadata",
47
62
  "__version__",
48
63
  ]
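{fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/block_sparse_attn_cute_fwd.py
@@ -46,8 +46,7 @@ def _load_fa4_cute():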
46
46
  return BlockSparseTensorsTorch, _flash_attn_fwd
47
47
 
48
48
 
49
- # FA4 BSA fwd uses (m_block_size, n_block_size); m_block_size=128 is the
50
- # Q-side tile, kv_block_size comes from the caller's VSA logical KV block.
49
+ # Q-side tile size; kv_block_size comes from the caller's VSA logical KV block.
51
50
  _M_BLOCK_SIZE_DEFAULT = 128
52
51
 
53
52
 
@@ -182,18 +181,18 @@ def _cute_forward(
182
181
  block_size=(q_sparse_block_size, kv_block_size),
183
182
  )
184
183
 
184
+ # _flash_attn_fwd returns (out, lse, p, row_max); keep the first two.
185
185
  out, lse = _flash_attn_fwd(
186
186
  q_bshd,
187
187
  k_bshd,
188
188
  v_bshd,
189
- m_block_size=_M_BLOCK_SIZE_DEFAULT,
190
- n_block_size=kv_block_size,
189
+ tile_mn=(_M_BLOCK_SIZE_DEFAULT, kv_block_size),
191
190
  mask_mod=_build_vbs_mask_mod(kv_block_size),
192
191
  block_sparse_tensors=sparse_tensors,
193
192
  aux_tensors=[variable_block_sizes],
194
193
  causal=False,
195
194
  return_lse=True,
196
- )
195
+ )[:2]
197
196
  return out, lse
198
197
 
199
198
 
fastvideo_kernel-0.3.2/python/fastvideo_kernel/version.py
@@ -0,0 +1 @@
1
+ __version__ = "0.3.2"
fastvideo_kernel-0.3.2/python/fastvideo_kernel/vsa_utils.py
@@ -0,0 +1,160 @@
1
+ """VSA metadata utilities — standalone, no fastvideo framework dependency.
2
+
3
+ Provides the tile-partition index helpers and variable-block-size
4
+ computations that are needed to call `video_sparse_attn` or
5
+ `block_sparse_attn_from_indices` without depending on the full
6
+ fastvideo package.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import functools
12
+ import math
13
+
14
+ import torch
15
+
16
+ VSA_TILE_SIZE = (4, 4, 4)
17
+ _SUPPORTED_VSA_BLOCK_VOLUMES = (64, 256)
18
+
19
+
20
+ def _canonicalize_device(device: torch.device | str) -> torch.device:
21
+ """Resolve an indexless CUDA device before it is used as a cache key."""
22
+ device = torch.device(device)
23
+ if device.type == "cuda" and device.index is None:
24
+ return torch.device("cuda", torch.cuda.current_device())
25
+ return device
26
+
27
+
28
+ @functools.lru_cache(maxsize=10)
29
+ def get_tile_partition_indices(
30
+ dit_seq_shape: tuple[int, int, int],
31
+ tile_size: tuple[int, int, int],
32
+ device: torch.device,
33
+ ) -> torch.LongTensor:
34
+ """Map raster-order token indices to tile-contiguous order.
35
+
36
+ Groups spatially adjacent tokens into (ts_t x ts_h x ts_w) tiles
37
+ so that each tile's tokens are contiguous in the output.
38
+ """
39
+ T, H, W = dit_seq_shape
40
+ ts, hs, ws = tile_size
41
+ indices = torch.arange(T * H * W, device=device, dtype=torch.long).reshape(T, H, W)
42
+ ls = []
43
+ for t in range(math.ceil(T / ts)):
44
+ for h in range(math.ceil(H / hs)):
45
+ for w in range(math.ceil(W / ws)):
46
+ ls.append(indices[
47
+ t * ts:min(t * ts + ts, T),
48
+ h * hs:min(h * hs + hs, H),
49
+ w * ws:min(w * ws + ws, W),
50
+ ].flatten())
51
+ return torch.cat(ls, dim=0)
52
+
53
+
54
+ @functools.lru_cache(maxsize=10)
55
+ def get_reverse_tile_partition_indices(
56
+ dit_seq_shape: tuple[int, int, int],
57
+ tile_size: tuple[int, int, int],
58
+ device: torch.device,
59
+ ) -> torch.LongTensor:
60
+ """Inverse of get_tile_partition_indices: tile order back to raster."""
61
+ return torch.argsort(get_tile_partition_indices(dit_seq_shape, tile_size, device))
62
+
63
+
64
+ @functools.lru_cache(maxsize=10)
65
+ def construct_variable_block_sizes(
66
+ dit_seq_shape: tuple[int, int, int],
67
+ num_tiles: tuple[int, int, int],
68
+ device: torch.device,
69
+ tile_size: tuple[int, int, int] = VSA_TILE_SIZE,
70
+ ) -> torch.LongTensor:
71
+ """Compute the number of valid tokens in each tile.
72
+
73
+ Tiles at the boundary of each dimension may contain fewer tokens
74
+ when the video shape is not evenly divisible by tile_size.
75
+ """
76
+ t, h, w = dit_seq_shape
77
+ ts_t, ts_h, ts_w = tile_size
78
+ n_t, n_h, n_w = num_tiles
79
+
80
+ def _sizes(dim_len: int, tile: int, n: int) -> torch.LongTensor:
81
+ sizes = torch.full((n,), tile, dtype=torch.int, device=device)
82
+ remainder = dim_len - (n - 1) * tile
83
+ sizes[-1] = remainder if remainder > 0 else tile
84
+ return sizes
85
+
86
+ t_sizes = _sizes(t, ts_t, n_t)
87
+ h_sizes = _sizes(h, ts_h, n_h)
88
+ w_sizes = _sizes(w, ts_w, n_w)
89
+
90
+ return (
91
+ t_sizes[:, None, None]
92
+ * h_sizes[None, :, None]
93
+ * w_sizes[None, None, :]
94
+ ).reshape(-1)
95
+
96
+
97
+ def get_non_pad_index(
98
+ variable_block_sizes: torch.LongTensor,
99
+ max_block_size: int,
100
+ ) -> torch.LongTensor:
101
+ """Find positions of real tokens within a block-padded layout.
102
+
103
+ Each block occupies max_block_size slots. This returns the flat
104
+ indices of the valid (non-padding) positions.
105
+ """
106
+ n_win = variable_block_sizes.shape[0]
107
+ device = variable_block_sizes.device
108
+ starts_pad = torch.arange(n_win, device=device) * max_block_size
109
+ index_pad = starts_pad[:, None] + torch.arange(max_block_size, device=device)[None, :]
110
+ index_mask = torch.arange(max_block_size, device=device)[None, :] < variable_block_sizes[:, None]
111
+ return index_pad[index_mask]
112
+
113
+
114
+ def build_vsa_metadata(
115
+ dit_seq_shape: tuple[int, int, int],
116
+ tile_size: tuple[int, int, int] = VSA_TILE_SIZE,
117
+ device: torch.device | str = "cuda",
118
+ ) -> dict:
119
+ """Build all VSA metadata from a video latent shape in one call.
120
+
121
+ Args:
122
+ dit_seq_shape: (T, H, W) — temporal frames, spatial height, width.
123
+ tile_size: (ts_t, ts_h, ts_w) — tokens per tile in each dimension.
124
+ The resulting tile volume must be supported by the VSA kernels.
125
+ device: Target device for index tensors.
126
+
127
+ Returns:
128
+ Dict with keys: tile_partition_indices, reverse_tile_partition_indices,
129
+ variable_block_sizes, non_pad_index, num_tiles, max_block_size.
130
+ """
131
+ device = _canonicalize_device(device)
132
+
133
+ T, H, W = dit_seq_shape
134
+ ts_t, ts_h, ts_w = tile_size
135
+ max_block_size = math.prod(tile_size)
136
+ if max_block_size not in _SUPPORTED_VSA_BLOCK_VOLUMES:
137
+ raise ValueError(
138
+ f"Unsupported VSA tile volume {max_block_size} for tile_size={tile_size}; "
139
+ f"supported volumes are {_SUPPORTED_VSA_BLOCK_VOLUMES}."
140
+ )
141
+
142
+ num_tiles = (
143
+ math.ceil(T / ts_t),
144
+ math.ceil(H / ts_h),
145
+ math.ceil(W / ts_w),
146
+ )
147
+
148
+ tile_indices = get_tile_partition_indices(dit_seq_shape, tile_size, device)
149
+ reverse_tile_indices = get_reverse_tile_partition_indices(dit_seq_shape, tile_size, device)
150
+ vbs = construct_variable_block_sizes(dit_seq_shape, num_tiles, device, tile_size)
151
+ npi = get_non_pad_index(vbs, max_block_size)
152
+
153
+ return {
154
+ "tile_partition_indices": tile_indices,
155
+ "reverse_tile_partition_indices": reverse_tile_indices,
156
+ "variable_block_sizes": vbs,
157
+ "non_pad_index": npi,
158
+ "num_tiles": num_tiles,
159
+ "max_block_size": max_block_size,
160
+ }
{fastvideo_kernel-0.3.0 → fastvideo_kernel-0.3.2}/tests/test_attn_qat_infer.py
@@ -18,10 +18,15 @@ import os
18
18
  import sys
19
19
  sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
20
20
 
21
+ import pytest
21
22
  import torch
22
23
  import torch.nn.functional as F
23
24
  from torch.nn.attention import SDPBackend, sdpa_kernel
24
25
 
26
+ # The FP4 extensions are only compiled under the sm_120a (Blackwell) arch
27
+ # gate; on other GPUs the api import below would die at collection time.
28
+ pytest.importorskip("fp4attn_cuda", reason="ATTN_QAT_INFER FP4 kernels require a sm_120a build")
29
+
25
30
  from attn_qat_infer.api import sageattn_blackwell
26
31
 
27
32
  DEVICE = torch.device("cuda")
@@ -0,0 +1,276 @@
1
+ """Tests for vsa_utils — standalone VSA metadata utilities.
2
+
3
+ These tests use CPU index computation except for an optional multi-GPU
4
+ regression covering cache isolation between CUDA devices.
5
+ """
6
+
7
+ import math
8
+ import pytest
9
+ import torch
10
+
11
+ from fastvideo_kernel.vsa_utils import (
12
+ VSA_TILE_SIZE,
13
+ _canonicalize_device,
14
+ get_tile_partition_indices,
15
+ get_reverse_tile_partition_indices,
16
+ construct_variable_block_sizes,
17
+ get_non_pad_index,
18
+ build_vsa_metadata,
19
+ )
20
+
21
+
22
class TestDeviceCanonicalization:
    """Checks on how device specifiers are resolved before metadata is built."""

    def test_unindexed_cuda_resolves_current_device(self, monkeypatch):
        """A bare "cuda" spec must pick up the index of the current device."""
        # Pretend device 3 is current so the test needs no real GPU.
        monkeypatch.setattr(torch.cuda, "current_device", lambda: 3)
        expected = torch.device("cuda:3")
        assert _canonicalize_device("cuda") == expected
        assert _canonicalize_device(torch.device("cuda")) == expected

    @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="requires two CUDA devices")
    def test_metadata_cache_isolated_between_cuda_devices(self):
        """Metadata requested under different current devices lands on each one."""
        shape = (7, 8, 8)
        tensor_keys = (
            "tile_partition_indices",
            "reverse_tile_partition_indices",
            "variable_block_sizes",
            "non_pad_index",
        )

        # Build the same shape once per device, always passing the unindexed
        # "cuda" spec so resolution depends on the active device context.
        per_device = {}
        for ordinal in (0, 1):
            with torch.cuda.device(ordinal):
                per_device[ordinal] = build_vsa_metadata(shape, device="cuda")

        for key in tensor_keys:
            for ordinal, metadata in per_device.items():
                assert metadata[key].device == torch.device(f"cuda:{ordinal}")
47
+
48
+
49
class TestGetTilePartitionIndices:
    """Properties of the forward tile-partition index tensor."""

    @pytest.mark.parametrize("dit_seq_shape,tile_size", [
        ((8, 16, 16), (4, 4, 4)),
        ((4, 8, 8), (4, 4, 4)),
        ((9, 10, 7), (4, 4, 4)),
    ])
    def test_is_valid_permutation(self, dit_seq_shape, tile_size):
        """The result is a 1-D long tensor holding each of 0..N-1 exactly once."""
        total = math.prod(dit_seq_shape)
        perm = get_tile_partition_indices(dit_seq_shape, tile_size, torch.device("cpu"))
        assert perm.shape == (total,)
        assert perm.dtype == torch.long
        # Combined with the length check above, set equality rules out duplicates.
        expected_members = set(range(total))
        assert set(perm.tolist()) == expected_members

    def test_exact_values_small(self):
        """A (2,2,2) volume with a (2,2,2) tile is a single tile: identity order."""
        whole = (2, 2, 2)
        perm = get_tile_partition_indices(whole, whole, torch.device("cpu"))
        assert perm.tolist() == list(range(8))

    def test_non_divisible_shape(self):
        """Every token is still covered when the shape is not a tile multiple."""
        shape = (5, 7, 3)
        total = 5 * 7 * 3
        perm = get_tile_partition_indices(shape, (4, 4, 4), torch.device("cpu"))
        assert perm.shape == (total,)
        assert set(perm.tolist()) == set(range(total))
78
+
79
+
80
class TestGetReverseTilePartitionIndices:
    """The reverse index tensor must undo the forward tile partition."""

    @pytest.mark.parametrize("dit_seq_shape", [
        (8, 16, 16),
        (9, 10, 7),
    ])
    def test_inverse_of_forward(self, dit_seq_shape):
        """Composing forward and reverse, in either order, gives the identity."""
        cpu = torch.device("cpu")
        tile = (4, 4, 4)
        forward = get_tile_partition_indices(dit_seq_shape, tile, cpu)
        backward = get_reverse_tile_partition_indices(dit_seq_shape, tile, cpu)
        expected = torch.arange(math.prod(dit_seq_shape), device=cpu)
        assert torch.equal(backward[forward], expected)
        assert torch.equal(forward[backward], expected)
96
+
97
+
98
class TestConstructVariableBlockSizes:
    """Invariants of the per-tile token counts."""

    @staticmethod
    def _tile_counts(shape, tile_size):
        """Return the number of tiles per axis, rounding partial tiles up."""
        return tuple(math.ceil(extent / edge) for extent, edge in zip(shape, tile_size))

    def test_sum_equals_total_tokens(self):
        """Block sizes add up to T*H*W."""
        shape, tile = (8, 16, 16), (4, 4, 4)
        counts = self._tile_counts(shape, tile)
        sizes = construct_variable_block_sizes(shape, counts, torch.device("cpu"), tile)
        assert sizes.sum().item() == math.prod(shape)

    def test_max_block_size(self):
        """No block holds more tokens than the tile volume."""
        shape, tile = (9, 10, 7), (4, 4, 4)
        counts = self._tile_counts(shape, tile)
        sizes = construct_variable_block_sizes(shape, counts, torch.device("cpu"), tile)
        assert sizes.max().item() <= math.prod(tile)

    def test_num_blocks(self):
        """There is one entry per tile."""
        shape, tile = (8, 16, 16), (4, 4, 4)
        counts = self._tile_counts(shape, tile)
        sizes = construct_variable_block_sizes(shape, counts, torch.device("cpu"), tile)
        assert sizes.shape[0] == math.prod(counts)

    def test_exact_divisible(self):
        """With an evenly divisible shape every block is a full 64-token tile."""
        sizes = construct_variable_block_sizes(
            (8, 8, 8), (2, 2, 2), torch.device("cpu"), (4, 4, 4)
        )
        assert (sizes == 64).all()

    def test_non_divisible_last_tile_smaller(self):
        """With a ragged shape at least one block falls short of the tile volume."""
        tile = (4, 4, 4)
        sizes = construct_variable_block_sizes(
            (9, 8, 8), (3, 2, 2), torch.device("cpu"), tile
        )
        assert sizes.min().item() < math.prod(tile)

    def test_custom_tile_size(self):
        """An explicit non-cubic tile_size is honoured."""
        shape = (6, 8, 16)
        sizes = construct_variable_block_sizes(
            shape, (3, 2, 2), torch.device("cpu"), (2, 4, 8)
        )
        assert sizes.sum().item() == math.prod(shape)
        assert (sizes == 64).all()
154
+
155
+
156
class TestGetNonPadIndex:
    """Behaviour of the index that selects real tokens out of padded blocks."""

    def test_length_equals_sum_block_sizes(self):
        """One index is produced per real token."""
        counts = [32, 48, 64]
        block_sizes = torch.tensor(counts, dtype=torch.long)
        picked = get_non_pad_index(block_sizes, 64)
        assert picked.shape[0] == sum(counts)

    def test_indices_in_valid_range(self):
        """Indices stay within [0, num_blocks * max_block_size)."""
        block_sizes = torch.tensor([32, 48], dtype=torch.long)
        picked = get_non_pad_index(block_sizes, 64)
        lowest, highest = picked.min().item(), picked.max().item()
        assert lowest >= 0
        assert highest < 2 * 64

    def test_block_boundary_alignment(self):
        """The first token of block i sits at offset i * max_block_size."""
        block_sizes = torch.tensor([20, 40], dtype=torch.long)
        picked = get_non_pad_index(block_sizes, 64)
        assert picked[0].item() == 0
        # Entry 20 is the first token of the second block.
        assert picked[20].item() == 64

    def test_full_blocks(self):
        """With no padding anywhere the result is simply 0..N-1."""
        block_sizes = torch.tensor([64, 64], dtype=torch.long)
        picked = get_non_pad_index(block_sizes, 64)
        assert torch.equal(picked, torch.arange(128))
183
+
184
+
185
class TestBuildVsaMetadata:
    """End-to-end checks on the dict returned by build_vsa_metadata."""

    def test_all_keys_present(self):
        """The result exposes exactly the documented set of keys."""
        result = build_vsa_metadata((8, 16, 16), device="cpu")
        assert set(result.keys()) == {
            "tile_partition_indices",
            "reverse_tile_partition_indices",
            "variable_block_sizes",
            "non_pad_index",
            "num_tiles",
            "max_block_size",
        }

    def test_types(self):
        """Index entries are tensors; num_tiles is a tuple; max_block_size an int."""
        result = build_vsa_metadata((8, 16, 16), device="cpu")
        for tensor_key in (
            "tile_partition_indices",
            "reverse_tile_partition_indices",
            "variable_block_sizes",
            "non_pad_index",
        ):
            assert isinstance(result[tensor_key], torch.Tensor)
        assert isinstance(result["num_tiles"], tuple)
        assert isinstance(result["max_block_size"], int)

    def test_num_tiles_correct(self):
        """Tile counts round up per axis and the block size is the tile volume."""
        result = build_vsa_metadata((9, 10, 7), tile_size=(4, 4, 4), device="cpu")
        assert result["num_tiles"] == (3, 3, 2)
        assert result["max_block_size"] == 64

    @pytest.mark.parametrize("tile_size,expected_num_tiles,expected_block_size", [
        ((2, 4, 8), (3, 2, 2), 64),
        ((4, 8, 8), (2, 1, 2), 256),
    ])
    def test_supported_custom_tile_size(self, tile_size, expected_num_tiles, expected_block_size):
        """Non-default tile sizes with a supported volume are accepted."""
        result = build_vsa_metadata((6, 8, 16), tile_size=tile_size, device="cpu")
        observed = (result["num_tiles"], result["max_block_size"])
        assert observed[0] == expected_num_tiles
        assert observed[1] == expected_block_size

    def test_unsupported_tile_volume(self):
        """A tile whose volume is not supported raises ValueError."""
        with pytest.raises(ValueError, match="Unsupported VSA tile volume 27"):
            build_vsa_metadata((6, 6, 6), tile_size=(3, 3, 3), device="cpu")

    def test_consistency(self):
        """Every component agrees on the total token count."""
        shape = (8, 16, 16)
        total = math.prod(shape)
        result = build_vsa_metadata(shape, device="cpu")
        assert result["tile_partition_indices"].shape == (total,)
        assert result["reverse_tile_partition_indices"].shape == (total,)
        assert result["variable_block_sizes"].sum().item() == total
        assert result["non_pad_index"].shape[0] == total
234
+
235
+
236
class TestConsistencyWithFramework:
    """Cross-check vsa_utils against the framework-level implementations.

    Every test in this class is skipped when the fastvideo framework cannot
    be imported.
    """

    @pytest.fixture(autouse=True)
    def _skip_if_no_framework(self):
        """Skip the current test if the framework backend is not importable."""
        try:
            # Imported only to probe availability; the name itself is unused.
            from fastvideo.attention.backends.video_sparse_attn import (  # noqa: F401
                get_tile_partition_indices as _probe,
            )
        except ImportError:
            pytest.skip("fastvideo framework not installed")

    @pytest.mark.parametrize("shape", [(8, 16, 16), (9, 10, 7)])
    def test_tile_indices_match(self, shape):
        """Forward tile indices equal the framework's for the same inputs."""
        from fastvideo.attention.backends.video_sparse_attn import (
            get_tile_partition_indices as framework_tile_indices,
        )
        cpu = torch.device("cpu")
        tile = (4, 4, 4)
        expected = framework_tile_indices(shape, tile, cpu)
        actual = get_tile_partition_indices(shape, tile, cpu)
        assert torch.equal(actual, expected)

    @pytest.mark.parametrize("shape", [(8, 16, 16), (9, 10, 7)])
    def test_variable_block_sizes_match(self, shape):
        """Variable block sizes equal the framework's for the same inputs."""
        from fastvideo.attention.backends.video_sparse_attn import (
            construct_variable_block_sizes as framework_block_sizes,
        )
        cpu = torch.device("cpu")
        tile = (4, 4, 4)
        counts = tuple(math.ceil(extent / edge) for extent, edge in zip(shape, tile))
        # The framework function is called without a tile_size argument.
        expected = framework_block_sizes(shape, counts, cpu)
        actual = construct_variable_block_sizes(shape, counts, cpu, tile)
        assert torch.equal(actual, expected)
273
+
274
+
275
if __name__ == "__main__":
    # pytest.main returns the session's exit code; hand it to the interpreter
    # so running this file directly reports failures through the process
    # exit status instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v", "-s"]))
@@ -1 +0,0 @@
1
- __version__ = "0.3.0"