fastvideo-kernel 0.3.1__tar.gz → 0.3.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/CMakeLists.txt +72 -1
  2. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/PKG-INFO +5 -5
  3. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/README.md +4 -4
  4. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/build.sh +7 -2
  5. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/gemm/gemm.cu +7 -3
  6. fastvideo_kernel-0.3.1/dist/fastvideo_kernel-0.3.1-cp312-cp312-manylinux_2_34_aarch64.manylinux_2_35_aarch64.whl → fastvideo_kernel-0.3.2/dist/fastvideo_kernel-0.3.2-cp312-cp312-manylinux_2_34_aarch64.manylinux_2_35_aarch64.whl +0 -0
  7. fastvideo_kernel-0.3.1/dist/fastvideo_kernel-0.3.1-cp312-cp312-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl → fastvideo_kernel-0.3.2/dist/fastvideo_kernel-0.3.2-cp312-cp312-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl +0 -0
  8. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/pyproject.toml +1 -1
  9. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/block_sparse_attn_cute_fwd.py +4 -5
  10. fastvideo_kernel-0.3.2/python/fastvideo_kernel/version.py +1 -0
  11. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/tests/test_attn_qat_infer.py +5 -0
  12. fastvideo_kernel-0.3.1/python/fastvideo_kernel/version.py +0 -1
  13. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/LICENSE +0 -0
  14. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/MANIFEST.in +0 -0
  15. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/attn_qat_infer/__init__.py +0 -0
  16. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/attn_qat_infer/api.py +0 -0
  17. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/__init__.py +0 -0
  18. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/api.cu +0 -0
  19. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/block_config.h +0 -0
  20. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/block_info.h +0 -0
  21. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/blockscaled_layout.h +0 -0
  22. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/cute_extension.h +0 -0
  23. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/epilogue_tma_ws.h +0 -0
  24. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/kernel_traits.h +0 -0
  25. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/kernel_ws.h +0 -0
  26. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/launch.h +0 -0
  27. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/mainloop_tma_ws.h +0 -0
  28. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/named_barrier.h +0 -0
  29. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/params.h +0 -0
  30. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/softmax_fused.h +0 -0
  31. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/static_switch.h +0 -0
  32. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/tile_scheduler.h +0 -0
  33. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/attn_qat_infer/blackwell/utils.h +0 -0
  34. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/attn_qat_infer/quantization/__init__.py +0 -0
  35. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/attn_qat_infer/quantization/bench/bench_quant_k.py +0 -0
  36. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/attn_qat_infer/quantization/bench/bench_quant_q.py +0 -0
  37. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/attn_qat_infer/quantization/bench/bench_quant_v.py +0 -0
  38. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/attn_qat_infer/quantization/bench/bench_utils.py +0 -0
  39. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/attn_qat_infer/quantization/cuda_utils.h +0 -0
  40. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/attn_qat_infer/quantization/fp4_quantization_4d.cu +0 -0
  41. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/benchmarks/bench_fused_compress_topk.py +0 -0
  42. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/benchmarks/bench_vsa.py +0 -0
  43. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/csrc/attention/block_sparse_h100.cu +0 -0
  44. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/csrc/attention/st_attn_h100.cu +0 -0
  45. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/csrc/common_extension.cpp +0 -0
  46. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/common/common.hpp +0 -0
  47. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/common/launch.hpp +0 -0
  48. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/common/load.hpp +0 -0
  49. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/common/store.hpp +0 -0
  50. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/gemm/kernel.hpp +0 -0
  51. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/gemm/launch.hpp +0 -0
  52. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/gemm/utils.hpp +0 -0
  53. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/norm/layernorm.cu +0 -0
  54. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/norm/layernorm.hpp +0 -0
  55. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/norm/rmsnorm.cu +0 -0
  56. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/norm/rmsnorm.hpp +0 -0
  57. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/quant/quant.cu +0 -0
  58. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/csrc/turbodiffusion/quant/quant.hpp +0 -0
  59. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/__init__.py +0 -0
  60. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/block_sparse_attn.py +0 -0
  61. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/block_sparse_attn_256.py +0 -0
  62. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/block_sparse_attn_varlen.py +0 -0
  63. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/ops.py +0 -0
  64. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/triton_kernels/attn_qat_train.py +0 -0
  65. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/triton_kernels/block_sparse_attn_triton.py +0 -0
  66. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/triton_kernels/fused_attention.py +0 -0
  67. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/triton_kernels/fused_compress_topk.py +0 -0
  68. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/triton_kernels/index.py +0 -0
  69. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/triton_kernels/nvfp4_utils.py +0 -0
  70. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/triton_kernels/quant_utils.py +0 -0
  71. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/triton_kernels/sla_triton.py +0 -0
  72. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/triton_kernels/st_attn_triton.py +0 -0
  73. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/turbodiffusion_ops.py +0 -0
  74. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/vmoba.py +0 -0
  75. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/python/fastvideo_kernel/vsa_utils.py +0 -0
  76. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/tests/__init__.py +0 -0
  77. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/tests/support_flex_sta.py +0 -0
  78. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/tests/test_attn_qat_train.py +0 -0
  79. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/tests/test_fused_compress_topk.py +0 -0
  80. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/tests/test_sta.py +0 -0
  81. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/tests/test_turbodiffusion.py +0 -0
  82. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/tests/test_vmoba_correctness.py +0 -0
  83. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/tests/test_vsa.py +0 -0
  84. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/tests/test_vsa256_forward.py +0 -0
  85. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/tests/test_vsa256_forward_cross.py +0 -0
  86. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/tests/test_vsa256_forward_vbs.py +0 -0
  87. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/tests/test_vsa256_triton.py +0 -0
  88. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/tests/test_vsa_forward.py +0 -0
  89. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/tests/test_vsa_utils.py +0 -0
  90. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/tests/test_vsa_varlen.py +0 -0
  91. {fastvideo_kernel-0.3.1 → fastvideo_kernel-0.3.2}/tests/utils.py +0 -0
@@ -1,6 +1,13 @@
1
1
  cmake_minimum_required(VERSION 3.26 FATAL_ERROR)
2
2
  project(fastvideo-kernel LANGUAGES CXX)
3
3
 
4
+ # Capture any caller-provided -DCMAKE_CUDA_ARCHITECTURES *before* enable_language(CUDA)
5
+ # auto-populates it with CMake's built-in default (an old arch, e.g. sm_75 on CUDA 13).
6
+ # torch's cmake actually ignores CMAKE_CUDA_ARCHITECTURES (it drives arch selection via
7
+ # TORCH_CUDA_ARCH_LIST), so we only use this captured value to honor an explicit pin by
8
+ # translating it into TORCH_CUDA_ARCH_LIST below.
9
+ set(_FASTVIDEO_USER_CUDA_ARCH "${CMAKE_CUDA_ARCHITECTURES}")
10
+
4
11
  # Prefer environment variable (used by CI or uv pip install git+repo_addr) if CMake var is not explicitly set.
5
12
  if(NOT DEFINED GPU_BACKEND AND DEFINED ENV{GPU_BACKEND})
6
13
  set(GPU_BACKEND "$ENV{GPU_BACKEND}")
@@ -19,6 +26,71 @@ endif()
19
26
  # Find Python and Torch
20
27
  find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
21
28
 
29
+ # ---------------------------------------------------------------------------
30
+ # Resolve the target CUDA architecture, BEFORE find_package(Torch) below.
31
+ #
32
+ # torch's cmake (Caffe2 public/cuda.cmake) takes over arch selection: it emits
33
+ # the real -gencode flags from TORCH_CUDA_ARCH_LIST and forces
34
+ # CMAKE_CUDA_ARCHITECTURES to OFF. So the *effective* arch is whatever
35
+ # TORCH_CUDA_ARCH_LIST is when find_package(Torch) runs. build.sh exports it;
36
+ # standards-based builds (pip / uv pip install, sdist) don't, and torch then
37
+ # auto-detects an arch that does not match the GPU -- the kernels build but fail
38
+ # at runtime ("no kernel image is available for execution on the device").
39
+ # Resolve it here when absent (mirrors build.sh): honor a pinned
40
+ # CMAKE_CUDA_ARCHITECTURES if given, else probe the visible GPU with torch.
41
+ # ---------------------------------------------------------------------------
42
+ if(NOT GPU_BACKEND STREQUAL "ROCM")
43
+ if(DEFINED ENV{TORCH_CUDA_ARCH_LIST})
44
+ message(STATUS "CUDA arch: TORCH_CUDA_ARCH_LIST=$ENV{TORCH_CUDA_ARCH_LIST} (from environment)")
45
+ elseif(TORCH_CUDA_ARCH_LIST)
46
+ set(ENV{TORCH_CUDA_ARCH_LIST} "${TORCH_CUDA_ARCH_LIST}")
47
+ message(STATUS "CUDA arch: TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST} (from cmake)")
48
+ else()
49
+ set(_FV_ARCH_LIST "")
50
+ if(_FASTVIDEO_USER_CUDA_ARCH)
51
+ # Caller pinned -DCMAKE_CUDA_ARCHITECTURES (which torch ignores); translate it
52
+ # to the TORCH_CUDA_ARCH_LIST spelling: "121" -> "12.1", "90a" -> "9.0a".
53
+ foreach(_fv_arch IN LISTS _FASTVIDEO_USER_CUDA_ARCH)
54
+ string(REGEX MATCH "[af]$" _fv_suffix "${_fv_arch}")
55
+ string(REGEX REPLACE "[af]$" "" _fv_num "${_fv_arch}")
56
+ string(REGEX REPLACE "(.)$" ".\\1" _fv_num "${_fv_num}") # dot before the last digit
57
+ list(APPEND _FV_ARCH_LIST "${_fv_num}${_fv_suffix}")
58
+ endforeach()
59
+ message(STATUS "CUDA arch: TORCH_CUDA_ARCH_LIST=${_FV_ARCH_LIST} (from -DCMAKE_CUDA_ARCHITECTURES=${_FASTVIDEO_USER_CUDA_ARCH})")
60
+ else()
61
+ # Best-effort probe of the visible GPU (mirrors build.sh detect_with_torch).
62
+ execute_process(
63
+ COMMAND "${Python_EXECUTABLE}" -c "import torch; assert torch.cuda.is_available(); mj, mn = torch.cuda.get_device_capability(0); print(f'{mj}.{mn}a' if (mj, mn) in ((9, 0), (12, 0)) else f'{mj}.{mn}')"
64
+ OUTPUT_VARIABLE _FV_ARCH_LIST
65
+ OUTPUT_STRIP_TRAILING_WHITESPACE
66
+ RESULT_VARIABLE _fv_detect_rc
67
+ ERROR_QUIET
68
+ )
69
+ if(_fv_detect_rc EQUAL 0 AND _FV_ARCH_LIST)
70
+ message(STATUS "CUDA arch: TORCH_CUDA_ARCH_LIST=${_FV_ARCH_LIST} (detected via torch, live GPU)")
71
+ else()
72
+ set(_FV_ARCH_LIST "")
73
+ endif()
74
+ endif()
75
+
76
+ if(_FV_ARCH_LIST)
77
+ set(TORCH_CUDA_ARCH_LIST "${_FV_ARCH_LIST}")
78
+ set(ENV{TORCH_CUDA_ARCH_LIST} "${_FV_ARCH_LIST}")
79
+ else()
80
+ message(FATAL_ERROR
81
+ "fastvideo-kernel: could not determine the target CUDA architecture.\n"
82
+ "Refusing to let torch auto-detect an arch that may not run on this GPU. "
83
+ "Fix with one of:\n"
84
+ " - set TORCH_CUDA_ARCH_LIST (e.g. 12.1, or 9.0a for Hopper), or\n"
85
+ " - pass -DCMAKE_CUDA_ARCHITECTURES=<arch> (e.g. 121), or\n"
86
+ " - build where the target GPU is visible to torch.\n"
87
+ "Note: 'pip/uv pip install' builds under build isolation, which hides the "
88
+ "GPU; set TORCH_CUDA_ARCH_LIST or add --no-build-isolation. "
89
+ "fastvideo-kernel/build.sh sets all of this for you.")
90
+ endif()
91
+ endif()
92
+ endif()
93
+
22
94
  # Robustly find Torch include paths using Python
23
95
  execute_process(
24
96
  COMMAND "${Python_EXECUTABLE}" -c "import torch; from torch.utils.cpp_extension import include_paths; print(';'.join(include_paths()))"
@@ -354,4 +426,3 @@ if(ENABLE_ATTN_QAT_INFER)
354
426
  install(TARGETS fp4attn_cuda LIBRARY DESTINATION .)
355
427
  install(TARGETS fp4quant_cuda LIBRARY DESTINATION .)
356
428
  endif()
357
-
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: fastvideo-kernel
3
- Version: 0.3.1
3
+ Version: 0.3.2
4
4
  Summary: Unified CUDA kernels for FastVideo
5
5
  Author-Email: Hao AI Lab <contact@haoailab.com>
6
6
  License: Apache License
@@ -239,13 +239,13 @@ fully usable without it).
239
239
  The symbols the fastpath needs (`flash_attn.cute.block_sparsity.BlockSparseTensorsTorch`,
240
240
  `flash_attn.cute.interface._flash_attn_fwd`) are provided upstream by
241
241
  [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention). Pin to
242
- commit `c19cd20e`: the wrapper targets that revision's `_flash_attn_fwd` signature
243
- (`m_block_size` / `n_block_size`); later upstream revisions reshaped it into a
244
- `tile_mn` tuple and are not drop-in compatible.
242
+ commit `940cd9680f3315f2f06b43ab5bea2c2cf2d96806`, the revision FastVideo pins as
243
+ the `flash-attn-4` source in the repo-root `pyproject.toml`; other revisions may
244
+ have an incompatible `_flash_attn_fwd` signature.
245
245
 
246
246
  ```bash
247
247
  pip install "nvidia-cutlass-dsl>=4.5.0" torchvision
248
- pip install "git+https://github.com/Dao-AILab/flash-attention.git@c19cd20e#subdirectory=flash_attn/cute"
248
+ pip install "git+https://github.com/Dao-AILab/flash-attention.git@940cd9680f3315f2f06b43ab5bea2c2cf2d96806#subdirectory=flash_attn/cute"
249
249
  ```
250
250
 
251
251
  The CuTe kernel JIT-compiles on first use. Verified on Blackwell (sm_100) against
@@ -39,13 +39,13 @@ fully usable without it).
39
39
  The symbols the fastpath needs (`flash_attn.cute.block_sparsity.BlockSparseTensorsTorch`,
40
40
  `flash_attn.cute.interface._flash_attn_fwd`) are provided upstream by
41
41
  [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention). Pin to
42
- commit `c19cd20e`: the wrapper targets that revision's `_flash_attn_fwd` signature
43
- (`m_block_size` / `n_block_size`); later upstream revisions reshaped it into a
44
- `tile_mn` tuple and are not drop-in compatible.
42
+ commit `940cd9680f3315f2f06b43ab5bea2c2cf2d96806`, the revision FastVideo pins as
43
+ the `flash-attn-4` source in the repo-root `pyproject.toml`; other revisions may
44
+ have an incompatible `_flash_attn_fwd` signature.
45
45
 
46
46
  ```bash
47
47
  pip install "nvidia-cutlass-dsl>=4.5.0" torchvision
48
- pip install "git+https://github.com/Dao-AILab/flash-attention.git@c19cd20e#subdirectory=flash_attn/cute"
48
+ pip install "git+https://github.com/Dao-AILab/flash-attention.git@940cd9680f3315f2f06b43ab5bea2c2cf2d96806#subdirectory=flash_attn/cute"
49
49
  ```
50
50
 
51
51
  The CuTe kernel JIT-compiles on first use. Verified on Blackwell (sm_100) against
@@ -39,8 +39,13 @@ if [[ -n "${CONDA_PREFIX:-}" ]]; then
39
39
  unset _need_clean _host_arch
40
40
  fi
41
41
 
42
- # Ensure submodules are initialized if needed (tk)
43
- git submodule update --init --recursive
42
+ # Ensure only the kernel's required headers are initialized. A repository-wide
43
+ # update also clones the unrelated VBench evaluation submodule. Skip outside a
44
+ # git checkout (e.g. Docker contexts that exclude .git), where the submodule
45
+ # contents must already be present.
46
+ if git rev-parse --git-dir >/dev/null 2>&1; then
47
+ git submodule update --init --recursive include/cutlass include/tk
48
+ fi
44
49
 
45
50
  # Install build dependencies
46
51
  uv pip install scikit-build-core cmake ninja
@@ -25,12 +25,16 @@
25
25
  #include "gemm/launch.hpp"
26
26
 
27
27
  void int8_gemm(
28
- at::Tensor const& A, at::Tensor const& A_S,
29
- at::Tensor const& B, at::Tensor const& B_S,
28
+ at::Tensor const& A, at::Tensor const& A_S,
29
+ at::Tensor const& B, at::Tensor const& B_S,
30
30
  torch::Tensor& C
31
31
  ) {
32
32
 
33
-
33
+ // The kernel dereferences raw pointers; a CPU tensor here (e.g. an Int8Linear
34
+ // never moved to CUDA) would otherwise fail as an illegal memory access.
35
+ TORCH_CHECK(A.is_cuda() && A_S.is_cuda() && B.is_cuda() && B_S.is_cuda() && C.is_cuda(),
36
+ "int8_gemm: all tensors must be on CUDA (move Int8Linear to CUDA before forward)");
37
+
34
38
  static constexpr int swizzle_dir = 1;
35
39
  static constexpr int swizzle_size_log = 5;
36
40
 
@@ -9,7 +9,7 @@ build-backend = "scikit_build_core.build"
9
9
 
10
10
  [project]
11
11
  name = "fastvideo-kernel"
12
- version = "0.3.1"
12
+ version = "0.3.2"
13
13
  description = "Unified CUDA kernels for FastVideo"
14
14
  readme = "README.md"
15
15
  requires-python = ">=3.10"
@@ -46,8 +46,7 @@ def _load_fa4_cute():
46
46
  return BlockSparseTensorsTorch, _flash_attn_fwd
47
47
 
48
48
 
49
- # FA4 BSA fwd uses (m_block_size, n_block_size); m_block_size=128 is the
50
- # Q-side tile, kv_block_size comes from the caller's VSA logical KV block.
49
+ # Q-side tile size; kv_block_size comes from the caller's VSA logical KV block.
51
50
  _M_BLOCK_SIZE_DEFAULT = 128
52
51
 
53
52
 
@@ -182,18 +181,18 @@ def _cute_forward(
182
181
  block_size=(q_sparse_block_size, kv_block_size),
183
182
  )
184
183
 
184
+ # _flash_attn_fwd returns (out, lse, p, row_max); keep the first two.
185
185
  out, lse = _flash_attn_fwd(
186
186
  q_bshd,
187
187
  k_bshd,
188
188
  v_bshd,
189
- m_block_size=_M_BLOCK_SIZE_DEFAULT,
190
- n_block_size=kv_block_size,
189
+ tile_mn=(_M_BLOCK_SIZE_DEFAULT, kv_block_size),
191
190
  mask_mod=_build_vbs_mask_mod(kv_block_size),
192
191
  block_sparse_tensors=sparse_tensors,
193
192
  aux_tensors=[variable_block_sizes],
194
193
  causal=False,
195
194
  return_lse=True,
196
- )
195
+ )[:2]
197
196
  return out, lse
198
197
 
199
198
 
@@ -0,0 +1 @@
1
+ __version__ = "0.3.2"
@@ -18,10 +18,15 @@ import os
18
18
  import sys
19
19
  sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
20
20
 
21
+ import pytest
21
22
  import torch
22
23
  import torch.nn.functional as F
23
24
  from torch.nn.attention import SDPBackend, sdpa_kernel
24
25
 
26
+ # The FP4 extensions are only compiled under the sm_120a (Blackwell) arch
27
+ # gate; on other GPUs the api import below would die at collection time.
28
+ pytest.importorskip("fp4attn_cuda", reason="ATTN_QAT_INFER FP4 kernels require a sm_120a build")
29
+
25
30
  from attn_qat_infer.api import sageattn_blackwell
26
31
 
27
32
  DEVICE = torch.device("cuda")
@@ -1 +0,0 @@
1
- __version__ = "0.3.1"