adasplash 0.2.0__tar.gz → 0.2.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adasplash-0.2.2/LICENSE +28 -0
- adasplash-0.2.2/PKG-INFO +305 -0
- adasplash-0.2.2/README.md +270 -0
- {adasplash-0.2.0 → adasplash-0.2.2}/adasplash/__init__.py +0 -17
- {adasplash-0.2.0 → adasplash-0.2.2}/adasplash/adasplash_block_mask.py +18 -0
- {adasplash-0.2.0 → adasplash-0.2.2}/adasplash/adasplash_no_block_mask.py +22 -0
- {adasplash-0.2.0 → adasplash-0.2.2}/adasplash/adasplash_v2.py +23 -0
- adasplash-0.2.2/adasplash.egg-info/PKG-INFO +305 -0
- {adasplash-0.2.0 → adasplash-0.2.2}/adasplash.egg-info/SOURCES.txt +0 -1
- {adasplash-0.2.0 → adasplash-0.2.2}/setup.py +2 -2
- {adasplash-0.2.0 → adasplash-0.2.2}/tests/test_attention.py +23 -4
- {adasplash-0.2.0 → adasplash-0.2.2}/tests/test_public_api.py +1 -13
- adasplash-0.2.0/LICENSE +0 -21
- adasplash-0.2.0/PKG-INFO +0 -176
- adasplash-0.2.0/README.md +0 -141
- adasplash-0.2.0/adasplash/attention.py +0 -73
- adasplash-0.2.0/adasplash.egg-info/PKG-INFO +0 -176
- {adasplash-0.2.0 → adasplash-0.2.2}/adasplash/triton_entmax.py +0 -0
- {adasplash-0.2.0 → adasplash-0.2.2}/adasplash/triton_entmax_v2.py +0 -0
- {adasplash-0.2.0 → adasplash-0.2.2}/adasplash.egg-info/dependency_links.txt +0 -0
- {adasplash-0.2.0 → adasplash-0.2.2}/adasplash.egg-info/requires.txt +0 -0
- {adasplash-0.2.0 → adasplash-0.2.2}/adasplash.egg-info/top_level.txt +0 -0
- {adasplash-0.2.0 → adasplash-0.2.2}/pyproject.toml +0 -0
- {adasplash-0.2.0 → adasplash-0.2.2}/setup.cfg +0 -0
- {adasplash-0.2.0 → adasplash-0.2.2}/tests/test_adasplash.py +0 -0
- {adasplash-0.2.0 → adasplash-0.2.2}/tests/test_adasplash_no_block_mask.py +0 -0
- {adasplash-0.2.0 → adasplash-0.2.2}/tests/test_adasplash_v2.py +0 -0
- {adasplash-0.2.0 → adasplash-0.2.2}/tests/test_triton_entmax.py +0 -0
adasplash-0.2.2/LICENSE
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
BSD 3-Clause License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025, Sardine LAB
|
|
4
|
+
|
|
5
|
+
Redistribution and use in source and binary forms, with or without
|
|
6
|
+
modification, are permitted provided that the following conditions are met:
|
|
7
|
+
|
|
8
|
+
1. Redistributions of source code must retain the above copyright notice, this
|
|
9
|
+
list of conditions and the following disclaimer.
|
|
10
|
+
|
|
11
|
+
2. Redistributions in binary form must reproduce the above copyright notice,
|
|
12
|
+
this list of conditions and the following disclaimer in the documentation
|
|
13
|
+
and/or other materials provided with the distribution.
|
|
14
|
+
|
|
15
|
+
3. Neither the name of the copyright holder nor the names of its contributors
|
|
16
|
+
may be used to endorse or promote products derived from this software without
|
|
17
|
+
specific prior written permission.
|
|
18
|
+
|
|
19
|
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
20
|
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
21
|
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
22
|
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
23
|
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
24
|
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
25
|
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
26
|
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
27
|
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
28
|
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
adasplash-0.2.2/PKG-INFO
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
1
|
+
Metadata-Version: 2.2
|
|
2
|
+
Name: adasplash
|
|
3
|
+
Version: 0.2.2
|
|
4
|
+
Summary: AdaSplash: Efficient Adaptive Sparse Attention in Triton
|
|
5
|
+
Home-page: https://github.com/deep-spin/adasplash
|
|
6
|
+
Author: Nuno Gonçalves, Marcos Treviso
|
|
7
|
+
Author-email: marcosvtreviso@gmail.com
|
|
8
|
+
Classifier: Programming Language :: Python :: 3
|
|
9
|
+
Classifier: License :: OSI Approved :: BSD License
|
|
10
|
+
Classifier: Operating System :: OS Independent
|
|
11
|
+
Requires-Python: >=3.8
|
|
12
|
+
Description-Content-Type: text/markdown
|
|
13
|
+
License-File: LICENSE
|
|
14
|
+
Requires-Dist: numpy
|
|
15
|
+
Requires-Dist: torch
|
|
16
|
+
Requires-Dist: triton
|
|
17
|
+
Provides-Extra: dev
|
|
18
|
+
Requires-Dist: pytest; extra == "dev"
|
|
19
|
+
Requires-Dist: black; extra == "dev"
|
|
20
|
+
Requires-Dist: isort; extra == "dev"
|
|
21
|
+
Requires-Dist: flake8; extra == "dev"
|
|
22
|
+
Requires-Dist: entmax; extra == "dev"
|
|
23
|
+
Requires-Dist: build; extra == "dev"
|
|
24
|
+
Requires-Dist: twine; extra == "dev"
|
|
25
|
+
Dynamic: author
|
|
26
|
+
Dynamic: author-email
|
|
27
|
+
Dynamic: classifier
|
|
28
|
+
Dynamic: description
|
|
29
|
+
Dynamic: description-content-type
|
|
30
|
+
Dynamic: home-page
|
|
31
|
+
Dynamic: provides-extra
|
|
32
|
+
Dynamic: requires-dist
|
|
33
|
+
Dynamic: requires-python
|
|
34
|
+
Dynamic: summary
|
|
35
|
+
|
|
36
|
+
# AdaSplash
|
|
37
|
+
|
|
38
|
+
[PyPI](https://pypi.org/project/adasplash/)
|
|
39
|
+
[Hugging Face](https://huggingface.co/sardinelab/)
|
|
40
|
+
[License: BSD-3-Clause](LICENSE)
|
|
41
|
+
|
|
42
|
+
AdaSplash is a Triton implementation of adaptive sparse attention based on entmax. It exposes the original **AdaSplash** kernels and the newer **🆕 AdaSplash-2** kernels through a backwards-compatible Python API.
|
|
43
|
+
|
|
44
|
+
- **AdaSplash-1**: adaptive sparse flash attention with entmax and block masking.
|
|
45
|
+
- **🆕 AdaSplash-2**: faster differentiable sparse attention based on histogram initialization, hybrid threshold solving, and a v2 sparse causal attention path.
|
|
46
|
+
|
|
47
|
+
Papers:
|
|
48
|
+
|
|
49
|
+
- AdaSplash: [Adaptive Sparse Flash Attention](https://openreview.net/forum?id=OWIPDWhUcO)
|
|
50
|
+
- AdaSplash-2: [Faster Differentiable Sparse Attention](https://openreview.net/forum?id=7qpvff2gWI)
|
|
51
|
+
|
|
52
|
+
## Installation
|
|
53
|
+
|
|
54
|
+
Install from PyPI:
|
|
55
|
+
|
|
56
|
+
```bash
|
|
57
|
+
pip install adasplash
|
|
58
|
+
```
|
|
59
|
+
|
|
60
|
+
Or install the latest development version:
|
|
61
|
+
|
|
62
|
+
```bash
|
|
63
|
+
pip install git+https://github.com/deep-spin/adasplash.git
|
|
64
|
+
```
|
|
65
|
+
|
|
66
|
+
AdaSplash requires PyTorch and Triton. CUDA is required to run the Triton kernels, but package imports are lazy, so `import adasplash` works without an active CUDA driver.
|
|
67
|
+
|
|
68
|
+
## What Is New In 🆕 AdaSplash-2?
|
|
69
|
+
|
|
70
|
+
AdaSplash-2 keeps the same sparse-entmax motivation as AdaSplash-1, but improves the core threshold and attention path:
|
|
71
|
+
|
|
72
|
+
- **Histogram initialization** for faster entmax threshold estimates.
|
|
73
|
+
- **Hybrid solver updates** that combine robust root-finding steps for the entmax threshold.
|
|
74
|
+
- **Sparse causal v2 attention** exposed as `adasplash_v2`.
|
|
75
|
+
- **Grouped-query attention support** through different query and key/value head counts.
|
|
76
|
+
- **Variable-length sequence support** with padded rows zeroed in the output.
|
|
77
|
+
- **Convenience v2 entmax APIs** for sparsemax, entmax-1.5, and generic entmax calls.
|
|
78
|
+
|
|
79
|
+
In `adasplash>=0.2.0`, bare `adasplash(q, k, v)` defaults to the AdaSplash-2 path when the call is supported. Explicit v1 entry points remain available for backwards compatibility.
|
|
80
|
+
|
|
81
|
+
## Public API
|
|
82
|
+
|
|
83
|
+
All public functions can be imported from the top-level package:
|
|
84
|
+
|
|
85
|
+
```python
|
|
86
|
+
from adasplash import (
|
|
87
|
+
adasplash,
|
|
88
|
+
adasplash_v1,
|
|
89
|
+
adasplash_v2,
|
|
90
|
+
adasplash_no_block_mask,
|
|
91
|
+
triton_entmax,
|
|
92
|
+
triton_entmax_v1,
|
|
93
|
+
triton_entmax_v2,
|
|
94
|
+
triton_sparsemax,
|
|
95
|
+
triton_entmax15,
|
|
96
|
+
)
|
|
97
|
+
```
|
|
98
|
+
|
|
99
|
+
| Function | Purpose |
|
|
100
|
+
| --- | --- |
|
|
101
|
+
| `adasplash` | Compatibility dispatcher. Uses v2 for supported causal `alpha=1.5` calls and falls back to v1 for v1-only behavior. |
|
|
102
|
+
| `adasplash_v2` | Direct AdaSplash-2 causal sparse attention. |
|
|
103
|
+
| `adasplash_v1` | Direct original AdaSplash block-mask implementation. |
|
|
104
|
+
| `adasplash_no_block_mask` | Original v1 no-block-mask implementation. |
|
|
105
|
+
| `triton_entmax` | Default v2 entmax API. |
|
|
106
|
+
| `triton_entmax_v2` | Direct v2 entmax with histogram and hybrid solver support. |
|
|
107
|
+
| `triton_entmax_v1` | Original entmax implementation. |
|
|
108
|
+
| `triton_sparsemax` | Convenience v2 sparsemax call, equivalent to entmax with `alpha=2.0`. |
|
|
109
|
+
| `triton_entmax15` | Convenience v2 entmax-1.5 call. |
|
|
110
|
+
|
|
111
|
+
## Sparse Attention Examples
|
|
112
|
+
|
|
113
|
+
### Default Dispatcher
|
|
114
|
+
|
|
115
|
+
```python
|
|
116
|
+
import torch
|
|
117
|
+
from adasplash import adasplash
|
|
118
|
+
|
|
119
|
+
q = torch.randn(1, 8, 128, 64, device="cuda")
|
|
120
|
+
k = torch.randn(1, 8, 128, 64, device="cuda")
|
|
121
|
+
v = torch.randn(1, 8, 128, 64, device="cuda")
|
|
122
|
+
|
|
123
|
+
out = adasplash(q, k, v)
|
|
124
|
+
```
|
|
125
|
+
|
|
126
|
+
For supported causal `alpha=1.5` calls, `adasplash` routes to AdaSplash-2. Calls that request v1-only behavior, such as `alpha != 1.5` or `is_causal=False`, route to the v1 implementation.
|
|
127
|
+
|
|
128
|
+
```python
|
|
129
|
+
# Uses AdaSplash-2.
|
|
130
|
+
out_v2_default = adasplash(q, k, v, is_causal=True)
|
|
131
|
+
|
|
132
|
+
# Uses the v1 compatibility path.
|
|
133
|
+
out_v1_alpha = adasplash(q, k, v, alpha=1.333)
|
|
134
|
+
out_v1_noncausal = adasplash(q, k, v, is_causal=False)
|
|
135
|
+
```
|
|
136
|
+
|
|
137
|
+
### Explicit V1 And V2 Calls
|
|
138
|
+
|
|
139
|
+
```python
|
|
140
|
+
from adasplash import adasplash_v1, adasplash_v2
|
|
141
|
+
|
|
142
|
+
out_v1 = adasplash_v1(q, k, v, alpha=1.5, is_causal=True, niter=10)
|
|
143
|
+
out_v2 = adasplash_v2(q, k, v, niter=1)
|
|
144
|
+
```
|
|
145
|
+
|
|
146
|
+
### Variable-Length Sequences
|
|
147
|
+
|
|
148
|
+
```python
|
|
149
|
+
from adasplash import adasplash
|
|
150
|
+
|
|
151
|
+
varlen = torch.tensor([96], device="cuda", dtype=torch.int32)
|
|
152
|
+
out = adasplash(q, k, v, varlen=varlen)
|
|
153
|
+
```
|
|
154
|
+
|
|
155
|
+
Rows beyond each valid sequence length are zeroed in the output.
|
|
156
|
+
|
|
157
|
+
### Grouped-Query Attention With AdaSplash-2
|
|
158
|
+
|
|
159
|
+
```python
|
|
160
|
+
from adasplash import adasplash_v2
|
|
161
|
+
|
|
162
|
+
q = torch.randn(1, 8, 256, 64, device="cuda")
|
|
163
|
+
k = torch.randn(1, 2, 256, 64, device="cuda")
|
|
164
|
+
v = torch.randn(1, 2, 256, 64, device="cuda")
|
|
165
|
+
|
|
166
|
+
out = adasplash_v2(q, k, v)
|
|
167
|
+
```
|
|
168
|
+
|
|
169
|
+
`adasplash_v2` supports grouped-query attention when the number of query heads is divisible by the number of key/value heads.
|
|
170
|
+
|
|
171
|
+
## Triton Entmax Examples
|
|
172
|
+
|
|
173
|
+
```python
|
|
174
|
+
import torch
|
|
175
|
+
from adasplash import triton_entmax, triton_entmax_v1, triton_sparsemax, triton_entmax15
|
|
176
|
+
|
|
177
|
+
x = torch.randn(128, 256, device="cuda")
|
|
178
|
+
|
|
179
|
+
y = triton_entmax(x, alpha=1.5, n_iter=2, use_histogram=True)
|
|
180
|
+
y_v1 = triton_entmax_v1(x, alpha=1.5, n_iter=10, fast_math=True)
|
|
181
|
+
y_sparsemax = triton_sparsemax(x)
|
|
182
|
+
y_entmax15 = triton_entmax15(x)
|
|
183
|
+
```
|
|
184
|
+
|
|
185
|
+
`triton_entmax` points to the v2 implementation as of `0.2.0`. Strict v1 users should call `triton_entmax_v1`.
|
|
186
|
+
|
|
187
|
+
For generic alpha values other than `1.5` and `2.0`, v2 disables histogram initialization internally and uses more refinement iterations for correctness.
|
|
188
|
+
|
|
189
|
+
## Attention Examples
|
|
190
|
+
|
|
191
|
+
The `examples/attention.py` file contains two small helpers that show the difference between the fused AdaSplash kernel and a dense reference-style implementation.
|
|
192
|
+
|
|
193
|
+
### Flash Entmax Attention
|
|
194
|
+
|
|
195
|
+
```python
|
|
196
|
+
from examples.attention import flash_entmax_attention
|
|
197
|
+
|
|
198
|
+
out = flash_entmax_attention(q, k, v, is_causal=True)
|
|
199
|
+
```
|
|
200
|
+
|
|
201
|
+
`flash_entmax_attention` is a thin example wrapper around `adasplash`, the actual fused flash entmax attention path.
|
|
202
|
+
|
|
203
|
+
### Slow Dense Entmax Attention
|
|
204
|
+
|
|
205
|
+
```python
|
|
206
|
+
from examples.attention import slow_entmax_attention
|
|
207
|
+
|
|
208
|
+
out = slow_entmax_attention(q, k, v, is_causal=True, padding="right")
|
|
209
|
+
```
|
|
210
|
+
|
|
211
|
+
`slow_entmax_attention` materializes dense attention scores and applies `triton_entmax`. It is useful for examples and small correctness checks, but it is not the AdaSplash flash kernel and should not be used for long contexts.
|
|
212
|
+
|
|
213
|
+
## Backwards Compatibility
|
|
214
|
+
|
|
215
|
+
AdaSplash `0.2.0` changes the defaults of the top-level APIs:
|
|
216
|
+
|
|
217
|
+
- `adasplash(q, k, v)` uses AdaSplash-2 when the call is supported.
|
|
218
|
+
- `triton_entmax(x)` uses the v2 entmax implementation.
|
|
219
|
+
- Explicit v1 names are preserved: `adasplash_v1`, `adasplash_no_block_mask`, and `triton_entmax_v1`.
|
|
220
|
+
|
|
221
|
+
If you need strict AdaSplash-1 behavior, use the `_v1` names directly.
|
|
222
|
+
|
|
223
|
+
## Testing
|
|
224
|
+
|
|
225
|
+
Install development dependencies:
|
|
226
|
+
|
|
227
|
+
```bash
|
|
228
|
+
pip install -r requirements-dev.txt
|
|
229
|
+
```
|
|
230
|
+
|
|
231
|
+
No-GPU import and public API checks:
|
|
232
|
+
|
|
233
|
+
```bash
|
|
234
|
+
TRITON_INTERPRET=1 pytest -q
|
|
235
|
+
```
|
|
236
|
+
|
|
237
|
+
Fast CUDA suite:
|
|
238
|
+
|
|
239
|
+
```bash
|
|
240
|
+
pytest -q -m "not slow and not stress"
|
|
241
|
+
```
|
|
242
|
+
|
|
243
|
+
Slow CUDA correctness suite:
|
|
244
|
+
|
|
245
|
+
```bash
|
|
246
|
+
pytest -q -m "slow"
|
|
247
|
+
```
|
|
248
|
+
|
|
249
|
+
Stress tests run dense reference checks for long contexts and may skip cases when the current GPU does not have enough memory or shared memory:
|
|
250
|
+
|
|
251
|
+
```bash
|
|
252
|
+
pytest -q -m "stress"
|
|
253
|
+
```
|
|
254
|
+
|
|
255
|
+
## Benchmarks And Models
|
|
256
|
+
|
|
257
|
+
### Efficiency AdaSplash-1
|
|
258
|
+

|
|
259
|
+
|
|
260
|
+
### Efficiency AdaSplash-2
|
|
261
|
+

|
|
262
|
+
|
|
263
|
+
### Sparse Models
|
|
264
|
+
|
|
265
|
+
Sparse models and related artifacts are hosted under the [SARDINE Lab Hugging Face organization](https://huggingface.co/sardinelab/).
|
|
266
|
+
|
|
267
|
+
For single-vector retrieval experiments, see the [Sparse ModernBERT repository](https://github.com/deep-spin/SparseModernBERT).
|
|
268
|
+
|
|
269
|
+
## Citation
|
|
270
|
+
|
|
271
|
+
If you use AdaSplash in your research, please cite the relevant paper.
|
|
272
|
+
|
|
273
|
+
AdaSplash-1:
|
|
274
|
+
|
|
275
|
+
```bibtex
|
|
276
|
+
@inproceedings{
|
|
277
|
+
goncalves2025adasplash,
|
|
278
|
+
title={AdaSplash: Adaptive Sparse Flash Attention},
|
|
279
|
+
author={Nuno Gon{\c{c}}alves and Marcos V Treviso and Andre Martins},
|
|
280
|
+
booktitle={Forty-second International Conference on Machine Learning},
|
|
281
|
+
year={2025},
|
|
282
|
+
url={https://openreview.net/forum?id=OWIPDWhUcO}
|
|
283
|
+
}
|
|
284
|
+
```
|
|
285
|
+
|
|
286
|
+
AdaSplash-2:
|
|
287
|
+
|
|
288
|
+
```bibtex
|
|
289
|
+
@inproceedings{
|
|
290
|
+
goncalves2026adasplash,
|
|
291
|
+
title={AdaSplash-2: Faster Differentiable Sparse Attention},
|
|
292
|
+
author={Nuno Gon{\c{c}}alves and Hugo Pitorro and Vlad Niculae and Edoardo Ponti and Lei Li and Andre Martins and Marcos V Treviso},
|
|
293
|
+
booktitle={Forty-third International Conference on Machine Learning},
|
|
294
|
+
year={2026},
|
|
295
|
+
url={https://openreview.net/forum?id=7qpvff2gWI}
|
|
296
|
+
}
|
|
297
|
+
```
|
|
298
|
+
|
|
299
|
+
## License
|
|
300
|
+
|
|
301
|
+
AdaSplash is licensed under the BSD-3-Clause License. See the [LICENSE](LICENSE) file for details.
|
|
302
|
+
|
|
303
|
+
## Acknowledgements
|
|
304
|
+
|
|
305
|
+
> We would like to thank the SARDINE lab team for the helpful discussions. This work was supported by the project DECOLLAGE (ERC-2022-CoG 101088763), by the Portuguese Recovery and Resilience Plan through project C64500888200000055 (Center for Responsible AI), and by FCT/MECI through national funds and when applicable co-funded EU funds under UID/50008: Instituto de Telecomunicações. Vlad Niculae is supported by the Dutch Research Council (NWO) via VI.Veni.212.228. Edoardo M. Ponti is supported by the ERC Starting Grant AToM-FM (101222956).
|
|
@@ -0,0 +1,270 @@
|
|
|
1
|
+
# AdaSplash
|
|
2
|
+
|
|
3
|
+
[PyPI](https://pypi.org/project/adasplash/)
|
|
4
|
+
[Hugging Face](https://huggingface.co/sardinelab/)
|
|
5
|
+
[License: BSD-3-Clause](LICENSE)
|
|
6
|
+
|
|
7
|
+
AdaSplash is a Triton implementation of adaptive sparse attention based on entmax. It exposes the original **AdaSplash** kernels and the newer **🆕 AdaSplash-2** kernels through a backwards-compatible Python API.
|
|
8
|
+
|
|
9
|
+
- **AdaSplash-1**: adaptive sparse flash attention with entmax and block masking.
|
|
10
|
+
- **🆕 AdaSplash-2**: faster differentiable sparse attention based on histogram initialization, hybrid threshold solving, and a v2 sparse causal attention path.
|
|
11
|
+
|
|
12
|
+
Papers:
|
|
13
|
+
|
|
14
|
+
- AdaSplash: [Adaptive Sparse Flash Attention](https://openreview.net/forum?id=OWIPDWhUcO)
|
|
15
|
+
- AdaSplash-2: [Faster Differentiable Sparse Attention](https://openreview.net/forum?id=7qpvff2gWI)
|
|
16
|
+
|
|
17
|
+
## Installation
|
|
18
|
+
|
|
19
|
+
Install from PyPI:
|
|
20
|
+
|
|
21
|
+
```bash
|
|
22
|
+
pip install adasplash
|
|
23
|
+
```
|
|
24
|
+
|
|
25
|
+
Or install the latest development version:
|
|
26
|
+
|
|
27
|
+
```bash
|
|
28
|
+
pip install git+https://github.com/deep-spin/adasplash.git
|
|
29
|
+
```
|
|
30
|
+
|
|
31
|
+
AdaSplash requires PyTorch and Triton. CUDA is required to run the Triton kernels, but package imports are lazy, so `import adasplash` works without an active CUDA driver.
|
|
32
|
+
|
|
33
|
+
## What Is New In 🆕 AdaSplash-2?
|
|
34
|
+
|
|
35
|
+
AdaSplash-2 keeps the same sparse-entmax motivation as AdaSplash-1, but improves the core threshold and attention path:
|
|
36
|
+
|
|
37
|
+
- **Histogram initialization** for faster entmax threshold estimates.
|
|
38
|
+
- **Hybrid solver updates** that combine robust root-finding steps for the entmax threshold.
|
|
39
|
+
- **Sparse causal v2 attention** exposed as `adasplash_v2`.
|
|
40
|
+
- **Grouped-query attention support** through different query and key/value head counts.
|
|
41
|
+
- **Variable-length sequence support** with padded rows zeroed in the output.
|
|
42
|
+
- **Convenience v2 entmax APIs** for sparsemax, entmax-1.5, and generic entmax calls.
|
|
43
|
+
|
|
44
|
+
In `adasplash>=0.2.0`, bare `adasplash(q, k, v)` defaults to the AdaSplash-2 path when the call is supported. Explicit v1 entry points remain available for backwards compatibility.
|
|
45
|
+
|
|
46
|
+
## Public API
|
|
47
|
+
|
|
48
|
+
All public functions can be imported from the top-level package:
|
|
49
|
+
|
|
50
|
+
```python
|
|
51
|
+
from adasplash import (
|
|
52
|
+
adasplash,
|
|
53
|
+
adasplash_v1,
|
|
54
|
+
adasplash_v2,
|
|
55
|
+
adasplash_no_block_mask,
|
|
56
|
+
triton_entmax,
|
|
57
|
+
triton_entmax_v1,
|
|
58
|
+
triton_entmax_v2,
|
|
59
|
+
triton_sparsemax,
|
|
60
|
+
triton_entmax15,
|
|
61
|
+
)
|
|
62
|
+
```
|
|
63
|
+
|
|
64
|
+
| Function | Purpose |
|
|
65
|
+
| --- | --- |
|
|
66
|
+
| `adasplash` | Compatibility dispatcher. Uses v2 for supported causal `alpha=1.5` calls and falls back to v1 for v1-only behavior. |
|
|
67
|
+
| `adasplash_v2` | Direct AdaSplash-2 causal sparse attention. |
|
|
68
|
+
| `adasplash_v1` | Direct original AdaSplash block-mask implementation. |
|
|
69
|
+
| `adasplash_no_block_mask` | Original v1 no-block-mask implementation. |
|
|
70
|
+
| `triton_entmax` | Default v2 entmax API. |
|
|
71
|
+
| `triton_entmax_v2` | Direct v2 entmax with histogram and hybrid solver support. |
|
|
72
|
+
| `triton_entmax_v1` | Original entmax implementation. |
|
|
73
|
+
| `triton_sparsemax` | Convenience v2 sparsemax call, equivalent to entmax with `alpha=2.0`. |
|
|
74
|
+
| `triton_entmax15` | Convenience v2 entmax-1.5 call. |
|
|
75
|
+
|
|
76
|
+
## Sparse Attention Examples
|
|
77
|
+
|
|
78
|
+
### Default Dispatcher
|
|
79
|
+
|
|
80
|
+
```python
|
|
81
|
+
import torch
|
|
82
|
+
from adasplash import adasplash
|
|
83
|
+
|
|
84
|
+
q = torch.randn(1, 8, 128, 64, device="cuda")
|
|
85
|
+
k = torch.randn(1, 8, 128, 64, device="cuda")
|
|
86
|
+
v = torch.randn(1, 8, 128, 64, device="cuda")
|
|
87
|
+
|
|
88
|
+
out = adasplash(q, k, v)
|
|
89
|
+
```
|
|
90
|
+
|
|
91
|
+
For supported causal `alpha=1.5` calls, `adasplash` routes to AdaSplash-2. Calls that request v1-only behavior, such as `alpha != 1.5` or `is_causal=False`, route to the v1 implementation.
|
|
92
|
+
|
|
93
|
+
```python
|
|
94
|
+
# Uses AdaSplash-2.
|
|
95
|
+
out_v2_default = adasplash(q, k, v, is_causal=True)
|
|
96
|
+
|
|
97
|
+
# Uses the v1 compatibility path.
|
|
98
|
+
out_v1_alpha = adasplash(q, k, v, alpha=1.333)
|
|
99
|
+
out_v1_noncausal = adasplash(q, k, v, is_causal=False)
|
|
100
|
+
```
|
|
101
|
+
|
|
102
|
+
### Explicit V1 And V2 Calls
|
|
103
|
+
|
|
104
|
+
```python
|
|
105
|
+
from adasplash import adasplash_v1, adasplash_v2
|
|
106
|
+
|
|
107
|
+
out_v1 = adasplash_v1(q, k, v, alpha=1.5, is_causal=True, niter=10)
|
|
108
|
+
out_v2 = adasplash_v2(q, k, v, niter=1)
|
|
109
|
+
```
|
|
110
|
+
|
|
111
|
+
### Variable-Length Sequences
|
|
112
|
+
|
|
113
|
+
```python
|
|
114
|
+
from adasplash import adasplash
|
|
115
|
+
|
|
116
|
+
varlen = torch.tensor([96], device="cuda", dtype=torch.int32)
|
|
117
|
+
out = adasplash(q, k, v, varlen=varlen)
|
|
118
|
+
```
|
|
119
|
+
|
|
120
|
+
Rows beyond each valid sequence length are zeroed in the output.
|
|
121
|
+
|
|
122
|
+
### Grouped-Query Attention With AdaSplash-2
|
|
123
|
+
|
|
124
|
+
```python
|
|
125
|
+
from adasplash import adasplash_v2
|
|
126
|
+
|
|
127
|
+
q = torch.randn(1, 8, 256, 64, device="cuda")
|
|
128
|
+
k = torch.randn(1, 2, 256, 64, device="cuda")
|
|
129
|
+
v = torch.randn(1, 2, 256, 64, device="cuda")
|
|
130
|
+
|
|
131
|
+
out = adasplash_v2(q, k, v)
|
|
132
|
+
```
|
|
133
|
+
|
|
134
|
+
`adasplash_v2` supports grouped-query attention when the number of query heads is divisible by the number of key/value heads.
|
|
135
|
+
|
|
136
|
+
## Triton Entmax Examples
|
|
137
|
+
|
|
138
|
+
```python
|
|
139
|
+
import torch
|
|
140
|
+
from adasplash import triton_entmax, triton_entmax_v1, triton_sparsemax, triton_entmax15
|
|
141
|
+
|
|
142
|
+
x = torch.randn(128, 256, device="cuda")
|
|
143
|
+
|
|
144
|
+
y = triton_entmax(x, alpha=1.5, n_iter=2, use_histogram=True)
|
|
145
|
+
y_v1 = triton_entmax_v1(x, alpha=1.5, n_iter=10, fast_math=True)
|
|
146
|
+
y_sparsemax = triton_sparsemax(x)
|
|
147
|
+
y_entmax15 = triton_entmax15(x)
|
|
148
|
+
```
|
|
149
|
+
|
|
150
|
+
`triton_entmax` points to the v2 implementation as of `0.2.0`. Strict v1 users should call `triton_entmax_v1`.
|
|
151
|
+
|
|
152
|
+
For generic alpha values other than `1.5` and `2.0`, v2 disables histogram initialization internally and uses more refinement iterations for correctness.
|
|
153
|
+
|
|
154
|
+
## Attention Examples
|
|
155
|
+
|
|
156
|
+
The `examples/attention.py` file contains two small helpers that show the difference between the fused AdaSplash kernel and a dense reference-style implementation.
|
|
157
|
+
|
|
158
|
+
### Flash Entmax Attention
|
|
159
|
+
|
|
160
|
+
```python
|
|
161
|
+
from examples.attention import flash_entmax_attention
|
|
162
|
+
|
|
163
|
+
out = flash_entmax_attention(q, k, v, is_causal=True)
|
|
164
|
+
```
|
|
165
|
+
|
|
166
|
+
`flash_entmax_attention` is a thin example wrapper around `adasplash`, the actual fused flash entmax attention path.
|
|
167
|
+
|
|
168
|
+
### Slow Dense Entmax Attention
|
|
169
|
+
|
|
170
|
+
```python
|
|
171
|
+
from examples.attention import slow_entmax_attention
|
|
172
|
+
|
|
173
|
+
out = slow_entmax_attention(q, k, v, is_causal=True, padding="right")
|
|
174
|
+
```
|
|
175
|
+
|
|
176
|
+
`slow_entmax_attention` materializes dense attention scores and applies `triton_entmax`. It is useful for examples and small correctness checks, but it is not the AdaSplash flash kernel and should not be used for long contexts.
|
|
177
|
+
|
|
178
|
+
## Backwards Compatibility
|
|
179
|
+
|
|
180
|
+
AdaSplash `0.2.0` changes the defaults of the top-level APIs:
|
|
181
|
+
|
|
182
|
+
- `adasplash(q, k, v)` uses AdaSplash-2 when the call is supported.
|
|
183
|
+
- `triton_entmax(x)` uses the v2 entmax implementation.
|
|
184
|
+
- Explicit v1 names are preserved: `adasplash_v1`, `adasplash_no_block_mask`, and `triton_entmax_v1`.
|
|
185
|
+
|
|
186
|
+
If you need strict AdaSplash-1 behavior, use the `_v1` names directly.
|
|
187
|
+
|
|
188
|
+
## Testing
|
|
189
|
+
|
|
190
|
+
Install development dependencies:
|
|
191
|
+
|
|
192
|
+
```bash
|
|
193
|
+
pip install -r requirements-dev.txt
|
|
194
|
+
```
|
|
195
|
+
|
|
196
|
+
No-GPU import and public API checks:
|
|
197
|
+
|
|
198
|
+
```bash
|
|
199
|
+
TRITON_INTERPRET=1 pytest -q
|
|
200
|
+
```
|
|
201
|
+
|
|
202
|
+
Fast CUDA suite:
|
|
203
|
+
|
|
204
|
+
```bash
|
|
205
|
+
pytest -q -m "not slow and not stress"
|
|
206
|
+
```
|
|
207
|
+
|
|
208
|
+
Slow CUDA correctness suite:
|
|
209
|
+
|
|
210
|
+
```bash
|
|
211
|
+
pytest -q -m "slow"
|
|
212
|
+
```
|
|
213
|
+
|
|
214
|
+
Stress tests run dense reference checks for long contexts and may skip cases when the current GPU does not have enough memory or shared memory:
|
|
215
|
+
|
|
216
|
+
```bash
|
|
217
|
+
pytest -q -m "stress"
|
|
218
|
+
```
|
|
219
|
+
|
|
220
|
+
## Benchmarks And Models
|
|
221
|
+
|
|
222
|
+
### Efficiency AdaSplash-1
|
|
223
|
+

|
|
224
|
+
|
|
225
|
+
### Efficiency AdaSplash-2
|
|
226
|
+

|
|
227
|
+
|
|
228
|
+
### Sparse Models
|
|
229
|
+
|
|
230
|
+
Sparse models and related artifacts are hosted under the [SARDINE Lab Hugging Face organization](https://huggingface.co/sardinelab/).
|
|
231
|
+
|
|
232
|
+
For single-vector retrieval experiments, see the [Sparse ModernBERT repository](https://github.com/deep-spin/SparseModernBERT).
|
|
233
|
+
|
|
234
|
+
## Citation
|
|
235
|
+
|
|
236
|
+
If you use AdaSplash in your research, please cite the relevant paper.
|
|
237
|
+
|
|
238
|
+
AdaSplash-1:
|
|
239
|
+
|
|
240
|
+
```bibtex
|
|
241
|
+
@inproceedings{
|
|
242
|
+
goncalves2025adasplash,
|
|
243
|
+
title={AdaSplash: Adaptive Sparse Flash Attention},
|
|
244
|
+
author={Nuno Gon{\c{c}}alves and Marcos V Treviso and Andre Martins},
|
|
245
|
+
booktitle={Forty-second International Conference on Machine Learning},
|
|
246
|
+
year={2025},
|
|
247
|
+
url={https://openreview.net/forum?id=OWIPDWhUcO}
|
|
248
|
+
}
|
|
249
|
+
```
|
|
250
|
+
|
|
251
|
+
AdaSplash-2:
|
|
252
|
+
|
|
253
|
+
```bibtex
|
|
254
|
+
@inproceedings{
|
|
255
|
+
goncalves2026adasplash,
|
|
256
|
+
title={AdaSplash-2: Faster Differentiable Sparse Attention},
|
|
257
|
+
author={Nuno Gon{\c{c}}alves and Hugo Pitorro and Vlad Niculae and Edoardo Ponti and Lei Li and Andre Martins and Marcos V Treviso},
|
|
258
|
+
booktitle={Forty-third International Conference on Machine Learning},
|
|
259
|
+
year={2026},
|
|
260
|
+
url={https://openreview.net/forum?id=7qpvff2gWI}
|
|
261
|
+
}
|
|
262
|
+
```
|
|
263
|
+
|
|
264
|
+
## License
|
|
265
|
+
|
|
266
|
+
AdaSplash is licensed under the BSD-3-Clause License. See the [LICENSE](LICENSE) file for details.
|
|
267
|
+
|
|
268
|
+
## Acknowledgements
|
|
269
|
+
|
|
270
|
+
> We would like to thank the SARDINE lab team for the helpful discussions. This work was supported by the project DECOLLAGE (ERC-2022-CoG 101088763), by the Portuguese Recovery and Resilience Plan through project C64500888200000055 (Center for Responsible AI), and by FCT/MECI through national funds and when applicable co-funded EU funds under UID/50008: Instituto de Telecomunicações. Vlad Niculae is supported by the Dutch Research Council (NWO) via VI.Veni.212.228. Edoardo M. Ponti is supported by the ERC Starting Grant AToM-FM (101222956).
|
|
@@ -68,22 +68,6 @@ def triton_entmax15(x, **kwargs):
|
|
|
68
68
|
return triton_entmax15(x, **kwargs)
|
|
69
69
|
|
|
70
70
|
|
|
71
|
-
def entmax_attention(q, k, v, alpha=1.5, varlen=None, is_causal=False, padding="right", niter=2, alibi_slopes=None):
|
|
72
|
-
from .attention import entmax_attention as _entmax_attention
|
|
73
|
-
|
|
74
|
-
return _entmax_attention(
|
|
75
|
-
q,
|
|
76
|
-
k,
|
|
77
|
-
v,
|
|
78
|
-
alpha=alpha,
|
|
79
|
-
varlen=varlen,
|
|
80
|
-
is_causal=is_causal,
|
|
81
|
-
padding=padding,
|
|
82
|
-
niter=niter,
|
|
83
|
-
alibi_slopes=alibi_slopes,
|
|
84
|
-
)
|
|
85
|
-
|
|
86
|
-
|
|
87
71
|
adasplash2 = _adasplash_v2
|
|
88
72
|
|
|
89
73
|
__all__ = [
|
|
@@ -98,5 +82,4 @@ __all__ = [
|
|
|
98
82
|
"triton_entmax_v2",
|
|
99
83
|
"triton_sparsemax",
|
|
100
84
|
"triton_entmax15",
|
|
101
|
-
"entmax_attention",
|
|
102
85
|
]
|
|
@@ -1211,4 +1211,22 @@ class _sparse_attention(torch.autograd.Function):
|
|
|
1211
1211
|
|
|
1212
1212
|
# TODO: Support the post_niter parameter.
|
|
1213
1213
|
def sparse_attn(q, k, v, alpha=1.5, is_causal=False, varlen=None, niter=10):
|
|
1214
|
+
"""Run the original AdaSplash sparse attention kernel with a block mask.
|
|
1215
|
+
|
|
1216
|
+
Args:
|
|
1217
|
+
q: Query tensor with shape ``(batch, n_heads, seq_len, head_dim)``.
|
|
1218
|
+
k: Key tensor with the same shape as ``q``.
|
|
1219
|
+
v: Value tensor with the same shape as ``q``.
|
|
1220
|
+
alpha: Entmax alpha. Values in ``(1, 2]`` are supported by the v1
|
|
1221
|
+
solver; ``1.5`` is the default.
|
|
1222
|
+
is_causal: If true, apply a lower-triangular causal attention mask.
|
|
1223
|
+
varlen: Optional integer tensor of shape ``(batch,)`` containing the
|
|
1224
|
+
valid sequence length for each batch item. Padded output rows are
|
|
1225
|
+
zeroed.
|
|
1226
|
+
niter: Number of iterations used by the entmax threshold solver.
|
|
1227
|
+
|
|
1228
|
+
Returns:
|
|
1229
|
+
Tensor with the same shape and dtype as ``q`` containing the attended
|
|
1230
|
+
values.
|
|
1231
|
+
"""
|
|
1214
1232
|
return _sparse_attention.apply(q, k, v, alpha, is_causal, varlen, niter)
|
|
@@ -1079,6 +1079,28 @@ class _sparse_attention(torch.autograd.Function):
|
|
|
1079
1079
|
|
|
1080
1080
|
|
|
1081
1081
|
def sparse_attn(q, k, v, alpha=1.5, is_causal=False, varlen=None, niter=10):
|
|
1082
|
+
"""Run the v1 AdaSplash attention kernel without block-mask pruning.
|
|
1083
|
+
|
|
1084
|
+
This variant preserves the original v1 semantics and entmax alpha support,
|
|
1085
|
+
but skips the block-mask sparsity prepass. It is useful as a compatibility
|
|
1086
|
+
path and as a simpler reference-style Triton implementation.
|
|
1087
|
+
|
|
1088
|
+
Args:
|
|
1089
|
+
q: Query tensor with shape ``(batch, n_heads, seq_len, head_dim)``.
|
|
1090
|
+
k: Key tensor with the same shape as ``q``.
|
|
1091
|
+
v: Value tensor with the same shape as ``q``.
|
|
1092
|
+
alpha: Entmax alpha. Values in ``(1, 2]`` are supported by the v1
|
|
1093
|
+
solver; ``1.5`` is the default.
|
|
1094
|
+
is_causal: If true, apply a lower-triangular causal attention mask.
|
|
1095
|
+
varlen: Optional integer tensor of shape ``(batch,)`` containing the
|
|
1096
|
+
valid sequence length for each batch item. Padded output rows are
|
|
1097
|
+
zeroed.
|
|
1098
|
+
niter: Number of iterations used by the entmax threshold solver.
|
|
1099
|
+
|
|
1100
|
+
Returns:
|
|
1101
|
+
Tensor with the same shape and dtype as ``q`` containing the attended
|
|
1102
|
+
values.
|
|
1103
|
+
"""
|
|
1082
1104
|
return _sparse_attention.apply(q, k, v, alpha, is_causal, varlen, niter)
|
|
1083
1105
|
|
|
1084
1106
|
|