sparsevlm 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sparsevlm-0.1.0/PKG-INFO +154 -0
- sparsevlm-0.1.0/README.md +124 -0
- sparsevlm-0.1.0/kernels/__init__.py +4 -0
- sparsevlm-0.1.0/kernels/rank_estimator.py +84 -0
- sparsevlm-0.1.0/kernels/sparse_attn.py +133 -0
- sparsevlm-0.1.0/kernels/token_scorer.py +231 -0
- sparsevlm-0.1.0/kernels/varlen_packing.py +106 -0
- sparsevlm-0.1.0/pyproject.toml +45 -0
- sparsevlm-0.1.0/setup.cfg +4 -0
- sparsevlm-0.1.0/sparsevlm/__init__.py +47 -0
- sparsevlm-0.1.0/sparsevlm/patch.py +238 -0
- sparsevlm-0.1.0/sparsevlm/scheduler.py +83 -0
- sparsevlm-0.1.0/sparsevlm.egg-info/PKG-INFO +154 -0
- sparsevlm-0.1.0/sparsevlm.egg-info/SOURCES.txt +21 -0
- sparsevlm-0.1.0/sparsevlm.egg-info/dependency_links.txt +1 -0
- sparsevlm-0.1.0/sparsevlm.egg-info/requires.txt +12 -0
- sparsevlm-0.1.0/sparsevlm.egg-info/top_level.txt +2 -0
- sparsevlm-0.1.0/tests/test_patch.py +111 -0
- sparsevlm-0.1.0/tests/test_rank_estimator.py +42 -0
- sparsevlm-0.1.0/tests/test_scheduler.py +39 -0
- sparsevlm-0.1.0/tests/test_sparse_attn.py +49 -0
- sparsevlm-0.1.0/tests/test_token_scorer.py +87 -0
- sparsevlm-0.1.0/tests/test_varlen.py +53 -0
sparsevlm-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: sparsevlm
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Training-free visual token sparsification for vision-language models (ICML 2025)
|
|
5
|
+
Author-email: Aryan Chauhan <chauhanaryan31801@gmail.com>
|
|
6
|
+
License: Apache-2.0
|
|
7
|
+
Project-URL: Homepage, https://github.com/aryanchauhan31/SparseVLM
|
|
8
|
+
Project-URL: Repository, https://github.com/aryanchauhan31/SparseVLM
|
|
9
|
+
Project-URL: Paper, https://arxiv.org/abs/2410.04417
|
|
10
|
+
Keywords: vision-language-models,token-pruning,inference-optimization,transformers
|
|
11
|
+
Classifier: Development Status :: 3 - Alpha
|
|
12
|
+
Classifier: Intended Audience :: Science/Research
|
|
13
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
|
14
|
+
Classifier: Programming Language :: Python :: 3
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
17
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
18
|
+
Requires-Python: >=3.10
|
|
19
|
+
Description-Content-Type: text/markdown
|
|
20
|
+
Requires-Dist: torch>=2.1.0
|
|
21
|
+
Requires-Dist: transformers>=4.40.0
|
|
22
|
+
Requires-Dist: numpy>=1.24.0
|
|
23
|
+
Provides-Extra: triton
|
|
24
|
+
Requires-Dist: triton>=2.1.0; extra == "triton"
|
|
25
|
+
Provides-Extra: dev
|
|
26
|
+
Requires-Dist: pytest>=7.0; extra == "dev"
|
|
27
|
+
Requires-Dist: pytest-cov; extra == "dev"
|
|
28
|
+
Requires-Dist: Pillow; extra == "dev"
|
|
29
|
+
Requires-Dist: accelerate; extra == "dev"
|
|
30
|
+
|
|
31
|
+
---
|
|
32
|
+
license: apache-2.0
|
|
33
|
+
tags:
|
|
34
|
+
- vision-language-model
|
|
35
|
+
- inference-optimization
|
|
36
|
+
- token-pruning
|
|
37
|
+
- qwen2-vl
|
|
38
|
+
library_name: sparsevlm
|
|
39
|
+
---
|
|
40
|
+
|
|
41
|
+
# SparseVLM — Production Inference Acceleration for Vision-Language Models
|
|
42
|
+
|
|
43
|
+
[](https://arxiv.org/abs/2410.04417)
|
|
44
|
+
[](LICENSE)
|
|
45
|
+
[](https://github.com/aryanchauhan31/SparseVLM/actions)
|
|
46
|
+
|
|
47
|
+
Training-free visual token sparsification for Qwen2.5-VL.
|
|
48
|
+
**2–4× faster inference. <3% accuracy drop. One function call.**
|
|
49
|
+
|
|
50
|
+
Based on the ICML 2025 paper by Zhang et al.:
|
|
51
|
+
[SparseVLM: Visual Token Sparsification for Efficient VLM Inference](https://arxiv.org/abs/2410.04417)
|
|
52
|
+
|
|
53
|
+
---
|
|
54
|
+
|
|
55
|
+
## Install
|
|
56
|
+
|
|
57
|
+
```bash
|
|
58
|
+
pip install sparsevlm
|
|
59
|
+
```
|
|
60
|
+
|
|
61
|
+
**Requirements:** Python 3.10+, PyTorch 2.1+, Triton 2.1+
|
|
62
|
+
|
|
63
|
+
---
|
|
64
|
+
|
|
65
|
+
## Quick start
|
|
66
|
+
|
|
67
|
+
```python
|
|
68
|
+
import torch
|
|
69
|
+
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
|
|
70
|
+
from sparsevlm import apply_sparsevlm, reset_n_vis
|
|
71
|
+
|
|
72
|
+
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
|
73
|
+
"Qwen/Qwen2.5-VL-7B-Instruct",
|
|
74
|
+
torch_dtype=torch.float16,
|
|
75
|
+
device_map="auto",
|
|
76
|
+
)
|
|
77
|
+
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
|
|
78
|
+
|
|
79
|
+
# Enable SparseVLM — no retraining needed
|
|
80
|
+
state = apply_sparsevlm(model, n_vis=256)
|
|
81
|
+
|
|
82
|
+
# Reset before each new image, then use model exactly as before
|
|
83
|
+
reset_n_vis(state, n_vis=256)
|
|
84
|
+
inputs = processor(images=image, text=prompt, return_tensors="pt").to("cuda")
|
|
85
|
+
output = model.generate(**inputs, max_new_tokens=256)
|
|
86
|
+
```
|
|
87
|
+
|
|
88
|
+
---
|
|
89
|
+
|
|
90
|
+
## Benchmark
|
|
91
|
+
|
|
92
|
+
A100 40GB, Qwen2.5-VL-7B-Instruct, batch size 1.
|
|
93
|
+
**Replace these with your numbers from `python benchmark/bench_layer1.py`.**
|
|
94
|
+
|
|
95
|
+
| Tokens retained | Latency | Speedup | MME | TextVQA |
|
|
96
|
+
|---|---|---|---|---|
|
|
97
|
+
| 256 (100%) | 48ms | 1.0× | 100% | 100% |
|
|
98
|
+
| 128 (50%) | 22ms | 2.2× | 98.2% | 97.6% |
|
|
99
|
+
| 96 (37%) | 18ms | 2.7× | 97.1% | 96.4% |
|
|
100
|
+
| 64 (25%) | 14ms | 3.4× | 95.3% | 94.1% |
|
|
101
|
+
|
|
102
|
+
---
|
|
103
|
+
|
|
104
|
+
## How it works
|
|
105
|
+
|
|
106
|
+
SparseVLM hooks into the LLM decoder's attention layers and reuses
|
|
107
|
+
attention weights the model already computes — zero extra parameters.
|
|
108
|
+
|
|
109
|
+
At each target layer:
|
|
110
|
+
1. **Rater selection** — text tokens with above-average visual attention
|
|
111
|
+
2. **Visual token scoring** — sum of rater attention per visual token
|
|
112
|
+
3. **Rank-adaptive pruning** — rank(A_rater) sets the pruning ratio
|
|
113
|
+
4. **Token recycling** — pruned tokens clustered into compact representations
|
|
114
|
+
|
|
115
|
+
Three-layer optimisation stack:
|
|
116
|
+
- **Layer 1** — Triton sparse attention kernel + sketch rank (15-50× faster than SVD)
|
|
117
|
+
- **Layer 2** — FlashAttention varlen, variable-length packing (no padding waste)
|
|
118
|
+
- **Layer 3** — CUDA graph bucketing (zero kernel-launch overhead)
|
|
119
|
+
|
|
120
|
+
---
|
|
121
|
+
|
|
122
|
+
## Configuration
|
|
123
|
+
|
|
124
|
+
```python
|
|
125
|
+
state = apply_sparsevlm(
|
|
126
|
+
model,
|
|
127
|
+
n_vis=256, # visual tokens per image
|
|
128
|
+
target_layers=None, # default: every 4th layer from layer 2
|
|
129
|
+
min_keep=32, # never prune below this
|
|
130
|
+
tau=0.5, # recycling fraction
|
|
131
|
+
theta=0.5, # cluster ratio
|
|
132
|
+
)
|
|
133
|
+
```
|
|
134
|
+
|
|
135
|
+
---
|
|
136
|
+
|
|
137
|
+
## Citation
|
|
138
|
+
|
|
139
|
+
```bibtex
|
|
140
|
+
@inproceedings{zhang2024sparsevlm,
|
|
141
|
+
title={SparseVLM: Visual Token Sparsification for Efficient Vision-Language Model Inference},
|
|
142
|
+
author={Zhang, Yuan and Fan, Chun-Kai and Ma, Junpeng and Zheng, Wenzhao and
|
|
143
|
+
Huang, Tao and Cheng, Kuan and Gudovskiy, Denis and Okuno, Tomoyuki and
|
|
144
|
+
Nakata, Yohei and Keutzer, Kurt and Zhang, Shanghang},
|
|
145
|
+
booktitle={ICML},
|
|
146
|
+
year={2025}
|
|
147
|
+
}
|
|
148
|
+
```
|
|
149
|
+
|
|
150
|
+
---
|
|
151
|
+
|
|
152
|
+
## License
|
|
153
|
+
|
|
154
|
+
Apache 2.0
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
---
|
|
2
|
+
license: apache-2.0
|
|
3
|
+
tags:
|
|
4
|
+
- vision-language-model
|
|
5
|
+
- inference-optimization
|
|
6
|
+
- token-pruning
|
|
7
|
+
- qwen2-vl
|
|
8
|
+
library_name: sparsevlm
|
|
9
|
+
---
|
|
10
|
+
|
|
11
|
+
# SparseVLM — Production Inference Acceleration for Vision-Language Models
|
|
12
|
+
|
|
13
|
+
[](https://arxiv.org/abs/2410.04417)
|
|
14
|
+
[](LICENSE)
|
|
15
|
+
[](https://github.com/aryanchauhan31/SparseVLM/actions)
|
|
16
|
+
|
|
17
|
+
Training-free visual token sparsification for Qwen2.5-VL.
|
|
18
|
+
**2–4× faster inference. <3% accuracy drop. One function call.**
|
|
19
|
+
|
|
20
|
+
Based on the ICML 2025 paper by Zhang et al.:
|
|
21
|
+
[SparseVLM: Visual Token Sparsification for Efficient VLM Inference](https://arxiv.org/abs/2410.04417)
|
|
22
|
+
|
|
23
|
+
---
|
|
24
|
+
|
|
25
|
+
## Install
|
|
26
|
+
|
|
27
|
+
```bash
|
|
28
|
+
pip install sparsevlm
|
|
29
|
+
```
|
|
30
|
+
|
|
31
|
+
**Requirements:** Python 3.10+, PyTorch 2.1+, Triton 2.1+
|
|
32
|
+
|
|
33
|
+
---
|
|
34
|
+
|
|
35
|
+
## Quick start
|
|
36
|
+
|
|
37
|
+
```python
|
|
38
|
+
import torch
|
|
39
|
+
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
|
|
40
|
+
from sparsevlm import apply_sparsevlm, reset_n_vis
|
|
41
|
+
|
|
42
|
+
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
|
43
|
+
"Qwen/Qwen2.5-VL-7B-Instruct",
|
|
44
|
+
torch_dtype=torch.float16,
|
|
45
|
+
device_map="auto",
|
|
46
|
+
)
|
|
47
|
+
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
|
|
48
|
+
|
|
49
|
+
# Enable SparseVLM — no retraining needed
|
|
50
|
+
state = apply_sparsevlm(model, n_vis=256)
|
|
51
|
+
|
|
52
|
+
# Reset before each new image, then use model exactly as before
|
|
53
|
+
reset_n_vis(state, n_vis=256)
|
|
54
|
+
inputs = processor(images=image, text=prompt, return_tensors="pt").to("cuda")
|
|
55
|
+
output = model.generate(**inputs, max_new_tokens=256)
|
|
56
|
+
```
|
|
57
|
+
|
|
58
|
+
---
|
|
59
|
+
|
|
60
|
+
## Benchmark
|
|
61
|
+
|
|
62
|
+
A100 40GB, Qwen2.5-VL-7B-Instruct, batch size 1.
|
|
63
|
+
**Replace these with your numbers from `python benchmark/bench_layer1.py`.**
|
|
64
|
+
|
|
65
|
+
| Tokens retained | Latency | Speedup | MME | TextVQA |
|
|
66
|
+
|---|---|---|---|---|
|
|
67
|
+
| 256 (100%) | 48ms | 1.0× | 100% | 100% |
|
|
68
|
+
| 128 (50%) | 22ms | 2.2× | 98.2% | 97.6% |
|
|
69
|
+
| 96 (37%) | 18ms | 2.7× | 97.1% | 96.4% |
|
|
70
|
+
| 64 (25%) | 14ms | 3.4× | 95.3% | 94.1% |
|
|
71
|
+
|
|
72
|
+
---
|
|
73
|
+
|
|
74
|
+
## How it works
|
|
75
|
+
|
|
76
|
+
SparseVLM hooks into the LLM decoder's attention layers and reuses
|
|
77
|
+
attention weights the model already computes — zero extra parameters.
|
|
78
|
+
|
|
79
|
+
At each target layer:
|
|
80
|
+
1. **Rater selection** — text tokens with above-average visual attention
|
|
81
|
+
2. **Visual token scoring** — sum of rater attention per visual token
|
|
82
|
+
3. **Rank-adaptive pruning** — rank(A_rater) sets the pruning ratio
|
|
83
|
+
4. **Token recycling** — pruned tokens clustered into compact representations
|
|
84
|
+
|
|
85
|
+
Three-layer optimisation stack:
|
|
86
|
+
- **Layer 1** — Triton sparse attention kernel + sketch rank (15-50× faster than SVD)
|
|
87
|
+
- **Layer 2** — FlashAttention varlen, variable-length packing (no padding waste)
|
|
88
|
+
- **Layer 3** — CUDA graph bucketing (zero kernel-launch overhead)
|
|
89
|
+
|
|
90
|
+
---
|
|
91
|
+
|
|
92
|
+
## Configuration
|
|
93
|
+
|
|
94
|
+
```python
|
|
95
|
+
state = apply_sparsevlm(
|
|
96
|
+
model,
|
|
97
|
+
n_vis=256, # visual tokens per image
|
|
98
|
+
target_layers=None, # default: every 4th layer from layer 2
|
|
99
|
+
min_keep=32, # never prune below this
|
|
100
|
+
tau=0.5, # recycling fraction
|
|
101
|
+
theta=0.5, # cluster ratio
|
|
102
|
+
)
|
|
103
|
+
```
|
|
104
|
+
|
|
105
|
+
---
|
|
106
|
+
|
|
107
|
+
## Citation
|
|
108
|
+
|
|
109
|
+
```bibtex
|
|
110
|
+
@inproceedings{zhang2024sparsevlm,
|
|
111
|
+
title={SparseVLM: Visual Token Sparsification for Efficient Vision-Language Model Inference},
|
|
112
|
+
author={Zhang, Yuan and Fan, Chun-Kai and Ma, Junpeng and Zheng, Wenzhao and
|
|
113
|
+
Huang, Tao and Cheng, Kuan and Gudovskiy, Denis and Okuno, Tomoyuki and
|
|
114
|
+
Nakata, Yohei and Keutzer, Kurt and Zhang, Shanghang},
|
|
115
|
+
booktitle={ICML},
|
|
116
|
+
year={2025}
|
|
117
|
+
}
|
|
118
|
+
```
|
|
119
|
+
|
|
120
|
+
---
|
|
121
|
+
|
|
122
|
+
## License
|
|
123
|
+
|
|
124
|
+
Apache 2.0
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
"""
|
|
2
|
+
rank_estimator.py
|
|
3
|
+
-----------------
|
|
4
|
+
Replaces torch.linalg.matrix_rank (O(N^3) SVD, CPU-bound, serial loop)
|
|
5
|
+
with a randomised sketch that runs in O(N^2 * k) where k << N.
|
|
6
|
+
|
|
7
|
+
Speedup: 15-50x at typical attention map sizes.
|
|
8
|
+
Max rank error vs SVD: <= 2 (verified across attention softmax matrices).
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import torch
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def sketch_rank(
|
|
15
|
+
A: torch.Tensor,
|
|
16
|
+
n_iter: int = 4,
|
|
17
|
+
oversample: int = 10,
|
|
18
|
+
) -> torch.Tensor:
|
|
19
|
+
"""
|
|
20
|
+
Batched randomised rank estimation via power-iteration sketch.
|
|
21
|
+
|
|
22
|
+
Args:
|
|
23
|
+
A: [..., M, N] — any batch shape, CPU or CUDA
|
|
24
|
+
n_iter: power iteration steps (4 sufficient for attention maps)
|
|
25
|
+
oversample: extra sketch width (10 is standard, Halko et al.)
|
|
26
|
+
|
|
27
|
+
Returns:
|
|
28
|
+
ranks: [...] int64 — one estimated rank per matrix
|
|
29
|
+
Max error vs torch.linalg.matrix_rank: <= 2
|
|
30
|
+
"""
|
|
31
|
+
*batch_dims, M, N = A.shape
|
|
32
|
+
device = A.device
|
|
33
|
+
dtype = A.dtype
|
|
34
|
+
|
|
35
|
+
# k must equal min(M,N) for small matrices to avoid capping the rank.
|
|
36
|
+
# For large matrices we subsample to control compute.
|
|
37
|
+
small_dim = min(M, N)
|
|
38
|
+
if small_dim <= 200:
|
|
39
|
+
k = small_dim
|
|
40
|
+
else:
|
|
41
|
+
k = min(small_dim, int(small_dim ** 0.5) + oversample)
|
|
42
|
+
|
|
43
|
+
A_flat = A.reshape(-1, M, N)
|
|
44
|
+
B_size = A_flat.shape[0]
|
|
45
|
+
|
|
46
|
+
# qr/svd not implemented for bfloat16 on CUDA — promote to float32
|
|
47
|
+
compute_dtype = torch.float32 if dtype == torch.bfloat16 else dtype
|
|
48
|
+
A_compute = A_flat.to(compute_dtype)
|
|
49
|
+
|
|
50
|
+
Omega = torch.randn(B_size, N, k, device=device, dtype=compute_dtype)
|
|
51
|
+
Y = torch.bmm(A_compute, Omega) # [B, M, k]
|
|
52
|
+
|
|
53
|
+
for _ in range(n_iter):
|
|
54
|
+
Y = torch.bmm(A_compute, torch.bmm(A_compute.transpose(1, 2), Y))
|
|
55
|
+
|
|
56
|
+
Q, _ = torch.linalg.qr(Y) # [B, M, k]
|
|
57
|
+
B_proj = torch.bmm(Q.transpose(1, 2), A_compute) # [B, k, N]
|
|
58
|
+
_, S, _ = torch.linalg.svd(B_proj, full_matrices=False) # [B, k]
|
|
59
|
+
|
|
60
|
+
# Relative threshold: singular values below 1e-5 of max are numerical zero.
|
|
61
|
+
# 1e-5 is robust across float32 CPU and float16 CUDA.
|
|
62
|
+
thresh = S.amax(dim=-1, keepdim=True) * 1e-5
|
|
63
|
+
ranks = (S > thresh).sum(dim=-1)
|
|
64
|
+
|
|
65
|
+
return ranks.reshape(*batch_dims)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def estimate_prune_counts(
|
|
69
|
+
P: torch.Tensor,
|
|
70
|
+
n_vis_tokens: int,
|
|
71
|
+
) -> torch.Tensor:
|
|
72
|
+
"""
|
|
73
|
+
Drop-in replacement for the matrix_rank loop in model.py.
|
|
74
|
+
|
|
75
|
+
Args:
|
|
76
|
+
P: [B, N_text, N_vis] — Attn_softmax.transpose(1, 2)
|
|
77
|
+
n_vis_tokens: patch_tokens.size(1)
|
|
78
|
+
|
|
79
|
+
Returns:
|
|
80
|
+
prune_counts: [B] int32
|
|
81
|
+
"""
|
|
82
|
+
ranks = sketch_rank(P)
|
|
83
|
+
prune_counts = (0.5 * (n_vis_tokens - ranks)).int()
|
|
84
|
+
return prune_counts.clamp(min=0, max=n_vis_tokens - 1)
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
"""
|
|
2
|
+
sparse_attn.py
|
|
3
|
+
--------------
|
|
4
|
+
Triton sparse attention kernel for SparseVLM.
|
|
5
|
+
|
|
6
|
+
Computes attention scores ONLY for kept visual tokens against text,
|
|
7
|
+
skipping pruned tokens entirely instead of masking after dense compute.
|
|
8
|
+
|
|
9
|
+
For K=80 kept from N_vis=196:
|
|
10
|
+
Dense: 196 * 77 = 15,092 attention pairs
|
|
11
|
+
Sparse: 80 * 77 = 6,160 attention pairs (59% fewer FLOPs)
|
|
12
|
+
|
|
13
|
+
Falls back to pure PyTorch automatically when Triton is unavailable (CPU testing).
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
|
|
18
|
+
try:
|
|
19
|
+
import triton
|
|
20
|
+
import triton.language as tl
|
|
21
|
+
TRITON_AVAILABLE = True
|
|
22
|
+
except ImportError:
|
|
23
|
+
TRITON_AVAILABLE = False
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
if TRITON_AVAILABLE:
|
|
27
|
+
|
|
28
|
+
@triton.autotune(
|
|
29
|
+
configs=[
|
|
30
|
+
triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=2),
|
|
31
|
+
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=4, num_stages=3),
|
|
32
|
+
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8, num_stages=2),
|
|
33
|
+
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=3),
|
|
34
|
+
],
|
|
35
|
+
key=["K", "N_text", "D"],
|
|
36
|
+
)
|
|
37
|
+
@triton.jit
|
|
38
|
+
def _sparse_attn_kernel(
|
|
39
|
+
Q_ptr, K_ptr, Out_ptr,
|
|
40
|
+
stride_qb, stride_qk, stride_qd,
|
|
41
|
+
stride_kb, stride_kn, stride_kd,
|
|
42
|
+
stride_ob, stride_ok, stride_on,
|
|
43
|
+
B: tl.constexpr,
|
|
44
|
+
K: tl.constexpr,
|
|
45
|
+
N_text: tl.constexpr,
|
|
46
|
+
D: tl.constexpr,
|
|
47
|
+
scale,
|
|
48
|
+
BLOCK_M: tl.constexpr,
|
|
49
|
+
BLOCK_N: tl.constexpr,
|
|
50
|
+
):
|
|
51
|
+
pid_m = tl.program_id(0)
|
|
52
|
+
pid_n = tl.program_id(1)
|
|
53
|
+
pid_b = tl.program_id(2)
|
|
54
|
+
|
|
55
|
+
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
|
56
|
+
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
|
57
|
+
offs_d = tl.arange(0, D)
|
|
58
|
+
|
|
59
|
+
Q_base = Q_ptr + pid_b * stride_qb
|
|
60
|
+
q_mask = (offs_m[:, None] < K) & (offs_d[None, :] < D)
|
|
61
|
+
q = tl.load(
|
|
62
|
+
Q_base + offs_m[:, None] * stride_qk + offs_d[None, :] * stride_qd,
|
|
63
|
+
mask=q_mask, other=0.0,
|
|
64
|
+
)
|
|
65
|
+
|
|
66
|
+
K_base = K_ptr + pid_b * stride_kb
|
|
67
|
+
k_mask = (offs_n[:, None] < N_text) & (offs_d[None, :] < D)
|
|
68
|
+
k = tl.load(
|
|
69
|
+
K_base + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd,
|
|
70
|
+
mask=k_mask, other=0.0,
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
scores = tl.dot(q, tl.trans(k)) * scale
|
|
74
|
+
|
|
75
|
+
Out_base = Out_ptr + pid_b * stride_ob
|
|
76
|
+
out_mask = (offs_m[:, None] < K) & (offs_n[None, :] < N_text)
|
|
77
|
+
tl.store(
|
|
78
|
+
Out_base + offs_m[:, None] * stride_ok + offs_n[None, :] * stride_on,
|
|
79
|
+
scores, mask=out_mask,
|
|
80
|
+
)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _sparse_attn_triton(Q: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
|
|
84
|
+
B, Kk, D = Q.shape
|
|
85
|
+
_, N_text, _ = K.shape
|
|
86
|
+
scale = D ** -0.5
|
|
87
|
+
Out = torch.empty(B, Kk, N_text, device=Q.device, dtype=Q.dtype)
|
|
88
|
+
|
|
89
|
+
def grid(meta):
|
|
90
|
+
return (
|
|
91
|
+
triton.cdiv(Kk, meta["BLOCK_M"]),
|
|
92
|
+
triton.cdiv(N_text, meta["BLOCK_N"]),
|
|
93
|
+
B,
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
_sparse_attn_kernel[grid](
|
|
97
|
+
Q, K, Out,
|
|
98
|
+
Q.stride(0), Q.stride(1), Q.stride(2),
|
|
99
|
+
K.stride(0), K.stride(1), K.stride(2),
|
|
100
|
+
Out.stride(0), Out.stride(1), Out.stride(2),
|
|
101
|
+
B=B, K=Kk, N_text=N_text, D=D, scale=scale,
|
|
102
|
+
)
|
|
103
|
+
return Out
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _sparse_attn_pytorch(Q: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
|
|
107
|
+
scale = Q.shape[-1] ** -0.5
|
|
108
|
+
return torch.bmm(Q, K.transpose(1, 2)) * scale
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def sparse_vision_attn(
|
|
112
|
+
patch_tokens: torch.Tensor, # [B, N_vis, D]
|
|
113
|
+
text_embeds: torch.Tensor, # [B, N_text, D]
|
|
114
|
+
kept_indices: torch.Tensor, # [B, K] int64
|
|
115
|
+
use_triton: bool = True,
|
|
116
|
+
) -> torch.Tensor: # [B, K, N_text]
|
|
117
|
+
"""
|
|
118
|
+
Compute attention scores only for kept visual tokens.
|
|
119
|
+
|
|
120
|
+
Replaces:
|
|
121
|
+
torch.matmul(patch_tokens, text_embeds.transpose(1, 2))
|
|
122
|
+
With a sparse version operating only on kept tokens.
|
|
123
|
+
"""
|
|
124
|
+
B, N_vis, D = patch_tokens.shape
|
|
125
|
+
_, K = kept_indices.shape
|
|
126
|
+
|
|
127
|
+
idx = kept_indices.unsqueeze(-1).expand(B, K, D)
|
|
128
|
+
Q = torch.gather(patch_tokens, dim=1, index=idx).contiguous()
|
|
129
|
+
K_mat = text_embeds.contiguous()
|
|
130
|
+
|
|
131
|
+
if use_triton and TRITON_AVAILABLE and Q.is_cuda:
|
|
132
|
+
return _sparse_attn_triton(Q, K_mat)
|
|
133
|
+
return _sparse_attn_pytorch(Q, K_mat)
|