adasplash 0.2.0__tar.gz → 0.2.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28) hide show
  1. adasplash-0.2.2/LICENSE +28 -0
  2. adasplash-0.2.2/PKG-INFO +305 -0
  3. adasplash-0.2.2/README.md +270 -0
  4. {adasplash-0.2.0 → adasplash-0.2.2}/adasplash/__init__.py +0 -17
  5. {adasplash-0.2.0 → adasplash-0.2.2}/adasplash/adasplash_block_mask.py +18 -0
  6. {adasplash-0.2.0 → adasplash-0.2.2}/adasplash/adasplash_no_block_mask.py +22 -0
  7. {adasplash-0.2.0 → adasplash-0.2.2}/adasplash/adasplash_v2.py +23 -0
  8. adasplash-0.2.2/adasplash.egg-info/PKG-INFO +305 -0
  9. {adasplash-0.2.0 → adasplash-0.2.2}/adasplash.egg-info/SOURCES.txt +0 -1
  10. {adasplash-0.2.0 → adasplash-0.2.2}/setup.py +2 -2
  11. {adasplash-0.2.0 → adasplash-0.2.2}/tests/test_attention.py +23 -4
  12. {adasplash-0.2.0 → adasplash-0.2.2}/tests/test_public_api.py +1 -13
  13. adasplash-0.2.0/LICENSE +0 -21
  14. adasplash-0.2.0/PKG-INFO +0 -176
  15. adasplash-0.2.0/README.md +0 -141
  16. adasplash-0.2.0/adasplash/attention.py +0 -73
  17. adasplash-0.2.0/adasplash.egg-info/PKG-INFO +0 -176
  18. {adasplash-0.2.0 → adasplash-0.2.2}/adasplash/triton_entmax.py +0 -0
  19. {adasplash-0.2.0 → adasplash-0.2.2}/adasplash/triton_entmax_v2.py +0 -0
  20. {adasplash-0.2.0 → adasplash-0.2.2}/adasplash.egg-info/dependency_links.txt +0 -0
  21. {adasplash-0.2.0 → adasplash-0.2.2}/adasplash.egg-info/requires.txt +0 -0
  22. {adasplash-0.2.0 → adasplash-0.2.2}/adasplash.egg-info/top_level.txt +0 -0
  23. {adasplash-0.2.0 → adasplash-0.2.2}/pyproject.toml +0 -0
  24. {adasplash-0.2.0 → adasplash-0.2.2}/setup.cfg +0 -0
  25. {adasplash-0.2.0 → adasplash-0.2.2}/tests/test_adasplash.py +0 -0
  26. {adasplash-0.2.0 → adasplash-0.2.2}/tests/test_adasplash_no_block_mask.py +0 -0
  27. {adasplash-0.2.0 → adasplash-0.2.2}/tests/test_adasplash_v2.py +0 -0
  28. {adasplash-0.2.0 → adasplash-0.2.2}/tests/test_triton_entmax.py +0 -0
@@ -0,0 +1,28 @@
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2025, Sardine LAB
4
+
5
+ Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation
13
+ and/or other materials provided with the distribution.
14
+
15
+ 3. Neither the name of the copyright holder nor the names of its contributors
16
+ may be used to endorse or promote products derived from this software without
17
+ specific prior written permission.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
@@ -0,0 +1,305 @@
1
+ Metadata-Version: 2.2
2
+ Name: adasplash
3
+ Version: 0.2.2
4
+ Summary: AdaSplash: Efficient Adaptive Sparse Attention in Triton
5
+ Home-page: https://github.com/deep-spin/adasplash
6
+ Author: Nuno Gonçalves, Marcos Treviso
7
+ Author-email: marcosvtreviso@gmail.com
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: License :: OSI Approved :: BSD License
10
+ Classifier: Operating System :: OS Independent
11
+ Requires-Python: >=3.8
12
+ Description-Content-Type: text/markdown
13
+ License-File: LICENSE
14
+ Requires-Dist: numpy
15
+ Requires-Dist: torch
16
+ Requires-Dist: triton
17
+ Provides-Extra: dev
18
+ Requires-Dist: pytest; extra == "dev"
19
+ Requires-Dist: black; extra == "dev"
20
+ Requires-Dist: isort; extra == "dev"
21
+ Requires-Dist: flake8; extra == "dev"
22
+ Requires-Dist: entmax; extra == "dev"
23
+ Requires-Dist: build; extra == "dev"
24
+ Requires-Dist: twine; extra == "dev"
25
+ Dynamic: author
26
+ Dynamic: author-email
27
+ Dynamic: classifier
28
+ Dynamic: description
29
+ Dynamic: description-content-type
30
+ Dynamic: home-page
31
+ Dynamic: provides-extra
32
+ Dynamic: requires-dist
33
+ Dynamic: requires-python
34
+ Dynamic: summary
35
+
36
+ # AdaSplash
37
+
38
+ [![PyPI version](https://img.shields.io/pypi/v/adasplash.svg)](https://pypi.org/project/adasplash/)
39
+ [![Hugging Face models](https://img.shields.io/badge/Hugging%20Face-sardinelab-ffcc4d?logo=huggingface&logoColor=black)](https://huggingface.co/sardinelab/)
40
+ [![License: BSD-3-Clause](https://img.shields.io/badge/License-BSD--3--Clause-blue.svg)](LICENSE)
41
+
42
+ AdaSplash is a Triton implementation of adaptive sparse attention based on entmax. It exposes the original **AdaSplash** kernels and the newer **🆕 AdaSplash-2** kernels through a backwards-compatible Python API.
43
+
44
+ - **AdaSplash-1**: adaptive sparse flash attention with entmax and block masking.
45
+ - **🆕 AdaSplash-2**: faster differentiable sparse attention based on histogram initialization, hybrid threshold solving, and a v2 sparse causal attention path.
46
+
47
+ Papers:
48
+
49
+ - AdaSplash: [Adaptive Sparse Flash Attention](https://openreview.net/forum?id=OWIPDWhUcO)
50
+ - AdaSplash-2: [Faster Differentiable Sparse Attention](https://openreview.net/forum?id=7qpvff2gWI)
51
+
52
+ ## Installation
53
+
54
+ Install from PyPI:
55
+
56
+ ```bash
57
+ pip install adasplash
58
+ ```
59
+
60
+ Or install the latest development version:
61
+
62
+ ```bash
63
+ pip install git+https://github.com/deep-spin/adasplash.git
64
+ ```
65
+
66
+ AdaSplash requires PyTorch and Triton. CUDA is required to run the Triton kernels, but package imports are lazy, so `import adasplash` works without an active CUDA driver.
67
+
68
+ ## What Is New In 🆕 AdaSplash-2?
69
+
70
+ AdaSplash-2 keeps the same sparse-entmax motivation as AdaSplash-1, but improves the core threshold and attention path:
71
+
72
+ - **Histogram initialization** for faster entmax threshold estimates.
73
+ - **Hybrid solver updates** that combine robust root-finding steps for the entmax threshold.
74
+ - **Sparse causal v2 attention** exposed as `adasplash_v2`.
75
+ - **Grouped-query attention support** through different query and key/value head counts.
76
+ - **Variable-length sequence support** with padded rows zeroed in the output.
77
+ - **Convenience v2 entmax APIs** for sparsemax, entmax-1.5, and generic entmax calls.
78
+
79
+ In `adasplash>=0.2.0`, bare `adasplash(q, k, v)` defaults to the AdaSplash-2 path when the call is supported. Explicit v1 entry points remain available for backwards compatibility.
80
+
81
+ ## Public API
82
+
83
+ All public functions can be imported from the top-level package:
84
+
85
+ ```python
86
+ from adasplash import (
87
+ adasplash,
88
+ adasplash_v1,
89
+ adasplash_v2,
90
+ adasplash_no_block_mask,
91
+ triton_entmax,
92
+ triton_entmax_v1,
93
+ triton_entmax_v2,
94
+ triton_sparsemax,
95
+ triton_entmax15,
96
+ )
97
+ ```
98
+
99
+ | Function | Purpose |
100
+ | --- | --- |
101
+ | `adasplash` | Compatibility dispatcher. Uses v2 for supported causal `alpha=1.5` calls and falls back to v1 for v1-only behavior. |
102
+ | `adasplash_v2` | Direct AdaSplash-2 causal sparse attention. |
103
+ | `adasplash_v1` | Direct original AdaSplash block-mask implementation. |
104
+ | `adasplash_no_block_mask` | Original v1 no-block-mask implementation. |
105
+ | `triton_entmax` | Default v2 entmax API. |
106
+ | `triton_entmax_v2` | Direct v2 entmax with histogram and hybrid solver support. |
107
+ | `triton_entmax_v1` | Original entmax implementation. |
108
+ | `triton_sparsemax` | Convenience v2 sparsemax call, equivalent to entmax with `alpha=2.0`. |
109
+ | `triton_entmax15` | Convenience v2 entmax-1.5 call. |
110
+
111
+ ## Sparse Attention Examples
112
+
113
+ ### Default Dispatcher
114
+
115
+ ```python
116
+ import torch
117
+ from adasplash import adasplash
118
+
119
+ q = torch.randn(1, 8, 128, 64, device="cuda")
120
+ k = torch.randn(1, 8, 128, 64, device="cuda")
121
+ v = torch.randn(1, 8, 128, 64, device="cuda")
122
+
123
+ out = adasplash(q, k, v)
124
+ ```
125
+
126
+ For supported causal `alpha=1.5` calls, `adasplash` routes to AdaSplash-2. Calls that request v1-only behavior, such as `alpha != 1.5` or `is_causal=False`, route to the v1 implementation.
127
+
128
+ ```python
129
+ # Uses AdaSplash-2.
130
+ out_v2_default = adasplash(q, k, v, is_causal=True)
131
+
132
+ # Uses the v1 compatibility path.
133
+ out_v1_alpha = adasplash(q, k, v, alpha=1.333)
134
+ out_v1_noncausal = adasplash(q, k, v, is_causal=False)
135
+ ```
136
+
137
+ ### Explicit V1 And V2 Calls
138
+
139
+ ```python
140
+ from adasplash import adasplash_v1, adasplash_v2
141
+
142
+ out_v1 = adasplash_v1(q, k, v, alpha=1.5, is_causal=True, niter=10)
143
+ out_v2 = adasplash_v2(q, k, v, niter=1)
144
+ ```
145
+
146
+ ### Variable-Length Sequences
147
+
148
+ ```python
149
+ from adasplash import adasplash
150
+
151
+ varlen = torch.tensor([96], device="cuda", dtype=torch.int32)
152
+ out = adasplash(q, k, v, varlen=varlen)
153
+ ```
154
+
155
+ Rows beyond each valid sequence length are zeroed in the output.
156
+
157
+ ### Grouped-Query Attention With AdaSplash-2
158
+
159
+ ```python
160
+ from adasplash import adasplash_v2
161
+
162
+ q = torch.randn(1, 8, 256, 64, device="cuda")
163
+ k = torch.randn(1, 2, 256, 64, device="cuda")
164
+ v = torch.randn(1, 2, 256, 64, device="cuda")
165
+
166
+ out = adasplash_v2(q, k, v)
167
+ ```
168
+
169
+ `adasplash_v2` supports grouped-query attention when the number of query heads is divisible by the number of key/value heads.
170
+
171
+ ## Triton Entmax Examples
172
+
173
+ ```python
174
+ import torch
175
+ from adasplash import triton_entmax, triton_entmax_v1, triton_sparsemax, triton_entmax15
176
+
177
+ x = torch.randn(128, 256, device="cuda")
178
+
179
+ y = triton_entmax(x, alpha=1.5, n_iter=2, use_histogram=True)
180
+ y_v1 = triton_entmax_v1(x, alpha=1.5, n_iter=10, fast_math=True)
181
+ y_sparsemax = triton_sparsemax(x)
182
+ y_entmax15 = triton_entmax15(x)
183
+ ```
184
+
185
+ Since `0.2.0`, `triton_entmax` points to the v2 implementation. Strict v1 users should call `triton_entmax_v1`.
186
+
187
+ For generic alpha values other than `1.5` and `2.0`, v2 disables histogram initialization internally and uses more refinement iterations for correctness.
188
+
189
+ ## Attention Examples
190
+
191
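+ The `examples/attention.py` file in the GitHub repository contains two small helpers that show the difference between the fused AdaSplash kernel and a dense reference-style implementation. The `examples` directory is not shipped in the PyPI package, so run the snippets below from a checkout of the repository.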
192
+
193
+ ### Flash Entmax Attention
194
+
195
+ ```python
196
+ from examples.attention import flash_entmax_attention
197
+
198
+ out = flash_entmax_attention(q, k, v, is_causal=True)
199
+ ```
200
+
201
+ `flash_entmax_attention` is a thin example wrapper around `adasplash`, the actual fused flash entmax attention path.
202
+
203
+ ### Slow Dense Entmax Attention
204
+
205
+ ```python
206
+ from examples.attention import slow_entmax_attention
207
+
208
+ out = slow_entmax_attention(q, k, v, is_causal=True, padding="right")
209
+ ```
210
+
211
+ `slow_entmax_attention` materializes dense attention scores and applies `triton_entmax`. It is useful for examples and small correctness checks, but it is not the AdaSplash flash kernel and should not be used for long contexts.
212
+
213
+ ## Backwards Compatibility
214
+
215
+ Starting with AdaSplash `0.2.0`, the top-level APIs default to the v2 implementations:
216
+
217
+ - `adasplash(q, k, v)` uses AdaSplash-2 when the call is supported.
218
+ - `triton_entmax(x)` uses the v2 entmax implementation.
219
+ - Explicit v1 names are preserved: `adasplash_v1`, `adasplash_no_block_mask`, and `triton_entmax_v1`.
220
+
221
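+ If you need strict AdaSplash-1 behavior, use the `_v1` names directly. The former top-level `entmax_attention` helper (and the `adasplash.attention` module) is no longer part of the package; see the Attention Examples section for the dense reference helper in the repository.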
222
+
223
+ ## Testing
224
+
225
+ From a checkout of the repository, install the development dependencies:
226
+
227
+ ```bash
228
+ pip install -r requirements-dev.txt
229
+ ```
230
+
231
+ No-GPU import and public API checks:
232
+
233
+ ```bash
234
+ TRITON_INTERPRET=1 pytest -q
235
+ ```
236
+
237
+ Fast CUDA suite:
238
+
239
+ ```bash
240
+ pytest -q -m "not slow and not stress"
241
+ ```
242
+
243
+ Slow CUDA correctness suite:
244
+
245
+ ```bash
246
+ pytest -q -m "slow"
247
+ ```
248
+
249
+ Stress tests run dense reference checks for long contexts and may skip cases when the current GPU does not have enough memory or shared memory:
250
+
251
+ ```bash
252
+ pytest -q -m "stress"
253
+ ```
254
+
255
+ ## Benchmarks And Models
256
+
257
+ ### Efficiency AdaSplash-1
258
+ ![Benchmark AdaSplash-1](benchmark.png)
259
+
260
+ ### Efficiency AdaSplash-2
261
+ ![Benchmark AdaSplash-2](benchmark2.png)
262
+
263
+ ### Sparse Models
264
+
265
+ Sparse models and related artifacts are hosted under the [SARDINE Lab Hugging Face organization](https://huggingface.co/sardinelab/).
266
+
267
+ For single-vector retrieval experiments, see the [Sparse ModernBERT repository](https://github.com/deep-spin/SparseModernBERT).
268
+
269
+ ## Citation
270
+
271
+ If you use AdaSplash in your research, please cite the relevant paper.
272
+
273
+ AdaSplash-1:
274
+
275
+ ```bibtex
276
+ @inproceedings{
277
+ goncalves2025adasplash,
278
+ title={AdaSplash: Adaptive Sparse Flash Attention},
279
+ author={Nuno Gon{\c{c}}alves and Marcos V Treviso and Andre Martins},
280
+ booktitle={Forty-second International Conference on Machine Learning},
281
+ year={2025},
282
+ url={https://openreview.net/forum?id=OWIPDWhUcO}
283
+ }
284
+ ```
285
+
286
+ AdaSplash-2:
287
+
288
+ ```bibtex
289
+ @inproceedings{
290
+ goncalves2026adasplash,
291
+ title={AdaSplash-2: Faster Differentiable Sparse Attention},
292
+ author={Nuno Gon{\c{c}}alves and Hugo Pitorro and Vlad Niculae and Edoardo Ponti and Lei Li and Andre Martins and Marcos V Treviso},
293
+ booktitle={Forty-third International Conference on Machine Learning},
294
+ year={2026},
295
+ url={https://openreview.net/forum?id=7qpvff2gWI}
296
+ }
297
+ ```
298
+
299
+ ## License
300
+
301
+ AdaSplash is licensed under the BSD-3-Clause License. See the [LICENSE](LICENSE) file for details.
302
+
303
+ ## Acknowledgements
304
+
305
+ > We would like to thank the SARDINE lab team for the helpful discussions. This work was supported by the project DECOLLAGE (ERC-2022-CoG 101088763), by the Portuguese Recovery and Resilience Plan through project C64500888200000055 (Center for Responsible AI), and by FCT/MECI through national funds and, when applicable, co-funded by EU funds under UID/50008: Instituto de Telecomunicações. Vlad Niculae is supported by the Dutch Research Council (NWO) via VI.Veni.212.228. Edoardo M. Ponti is supported by the ERC Starting Grant AToM-FM (101222956).
@@ -0,0 +1,270 @@
1
+ # AdaSplash
2
+
3
+ [![PyPI version](https://img.shields.io/pypi/v/adasplash.svg)](https://pypi.org/project/adasplash/)
4
+ [![Hugging Face models](https://img.shields.io/badge/Hugging%20Face-sardinelab-ffcc4d?logo=huggingface&logoColor=black)](https://huggingface.co/sardinelab/)
5
+ [![License: BSD-3-Clause](https://img.shields.io/badge/License-BSD--3--Clause-blue.svg)](LICENSE)
6
+
7
+ AdaSplash is a Triton implementation of adaptive sparse attention based on entmax. It exposes the original **AdaSplash** kernels and the newer **🆕 AdaSplash-2** kernels through a backwards-compatible Python API.
8
+
9
+ - **AdaSplash-1**: adaptive sparse flash attention with entmax and block masking.
10
+ - **🆕 AdaSplash-2**: faster differentiable sparse attention based on histogram initialization, hybrid threshold solving, and a v2 sparse causal attention path.
11
+
12
+ Papers:
13
+
14
+ - AdaSplash: [Adaptive Sparse Flash Attention](https://openreview.net/forum?id=OWIPDWhUcO)
15
+ - AdaSplash-2: [Faster Differentiable Sparse Attention](https://openreview.net/forum?id=7qpvff2gWI)
16
+
17
+ ## Installation
18
+
19
+ Install from PyPI:
20
+
21
+ ```bash
22
+ pip install adasplash
23
+ ```
24
+
25
+ Or install the latest development version:
26
+
27
+ ```bash
28
+ pip install git+https://github.com/deep-spin/adasplash.git
29
+ ```
30
+
31
+ AdaSplash requires PyTorch and Triton. CUDA is required to run the Triton kernels, but package imports are lazy, so `import adasplash` works without an active CUDA driver.
32
+
33
+ ## What Is New In 🆕 AdaSplash-2?
34
+
35
+ AdaSplash-2 keeps the same sparse-entmax motivation as AdaSplash-1, but improves the core threshold and attention path:
36
+
37
+ - **Histogram initialization** for faster entmax threshold estimates.
38
+ - **Hybrid solver updates** that combine robust root-finding steps for the entmax threshold.
39
+ - **Sparse causal v2 attention** exposed as `adasplash_v2`.
40
+ - **Grouped-query attention support** through different query and key/value head counts.
41
+ - **Variable-length sequence support** with padded rows zeroed in the output.
42
+ - **Convenience v2 entmax APIs** for sparsemax, entmax-1.5, and generic entmax calls.
43
+
44
+ In `adasplash>=0.2.0`, bare `adasplash(q, k, v)` defaults to the AdaSplash-2 path when the call is supported. Explicit v1 entry points remain available for backwards compatibility.
45
+
46
+ ## Public API
47
+
48
+ All public functions can be imported from the top-level package:
49
+
50
+ ```python
51
+ from adasplash import (
52
+ adasplash,
53
+ adasplash_v1,
54
+ adasplash_v2,
55
+ adasplash_no_block_mask,
56
+ triton_entmax,
57
+ triton_entmax_v1,
58
+ triton_entmax_v2,
59
+ triton_sparsemax,
60
+ triton_entmax15,
61
+ )
62
+ ```
63
+
64
+ | Function | Purpose |
65
+ | --- | --- |
66
+ | `adasplash` | Compatibility dispatcher. Uses v2 for supported causal `alpha=1.5` calls and falls back to v1 for v1-only behavior. |
67
+ | `adasplash_v2` | Direct AdaSplash-2 causal sparse attention. |
68
+ | `adasplash_v1` | Direct original AdaSplash block-mask implementation. |
69
+ | `adasplash_no_block_mask` | Original v1 no-block-mask implementation. |
70
+ | `triton_entmax` | Default v2 entmax API. |
71
+ | `triton_entmax_v2` | Direct v2 entmax with histogram and hybrid solver support. |
72
+ | `triton_entmax_v1` | Original entmax implementation. |
73
+ | `triton_sparsemax` | Convenience v2 sparsemax call, equivalent to entmax with `alpha=2.0`. |
74
+ | `triton_entmax15` | Convenience v2 entmax-1.5 call. |
75
+
76
+ ## Sparse Attention Examples
77
+
78
+ ### Default Dispatcher
79
+
80
+ ```python
81
+ import torch
82
+ from adasplash import adasplash
83
+
84
+ q = torch.randn(1, 8, 128, 64, device="cuda")
85
+ k = torch.randn(1, 8, 128, 64, device="cuda")
86
+ v = torch.randn(1, 8, 128, 64, device="cuda")
87
+
88
+ out = adasplash(q, k, v)
89
+ ```
90
+
91
+ For supported causal `alpha=1.5` calls, `adasplash` routes to AdaSplash-2. Calls that request v1-only behavior, such as `alpha != 1.5` or `is_causal=False`, route to the v1 implementation.
92
+
93
+ ```python
94
+ # Uses AdaSplash-2.
95
+ out_v2_default = adasplash(q, k, v, is_causal=True)
96
+
97
+ # Uses the v1 compatibility path.
98
+ out_v1_alpha = adasplash(q, k, v, alpha=1.333)
99
+ out_v1_noncausal = adasplash(q, k, v, is_causal=False)
100
+ ```
101
+
102
+ ### Explicit V1 And V2 Calls
103
+
104
+ ```python
105
+ from adasplash import adasplash_v1, adasplash_v2
106
+
107
+ out_v1 = adasplash_v1(q, k, v, alpha=1.5, is_causal=True, niter=10)
108
+ out_v2 = adasplash_v2(q, k, v, niter=1)
109
+ ```
110
+
111
+ ### Variable-Length Sequences
112
+
113
+ ```python
114
+ from adasplash import adasplash
115
+
116
+ varlen = torch.tensor([96], device="cuda", dtype=torch.int32)
117
+ out = adasplash(q, k, v, varlen=varlen)
118
+ ```
119
+
120
+ Rows beyond each valid sequence length are zeroed in the output.
121
+
122
+ ### Grouped-Query Attention With AdaSplash-2
123
+
124
+ ```python
125
+ from adasplash import adasplash_v2
126
+
127
+ q = torch.randn(1, 8, 256, 64, device="cuda")
128
+ k = torch.randn(1, 2, 256, 64, device="cuda")
129
+ v = torch.randn(1, 2, 256, 64, device="cuda")
130
+
131
+ out = adasplash_v2(q, k, v)
132
+ ```
133
+
134
+ `adasplash_v2` supports grouped-query attention when the number of query heads is divisible by the number of key/value heads.
135
+
136
+ ## Triton Entmax Examples
137
+
138
+ ```python
139
+ import torch
140
+ from adasplash import triton_entmax, triton_entmax_v1, triton_sparsemax, triton_entmax15
141
+
142
+ x = torch.randn(128, 256, device="cuda")
143
+
144
+ y = triton_entmax(x, alpha=1.5, n_iter=2, use_histogram=True)
145
+ y_v1 = triton_entmax_v1(x, alpha=1.5, n_iter=10, fast_math=True)
146
+ y_sparsemax = triton_sparsemax(x)
147
+ y_entmax15 = triton_entmax15(x)
148
+ ```
149
+
150
+ Since `0.2.0`, `triton_entmax` points to the v2 implementation. Strict v1 users should call `triton_entmax_v1`.
151
+
152
+ For generic alpha values other than `1.5` and `2.0`, v2 disables histogram initialization internally and uses more refinement iterations for correctness.
153
+
154
+ ## Attention Examples
155
+
156
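+ The `examples/attention.py` file in the GitHub repository contains two small helpers that show the difference between the fused AdaSplash kernel and a dense reference-style implementation. The `examples` directory is not shipped in the PyPI package, so run the snippets below from a checkout of the repository.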
157
+
158
+ ### Flash Entmax Attention
159
+
160
+ ```python
161
+ from examples.attention import flash_entmax_attention
162
+
163
+ out = flash_entmax_attention(q, k, v, is_causal=True)
164
+ ```
165
+
166
+ `flash_entmax_attention` is a thin example wrapper around `adasplash`, the actual fused flash entmax attention path.
167
+
168
+ ### Slow Dense Entmax Attention
169
+
170
+ ```python
171
+ from examples.attention import slow_entmax_attention
172
+
173
+ out = slow_entmax_attention(q, k, v, is_causal=True, padding="right")
174
+ ```
175
+
176
+ `slow_entmax_attention` materializes dense attention scores and applies `triton_entmax`. It is useful for examples and small correctness checks, but it is not the AdaSplash flash kernel and should not be used for long contexts.
177
+
178
+ ## Backwards Compatibility
179
+
180
+ Starting with AdaSplash `0.2.0`, the top-level APIs default to the v2 implementations:
181
+
182
+ - `adasplash(q, k, v)` uses AdaSplash-2 when the call is supported.
183
+ - `triton_entmax(x)` uses the v2 entmax implementation.
184
+ - Explicit v1 names are preserved: `adasplash_v1`, `adasplash_no_block_mask`, and `triton_entmax_v1`.
185
+
186
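+ If you need strict AdaSplash-1 behavior, use the `_v1` names directly. The former top-level `entmax_attention` helper (and the `adasplash.attention` module) is no longer part of the package; see the Attention Examples section for the dense reference helper in the repository.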
187
+
188
+ ## Testing
189
+
190
+ From a checkout of the repository, install the development dependencies:
191
+
192
+ ```bash
193
+ pip install -r requirements-dev.txt
194
+ ```
195
+
196
+ No-GPU import and public API checks:
197
+
198
+ ```bash
199
+ TRITON_INTERPRET=1 pytest -q
200
+ ```
201
+
202
+ Fast CUDA suite:
203
+
204
+ ```bash
205
+ pytest -q -m "not slow and not stress"
206
+ ```
207
+
208
+ Slow CUDA correctness suite:
209
+
210
+ ```bash
211
+ pytest -q -m "slow"
212
+ ```
213
+
214
+ Stress tests run dense reference checks for long contexts and may skip cases when the current GPU does not have enough memory or shared memory:
215
+
216
+ ```bash
217
+ pytest -q -m "stress"
218
+ ```
219
+
220
+ ## Benchmarks And Models
221
+
222
+ ### Efficiency AdaSplash-1
223
+ ![Benchmark AdaSplash-1](benchmark.png)
224
+
225
+ ### Efficiency AdaSplash-2
226
+ ![Benchmark AdaSplash-2](benchmark2.png)
227
+
228
+ ### Sparse Models
229
+
230
+ Sparse models and related artifacts are hosted under the [SARDINE Lab Hugging Face organization](https://huggingface.co/sardinelab/).
231
+
232
+ For single-vector retrieval experiments, see the [Sparse ModernBERT repository](https://github.com/deep-spin/SparseModernBERT).
233
+
234
+ ## Citation
235
+
236
+ If you use AdaSplash in your research, please cite the relevant paper.
237
+
238
+ AdaSplash-1:
239
+
240
+ ```bibtex
241
+ @inproceedings{
242
+ goncalves2025adasplash,
243
+ title={AdaSplash: Adaptive Sparse Flash Attention},
244
+ author={Nuno Gon{\c{c}}alves and Marcos V Treviso and Andre Martins},
245
+ booktitle={Forty-second International Conference on Machine Learning},
246
+ year={2025},
247
+ url={https://openreview.net/forum?id=OWIPDWhUcO}
248
+ }
249
+ ```
250
+
251
+ AdaSplash-2:
252
+
253
+ ```bibtex
254
+ @inproceedings{
255
+ goncalves2026adasplash,
256
+ title={AdaSplash-2: Faster Differentiable Sparse Attention},
257
+ author={Nuno Gon{\c{c}}alves and Hugo Pitorro and Vlad Niculae and Edoardo Ponti and Lei Li and Andre Martins and Marcos V Treviso},
258
+ booktitle={Forty-third International Conference on Machine Learning},
259
+ year={2026},
260
+ url={https://openreview.net/forum?id=7qpvff2gWI}
261
+ }
262
+ ```
263
+
264
+ ## License
265
+
266
+ AdaSplash is licensed under the BSD-3-Clause License. See the [LICENSE](LICENSE) file for details.
267
+
268
+ ## Acknowledgements
269
+
270
+ > We would like to thank the SARDINE lab team for the helpful discussions. This work was supported by the project DECOLLAGE (ERC-2022-CoG 101088763), by the Portuguese Recovery and Resilience Plan through project C64500888200000055 (Center for Responsible AI), and by FCT/MECI through national funds and, when applicable, co-funded by EU funds under UID/50008: Instituto de Telecomunicações. Vlad Niculae is supported by the Dutch Research Council (NWO) via VI.Veni.212.228. Edoardo M. Ponti is supported by the ERC Starting Grant AToM-FM (101222956).
@@ -68,22 +68,6 @@ def triton_entmax15(x, **kwargs):
68
68
  return triton_entmax15(x, **kwargs)
69
69
 
70
70
 
71
- def entmax_attention(q, k, v, alpha=1.5, varlen=None, is_causal=False, padding="right", niter=2, alibi_slopes=None):
72
- from .attention import entmax_attention as _entmax_attention
73
-
74
- return _entmax_attention(
75
- q,
76
- k,
77
- v,
78
- alpha=alpha,
79
- varlen=varlen,
80
- is_causal=is_causal,
81
- padding=padding,
82
- niter=niter,
83
- alibi_slopes=alibi_slopes,
84
- )
85
-
86
-
87
71
  adasplash2 = _adasplash_v2
88
72
 
89
73
  __all__ = [
@@ -98,5 +82,4 @@ __all__ = [
98
82
  "triton_entmax_v2",
99
83
  "triton_sparsemax",
100
84
  "triton_entmax15",
101
- "entmax_attention",
102
85
  ]
@@ -1211,4 +1211,22 @@ class _sparse_attention(torch.autograd.Function):
1211
1211
 
1212
1212
  # TODO: Support the post_niter parameter.
1213
1213
  def sparse_attn(q, k, v, alpha=1.5, is_causal=False, varlen=None, niter=10):
1214
+ """Run the original AdaSplash sparse attention kernel with a block mask.
1215
+
1216
+ Args:
1217
+ q: Query tensor with shape ``(batch, n_heads, seq_len, head_dim)``.
1218
+ k: Key tensor with the same shape as ``q``.
1219
+ v: Value tensor with the same shape as ``q``.
1220
+ alpha: Entmax alpha. Values in ``(1, 2]`` are supported by the v1
1221
+ solver; ``1.5`` is the default.
1222
+ is_causal: If true, apply a lower-triangular causal attention mask.
1223
+ varlen: Optional integer tensor of shape ``(batch,)`` containing the
1224
+ valid sequence length for each batch item. Padded output rows are
1225
+ zeroed.
1226
+ niter: Number of iterations used by the entmax threshold solver.
1227
+
1228
+ Returns:
1229
+ Tensor with the same shape and dtype as ``q`` containing the attended
1230
+ values.
1231
+ """
1214
1232
  return _sparse_attention.apply(q, k, v, alpha, is_causal, varlen, niter)
@@ -1079,6 +1079,28 @@ class _sparse_attention(torch.autograd.Function):
1079
1079
 
1080
1080
 
1081
1081
  def sparse_attn(q, k, v, alpha=1.5, is_causal=False, varlen=None, niter=10):
1082
+ """Run the v1 AdaSplash attention kernel without block-mask pruning.
1083
+
1084
+ This variant preserves the original v1 semantics and entmax alpha support,
1085
+ but skips the block-mask sparsity prepass. It is useful as a compatibility
1086
+ path and as a simpler reference-style Triton implementation.
1087
+
1088
+ Args:
1089
+ q: Query tensor with shape ``(batch, n_heads, seq_len, head_dim)``.
1090
+ k: Key tensor with the same shape as ``q``.
1091
+ v: Value tensor with the same shape as ``q``.
1092
+ alpha: Entmax alpha. Values in ``(1, 2]`` are supported by the v1
1093
+ solver; ``1.5`` is the default.
1094
+ is_causal: If true, apply a lower-triangular causal attention mask.
1095
+ varlen: Optional integer tensor of shape ``(batch,)`` containing the
1096
+ valid sequence length for each batch item. Padded output rows are
1097
+ zeroed.
1098
+ niter: Number of iterations used by the entmax threshold solver.
1099
+
1100
+ Returns:
1101
+ Tensor with the same shape and dtype as ``q`` containing the attended
1102
+ values.
1103
+ """
1082
1104
  return _sparse_attention.apply(q, k, v, alpha, is_causal, varlen, niter)
1083
1105
 
1084
1106