adasplash 0.2.1__tar.gz → 0.2.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24):
  1. {adasplash-0.2.1 → adasplash-0.2.2}/PKG-INFO +20 -16
  2. {adasplash-0.2.1 → adasplash-0.2.2}/README.md +19 -15
  3. {adasplash-0.2.1 → adasplash-0.2.2}/adasplash/__init__.py +0 -17
  4. {adasplash-0.2.1 → adasplash-0.2.2}/adasplash.egg-info/PKG-INFO +20 -16
  5. {adasplash-0.2.1 → adasplash-0.2.2}/adasplash.egg-info/SOURCES.txt +0 -1
  6. {adasplash-0.2.1 → adasplash-0.2.2}/setup.py +1 -1
  7. {adasplash-0.2.1 → adasplash-0.2.2}/tests/test_attention.py +23 -4
  8. {adasplash-0.2.1 → adasplash-0.2.2}/tests/test_public_api.py +1 -13
  9. adasplash-0.2.1/adasplash/attention.py +0 -73
  10. {adasplash-0.2.1 → adasplash-0.2.2}/LICENSE +0 -0
  11. {adasplash-0.2.1 → adasplash-0.2.2}/adasplash/adasplash_block_mask.py +0 -0
  12. {adasplash-0.2.1 → adasplash-0.2.2}/adasplash/adasplash_no_block_mask.py +0 -0
  13. {adasplash-0.2.1 → adasplash-0.2.2}/adasplash/adasplash_v2.py +0 -0
  14. {adasplash-0.2.1 → adasplash-0.2.2}/adasplash/triton_entmax.py +0 -0
  15. {adasplash-0.2.1 → adasplash-0.2.2}/adasplash/triton_entmax_v2.py +0 -0
  16. {adasplash-0.2.1 → adasplash-0.2.2}/adasplash.egg-info/dependency_links.txt +0 -0
  17. {adasplash-0.2.1 → adasplash-0.2.2}/adasplash.egg-info/requires.txt +0 -0
  18. {adasplash-0.2.1 → adasplash-0.2.2}/adasplash.egg-info/top_level.txt +0 -0
  19. {adasplash-0.2.1 → adasplash-0.2.2}/pyproject.toml +0 -0
  20. {adasplash-0.2.1 → adasplash-0.2.2}/setup.cfg +0 -0
  21. {adasplash-0.2.1 → adasplash-0.2.2}/tests/test_adasplash.py +0 -0
  22. {adasplash-0.2.1 → adasplash-0.2.2}/tests/test_adasplash_no_block_mask.py +0 -0
  23. {adasplash-0.2.1 → adasplash-0.2.2}/tests/test_adasplash_v2.py +0 -0
  24. {adasplash-0.2.1 → adasplash-0.2.2}/tests/test_triton_entmax.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: adasplash
3
- Version: 0.2.1
3
+ Version: 0.2.2
4
4
  Summary: AdaSplash: Efficient Adaptive Sparse Attention in Triton
5
5
  Home-page: https://github.com/deep-spin/adasplash
6
6
  Author: Nuno Gonçalves, Marcos Treviso
@@ -93,7 +93,6 @@ from adasplash import (
93
93
  triton_entmax_v2,
94
94
  triton_sparsemax,
95
95
  triton_entmax15,
96
- entmax_attention,
97
96
  )
98
97
  ```
99
98
 
@@ -108,7 +107,6 @@ from adasplash import (
108
107
  | `triton_entmax_v1` | Original entmax implementation. |
109
108
  | `triton_sparsemax` | Convenience v2 sparsemax call, equivalent to entmax with `alpha=2.0`. |
110
109
  | `triton_entmax15` | Convenience v2 entmax-1.5 call. |
111
- | `entmax_attention` | Dense attention utility using v2 `triton_entmax`. |
112
110
 
113
111
  ## Sparse Attention Examples
114
112
 
@@ -188,23 +186,29 @@ y_entmax15 = triton_entmax15(x)
188
186
 
189
187
  For generic alpha values other than `1.5` and `2.0`, v2 disables histogram initialization internally and uses more refinement iterations for correctness.
190
188
 
191
- ## Dense Entmax Attention Utility
189
+ ## Attention Examples
190
+
191
+ The `examples/attention.py` file contains two small helpers that show the difference between the fused AdaSplash kernel and a dense reference-style implementation.
192
+
193
+ ### Flash Entmax Attention
192
194
 
193
195
  ```python
194
- from adasplash import entmax_attention
195
-
196
- out = entmax_attention(
197
- q,
198
- k,
199
- v,
200
- alpha=1.5,
201
- is_causal=True,
202
- varlen=None,
203
- padding="right",
204
- )
196
+ from examples.attention import flash_entmax_attention
197
+
198
+ out = flash_entmax_attention(q, k, v, is_causal=True)
199
+ ```
200
+
201
+ `flash_entmax_attention` is a thin example wrapper around `adasplash`, the actual fused flash entmax attention path.
202
+
203
+ ### Slow Dense Entmax Attention
204
+
205
+ ```python
206
+ from examples.attention import slow_entmax_attention
207
+
208
+ out = slow_entmax_attention(q, k, v, is_causal=True, padding="right")
205
209
  ```
206
210
 
207
- `entmax_attention` is a dense utility built on top of v2 `triton_entmax`. It supports causal masking, non-causal masking, variable lengths, left/right padding, ALiBi slopes, and gradients through `q`, `k`, and `v`.
211
+ `slow_entmax_attention` materializes dense attention scores and applies `triton_entmax`. It is useful for examples and small correctness checks, but it is not the AdaSplash flash kernel and should not be used for long contexts.
208
212
 
209
213
  ## Backwards Compatibility
210
214
 
@@ -58,7 +58,6 @@ from adasplash import (
58
58
  triton_entmax_v2,
59
59
  triton_sparsemax,
60
60
  triton_entmax15,
61
- entmax_attention,
62
61
  )
63
62
  ```
64
63
 
@@ -73,7 +72,6 @@ from adasplash import (
73
72
  | `triton_entmax_v1` | Original entmax implementation. |
74
73
  | `triton_sparsemax` | Convenience v2 sparsemax call, equivalent to entmax with `alpha=2.0`. |
75
74
  | `triton_entmax15` | Convenience v2 entmax-1.5 call. |
76
- | `entmax_attention` | Dense attention utility using v2 `triton_entmax`. |
77
75
 
78
76
  ## Sparse Attention Examples
79
77
 
@@ -153,23 +151,29 @@ y_entmax15 = triton_entmax15(x)
153
151
 
154
152
  For generic alpha values other than `1.5` and `2.0`, v2 disables histogram initialization internally and uses more refinement iterations for correctness.
155
153
 
156
- ## Dense Entmax Attention Utility
154
+ ## Attention Examples
155
+
156
+ The `examples/attention.py` file contains two small helpers that show the difference between the fused AdaSplash kernel and a dense reference-style implementation.
157
+
158
+ ### Flash Entmax Attention
157
159
 
158
160
  ```python
159
- from adasplash import entmax_attention
160
-
161
- out = entmax_attention(
162
- q,
163
- k,
164
- v,
165
- alpha=1.5,
166
- is_causal=True,
167
- varlen=None,
168
- padding="right",
169
- )
161
+ from examples.attention import flash_entmax_attention
162
+
163
+ out = flash_entmax_attention(q, k, v, is_causal=True)
164
+ ```
165
+
166
+ `flash_entmax_attention` is a thin example wrapper around `adasplash`, the actual fused flash entmax attention path.
167
+
168
+ ### Slow Dense Entmax Attention
169
+
170
+ ```python
171
+ from examples.attention import slow_entmax_attention
172
+
173
+ out = slow_entmax_attention(q, k, v, is_causal=True, padding="right")
170
174
  ```
171
175
 
172
- `entmax_attention` is a dense utility built on top of v2 `triton_entmax`. It supports causal masking, non-causal masking, variable lengths, left/right padding, ALiBi slopes, and gradients through `q`, `k`, and `v`.
176
+ `slow_entmax_attention` materializes dense attention scores and applies `triton_entmax`. It is useful for examples and small correctness checks, but it is not the AdaSplash flash kernel and should not be used for long contexts.
173
177
 
174
178
  ## Backwards Compatibility
175
179
 
@@ -68,22 +68,6 @@ def triton_entmax15(x, **kwargs):
68
68
  return triton_entmax15(x, **kwargs)
69
69
 
70
70
 
71
- def entmax_attention(q, k, v, alpha=1.5, varlen=None, is_causal=False, padding="right", niter=2, alibi_slopes=None):
72
- from .attention import entmax_attention as _entmax_attention
73
-
74
- return _entmax_attention(
75
- q,
76
- k,
77
- v,
78
- alpha=alpha,
79
- varlen=varlen,
80
- is_causal=is_causal,
81
- padding=padding,
82
- niter=niter,
83
- alibi_slopes=alibi_slopes,
84
- )
85
-
86
-
87
71
  adasplash2 = _adasplash_v2
88
72
 
89
73
  __all__ = [
@@ -98,5 +82,4 @@ __all__ = [
98
82
  "triton_entmax_v2",
99
83
  "triton_sparsemax",
100
84
  "triton_entmax15",
101
- "entmax_attention",
102
85
  ]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: adasplash
3
- Version: 0.2.1
3
+ Version: 0.2.2
4
4
  Summary: AdaSplash: Efficient Adaptive Sparse Attention in Triton
5
5
  Home-page: https://github.com/deep-spin/adasplash
6
6
  Author: Nuno Gonçalves, Marcos Treviso
@@ -93,7 +93,6 @@ from adasplash import (
93
93
  triton_entmax_v2,
94
94
  triton_sparsemax,
95
95
  triton_entmax15,
96
- entmax_attention,
97
96
  )
98
97
  ```
99
98
 
@@ -108,7 +107,6 @@ from adasplash import (
108
107
  | `triton_entmax_v1` | Original entmax implementation. |
109
108
  | `triton_sparsemax` | Convenience v2 sparsemax call, equivalent to entmax with `alpha=2.0`. |
110
109
  | `triton_entmax15` | Convenience v2 entmax-1.5 call. |
111
- | `entmax_attention` | Dense attention utility using v2 `triton_entmax`. |
112
110
 
113
111
  ## Sparse Attention Examples
114
112
 
@@ -188,23 +186,29 @@ y_entmax15 = triton_entmax15(x)
188
186
 
189
187
  For generic alpha values other than `1.5` and `2.0`, v2 disables histogram initialization internally and uses more refinement iterations for correctness.
190
188
 
191
- ## Dense Entmax Attention Utility
189
+ ## Attention Examples
190
+
191
+ The `examples/attention.py` file contains two small helpers that show the difference between the fused AdaSplash kernel and a dense reference-style implementation.
192
+
193
+ ### Flash Entmax Attention
192
194
 
193
195
  ```python
194
- from adasplash import entmax_attention
195
-
196
- out = entmax_attention(
197
- q,
198
- k,
199
- v,
200
- alpha=1.5,
201
- is_causal=True,
202
- varlen=None,
203
- padding="right",
204
- )
196
+ from examples.attention import flash_entmax_attention
197
+
198
+ out = flash_entmax_attention(q, k, v, is_causal=True)
199
+ ```
200
+
201
+ `flash_entmax_attention` is a thin example wrapper around `adasplash`, the actual fused flash entmax attention path.
202
+
203
+ ### Slow Dense Entmax Attention
204
+
205
+ ```python
206
+ from examples.attention import slow_entmax_attention
207
+
208
+ out = slow_entmax_attention(q, k, v, is_causal=True, padding="right")
205
209
  ```
206
210
 
207
- `entmax_attention` is a dense utility built on top of v2 `triton_entmax`. It supports causal masking, non-causal masking, variable lengths, left/right padding, ALiBi slopes, and gradients through `q`, `k`, and `v`.
211
+ `slow_entmax_attention` materializes dense attention scores and applies `triton_entmax`. It is useful for examples and small correctness checks, but it is not the AdaSplash flash kernel and should not be used for long contexts.
208
212
 
209
213
  ## Backwards Compatibility
210
214
 
@@ -6,7 +6,6 @@ adasplash/__init__.py
6
6
  adasplash/adasplash_block_mask.py
7
7
  adasplash/adasplash_no_block_mask.py
8
8
  adasplash/adasplash_v2.py
9
- adasplash/attention.py
10
9
  adasplash/triton_entmax.py
11
10
  adasplash/triton_entmax_v2.py
12
11
  adasplash.egg-info/PKG-INFO
@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
2
2
 
3
3
  setup(
4
4
  name="adasplash",
5
- version="0.2.1",
5
+ version="0.2.2",
6
6
  author="Nuno Gonçalves, Marcos Treviso",
7
7
  author_email="marcosvtreviso@gmail.com",
8
8
  description="AdaSplash: Efficient Adaptive Sparse Attention in Triton",
@@ -4,7 +4,7 @@ import pytest
4
4
  import torch
5
5
  from entmax import entmax_bisect
6
6
 
7
- import adasplash
7
+ from examples.attention import flash_entmax_attention, slow_entmax_attention
8
8
 
9
9
 
10
10
  pytestmark = pytest.mark.gpu
@@ -59,7 +59,7 @@ def _run_attention_case(padding, is_causal):
59
59
  ref = reference_attention(q, k, v, varlen=varlen, is_causal=is_causal, padding=padding, alibi_slopes=alibi)
60
60
  ref_dq, ref_dk, ref_dv = torch.autograd.grad(ref, (q, k, v), do)
61
61
 
62
- out = adasplash.entmax_attention(
62
+ out = slow_entmax_attention(
63
63
  q,
64
64
  k,
65
65
  v,
@@ -77,12 +77,31 @@ def _run_attention_case(padding, is_causal):
77
77
  assert torch.allclose(tri_dv, ref_dv, atol=1e-3, rtol=1e-3)
78
78
 
79
79
 
80
- def test_entmax_attention_fast_forward_backward_smoke():
80
+ def test_slow_entmax_attention_fast_forward_backward_smoke():
81
81
  _run_attention_case(padding="right", is_causal=True)
82
82
 
83
83
 
84
+ def test_flash_entmax_attention_example_smoke():
85
+ torch.manual_seed(42)
86
+ q = torch.randn(1, 1, 128, 32, device="cuda", dtype=torch.float32, requires_grad=True).contiguous()
87
+ k = torch.randn_like(q, requires_grad=True).contiguous()
88
+ v = torch.randn_like(q, requires_grad=True).contiguous()
89
+ do = torch.randn_like(q)
90
+
91
+ ref = reference_attention(q, k, v, is_causal=True)
92
+ ref_dq, ref_dk, ref_dv = torch.autograd.grad(ref, (q, k, v), do)
93
+
94
+ out = flash_entmax_attention(q, k, v, is_causal=True, niter=10)
95
+ tri_dq, tri_dk, tri_dv = torch.autograd.grad(out, (q, k, v), do)
96
+
97
+ assert torch.allclose(out, ref, atol=1e-4, rtol=1e-4)
98
+ assert torch.allclose(tri_dq, ref_dq, atol=1e-4, rtol=1e-4)
99
+ assert torch.allclose(tri_dk, ref_dk, atol=1e-4, rtol=1e-4)
100
+ assert torch.allclose(tri_dv, ref_dv, atol=1e-4, rtol=1e-4)
101
+
102
+
84
103
  @pytest.mark.slow
85
104
  @pytest.mark.parametrize("padding", ["left", "right"])
86
105
  @pytest.mark.parametrize("is_causal", [False, True])
87
- def test_entmax_attention_forward_backward_matches_reference(padding, is_causal):
106
+ def test_slow_entmax_attention_forward_backward_matches_reference(padding, is_causal):
88
107
  _run_attention_case(padding=padding, is_causal=is_causal)
@@ -16,7 +16,6 @@ def test_public_api_exports_are_lazy_and_versioned():
16
16
  "triton_entmax_v2",
17
17
  "triton_sparsemax",
18
18
  "triton_entmax15",
19
- "entmax_attention",
20
19
  ]:
21
20
  assert name in adasplash.__all__
22
21
  assert callable(getattr(adasplash, name))
@@ -49,24 +48,13 @@ def test_dispatcher_signatures_are_stable():
49
48
  "use_histogram",
50
49
  "fast_math",
51
50
  ]
52
- assert list(inspect.signature(adasplash.entmax_attention).parameters) == [
53
- "q",
54
- "k",
55
- "v",
56
- "alpha",
57
- "varlen",
58
- "is_causal",
59
- "padding",
60
- "niter",
61
- "alibi_slopes",
62
- ]
51
+ assert not hasattr(adasplash, "entmax_attention")
63
52
 
64
53
 
65
54
  def test_package_source_allowlist():
66
55
  package_dir = Path(adasplash.__file__).resolve().parent
67
56
  allowed = {
68
57
  "__init__.py",
69
- "attention.py",
70
58
  "adasplash_block_mask.py",
71
59
  "adasplash_no_block_mask.py",
72
60
  "adasplash_v2.py",
@@ -1,73 +0,0 @@
1
- import math
2
-
3
- import torch
4
-
5
-
6
- def _varlen_mask(varlen, size, padding):
7
- positions = torch.arange(size, device=varlen.device)
8
- if padding == "right":
9
- return positions[None, :] < varlen[:, None]
10
- if padding == "left":
11
- return positions[None, :] >= size - varlen[:, None]
12
- raise ValueError("padding must be either 'right' or 'left'.")
13
-
14
-
15
- def _alibi_bias(q, k, alibi_slopes):
16
- _, n_heads, q_len, _ = q.shape
17
- k_len = k.shape[-2]
18
- if alibi_slopes.shape != (n_heads,):
19
- raise ValueError(f"alibi_slopes must have shape ({n_heads},); got {tuple(alibi_slopes.shape)}.")
20
-
21
- if q_len == 1 and k_len > 1:
22
- rel_pos = torch.arange(k_len, device=q.device) - (k_len - 1)
23
- rel_pos = rel_pos.view(1, 1, 1, k_len)
24
- else:
25
- q_pos = torch.arange(q_len, device=q.device)
26
- k_pos = torch.arange(k_len, device=q.device)
27
- rel_pos = k_pos[None, :] - q_pos[:, None]
28
- rel_pos = rel_pos.view(1, 1, q_len, k_len)
29
- return alibi_slopes.to(q.device).view(1, n_heads, 1, 1) * rel_pos
30
-
31
-
32
- def entmax_attention(q, k, v, alpha=1.5, varlen=None, is_causal=False, padding="right", niter=2, alibi_slopes=None):
33
- """Dense QK attention using the public v2 Triton entmax activation."""
34
- if q.dim() != 4 or k.dim() != 4 or v.dim() != 4:
35
- raise ValueError("q, k and v must have shape (batch, heads, seq_len, head_dim).")
36
- if k.shape != v.shape:
37
- raise ValueError(f"k and v must have the same shape; got {tuple(k.shape)} and {tuple(v.shape)}.")
38
- if q.shape[0] != k.shape[0] or q.shape[1] != k.shape[1] or q.shape[3] != k.shape[3]:
39
- raise ValueError("q, k and v must agree on batch, heads and head_dim.")
40
-
41
- _, _, q_len, head_dim = q.shape
42
- k_len = k.shape[-2]
43
- scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(head_dim)
44
-
45
- if alibi_slopes is not None:
46
- scores = scores + _alibi_bias(q, k, alibi_slopes)
47
-
48
- if is_causal:
49
- if q_len == k_len:
50
- causal = torch.tril(torch.ones(q_len, k_len, device=q.device, dtype=torch.bool))
51
- else:
52
- q_pos = torch.arange(q_len, device=q.device) + (k_len - q_len)
53
- k_pos = torch.arange(k_len, device=q.device)
54
- causal = q_pos[:, None] >= k_pos[None, :]
55
- scores = scores.masked_fill(~causal.view(1, 1, q_len, k_len), float("-inf"))
56
-
57
- output_mask = None
58
- if varlen is not None:
59
- if varlen.dim() != 1 or varlen.shape[0] != q.shape[0]:
60
- raise ValueError(f"varlen must be a 1-D tensor of shape ({q.shape[0]},).")
61
- key_mask = _varlen_mask(varlen.to(q.device), k_len, padding)
62
- scores = scores.masked_fill(~key_mask[:, None, None, :], float("-inf"))
63
- if q_len == k_len:
64
- output_mask = _varlen_mask(varlen.to(q.device), q_len, padding)[:, None, :, None]
65
- scores = scores.masked_fill(~output_mask, 0.0)
66
-
67
- from .triton_entmax_v2 import triton_entmax
68
-
69
- probs = triton_entmax(scores.contiguous(), alpha=alpha, n_iter=niter, fast_math=False)
70
- out = torch.matmul(probs, v)
71
- if output_mask is not None:
72
- out = out.masked_fill(~output_mask, 0)
73
- return out
File without changes
File without changes
File without changes