fastvideo-kernel 0.2.4__tar.gz → 0.2.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/PKG-INFO +12 -1
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/README.md +11 -0
- fastvideo_kernel-0.2.6/benchmarks/bench_vsa.py +166 -0
- fastvideo_kernel-0.2.4/dist/fastvideo_kernel-0.2.4-cp311-cp311-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl → fastvideo_kernel-0.2.6/dist/fastvideo_kernel-0.2.6-cp310-cp310-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl +0 -0
- fastvideo_kernel-0.2.4/dist/fastvideo_kernel-0.2.4-cp312-cp312-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl → fastvideo_kernel-0.2.6/dist/fastvideo_kernel-0.2.6-cp311-cp311-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl +0 -0
- fastvideo_kernel-0.2.4/dist/fastvideo_kernel-0.2.4-cp310-cp310-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl → fastvideo_kernel-0.2.6/dist/fastvideo_kernel-0.2.6-cp312-cp312-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/pyproject.toml +1 -1
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/block_sparse_attn.py +23 -27
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/ops.py +0 -6
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/triton_kernels/block_sparse_attn_triton.py +314 -47
- fastvideo_kernel-0.2.6/python/fastvideo_kernel/version.py +1 -0
- fastvideo_kernel-0.2.4/python/fastvideo_kernel/version.py +0 -1
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/CMakeLists.txt +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/LICENSE +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/MANIFEST.in +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/build.sh +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/csrc/attention/block_sparse_h100.cu +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/csrc/attention/st_attn_h100.cu +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/csrc/common_extension.cpp +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/common/common.hpp +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/common/launch.hpp +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/common/load.hpp +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/common/store.hpp +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/gemm/gemm.cu +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/gemm/kernel.hpp +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/gemm/launch.hpp +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/gemm/utils.hpp +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/norm/layernorm.cu +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/norm/layernorm.hpp +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/norm/rmsnorm.cu +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/norm/rmsnorm.hpp +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/quant/quant.cu +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/quant/quant.hpp +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/__init__.py +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/triton_kernels/index.py +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/triton_kernels/sla_triton.py +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/triton_kernels/st_attn_triton.py +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/turbodiffusion_ops.py +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/vmoba.py +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/tests/__init__.py +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/tests/support_flex_sta.py +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/tests/test_sta.py +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/tests/test_turbodiffusion.py +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/tests/test_vmoba_correctness.py +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/tests/test_vsa.py +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/tests/test_vsa_forward.py +0 -0
- {fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/tests/utils.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: fastvideo-kernel
|
|
3
|
-
Version: 0.2.4
|
|
3
|
+
Version: 0.2.6
|
|
4
4
|
Summary: Unified CUDA kernels for FastVideo
|
|
5
5
|
Author-Email: Hao AI Lab <contact@haoailab.com>
|
|
6
6
|
License: Apache License
|
|
@@ -240,6 +240,17 @@ out = video_sparse_attn(q, k, v, block_sizes, block_sizes, topk=5)
|
|
|
240
240
|
out = moba_attn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_k, ...)
|
|
241
241
|
```
|
|
242
242
|
|
|
243
|
+
## Benchmark
|
|
244
|
+
|
|
245
|
+
### VSA (block-sparse) TFLOPs
|
|
246
|
+
|
|
247
|
+
After building/installing `fastvideo-kernel`, run:
|
|
248
|
+
|
|
249
|
+
```bash
|
|
250
|
+
cd fastvideo-kernel
|
|
251
|
+
python benchmarks/bench_vsa.py --batch_size 1 --num_heads 16 --head_dim 128 --q_seq_lens 49152 --topk 64
|
|
252
|
+
```
|
|
253
|
+
|
|
243
254
|
### TurboDiffusion Kernels
|
|
244
255
|
|
|
245
256
|
This package also includes kernels from [TurboDiffusion](https://github.com/thu-ml/TurboDiffusion), including INT8 GEMM, Quantization, RMSNorm and LayerNorm.
|
|
@@ -40,6 +40,17 @@ out = video_sparse_attn(q, k, v, block_sizes, block_sizes, topk=5)
|
|
|
40
40
|
out = moba_attn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_k, ...)
|
|
41
41
|
```
|
|
42
42
|
|
|
43
|
+
## Benchmark
|
|
44
|
+
|
|
45
|
+
### VSA (block-sparse) TFLOPs
|
|
46
|
+
|
|
47
|
+
After building/installing `fastvideo-kernel`, run:
|
|
48
|
+
|
|
49
|
+
```bash
|
|
50
|
+
cd fastvideo-kernel
|
|
51
|
+
python benchmarks/bench_vsa.py --batch_size 1 --num_heads 16 --head_dim 128 --q_seq_lens 49152 --topk 64
|
|
52
|
+
```
|
|
53
|
+
|
|
43
54
|
### TurboDiffusion Kernels
|
|
44
55
|
|
|
45
56
|
This package also includes kernels from [TurboDiffusion](https://github.com/thu-ml/TurboDiffusion), including INT8 GEMM, Quantization, RMSNorm and LayerNorm.
|
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Benchmark VSA *wrapper* performance (forward + backward) and report TFLOPs.
|
|
4
|
+
|
|
5
|
+
This script benchmarks the autograd-enabled wrapper:
|
|
6
|
+
- fastvideo_kernel.block_sparse_attn.block_sparse_attn
|
|
7
|
+
|
|
8
|
+
So measured time includes wrapper overhead (map->index conversion, dispatch) plus kernel time.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import argparse
|
|
14
|
+
import os
|
|
15
|
+
import random
|
|
16
|
+
from typing import Tuple, Callable
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
import torch
|
|
20
|
+
|
|
21
|
+
try:
|
|
22
|
+
from triton.testing import do_bench
|
|
23
|
+
except Exception as e: # pragma: no cover
|
|
24
|
+
raise ImportError("This benchmark requires triton (for triton.testing.do_bench).") from e
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# Token tile sizes along the Q (BLOCK_M) and KV (BLOCK_N) sequence axes. The
# benchmark only runs sequence lengths that are multiples of these, and the
# Triton kernels assert q/kv lengths are multiples of 64.
BLOCK_M = 64
BLOCK_N = 64
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def set_seed(seed: int = 42) -> None:
    """Seed every RNG the benchmark touches so runs are reproducible.

    Seeds, in order: Python's ``random``, NumPy's global RNG, Torch's CPU RNG,
    the current CUDA device's RNG and the RNGs of all CUDA devices.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def parse_arguments() -> argparse.Namespace:
|
|
40
|
+
p = argparse.ArgumentParser(description="Benchmark FastVideo VSA block-sparse attention")
|
|
41
|
+
p.add_argument("--batch_size", type=int, default=1)
|
|
42
|
+
p.add_argument("--num_heads", type=int, default=12)
|
|
43
|
+
p.add_argument("--head_dim", type=int, default=128, choices=[64, 128])
|
|
44
|
+
p.add_argument("--topk", type=int, default=None, help="KV blocks per Q block (default: ~90%% sparsity)")
|
|
45
|
+
p.add_argument("--q_seq_lens", type=int, nargs="+", default=[49152], help="Q sequence lengths (must be /64)")
|
|
46
|
+
p.add_argument("--kv_seq_lens", type=int, nargs="+", default=None, help="KV sequence lengths (defaults to q_seq_len)")
|
|
47
|
+
p.add_argument("--warmup", type=int, default=5)
|
|
48
|
+
p.add_argument("--rep", type=int, default=20)
|
|
49
|
+
p.add_argument("--seed", type=int, default=42)
|
|
50
|
+
p.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16"])
|
|
51
|
+
p.add_argument("--force_triton", action="store_true", help="Force wrapper to use Triton path (if supported by shapes).")
|
|
52
|
+
return p.parse_args()
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def create_qkv(batch: int, heads: int, q_len: int, kv_len: int, d: int, dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
56
|
+
q = torch.randn(batch, heads, q_len, d, dtype=dtype, device="cuda")
|
|
57
|
+
k = torch.randn(batch, heads, kv_len, d, dtype=dtype, device="cuda")
|
|
58
|
+
v = torch.randn(batch, heads, kv_len, d, dtype=dtype, device="cuda")
|
|
59
|
+
return q, k, v
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def make_block_map(bs: int, h: int, num_q_blocks: int, num_kv_blocks: int, topk: int) -> torch.Tensor:
|
|
63
|
+
# block_map: [bs, h, num_q_blocks, num_kv_blocks] bool
|
|
64
|
+
scores = torch.rand(bs, h, num_q_blocks, num_kv_blocks, device="cuda")
|
|
65
|
+
topk = min(max(1, topk), num_kv_blocks)
|
|
66
|
+
idx = torch.topk(scores, topk, dim=-1).indices
|
|
67
|
+
block_map = torch.zeros(bs, h, num_q_blocks, num_kv_blocks, dtype=torch.bool, device="cuda")
|
|
68
|
+
block_map.scatter_(-1, idx, True)
|
|
69
|
+
return block_map
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def flops_sparse_attention(bs: int, h: int, d: int, q_len: int, topk_blocks: int, block_n: int) -> float:
    """Approximate forward FLOP count of block-sparse attention.

    Only the two matmuls are counted: ``Q @ K^T`` and ``P @ V``. Each query row
    attends to ``topk_blocks * block_n`` KV tokens, so each matmul costs about
    ``2 * bs * h * q_len * attended_kv * d`` FLOPs. The total is twice that.
    """
    attended_kv = topk_blocks * block_n
    return 4.0 * bs * h * d * q_len * attended_kv
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def bench_ms(fn: Callable[[], object], warmup: int, rep: int) -> float:
    """Time ``fn`` with ``triton.testing.do_bench`` and return the result in milliseconds.

    ``warmup`` and ``rep`` are passed through unchanged. do_bench interprets them
    as time budgets in milliseconds, not as iteration counts.
    """
    bench_kwargs = {"warmup": warmup, "rep": rep, "quantiles": None}
    elapsed_ms = do_bench(fn, **bench_kwargs)
    return elapsed_ms
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _bench_shape(
    attn_fn: Callable[..., object],
    args: argparse.Namespace,
    bs: int,
    h: int,
    d: int,
    dtype: torch.dtype,
    q_len: int,
    kv_len: int,
) -> None:
    """Benchmark forward and backward of the attention wrapper for one (q_len, kv_len) pair.

    All per-shape tensors (q/k/v, the block map and the retained autograd graph)
    are locals of this helper. They are therefore released when it returns,
    instead of staying alive while the next shape allocates its own buffers.
    """
    num_q_blocks = q_len // BLOCK_M
    num_kv_blocks = kv_len // BLOCK_N
    topk = args.topk if args.topk is not None else max(1, num_kv_blocks // 10)
    # Clamp to [1, num_kv_blocks], the same range make_block_map enforces.
    # This keeps the printed topk and the FLOP count consistent with the mask
    # that is actually benchmarked (e.g. --topk 0 runs with 1 block per row).
    topk = min(max(1, topk), num_kv_blocks)

    print("\n" + "=" * 80)
    print(f"q_len={q_len}, kv_len={kv_len}, num_q_blocks={num_q_blocks}, num_kv_blocks={num_kv_blocks}, topk={topk}")

    q, k, v = create_qkv(bs, h, q_len, kv_len, d, dtype)
    block_map = make_block_map(bs, h, num_q_blocks, num_kv_blocks, topk)

    # Variable block sizes: default full blocks (64 tokens per KV block)
    variable_block_sizes = torch.full((num_kv_blocks,), BLOCK_N, dtype=torch.int32, device="cuda")

    def _fwd():
        return attn_fn(q, k, v, block_map, variable_block_sizes)

    fwd_ms = bench_ms(_fwd, warmup=args.warmup, rep=args.rep)

    # Backward benchmark (wrapper autograd). We build the graph once, then repeatedly run backward
    # on the retained graph so bwd timing excludes the forward compute.
    q_ = q.detach().requires_grad_(True)
    k_ = k.detach().requires_grad_(True)
    v_ = v.detach().requires_grad_(True)
    o_, _aux_ = attn_fn(q_, k_, v_, block_map, variable_block_sizes)
    og = torch.randn_like(o_)
    loss = (o_ * og).sum()

    for _ in range(max(1, args.warmup // 2)):
        torch.autograd.grad(loss, (q_, k_, v_), retain_graph=True)
    torch.cuda.synchronize()

    bwd_ms = bench_ms(
        lambda: torch.autograd.grad(loss, (q_, k_, v_), retain_graph=True),
        warmup=0,
        rep=max(5, args.rep // 2),
    )

    flops = flops_sparse_attention(bs, h, d, q_len, topk, BLOCK_N)
    # ms -> s is *1e-3, so FLOPs / ms * 1e3 = FLOP/s; *1e-12 gives TFLOP/s.
    fwd_tflops = flops / fwd_ms * 1e-12 * 1e3
    # Rough backward multiplier (attention backward typically ~2-3x forward)
    bwd_tflops = (2.5 * flops) / bwd_ms * 1e-12 * 1e3

    print(f"fwd(wrapper): {fwd_ms:.3f} ms | {fwd_tflops:.2f} TFLOPs (approx)")
    print(f"bwd(wrapper): {bwd_ms:.3f} ms | {bwd_tflops:.2f} TFLOPs (approx)")


def main() -> None:
    """Parse CLI options, configure dispatch, and benchmark each (q_len, kv_len) pair.

    Raises:
        ValueError: if ``--kv_seq_lens`` is given with a different number of
            entries than ``--q_seq_lens``.
    """
    args = parse_arguments()
    set_seed(args.seed)

    dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16

    if args.force_triton:
        # The wrapper reads this variable when choosing between SM90 and Triton.
        os.environ["FASTVIDEO_KERNEL_VSA_FORCE_TRITON"] = "1"

    # Imported after the environment is configured.
    from fastvideo_kernel.block_sparse_attn import block_sparse_attn

    bs, h, d = args.batch_size, args.num_heads, args.head_dim
    kv_seq_lens = args.kv_seq_lens
    if kv_seq_lens is None:
        kv_seq_lens = args.q_seq_lens
    if len(kv_seq_lens) != len(args.q_seq_lens):
        raise ValueError("kv_seq_lens must have the same number of entries as q_seq_lens (or be omitted).")

    print("VSA Block-Sparse Attention Benchmark (WRAPPER)")
    print(f"device: {torch.cuda.get_device_name(0)}")
    print(f"batch={bs}, heads={h}, head_dim={d}, dtype={args.dtype}")
    print(f"BLOCK_M={BLOCK_M}, BLOCK_N={BLOCK_N}")
    print("NOTE: timings include wrapper overhead (map->index + dispatch).")
    if args.force_triton:
        print("dispatch: forced Triton (FASTVIDEO_KERNEL_VSA_FORCE_TRITON=1)")
    else:
        print("dispatch: SM90 if available, else Triton")

    for q_len, kv_len in zip(args.q_seq_lens, kv_seq_lens):
        # 0 passes the divisibility check below but yields empty tiles / masks.
        if q_len <= 0 or kv_len <= 0:
            print(f"[skip] q_len={q_len}, kv_len={kv_len} must be positive")
            continue
        if q_len % BLOCK_M != 0 or kv_len % BLOCK_N != 0:
            print(f"[skip] q_len={q_len}, kv_len={kv_len} must be divisible by 64")
            continue
        _bench_shape(block_sparse_attn, args, bs, h, d, dtype, q_len, kv_len)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
if __name__ == "__main__":
    # The benchmark allocates every tensor on CUDA, so refuse to run without it.
    cuda_ok = torch.cuda.is_available()
    if cuda_ok:
        main()
    else:
        raise RuntimeError("CUDA is required for this benchmark.")
|
|
165
|
+
|
|
166
|
+
|
{fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/block_sparse_attn.py
RENAMED
|
@@ -30,13 +30,12 @@ def _force_triton() -> bool:
|
|
|
30
30
|
return os.environ.get("FASTVIDEO_KERNEL_VSA_FORCE_TRITON", "0") == "1"
|
|
31
31
|
|
|
32
32
|
|
|
33
|
-
def
|
|
33
|
+
def _map_to_index(block_map: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
34
34
|
"""
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
num: [B, H, Q] int32 (#kv blocks per q block)
|
|
35
|
+
Preferred map->index conversion used by the wrapper.
|
|
36
|
+
|
|
37
|
+
This wrapper **requires** the Triton implementation.
|
|
38
|
+
If Triton (or the Triton map_to_index module) is not available, it raises.
|
|
40
39
|
"""
|
|
41
40
|
if block_map.dim() == 3:
|
|
42
41
|
block_map = block_map.unsqueeze(0)
|
|
@@ -45,20 +44,17 @@ def _map_to_index_torch(block_map: torch.Tensor) -> Tuple[torch.Tensor, torch.Te
|
|
|
45
44
|
if block_map.dtype != torch.bool:
|
|
46
45
|
block_map = block_map.to(torch.bool)
|
|
47
46
|
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
num = torch.zeros((B, H, Q), dtype=torch.int32, device=block_map.device)
|
|
47
|
+
if not block_map.is_cuda:
|
|
48
|
+
raise RuntimeError("block_map must be a CUDA tensor (Triton map_to_index required).")
|
|
51
49
|
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
num[b, h, q] = n
|
|
61
|
-
return index, num
|
|
50
|
+
try:
|
|
51
|
+
from fastvideo_kernel.triton_kernels.index import map_to_index as triton_map_to_index # local import
|
|
52
|
+
except Exception as e:
|
|
53
|
+
raise ImportError(
|
|
54
|
+
"Triton map_to_index is required but not available. "
|
|
55
|
+
"Ensure Triton is installed and fastvideo_kernel.triton_kernels.index is importable."
|
|
56
|
+
) from e
|
|
57
|
+
return triton_map_to_index(block_map)
|
|
62
58
|
|
|
63
59
|
|
|
64
60
|
@torch.library.custom_op(
|
|
@@ -77,7 +73,7 @@ def block_sparse_attn_triton(
|
|
|
77
73
|
k = k.contiguous()
|
|
78
74
|
v = v.contiguous()
|
|
79
75
|
block_map = block_map.to(torch.bool)
|
|
80
|
-
q2k_idx, q2k_num =
|
|
76
|
+
q2k_idx, q2k_num = _map_to_index(block_map)
|
|
81
77
|
|
|
82
78
|
from fastvideo_kernel.triton_kernels.block_sparse_attn_triton import ( # local import
|
|
83
79
|
triton_block_sparse_attn_forward,
|
|
@@ -87,6 +83,7 @@ def block_sparse_attn_triton(
|
|
|
87
83
|
return o, M
|
|
88
84
|
|
|
89
85
|
|
|
86
|
+
|
|
90
87
|
@torch.library.register_fake("fastvideo_kernel::block_sparse_attn_triton")
|
|
91
88
|
def _block_sparse_attn_triton_fake(
|
|
92
89
|
q: torch.Tensor,
|
|
@@ -117,8 +114,8 @@ def block_sparse_attn_backward_triton(
|
|
|
117
114
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
118
115
|
grad_output = grad_output.contiguous()
|
|
119
116
|
block_map = block_map.to(torch.bool)
|
|
120
|
-
q2k_idx, q2k_num =
|
|
121
|
-
k2q_idx, k2q_num =
|
|
117
|
+
q2k_idx, q2k_num = _map_to_index(block_map)
|
|
118
|
+
k2q_idx, k2q_num = _map_to_index(block_map.transpose(-1, -2).contiguous())
|
|
122
119
|
|
|
123
120
|
from fastvideo_kernel.triton_kernels.block_sparse_attn_triton import ( # local import
|
|
124
121
|
triton_block_sparse_attn_backward,
|
|
@@ -182,7 +179,7 @@ def block_sparse_attn_sm90(
|
|
|
182
179
|
k_padded = k_padded.contiguous()
|
|
183
180
|
v_padded = v_padded.contiguous()
|
|
184
181
|
block_map = block_map.to(torch.bool)
|
|
185
|
-
q2k_idx, q2k_num =
|
|
182
|
+
q2k_idx, q2k_num = _map_to_index(block_map)
|
|
186
183
|
|
|
187
184
|
o_padded, lse_padded = block_sparse_fwd(
|
|
188
185
|
q_padded, k_padded, v_padded, q2k_idx, q2k_num, variable_block_sizes.int()
|
|
@@ -224,7 +221,7 @@ def block_sparse_attn_backward_sm90(
|
|
|
224
221
|
|
|
225
222
|
grad_output_padded = grad_output_padded.contiguous()
|
|
226
223
|
block_map = block_map.to(torch.bool)
|
|
227
|
-
k2q_idx, k2q_num =
|
|
224
|
+
k2q_idx, k2q_num = _map_to_index(block_map.transpose(-1, -2).contiguous())
|
|
228
225
|
|
|
229
226
|
dq, dk, dv = block_sparse_bwd(
|
|
230
227
|
q_padded,
|
|
@@ -290,9 +287,8 @@ def block_sparse_attn(
|
|
|
290
287
|
block_sparse_fwd, block_sparse_bwd = _get_sm90_ops()
|
|
291
288
|
if (not _force_triton()) and _is_sm90() and (block_sparse_fwd is not None) and (block_sparse_bwd is not None):
|
|
292
289
|
return block_sparse_attn_sm90(q, k, v, block_map, variable_block_sizes)
|
|
293
|
-
# Triton path:
|
|
294
|
-
|
|
295
|
-
raise RuntimeError("Triton fallback requires q/k/v to have the same padded length.")
|
|
290
|
+
# Triton path: supports q_seq_len != kv_seq_len as long as both are padded
|
|
291
|
+
# to a multiple of the block size (64 tokens).
|
|
296
292
|
return block_sparse_attn_triton(q, k, v, block_map, variable_block_sizes)
|
|
297
293
|
|
|
298
294
|
|
|
@@ -141,12 +141,6 @@ def video_sparse_attn(
|
|
|
141
141
|
# Use autograd-enabled wrapper so backward works (and still uses SM90 kernel when available)
|
|
142
142
|
out_s = block_sparse_attn(q, k, v, mask, variable_block_sizes)[0]
|
|
143
143
|
else:
|
|
144
|
-
if q_seq_len != kv_seq_len:
|
|
145
|
-
raise RuntimeError(
|
|
146
|
-
"q/k have different lengths, but the compiled CUDA kernel (block_sparse_fwd) "
|
|
147
|
-
"is not available. The Triton fallback currently requires q and k/v to have "
|
|
148
|
-
"the same padded length."
|
|
149
|
-
)
|
|
150
144
|
# Triton-only forward (kept for environments without the wrapper deps)
|
|
151
145
|
out_s, _ = triton_block_sparse_attn_forward(q, k, v, idx, num, variable_block_sizes)
|
|
152
146
|
|
|
@@ -29,7 +29,7 @@ configs = [
|
|
|
29
29
|
|
|
30
30
|
|
|
31
31
|
# ──────────────────────────── SPARSE ADDITION BEGIN ───────────────────────────
|
|
32
|
-
@triton.autotune(configs, key=["
|
|
32
|
+
@triton.autotune(configs, key=["N_CTX_Q", "HEAD_DIM"])
|
|
33
33
|
@triton.jit
|
|
34
34
|
def _attn_fwd_sparse(
|
|
35
35
|
Q,
|
|
@@ -60,7 +60,8 @@ def _attn_fwd_sparse(
|
|
|
60
60
|
stride_on,
|
|
61
61
|
Z,
|
|
62
62
|
H,
|
|
63
|
-
|
|
63
|
+
N_CTX_Q, #
|
|
64
|
+
N_CTX_KV, #
|
|
64
65
|
HEAD_DIM: tl.constexpr, #
|
|
65
66
|
BLOCK_M: tl.constexpr,
|
|
66
67
|
BLOCK_N: tl.constexpr,
|
|
@@ -75,24 +76,29 @@ def _attn_fwd_sparse(
|
|
|
75
76
|
off_hz = tl.program_id(1) # fused (batch, head)
|
|
76
77
|
b = off_hz // H
|
|
77
78
|
h = off_hz % H
|
|
78
|
-
q_tiles =
|
|
79
|
+
q_tiles = N_CTX_Q // BLOCK_M
|
|
79
80
|
meta_base = ((b * H + h) * q_tiles + q_blk)
|
|
80
81
|
|
|
81
82
|
kv_blocks = tl.load(q2k_num + meta_base) # int32
|
|
82
83
|
kv_ptr = q2k_index + meta_base * max_kv_blks # ptr to list
|
|
83
84
|
|
|
84
85
|
# ----- base pointers -----
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
86
|
+
# Note: when q and kv have different sequence lengths, their per-(batch,head)
|
|
87
|
+
# strides differ, so we must compute separate base offsets.
|
|
88
|
+
q_off = (b.to(tl.int64) * stride_qz + h.to(tl.int64) * stride_qh)
|
|
89
|
+
k_off = (b.to(tl.int64) * stride_kz + h.to(tl.int64) * stride_kh)
|
|
90
|
+
v_off = (b.to(tl.int64) * stride_vz + h.to(tl.int64) * stride_vh)
|
|
91
|
+
o_off = (b.to(tl.int64) * stride_oz + h.to(tl.int64) * stride_oh)
|
|
92
|
+
|
|
93
|
+
Q_ptr = tl.make_block_ptr(base=Q + q_off,
|
|
94
|
+
shape=(N_CTX_Q, HEAD_DIM),
|
|
89
95
|
strides=(stride_qm, stride_qk),
|
|
90
96
|
offsets=(q_blk * BLOCK_M, 0),
|
|
91
97
|
block_shape=(BLOCK_M, HEAD_DIM),
|
|
92
98
|
order=(1, 0))
|
|
93
99
|
|
|
94
|
-
K_base = tl.make_block_ptr(base=K +
|
|
95
|
-
shape=(HEAD_DIM,
|
|
100
|
+
K_base = tl.make_block_ptr(base=K + k_off,
|
|
101
|
+
shape=(HEAD_DIM, N_CTX_KV),
|
|
96
102
|
strides=(stride_kk, stride_kn),
|
|
97
103
|
offsets=(0, 0),
|
|
98
104
|
block_shape=(HEAD_DIM, BLOCK_N),
|
|
@@ -100,15 +106,15 @@ def _attn_fwd_sparse(
|
|
|
100
106
|
|
|
101
107
|
v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1,
|
|
102
108
|
0)
|
|
103
|
-
V_base = tl.make_block_ptr(base=V +
|
|
104
|
-
shape=(
|
|
109
|
+
V_base = tl.make_block_ptr(base=V + v_off,
|
|
110
|
+
shape=(N_CTX_KV, HEAD_DIM),
|
|
105
111
|
strides=(stride_vk, stride_vn),
|
|
106
112
|
offsets=(0, 0),
|
|
107
113
|
block_shape=(BLOCK_N, HEAD_DIM),
|
|
108
114
|
order=v_order)
|
|
109
115
|
|
|
110
|
-
O_ptr = tl.make_block_ptr(base=Out +
|
|
111
|
-
shape=(
|
|
116
|
+
O_ptr = tl.make_block_ptr(base=Out + o_off,
|
|
117
|
+
shape=(N_CTX_Q, HEAD_DIM),
|
|
112
118
|
strides=(stride_om, stride_on),
|
|
113
119
|
offsets=(q_blk * BLOCK_M, 0),
|
|
114
120
|
block_shape=(BLOCK_M, HEAD_DIM),
|
|
@@ -150,7 +156,7 @@ def _attn_fwd_sparse(
|
|
|
150
156
|
# ----- epilogue -----
|
|
151
157
|
m_i += tl.math.log2(l_i)
|
|
152
158
|
acc = acc / l_i[:, None]
|
|
153
|
-
tl.store(M + off_hz *
|
|
159
|
+
tl.store(M + off_hz * N_CTX_Q + offs_m, m_i)
|
|
154
160
|
tl.store(O_ptr, acc.to(Out.type.element_ty))
|
|
155
161
|
|
|
156
162
|
|
|
@@ -201,7 +207,7 @@ def _attn_bwd_dkdv(
|
|
|
201
207
|
stride_tok,
|
|
202
208
|
stride_d, #
|
|
203
209
|
H,
|
|
204
|
-
|
|
210
|
+
N_CTX_KV,
|
|
205
211
|
BLOCK_M1: tl.constexpr, #
|
|
206
212
|
BLOCK_N1: tl.constexpr, #
|
|
207
213
|
HEAD_DIM: tl.constexpr, #
|
|
@@ -221,8 +227,8 @@ def _attn_bwd_dkdv(
|
|
|
221
227
|
off_hz = tl.program_id(2) # fused (batch, head)
|
|
222
228
|
b = off_hz // H
|
|
223
229
|
h = off_hz % H
|
|
224
|
-
|
|
225
|
-
meta_base = ((b * H + h) *
|
|
230
|
+
kv_tiles = N_CTX_KV // BLOCK_N1
|
|
231
|
+
meta_base = ((b * H + h) * kv_tiles + kv_blk)
|
|
226
232
|
|
|
227
233
|
q_blocks = tl.load(k2q_num + meta_base) # int32
|
|
228
234
|
q_ptr = k2q_index + meta_base * max_q_blks # ptr to list
|
|
@@ -302,16 +308,21 @@ def _attn_bwd_dq(
|
|
|
302
308
|
|
|
303
309
|
kv_blocks = tl.load(q2k_num + meta_base) # int32
|
|
304
310
|
kv_ptr = q2k_index + meta_base * max_kv_blks # ptr to list
|
|
305
|
-
block_size = tl.load(variable_block_sizes + q_blk)
|
|
306
311
|
|
|
307
312
|
for blk_idx in range(kv_blocks * 2):
|
|
308
|
-
|
|
309
|
-
|
|
313
|
+
kv_idx = tl.load(kv_ptr + blk_idx // 2).to(tl.int32)
|
|
314
|
+
# variable_block_sizes is defined per KV block (tile). Mask must therefore
|
|
315
|
+
# use kv_idx (not q_blk). Also, because we split each 64-token block into
|
|
316
|
+
# two 32-token halves, the mask must account for the half-block offset.
|
|
317
|
+
block_size = tl.load(variable_block_sizes + kv_idx).to(tl.int32)
|
|
318
|
+
half = (blk_idx % 2).to(tl.int32)
|
|
319
|
+
block_sparse_offset = (kv_idx * 2 + half) * step_n * stride_tok
|
|
310
320
|
kT = tl.load(kT_ptrs + block_sparse_offset)
|
|
311
321
|
vT = tl.load(vT_ptrs + block_sparse_offset)
|
|
312
322
|
qk = tl.dot(q, kT)
|
|
313
323
|
p = tl.math.exp2(qk - m)
|
|
314
|
-
|
|
324
|
+
offs_in_block = half * step_n + tl.arange(0, BLOCK_N2)
|
|
325
|
+
mask = offs_in_block < block_size
|
|
315
326
|
p = tl.where(mask[None, :], p, 0.0)
|
|
316
327
|
# Compute dP and dS.
|
|
317
328
|
dp = tl.dot(do, vT).to(tl.float32)
|
|
@@ -467,19 +478,235 @@ def _attn_bwd(
|
|
|
467
478
|
tl.store(dq_ptrs, dq)
|
|
468
479
|
|
|
469
480
|
|
|
481
|
+
@triton.jit
def _attn_bwd_dkdv_kernel(
    Q,
    K,
    V,
    sm_scale,  # softmax scale; applied to dK before the final store
    DO,  # upstream gradient dO, same layout as Q
    DK,
    DV,  # outputs, same layout as K / V
    M,  # per-row stats written by the forward kernel, laid out [B*H, N_CTX_Q]
    D,  # per-row "delta" values, same [B*H, N_CTX_Q] layout as M
    k2q_index,  # per-KV-block list of Q block indices (row length max_q_blks)
    k2q_num,  # number of valid entries per row of k2q_index
    max_q_blks,
    variable_block_sizes,  # valid tokens per KV block
    # shared token/dim strides (assumed contiguous along token and dim)
    stride_tok,
    stride_d,  #
    # batch/head strides (may differ between Q and KV)
    stride_qz,
    stride_qh,
    stride_kz,
    stride_kh,
    stride_vz,
    stride_vh,
    stride_doz,
    stride_doh,
    stride_dkz,
    stride_dkh,
    stride_dvz,
    stride_dvh,
    H,
    N_CTX_Q,
    N_CTX_KV,
    BLOCK_M1: tl.constexpr,  # Q tile size used by _attn_bwd_dkdv
    BLOCK_N1: tl.constexpr,  # KV tile size handled by one program
    HEAD_DIM: tl.constexpr):
    """
    Backward kernel that computes dK and dV for one KV tile of BLOCK_N1 tokens
    of one (batch, head). It loads the tile's K/V rows once and delegates
    accumulation to _attn_bwd_dkdv, which walks the Q blocks listed for this KV
    block in k2q_index / k2q_num.

    Q and K/V may have different sequence lengths (N_CTX_Q vs N_CTX_KV), so each
    tensor gets its own batch/head base offset. The token/dim strides are shared.

    Grid:
      pid0: kv_blk in [0, N_CTX_KV/BLOCK_N1)
      pid2: fused (batch, head) in [0, B*H)
    """
    bhid = tl.program_id(2)
    b = bhid // H
    h = bhid % H
    kv_blk = tl.program_id(0)

    # Per-(batch, head) base offsets, computed in int64 so b*stride products
    # cannot overflow int32 for large tensors.
    q_adj = (b.to(tl.int64) * stride_qz + h.to(tl.int64) * stride_qh)
    kv_adj_k = (b.to(tl.int64) * stride_kz + h.to(tl.int64) * stride_kh)
    kv_adj_v = (b.to(tl.int64) * stride_vz + h.to(tl.int64) * stride_vh)
    do_adj = (b.to(tl.int64) * stride_doz + h.to(tl.int64) * stride_doh)
    dk_adj = (b.to(tl.int64) * stride_dkz + h.to(tl.int64) * stride_dkh)
    dv_adj = (b.to(tl.int64) * stride_dvz + h.to(tl.int64) * stride_dvh)

    Q = Q + q_adj
    K = K + kv_adj_k
    V = V + kv_adj_v
    DO = DO + do_adj
    DK = DK + dk_adj
    DV = DV + dv_adj

    # M and D (delta) are always sized by Q length.
    M = M + (bhid * N_CTX_Q).to(tl.int64)
    D = D + (bhid * N_CTX_Q).to(tl.int64)

    offs_k = tl.arange(0, HEAD_DIM)
    start_n = kv_blk * BLOCK_N1
    offs_n = start_n + tl.arange(0, BLOCK_N1)  # token rows owned by this program

    # load K and V once; the same tiles are reused for every Q block visited
    # inside _attn_bwd_dkdv.
    k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
    v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)

    # fp32 accumulators for the gradients of this KV tile.
    dv_acc = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
    dk_acc = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)

    num_steps = N_CTX_Q // BLOCK_M1
    dk_acc, dv_acc = _attn_bwd_dkdv(
        dk_acc,
        dv_acc,
        Q,
        k,
        v,
        sm_scale,
        DO,
        M,
        D,
        k2q_index,
        k2q_num,
        max_q_blks,
        variable_block_sizes,
        stride_tok,
        stride_d,
        H,
        N_CTX_KV,
        BLOCK_M1=BLOCK_M1,
        BLOCK_N1=BLOCK_N1,
        HEAD_DIM=HEAD_DIM,
        start_n=start_n,
        start_m=0,
        num_steps=num_steps,
    )

    dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dv_ptrs, dv_acc)

    # dK gets the softmax scale applied here, before the store; dV is stored as accumulated.
    dk_acc *= sm_scale
    dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dk_ptrs, dk_acc)
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
@triton.jit
def _attn_bwd_dq_kernel(
    Q,
    K,
    V,
    DO,  # upstream gradient dO, same layout as Q
    DQ,  # output, same layout as Q
    M,  # per-row stats written by the forward kernel, laid out [B*H, N_CTX_Q]
    D,  # per-row "delta" values, same [B*H, N_CTX_Q] layout as M
    q2k_index,  # per-Q-block list of KV block indices (row length max_kv_blks)
    q2k_num,  # number of valid entries per row of q2k_index
    max_kv_blks,
    variable_block_sizes,  # valid tokens per KV block
    # shared token/dim strides (assumed contiguous along token and dim)
    stride_tok,
    stride_d,  #
    # batch/head strides (may differ between Q and KV)
    stride_qz,
    stride_qh,
    stride_kz,
    stride_kh,
    stride_vz,
    stride_vh,
    stride_doz,
    stride_doh,
    stride_dqz,
    stride_dqh,
    H,
    N_CTX_Q,
    BLOCK_M2: tl.constexpr,  # Q tile size handled by one program
    BLOCK_N2: tl.constexpr,  # KV sub-tile size used by _attn_bwd_dq
    HEAD_DIM: tl.constexpr):
    """
    Backward kernel that computes dQ for one Q tile of BLOCK_M2 tokens of one
    (batch, head). It loads the tile's Q, dO and M rows, and delegates the
    accumulation to _attn_bwd_dq. That helper walks the KV blocks listed for this
    Q block in q2k_index / q2k_num and loads the K/V tiles itself.

    NOTE(review): unlike the dK/dV kernel, this one takes no sm_scale argument.
    Any softmax scaling must already be reflected in the inputs (e.g. a
    pre-scaled K); confirm against the caller.

    Grid:
      pid0: q_blk in [0, N_CTX_Q/BLOCK_M2)
      pid2: fused (batch, head) in [0, B*H)
    """
    LN2 = 0.6931471824645996  # = ln(2), rounded to float32
    bhid = tl.program_id(2)
    b = bhid // H
    h = bhid % H
    q_blk = tl.program_id(0)

    # Per-(batch, head) base offsets, computed in int64 so b*stride products
    # cannot overflow int32 for large tensors.
    q_adj = (b.to(tl.int64) * stride_qz + h.to(tl.int64) * stride_qh)
    kv_adj_k = (b.to(tl.int64) * stride_kz + h.to(tl.int64) * stride_kh)
    kv_adj_v = (b.to(tl.int64) * stride_vz + h.to(tl.int64) * stride_vh)
    do_adj = (b.to(tl.int64) * stride_doz + h.to(tl.int64) * stride_doh)
    dq_adj = (b.to(tl.int64) * stride_dqz + h.to(tl.int64) * stride_dqh)

    Q = Q + q_adj
    K = K + kv_adj_k
    V = V + kv_adj_v
    DO = DO + do_adj
    DQ = DQ + dq_adj

    # M and D (delta) are sized by Q length.
    M = M + (bhid * N_CTX_Q).to(tl.int64)
    D = D + (bhid * N_CTX_Q).to(tl.int64)

    offs_k = tl.arange(0, HEAD_DIM)
    start_m = q_blk * BLOCK_M2
    offs_m = start_m + tl.arange(0, BLOCK_M2)  # token rows owned by this program

    q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
    do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
    # Column vector [BLOCK_M2, 1] so it broadcasts against the per-row scores in _attn_bwd_dq.
    m = tl.load(M + offs_m)[:, None]

    dq_acc = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
    num_steps = 0  # unused in _attn_bwd_dq
    dq_acc = _attn_bwd_dq(
        dq_acc,
        q,
        K,
        V,
        do,
        m,
        D,
        q2k_index,
        q2k_num,
        max_kv_blks,
        variable_block_sizes,
        stride_tok,
        stride_d,
        H,
        N_CTX_Q,
        BLOCK_M2=BLOCK_M2,
        BLOCK_N2=BLOCK_N2,
        HEAD_DIM=HEAD_DIM,
        start_m=start_m,
        start_n=0,
        num_steps=num_steps,
    )

    dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    # _attn_bwd_dq forms probabilities with exp2 (base-2 domain). Multiplying by
    # ln(2) converts the accumulated gradient back to natural-log units.
    dq_acc *= LN2
    tl.store(dq_ptrs, dq_acc)
|
|
689
|
+
|
|
690
|
+
|
|
470
691
|
# ──────────────────────────── SPARSE ADDITION BEGIN ───────────────────────────
|
|
471
692
|
def triton_block_sparse_attn_forward(q, k, v, q2k_index, q2k_num,
|
|
472
693
|
variable_block_sizes):
|
|
473
|
-
B, H,
|
|
694
|
+
B, H, Tq, D = q.shape
|
|
695
|
+
Tkv = k.shape[2]
|
|
474
696
|
sm_scale = 1.0 / math.sqrt(D)
|
|
475
697
|
max_kv_blks = q2k_index.shape[-1]
|
|
476
|
-
assert
|
|
698
|
+
assert Tq % 64 == 0, f"q length must be a multiple of 64, but got {Tq}"
|
|
699
|
+
assert Tkv % 64 == 0, f"kv length must be a multiple of 64, but got {Tkv}"
|
|
477
700
|
assert q2k_num.shape[
|
|
478
|
-
-1] ==
|
|
701
|
+
-1] == Tq // 64, f"shape mismatch, Tq // 64 = {Tq // 64}, q2k_num.shape[-2] = {q2k_num.shape[-2]}"
|
|
702
|
+
assert variable_block_sizes.numel() == Tkv // 64, (
|
|
703
|
+
f"shape mismatch, variable_block_sizes must have length {Tkv // 64}, "
|
|
704
|
+
f"got {variable_block_sizes.numel()}"
|
|
705
|
+
)
|
|
479
706
|
o = torch.empty_like(q)
|
|
480
|
-
M = torch.empty((B, H,
|
|
707
|
+
M = torch.empty((B, H, Tq), dtype=torch.float32, device=q.device)
|
|
481
708
|
|
|
482
|
-
grid = lambda _: (triton.cdiv(
|
|
709
|
+
grid = lambda _: (triton.cdiv(Tq, 64), B * H, 1)
|
|
483
710
|
_attn_fwd_sparse[grid](q,
|
|
484
711
|
k,
|
|
485
712
|
v,
|
|
@@ -508,7 +735,8 @@ def triton_block_sparse_attn_forward(q, k, v, q2k_index, q2k_num,
|
|
|
508
735
|
o.stride(3),
|
|
509
736
|
B,
|
|
510
737
|
H,
|
|
511
|
-
|
|
738
|
+
Tq,
|
|
739
|
+
Tkv,
|
|
512
740
|
HEAD_DIM=D,
|
|
513
741
|
STAGE=3)
|
|
514
742
|
|
|
@@ -518,21 +746,21 @@ def triton_block_sparse_attn_forward(q, k, v, q2k_index, q2k_num,
|
|
|
518
746
|
def triton_block_sparse_attn_backward(do, q, k, v, o, M, q2k_index, q2k_num,
|
|
519
747
|
k2q_index, k2q_num, variable_block_sizes):
|
|
520
748
|
assert do.is_contiguous()
|
|
521
|
-
assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
|
|
522
749
|
|
|
523
|
-
B, H,
|
|
750
|
+
B, H, Tq, D = q.shape
|
|
751
|
+
Tkv = k.shape[2]
|
|
524
752
|
sm_scale = 1.0 / math.sqrt(D)
|
|
525
753
|
dq = torch.empty_like(q)
|
|
526
754
|
dk = torch.empty_like(k)
|
|
527
755
|
dv = torch.empty_like(v)
|
|
528
|
-
BATCH, N_HEAD
|
|
756
|
+
BATCH, N_HEAD = q.shape[:2]
|
|
529
757
|
BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32
|
|
530
758
|
RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
|
|
531
759
|
arg_k = k
|
|
532
760
|
arg_k = arg_k * (sm_scale * RCP_LN2)
|
|
533
761
|
PRE_BLOCK = 64
|
|
534
|
-
assert
|
|
535
|
-
pre_grid = (
|
|
762
|
+
assert Tq % PRE_BLOCK == 0
|
|
763
|
+
pre_grid = (Tq // PRE_BLOCK, BATCH * N_HEAD)
|
|
536
764
|
delta = torch.empty_like(M)
|
|
537
765
|
_attn_bwd_preprocess[pre_grid](
|
|
538
766
|
o,
|
|
@@ -540,7 +768,7 @@ def triton_block_sparse_attn_backward(do, q, k, v, o, M, q2k_index, q2k_num,
|
|
|
540
768
|
delta, #
|
|
541
769
|
BATCH,
|
|
542
770
|
N_HEAD,
|
|
543
|
-
|
|
771
|
+
Tq, #
|
|
544
772
|
BLOCK_M=PRE_BLOCK,
|
|
545
773
|
HEAD_DIM=D #
|
|
546
774
|
)
|
|
@@ -548,36 +776,75 @@ def triton_block_sparse_attn_backward(do, q, k, v, o, M, q2k_index, q2k_num,
|
|
|
548
776
|
max_q_blks = k2q_index.shape[-1]
|
|
549
777
|
max_kv_blks = q2k_index.shape[-1]
|
|
550
778
|
|
|
551
|
-
|
|
552
|
-
|
|
779
|
+
# dK/dV kernel: grid over KV blocks
|
|
780
|
+
grid_kv = (Tkv // BLOCK_N1, 1, BATCH * N_HEAD)
|
|
781
|
+
_attn_bwd_dkdv_kernel[grid_kv](
|
|
553
782
|
q,
|
|
554
783
|
arg_k,
|
|
555
784
|
v,
|
|
556
785
|
sm_scale,
|
|
557
786
|
do,
|
|
558
|
-
dq,
|
|
559
787
|
dk,
|
|
560
|
-
dv,
|
|
788
|
+
dv,
|
|
561
789
|
M,
|
|
562
|
-
delta,
|
|
563
|
-
q2k_index,
|
|
564
|
-
q2k_num,
|
|
565
|
-
max_kv_blks,
|
|
790
|
+
delta,
|
|
566
791
|
k2q_index,
|
|
567
792
|
k2q_num,
|
|
568
793
|
max_q_blks,
|
|
569
794
|
variable_block_sizes,
|
|
795
|
+
q.stride(2),
|
|
796
|
+
q.stride(3),
|
|
570
797
|
q.stride(0),
|
|
571
798
|
q.stride(1),
|
|
572
|
-
|
|
573
|
-
|
|
799
|
+
arg_k.stride(0),
|
|
800
|
+
arg_k.stride(1),
|
|
801
|
+
v.stride(0),
|
|
802
|
+
v.stride(1),
|
|
803
|
+
do.stride(0),
|
|
804
|
+
do.stride(1),
|
|
805
|
+
dk.stride(0),
|
|
806
|
+
dk.stride(1),
|
|
807
|
+
dv.stride(0),
|
|
808
|
+
dv.stride(1),
|
|
574
809
|
N_HEAD,
|
|
575
|
-
|
|
810
|
+
Tq,
|
|
811
|
+
Tkv,
|
|
576
812
|
BLOCK_M1=BLOCK_M1,
|
|
577
|
-
BLOCK_N1=BLOCK_N1,
|
|
813
|
+
BLOCK_N1=BLOCK_N1,
|
|
814
|
+
HEAD_DIM=D,
|
|
815
|
+
)
|
|
816
|
+
|
|
817
|
+
# dQ kernel: grid over Q blocks
|
|
818
|
+
grid_q = (Tq // BLOCK_M2, 1, BATCH * N_HEAD)
|
|
819
|
+
_attn_bwd_dq_kernel[grid_q](
|
|
820
|
+
q,
|
|
821
|
+
arg_k,
|
|
822
|
+
v,
|
|
823
|
+
do,
|
|
824
|
+
dq,
|
|
825
|
+
M,
|
|
826
|
+
delta,
|
|
827
|
+
q2k_index,
|
|
828
|
+
q2k_num,
|
|
829
|
+
max_kv_blks,
|
|
830
|
+
variable_block_sizes,
|
|
831
|
+
q.stride(2),
|
|
832
|
+
q.stride(3),
|
|
833
|
+
q.stride(0),
|
|
834
|
+
q.stride(1),
|
|
835
|
+
arg_k.stride(0),
|
|
836
|
+
arg_k.stride(1),
|
|
837
|
+
v.stride(0),
|
|
838
|
+
v.stride(1),
|
|
839
|
+
do.stride(0),
|
|
840
|
+
do.stride(1),
|
|
841
|
+
dq.stride(0),
|
|
842
|
+
dq.stride(1),
|
|
843
|
+
N_HEAD,
|
|
844
|
+
Tq,
|
|
578
845
|
BLOCK_M2=BLOCK_M2,
|
|
579
|
-
BLOCK_N2=BLOCK_N2,
|
|
580
|
-
HEAD_DIM=D
|
|
846
|
+
BLOCK_N2=BLOCK_N2,
|
|
847
|
+
HEAD_DIM=D,
|
|
581
848
|
)
|
|
582
849
|
|
|
583
850
|
return dq, dk, dv
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = "0.2.6"
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
__version__ = "0.2.4"
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/triton_kernels/index.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{fastvideo_kernel-0.2.4 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/turbodiffusion_ops.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|