blksprs 2.1.10__tar.gz → 2.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {blksprs-2.1.10 → blksprs-2.2}/PKG-INFO +32 -21
  2. {blksprs-2.1.10 → blksprs-2.2}/README.md +31 -20
  3. {blksprs-2.1.10 → blksprs-2.2}/blksprs/__init__.py +2 -1
  4. {blksprs-2.1.10 → blksprs-2.2}/blksprs/ops/distribution.py +3 -3
  5. blksprs-2.2/blksprs/ops/flash_attention.py +612 -0
  6. {blksprs-2.1.10 → blksprs-2.2}/blksprs/utils/autotuning.py +0 -1
  7. {blksprs-2.1.10 → blksprs-2.2}/blksprs.egg-info/PKG-INFO +32 -21
  8. {blksprs-2.1.10 → blksprs-2.2}/blksprs.egg-info/SOURCES.txt +1 -0
  9. {blksprs-2.1.10 → blksprs-2.2}/pyproject.toml +7 -7
  10. {blksprs-2.1.10 → blksprs-2.2}/blksprs/layouting/distribution_layout.py +0 -0
  11. {blksprs-2.1.10 → blksprs-2.2}/blksprs/layouting/sparsity_layout.py +0 -0
  12. {blksprs-2.1.10 → blksprs-2.2}/blksprs/ops/conversion.py +0 -0
  13. {blksprs-2.1.10 → blksprs-2.2}/blksprs/ops/flow.py +0 -0
  14. {blksprs-2.1.10 → blksprs-2.2}/blksprs/ops/matmul.py +0 -0
  15. {blksprs-2.1.10 → blksprs-2.2}/blksprs/ops/misc/broadcast_ops.py +0 -0
  16. {blksprs-2.1.10 → blksprs-2.2}/blksprs/ops/misc/row_wise.py +0 -0
  17. {blksprs-2.1.10 → blksprs-2.2}/blksprs/ops/partitioning.py +0 -0
  18. {blksprs-2.1.10 → blksprs-2.2}/blksprs/ops/repeat.py +0 -0
  19. {blksprs-2.1.10 → blksprs-2.2}/blksprs/ops/softmax.py +0 -0
  20. {blksprs-2.1.10 → blksprs-2.2}/blksprs/ops/transpose.py +0 -0
  21. {blksprs-2.1.10 → blksprs-2.2}/blksprs/utils/benchmarking.py +0 -0
  22. {blksprs-2.1.10 → blksprs-2.2}/blksprs/utils/blksprs_tensor.py +0 -0
  23. {blksprs-2.1.10 → blksprs-2.2}/blksprs/utils/processing.py +0 -0
  24. {blksprs-2.1.10 → blksprs-2.2}/blksprs/utils/tools.py +0 -0
  25. {blksprs-2.1.10 → blksprs-2.2}/blksprs/utils/validation.py +0 -0
  26. {blksprs-2.1.10 → blksprs-2.2}/blksprs.egg-info/dependency_links.txt +0 -0
  27. {blksprs-2.1.10 → blksprs-2.2}/blksprs.egg-info/requires.txt +0 -0
  28. {blksprs-2.1.10 → blksprs-2.2}/blksprs.egg-info/top_level.txt +0 -0
  29. {blksprs-2.1.10 → blksprs-2.2}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: blksprs
- Version: 2.1.10
+ Version: 2.2
  Summary: A lightweight library for operations on block-sparse matrices in PyTorch.
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -17,20 +17,13 @@ Requires-Dist: coverage; extra == "test"
  Requires-Dist: build; extra == "test"
  Requires-Dist: matplotlib; extra == "test"

- # blksprs
+ # 🧊 blksprs

  [![GitHub Release](https://img.shields.io/github/v/release/FelixSchoen/blksprs?include_prereleases&label=Latest%20Release)](https://github.com/FelixSchoen/blksprs/releases)
  [![Python 3.11](https://img.shields.io/badge/Python%20Version-3.11-blue)](https://www.python.org/downloads/release/python-3119/)
  [![Python 3.12](https://img.shields.io/badge/Python%20Version-3.12-blue)](https://www.python.org/downloads/release/python-31210/)

- ## Overview
-
- ### News
-
- 🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
- LUTs, autocasting, and makes use of `torch.library.triton_op()`!
-
- ---
+ ## 📖 Overview

  A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.

@@ -46,6 +39,7 @@ Currently supported operations (includes gradient calculation):
  - Splitting and merging of matrices (_currently* only supports splitting and merging along the last dimension_)
  - Conversion to and from sparse form
  - Conversion to different sparsity layouts and different sparsity block sizes
+ - Flash Attention (_supports custom masks and cross-attention_)

  As with this library sparse matrices are represented using a tuple of `(matrix, sparsity_layout, sparsity_block_size)`,
  any element-wise operations can be applied in regular torch-like fashion.
@@ -74,7 +68,7 @@ Furthermore, the library provides a set of utility functions

  _* see the [Roadmap](#roadmap) section for more information_

- ## Installation
+ ## 🛠️ Installation

  Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible with the Linux platform**.
  Keep track of this [issue](https://github.com/triton-lang/triton/issues/1640) for updates.
@@ -89,11 +83,11 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u
  - _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.3.1)_
  - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_

- ## Changelog
+ ## 📝 Changelog

  See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.md) for a detailed changelog.

- ## Roadmap
+ ## 🗺️ Roadmap

  Note that since this library covers all our current needs it is in a **bugfix-only** state.
  This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and
@@ -105,17 +99,15 @@ We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
  It might be that this changes with future projects, but as of August 2025, we are content with the current state of the
  library.

- ## Known Limitations and Issues
+ ## ⚠️ Known Limitations and Issues

- - Triton has a bug with `tl.atomix_max()` used for the row-wise max operation.
-   In order to work around this bug a manual conversion of some values is needed, (slightly) negatively impacting
-   performance.
-   Watch the [issue](https://github.com/triton-lang/triton/issues/6376) on Triton's issue tracker for more information.
  - There will be some slight numerical differences between vanilla and blksprs operations.
    These instabilities are due to Triton and thus cannot be fixed by this library alone.
    However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.

- ## Usage
+ - Flash Attention is a recent addition. While it has been tested and appears stable, please report any issues you encounter.
+
+ ## 💻 Usage

  We provide an example below to demonstrate the usage of the library.
  For more detailed examples, please refer to
@@ -128,7 +120,6 @@ the [test cases](https://github.com/FelixSchoen/blksprs/blob/main/test/cases/tes
  import torch
  import blksprs as bs

-
  def test_readme():
      # Set up parameters (batch size, number of heads, dimensions for matrices (m, k) and (n, k))
      b, h, m, n, k = 2, 4, 64, 64, 16
@@ -193,10 +184,30 @@ def test_readme():
      # Other available functions
      bs.ops.transpose(o_sparse, sparsity_layout_o, sparsity_block_size)
      bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size, flag_fused=False)
-     bs.ops.softmax_fused(o_sparse, sparsity_layout_o, sparsity_block_size) # Significantly faster version that requires that rows of matrix fit into memory (default if flag is not set)
+     bs.ops.softmax_fused(o_sparse, sparsity_layout_o,
+                          sparsity_block_size) # Significantly faster version that requires that rows of matrix fit into memory (default if flag is not set)
      bs.ops.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size)
      bs.ops.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size)

+     # Flash Attention
+     seq_len, head_dim = 512, 64
+     sparsity_block_size_attn = 128
+
+     q = torch.randn(b, seq_len, h, head_dim, device="cuda")
+     k = torch.randn(b, seq_len, h, head_dim, device="cuda")
+     v = torch.randn(b, seq_len, h, head_dim, device="cuda")
+
+     n_batches_attn = b * h
+     n_seq_blocks = seq_len // sparsity_block_size_attn
+     attention_layout = torch.tril(torch.ones(n_batches_attn, n_seq_blocks, n_seq_blocks, device="cuda", dtype=torch.bool))
+
+     lut = bs.ops.flash_attention_build_lut(attention_layout, n_seq_blocks, n_seq_blocks)
+
+     attn_out = bs.ops.flash_attention(q, k, v, attention_layout, sparsity_block_size_attn, lut=lut)
+
+     assert attn_out.shape == (b, seq_len, h, head_dim)
+
+

  def _get_random_sparsity_layout(b, m, n, sparsity_block_size, sparsity_percentage):
      """Helper function, creates a random sparsity layout for a given shape with a given percentage of blocks marked as sparse.
@@ -1,17 +1,10 @@
- # blksprs
+ # 🧊 blksprs

  [![GitHub Release](https://img.shields.io/github/v/release/FelixSchoen/blksprs?include_prereleases&label=Latest%20Release)](https://github.com/FelixSchoen/blksprs/releases)
  [![Python 3.11](https://img.shields.io/badge/Python%20Version-3.11-blue)](https://www.python.org/downloads/release/python-3119/)
  [![Python 3.12](https://img.shields.io/badge/Python%20Version-3.12-blue)](https://www.python.org/downloads/release/python-31210/)

- ## Overview
-
- ### News
-
- 🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
- LUTs, autocasting, and makes use of `torch.library.triton_op()`!
-
- ---
+ ## 📖 Overview

  A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.

@@ -27,6 +20,7 @@ Currently supported operations (includes gradient calculation):
  - Splitting and merging of matrices (_currently* only supports splitting and merging along the last dimension_)
  - Conversion to and from sparse form
  - Conversion to different sparsity layouts and different sparsity block sizes
+ - Flash Attention (_supports custom masks and cross-attention_)

  As with this library sparse matrices are represented using a tuple of `(matrix, sparsity_layout, sparsity_block_size)`,
  any element-wise operations can be applied in regular torch-like fashion.
@@ -55,7 +49,7 @@ Furthermore, the library provides a set of utility functions

  _* see the [Roadmap](#roadmap) section for more information_

- ## Installation
+ ## 🛠️ Installation

  Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible with the Linux platform**.
  Keep track of this [issue](https://github.com/triton-lang/triton/issues/1640) for updates.
@@ -70,11 +64,11 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u
  - _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.3.1)_
  - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_

- ## Changelog
+ ## 📝 Changelog

  See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.md) for a detailed changelog.

- ## Roadmap
+ ## 🗺️ Roadmap

  Note that since this library covers all our current needs it is in a **bugfix-only** state.
  This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and
@@ -86,17 +80,15 @@ We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
  It might be that this changes with future projects, but as of August 2025, we are content with the current state of the
  library.

- ## Known Limitations and Issues
+ ## ⚠️ Known Limitations and Issues

- - Triton has a bug with `tl.atomix_max()` used for the row-wise max operation.
-   In order to work around this bug a manual conversion of some values is needed, (slightly) negatively impacting
-   performance.
-   Watch the [issue](https://github.com/triton-lang/triton/issues/6376) on Triton's issue tracker for more information.
  - There will be some slight numerical differences between vanilla and blksprs operations.
    These instabilities are due to Triton and thus cannot be fixed by this library alone.
    However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.

- ## Usage
+ - Flash Attention is a recent addition. While it has been tested and appears stable, please report any issues you encounter.
+
+ ## 💻 Usage

  We provide an example below to demonstrate the usage of the library.
  For more detailed examples, please refer to
@@ -109,7 +101,6 @@ the [test cases](https://github.com/FelixSchoen/blksprs/blob/main/test/cases/tes
  import torch
  import blksprs as bs

-
  def test_readme():
      # Set up parameters (batch size, number of heads, dimensions for matrices (m, k) and (n, k))
      b, h, m, n, k = 2, 4, 64, 64, 16
@@ -174,10 +165,30 @@ def test_readme():
      # Other available functions
      bs.ops.transpose(o_sparse, sparsity_layout_o, sparsity_block_size)
      bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size, flag_fused=False)
-     bs.ops.softmax_fused(o_sparse, sparsity_layout_o, sparsity_block_size) # Significantly faster version that requires that rows of matrix fit into memory (default if flag is not set)
+     bs.ops.softmax_fused(o_sparse, sparsity_layout_o,
+                          sparsity_block_size) # Significantly faster version that requires that rows of matrix fit into memory (default if flag is not set)
      bs.ops.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size)
      bs.ops.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size)

+     # Flash Attention
+     seq_len, head_dim = 512, 64
+     sparsity_block_size_attn = 128
+
+     q = torch.randn(b, seq_len, h, head_dim, device="cuda")
+     k = torch.randn(b, seq_len, h, head_dim, device="cuda")
+     v = torch.randn(b, seq_len, h, head_dim, device="cuda")
+
+     n_batches_attn = b * h
+     n_seq_blocks = seq_len // sparsity_block_size_attn
+     attention_layout = torch.tril(torch.ones(n_batches_attn, n_seq_blocks, n_seq_blocks, device="cuda", dtype=torch.bool))
+
+     lut = bs.ops.flash_attention_build_lut(attention_layout, n_seq_blocks, n_seq_blocks)
+
+     attn_out = bs.ops.flash_attention(q, k, v, attention_layout, sparsity_block_size_attn, lut=lut)
+
+     assert attn_out.shape == (b, seq_len, h, head_dim)
+
+

  def _get_random_sparsity_layout(b, m, n, sparsity_block_size, sparsity_percentage):
      """Helper function, creates a random sparsity layout for a given shape with a given percentage of blocks marked as sparse.
@@ -4,7 +4,7 @@ import torch
  # Capture scalar outputs for JIT compilation
  torch._dynamo.config.capture_scalar_outputs = True
  # Set version
- __version__ = "2.1.10"
+ __version__ = "2.2"

  # Imports

@@ -14,6 +14,7 @@ from blksprs.utils.blksprs_tensor import BlksprsTensor
  class ops:
      from blksprs.ops.conversion import to_dense, to_sparse, from_blksprs, to_blksprs, adapt_layout
      from blksprs.ops.distribution import gather, scatter, scatter_reduce
+     from blksprs.ops.flash_attention import flash_attention, flash_attention_build_lut
      from blksprs.ops.matmul import matmul
      from blksprs.ops.softmax import softmax, softmax_fused
      from blksprs.ops.transpose import transpose
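The hunk above is what exposes the new kernels on the public `ops` namespace. A minimal sketch (not taken from the package) of what that surface looks like from user code, using the `bs` alias from the README; the dense 4×4 block layout is illustrative and CUDA is assumed, as for the rest of the library:

```python
import torch
import blksprs as bs

# Illustrative block-level layout: 1 batch*head entry, 4x4 blocks, all blocks present.
layout = torch.ones(1, 4, 4, dtype=torch.bool, device="cuda")

# The LUT is returned as a plain dict and can be cached and reused across calls.
lut = bs.ops.flash_attention_build_lut(layout)
assert {"attn_lut", "attn_offsets", "rev_attn_lut", "rev_attn_offsets"} <= lut.keys()
assert callable(bs.ops.flash_attention)
```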
@@ -174,7 +174,7 @@ def gather_kernel(x,
  dst_col_x)
  blk_x_msk = (((blk_x_idx >= 0) &
  (blk_x_idx < x_b * x_b_s)) &
- (rev_idx_spa_x_msk != -1))
+ (rev_idx_spa_x >= 0))
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

  # Store output
@@ -183,7 +183,7 @@ def gather_kernel(x,
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
  blk_o_msk = (((blk_o_idx >= 0) &
  (blk_o_idx < o_b * o_b_s)) &
- (rev_idx_spa_x_msk != -1))
+ (rev_idx_spa_x >= 0))
  tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)


@@ -426,7 +426,7 @@ def scatter_reduce_kernel(x,
  dst_col_o)
  blk_o_msk = (((blk_o_idx >= 0) &
  (blk_o_idx < o_b * o_b_s)) &
- (rev_idx_spa_o_msk != -1))
+ (rev_idx_spa_o >= 0))

  if reduce_op_ind == 0:
      tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
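The three hunks above replace a check on the load mask (`rev_idx_spa_*_msk != -1`, which is always true for a boolean mask) with a check on the loaded reverse sparsity index itself. A minimal illustration of the sentinel convention this appears to rely on; the tensor and names below are made up for the example and are not the library's internals:

```python
import torch

# Hypothetical reverse sparsity index: the position of each block in the compacted
# sparse tensor, with -1 marking blocks that are absent from the sparsity layout.
rev_idx_spa = torch.tensor([3, -1, 0, 2, -1])
rev_idx_spa_msk = torch.ones_like(rev_idx_spa, dtype=torch.bool)  # a boolean load mask

print((rev_idx_spa_msk != -1).all().item())  # True everywhere, so this check filters nothing
print((rev_idx_spa >= 0).tolist())           # [True, False, True, True, False], skips absent blocks
```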
@@ -0,0 +1,612 @@
1
+ """Block-sparse Flash Attention implementation for blksprs.
2
+
3
+ This module implements Flash Attention 2 algorithm with block-sparse support,
4
+ including cross-attention (seq_q != seq_k) and custom attention masks.
5
+
6
+ Note: This implementation was developed with AI assistance.
7
+ """
8
+
9
+ import math
10
+ from typing import Tuple
11
+
12
+ import torch
13
+ import triton
14
+ from torch import Tensor
15
+ from triton import language as tl
16
+
17
+ from blksprs.utils.validation import validate_contiguous, validate_device, validate_dtype_float, ensure_contiguous
18
+
19
+
20
+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
21
+ def flash_attention(
22
+ q: Tensor,
23
+ k: Tensor,
24
+ v: Tensor,
25
+ attention_layout: Tensor,
26
+ sparsity_block_size: int,
27
+ scale: float = None,
28
+ attention_mask: Tensor = None,
29
+ lut: dict = None,
30
+ ) -> Tensor:
31
+ """Block-sparse flash attention with optional attention mask.
32
+
33
+ Args:
34
+ q: Query tensor [batch, seq_q, n_heads, head_dim]
35
+ k: Key tensor [batch, seq_k, n_heads, head_dim]
36
+ v: Value tensor [batch, seq_k, n_heads, head_dim]
37
+ attention_layout: Block attention pattern [batch*heads, n_seq_blocks_q, n_seq_blocks_k]
38
+ sparsity_block_size: Block size for sparsity pattern
39
+ scale: Attention scale (default: 1/sqrt(head_dim))
40
+ attention_mask: Boolean mask [batch*heads, seq_q, seq_k] where True=masked (default None)
41
+ lut: Optional pre-computed LUT dictionary
42
+
43
+ Returns:
44
+ Output tensor [batch, seq_q, n_heads, head_dim]
45
+ """
46
+ q, k, v = ensure_contiguous(q, k, v)
47
+
48
+ validate_contiguous(q, k, v)
49
+ validate_dtype_float(q, k, v)
50
+ validate_device(q, k, v)
51
+
52
+ batch, seq_q, n_heads, head_dim = q.shape
53
+ _, seq_k, _, _ = k.shape
54
+
55
+ if k.shape[0] != batch or k.shape[2] != n_heads or k.shape[3] != head_dim:
56
+ raise ValueError("K must have compatible shape with Q")
57
+ if v.shape != k.shape:
58
+ raise ValueError("V must have same shape as K")
59
+ if not (sparsity_block_size >= 16 and (sparsity_block_size & (sparsity_block_size - 1)) == 0):
60
+ raise ValueError(f"sparsity_block_size must be power of 2 >= 16, got {sparsity_block_size}")
61
+ if seq_q % sparsity_block_size != 0:
62
+ raise ValueError(f"seq_q ({seq_q}) must be divisible by sparsity_block_size")
63
+ if seq_k % sparsity_block_size != 0:
64
+ raise ValueError(f"seq_k ({seq_k}) must be divisible by sparsity_block_size")
65
+
66
+ n_batches = batch * n_heads
67
+ n_seq_blocks_q = seq_q // sparsity_block_size
68
+ n_seq_blocks_k = seq_k // sparsity_block_size
69
+
70
+ expected_layout_shape = (n_batches, n_seq_blocks_q, n_seq_blocks_k)
71
+ if attention_layout.shape != expected_layout_shape:
72
+ raise ValueError(f"attention_layout shape {tuple(attention_layout.shape)} doesn't match expected {expected_layout_shape}")
73
+
74
+ if scale is None:
75
+ scale = 1.0 / math.sqrt(head_dim)
76
+
77
+ if lut is None:
78
+ lut = flash_attention_build_lut(attention_layout, n_seq_blocks_q, n_seq_blocks_k)
79
+
80
+ has_mask = attention_mask is not None
81
+ if has_mask:
82
+ if attention_mask.shape != (n_batches, seq_q, seq_k):
83
+ raise ValueError(f"attention_mask shape {tuple(attention_mask.shape)} doesn't match expected ({n_batches}, {seq_q}, {seq_k})")
84
+ attention_mask_additive = torch.where(
85
+ attention_mask,
86
+ torch.tensor(float("-inf"), device=attention_mask.device, dtype=q.dtype),
87
+ torch.tensor(0.0, device=attention_mask.device, dtype=q.dtype)
88
+ ).contiguous()
89
+ else:
90
+ attention_mask_additive = torch.empty(0, device=q.device, dtype=q.dtype)
91
+
92
+ return BlockSparseFlashAttention.apply(
93
+ q, k, v,
94
+ attention_mask_additive,
95
+ lut["attn_lut"], lut["attn_offsets"],
96
+ lut["rev_attn_lut"], lut["rev_attn_offsets"],
97
+ sparsity_block_size, n_seq_blocks_q, n_seq_blocks_k,
98
+ lut["max_kv_blocks"], lut["max_q_per_k"],
99
+ scale, has_mask,
100
+ )
101
+
102
+
103
+ class BlockSparseFlashAttention(torch.autograd.Function):
104
+ """Block-sparse Flash Attention with autograd support."""
105
+
106
+ @staticmethod
107
+ def forward(ctx, q, k, v, attention_mask, attn_lut, attn_offsets, rev_attn_lut, rev_attn_offsets,
108
+ sparsity_block_size, n_seq_blocks_q, n_seq_blocks_k, max_kv_blocks, max_q_per_k, scale, has_mask):
109
+ batch, seq_q, n_heads, head_dim = q.shape
110
+ _, seq_k, _, _ = k.shape
111
+ n_batches = batch * n_heads
112
+
113
+ q_flat = q.permute(0, 2, 1, 3).reshape(n_batches, seq_q, head_dim).contiguous()
114
+ k_flat = k.permute(0, 2, 1, 3).reshape(n_batches, seq_k, head_dim).contiguous()
115
+ v_flat = v.permute(0, 2, 1, 3).reshape(n_batches, seq_k, head_dim).contiguous()
116
+
117
+ o_flat = torch.empty_like(q_flat)
118
+ lse = torch.empty(n_batches, seq_q, device=q.device, dtype=torch.float32)
119
+ l = torch.empty(n_batches, seq_q, device=q.device, dtype=torch.float32)
120
+
121
+ if head_dim <= 64:
122
+ BLOCK_M = min(128, sparsity_block_size)
123
+ elif head_dim <= 128:
124
+ BLOCK_M = min(64, sparsity_block_size)
125
+ else:
126
+ BLOCK_M = min(32, sparsity_block_size)
127
+ BLOCK_N = sparsity_block_size
128
+
129
+ n_m_tiles = seq_q // BLOCK_M
130
+ grid = (n_m_tiles, n_batches)
131
+
132
+ if has_mask:
133
+ mask_stride_batch = attention_mask.stride(0)
134
+ mask_stride_row = attention_mask.stride(1)
135
+ mask_stride_col = attention_mask.stride(2)
136
+ else:
137
+ mask_stride_batch = 0
138
+ mask_stride_row = 0
139
+ mask_stride_col = 0
140
+
141
+ flash_attention_fwd_kernel[grid](
142
+ q_flat, k_flat, v_flat, o_flat,
143
+ attention_mask if has_mask else q_flat,
144
+ attn_lut, attn_offsets,
145
+ lse, l,
146
+ q_flat.stride(0), q_flat.stride(1), q_flat.stride(2),
147
+ k_flat.stride(0), k_flat.stride(1), k_flat.stride(2),
148
+ mask_stride_batch, mask_stride_row, mask_stride_col,
149
+ n_batches, seq_q, seq_k, head_dim, sparsity_block_size, n_seq_blocks_q, max_kv_blocks,
150
+ scale,
151
+ has_mask,
152
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
153
+ num_stages=4, num_warps=4,
154
+ )
155
+
156
+ o = o_flat.reshape(batch, n_heads, seq_q, head_dim).permute(0, 2, 1, 3).contiguous()
157
+
158
+ ctx.save_for_backward(q_flat, k_flat, v_flat, o_flat, lse,
159
+ attn_lut, attn_offsets, rev_attn_lut, rev_attn_offsets,
160
+ attention_mask if has_mask else torch.empty(0, device=q.device))
161
+ ctx.sparsity_block_size = sparsity_block_size
162
+ ctx.n_seq_blocks_q = n_seq_blocks_q
163
+ ctx.n_seq_blocks_k = n_seq_blocks_k
164
+ ctx.max_kv_blocks = max_kv_blocks
165
+ ctx.max_q_per_k = max_q_per_k
166
+ ctx.scale = scale
167
+ ctx.has_mask = has_mask
168
+ ctx.batch = batch
169
+ ctx.n_heads = n_heads
170
+ ctx.seq_q = seq_q
171
+ ctx.seq_k = seq_k
172
+ ctx.head_dim = head_dim
173
+ ctx.BLOCK_M = BLOCK_M
174
+ ctx.BLOCK_N = BLOCK_N
175
+
176
+ return o
177
+
178
+ @staticmethod
179
+ def backward(ctx, grad_output):
180
+ (q_flat, k_flat, v_flat, o_flat, lse,
181
+ attn_lut, attn_offsets, rev_attn_lut, rev_attn_offsets, attention_mask) = ctx.saved_tensors
182
+
183
+ batch = ctx.batch
184
+ n_heads = ctx.n_heads
185
+ seq_q = ctx.seq_q
186
+ seq_k = ctx.seq_k
187
+ head_dim = ctx.head_dim
188
+ n_batches = batch * n_heads
189
+ sparsity_block_size = ctx.sparsity_block_size
190
+ BLOCK_M = ctx.BLOCK_M
191
+ BLOCK_N = ctx.BLOCK_N
192
+ has_mask = ctx.has_mask
193
+
194
+ do_flat = grad_output.permute(0, 2, 1, 3).reshape(n_batches, seq_q, head_dim).contiguous()
195
+
196
+ dq_flat = torch.zeros_like(q_flat)
197
+ dk_flat = torch.zeros_like(k_flat)
198
+ dv_flat = torch.zeros_like(v_flat)
199
+ delta = torch.empty(n_batches, seq_q, device=q_flat.device, dtype=torch.float32)
200
+
201
+ if has_mask:
202
+ mask_stride_batch = attention_mask.stride(0)
203
+ mask_stride_row = attention_mask.stride(1)
204
+ mask_stride_col = attention_mask.stride(2)
205
+ else:
206
+ mask_stride_batch = 0
207
+ mask_stride_row = 0
208
+ mask_stride_col = 0
209
+
210
+ n_m_tiles_q = seq_q // BLOCK_M
211
+ flash_attention_bwd_preprocess_kernel[(n_m_tiles_q, n_batches)](
212
+ o_flat, do_flat, delta,
213
+ o_flat.stride(0), o_flat.stride(1), o_flat.stride(2),
214
+ seq_q, head_dim,
215
+ BLOCK_M=BLOCK_M,
216
+ )
217
+
218
+ n_n_tiles_k = seq_k // BLOCK_N
219
+ flash_attention_bwd_dkdv_kernel[(n_n_tiles_k, n_batches)](
220
+ q_flat, k_flat, v_flat, do_flat,
221
+ dk_flat, dv_flat,
222
+ lse, delta,
223
+ attention_mask if has_mask else q_flat,
224
+ rev_attn_lut, rev_attn_offsets,
225
+ q_flat.stride(0), q_flat.stride(1),
226
+ k_flat.stride(0), k_flat.stride(1),
227
+ q_flat.stride(2),
228
+ mask_stride_batch, mask_stride_row, mask_stride_col,
229
+ n_batches, seq_q, seq_k, head_dim, sparsity_block_size, ctx.n_seq_blocks_k, ctx.max_q_per_k,
230
+ ctx.scale,
231
+ has_mask,
232
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
233
+ )
234
+
235
+ flash_attention_bwd_dq_kernel[(n_m_tiles_q, n_batches)](
236
+ q_flat, k_flat, v_flat, do_flat,
237
+ dq_flat,
238
+ lse, delta,
239
+ attention_mask if has_mask else q_flat,
240
+ attn_lut, attn_offsets,
241
+ q_flat.stride(0), q_flat.stride(1),
242
+ k_flat.stride(0), k_flat.stride(1),
243
+ q_flat.stride(2),
244
+ mask_stride_batch, mask_stride_row, mask_stride_col,
245
+ n_batches, seq_q, seq_k, head_dim, sparsity_block_size, ctx.n_seq_blocks_q, ctx.max_kv_blocks,
246
+ ctx.scale,
247
+ has_mask,
248
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
249
+ )
250
+
251
+ dq = dq_flat.reshape(batch, n_heads, seq_q, head_dim).permute(0, 2, 1, 3).contiguous()
252
+ dk = dk_flat.reshape(batch, n_heads, seq_k, head_dim).permute(0, 2, 1, 3).contiguous()
253
+ dv = dv_flat.reshape(batch, n_heads, seq_k, head_dim).permute(0, 2, 1, 3).contiguous()
254
+
255
+ return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None
256
+
257
+
258
+ @triton.jit
259
+ def flash_attention_fwd_kernel(
260
+ q_ptr, k_ptr, v_ptr, o_ptr,
261
+ mask_ptr,
262
+ attn_lut_ptr, attn_offsets_ptr,
263
+ m_ptr, l_ptr,
264
+ stride_q_batch, stride_q_seq, stride_q_dim,
265
+ stride_kv_batch, stride_kv_seq, stride_kv_dim,
266
+ stride_mask_batch, stride_mask_row, stride_mask_col,
267
+ n_batches: tl.constexpr,
268
+ seq_q: tl.constexpr,
269
+ seq_k: tl.constexpr,
270
+ head_dim: tl.constexpr,
271
+ sparsity_block_size: tl.constexpr,
272
+ n_seq_blocks_q: tl.constexpr,
273
+ max_kv_blocks: tl.constexpr,
274
+ scale,
275
+ has_mask: tl.constexpr,
276
+ BLOCK_M: tl.constexpr,
277
+ BLOCK_N: tl.constexpr,
278
+ ):
279
+ """Flash attention forward kernel with block-sparse mask support."""
280
+ pid_m = tl.program_id(0)
281
+ pid_batch = tl.program_id(1)
282
+
283
+ n_m_tiles: tl.constexpr = sparsity_block_size // BLOCK_M
284
+ n_n_tiles: tl.constexpr = sparsity_block_size // BLOCK_N
285
+
286
+ q_seq_block = pid_m // n_m_tiles
287
+ m_tile_idx = pid_m % n_m_tiles
288
+
289
+ q_row_start = q_seq_block * sparsity_block_size + m_tile_idx * BLOCK_M
290
+ offs_m = q_row_start + tl.arange(0, BLOCK_M)
291
+ offs_d = tl.arange(0, head_dim)
292
+
293
+ q_ptrs = q_ptr + pid_batch * stride_q_batch + offs_m[:, None] * stride_q_seq + offs_d[None, :]
294
+ q_mask = offs_m[:, None] < seq_q
295
+ q = tl.load(q_ptrs, mask=q_mask, other=0.0)
296
+
297
+ m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
298
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
299
+ acc = tl.zeros([BLOCK_M, head_dim], dtype=tl.float32)
300
+
301
+ qk_scale = scale * 1.44269504
302
+
303
+ attn_offset_idx = pid_batch * n_seq_blocks_q + q_seq_block
304
+ attn_start = tl.load(attn_offsets_ptr + attn_offset_idx)
305
+ attn_end = tl.load(attn_offsets_ptr + attn_offset_idx + 1)
306
+ n_kv_blocks = attn_end - attn_start
307
+
308
+ for kv_idx in range(max_kv_blocks):
309
+ if kv_idx < n_kv_blocks:
310
+ k_seq_block = tl.load(attn_lut_ptr + attn_start + kv_idx)
311
+
312
+ k_row_start = k_seq_block * sparsity_block_size
313
+ offs_n = k_row_start + tl.arange(0, BLOCK_N)
314
+
315
+ k_ptrs = k_ptr + pid_batch * stride_kv_batch + offs_n[:, None] * stride_kv_seq + offs_d[None, :]
316
+ k_mask = offs_n[:, None] < seq_k
317
+ k = tl.load(k_ptrs, mask=k_mask, other=0.0)
318
+
319
+ qk = tl.dot(q, tl.trans(k)) * qk_scale
320
+
321
+ if has_mask:
322
+ mask_ptrs = mask_ptr + pid_batch * stride_mask_batch + offs_m[:, None] * stride_mask_row + offs_n[None, :] * stride_mask_col
323
+ mask_vals = tl.load(mask_ptrs, mask=(offs_m[:, None] < seq_q) & (offs_n[None, :] < seq_k), other=0.0)
324
+ qk = qk + mask_vals
325
+
326
+ m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
327
+ alpha = tl.math.exp2(m_i - m_ij)
328
+ p = tl.math.exp2(qk - m_ij[:, None])
329
+ l_i = l_i * alpha + tl.sum(p, axis=1)
330
+ acc = acc * alpha[:, None]
331
+
332
+ v_ptrs = v_ptr + pid_batch * stride_kv_batch + offs_n[:, None] * stride_kv_seq + offs_d[None, :]
333
+ v = tl.load(v_ptrs, mask=k_mask, other=0.0)
334
+ acc = tl.dot(p.to(v.dtype), v, acc)
335
+
336
+ m_i = m_ij
337
+
338
+ has_attention = l_i > 0
339
+ l_safe = tl.where(has_attention, l_i, 1.0)
340
+ acc = acc / l_safe[:, None]
341
+ acc = tl.where(has_attention[:, None], acc, 0.0)
342
+
343
+ o_ptrs = o_ptr + pid_batch * stride_q_batch + offs_m[:, None] * stride_q_seq + offs_d[None, :]
344
+ tl.store(o_ptrs, acc.to(o_ptr.dtype.element_ty), mask=offs_m[:, None] < seq_q)
345
+
346
+ lse = tl.where(has_attention, m_i + tl.math.log2(l_safe), float("-inf"))
347
+ tl.store(m_ptr + pid_batch * seq_q + offs_m, lse, mask=offs_m < seq_q)
348
+ tl.store(l_ptr + pid_batch * seq_q + offs_m, l_i, mask=offs_m < seq_q)
349
+
350
+
351
+ @triton.jit
352
+ def flash_attention_bwd_preprocess_kernel(
353
+ o_ptr, do_ptr, delta_ptr,
354
+ stride_batch, stride_seq, stride_dim,
355
+ seq_len: tl.constexpr,
356
+ head_dim: tl.constexpr,
357
+ BLOCK_M: tl.constexpr,
358
+ ):
359
+ """Compute delta = (O * dO).sum(dim=-1)."""
360
+ pid_m = tl.program_id(0)
361
+ pid_batch = tl.program_id(1)
362
+
363
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
364
+ offs_d = tl.arange(0, head_dim)
365
+
366
+ o_ptrs = o_ptr + pid_batch * stride_batch + offs_m[:, None] * stride_seq + offs_d[None, :]
367
+ do_ptrs = do_ptr + pid_batch * stride_batch + offs_m[:, None] * stride_seq + offs_d[None, :]
368
+ mask = offs_m[:, None] < seq_len
369
+
370
+ o = tl.load(o_ptrs, mask=mask, other=0.0).to(tl.float32)
371
+ do = tl.load(do_ptrs, mask=mask, other=0.0).to(tl.float32)
372
+ delta = tl.sum(o * do, axis=1)
373
+
374
+ tl.store(delta_ptr + pid_batch * seq_len + offs_m, delta, mask=offs_m < seq_len)
375
+
376
+
377
+ @triton.jit
378
+ def flash_attention_bwd_dkdv_kernel(
379
+ q_ptr, k_ptr, v_ptr, do_ptr,
380
+ dk_ptr, dv_ptr,
381
+ lse_ptr, delta_ptr,
382
+ mask_ptr,
383
+ rev_attn_lut_ptr, rev_attn_offsets_ptr,
384
+ stride_q_batch, stride_q_seq,
385
+ stride_kv_batch, stride_kv_seq,
386
+ stride_dim,
387
+ stride_mask_batch, stride_mask_row, stride_mask_col,
388
+ n_batches: tl.constexpr,
389
+ seq_q: tl.constexpr,
390
+ seq_k: tl.constexpr,
391
+ head_dim: tl.constexpr,
392
+ sparsity_block_size: tl.constexpr,
393
+ n_seq_blocks_k: tl.constexpr,
394
+ max_q_per_k: tl.constexpr,
395
+ scale,
396
+ has_mask: tl.constexpr,
397
+ BLOCK_M: tl.constexpr,
398
+ BLOCK_N: tl.constexpr,
399
+ ):
400
+ """Compute dK and dV gradients."""
401
+ pid_n = tl.program_id(0)
402
+ pid_batch = tl.program_id(1)
403
+
404
+ n_n_tiles = sparsity_block_size // BLOCK_N
405
+ n_m_tiles = sparsity_block_size // BLOCK_M
406
+
407
+ k_seq_block = pid_n // n_n_tiles
408
+ n_tile_idx = pid_n % n_n_tiles
409
+
410
+ k_row_start = k_seq_block * sparsity_block_size + n_tile_idx * BLOCK_N
411
+ offs_n = k_row_start + tl.arange(0, BLOCK_N)
412
+ offs_d = tl.arange(0, head_dim)
413
+
414
+ qk_scale = scale * 1.44269504
415
+
416
+ k_ptrs = k_ptr + pid_batch * stride_kv_batch + offs_n[:, None] * stride_kv_seq + offs_d[None, :]
417
+ v_ptrs = v_ptr + pid_batch * stride_kv_batch + offs_n[:, None] * stride_kv_seq + offs_d[None, :]
418
+ k_mask = offs_n[:, None] < seq_k
419
+ k = tl.load(k_ptrs, mask=k_mask, other=0.0)
420
+ v = tl.load(v_ptrs, mask=k_mask, other=0.0)
421
+
422
+ dk = tl.zeros([BLOCK_N, head_dim], dtype=tl.float32)
423
+ dv = tl.zeros([BLOCK_N, head_dim], dtype=tl.float32)
424
+
425
+ rev_offset_idx = pid_batch * n_seq_blocks_k + k_seq_block
426
+ rev_start = tl.load(rev_attn_offsets_ptr + rev_offset_idx)
427
+ rev_end = tl.load(rev_attn_offsets_ptr + rev_offset_idx + 1)
428
+ n_q_blocks = rev_end - rev_start
429
+
430
+ for q_idx in range(max_q_per_k):
431
+ if q_idx < n_q_blocks:
432
+ q_seq_block = tl.load(rev_attn_lut_ptr + rev_start + q_idx)
433
+
434
+ for m_tile_idx in range(n_m_tiles):
435
+ q_row_start = q_seq_block * sparsity_block_size + m_tile_idx * BLOCK_M
436
+ offs_m = q_row_start + tl.arange(0, BLOCK_M)
437
+
438
+ q_ptrs = q_ptr + pid_batch * stride_q_batch + offs_m[:, None] * stride_q_seq + offs_d[None, :]
439
+ do_ptrs = do_ptr + pid_batch * stride_q_batch + offs_m[:, None] * stride_q_seq + offs_d[None, :]
440
+ q_mask = offs_m[:, None] < seq_q
441
+ q = tl.load(q_ptrs, mask=q_mask, other=0.0)
442
+ do = tl.load(do_ptrs, mask=q_mask, other=0.0)
443
+
444
+ m = tl.load(lse_ptr + pid_batch * seq_q + offs_m, mask=offs_m < seq_q, other=0.0)
445
+ Di = tl.load(delta_ptr + pid_batch * seq_q + offs_m, mask=offs_m < seq_q, other=0.0)
446
+
447
+ qk = tl.dot(q, tl.trans(k)) * qk_scale
448
+
449
+ if has_mask:
450
+ mask_ptrs = mask_ptr + pid_batch * stride_mask_batch + offs_m[:, None] * stride_mask_row + offs_n[None, :] * stride_mask_col
451
+ mask_vals = tl.load(mask_ptrs, mask=(offs_m[:, None] < seq_q) & (offs_n[None, :] < seq_k), other=0.0)
452
+ qk = qk + mask_vals
453
+
454
+ valid_lse = m > float("-inf")
455
+ safe_m = tl.where(valid_lse, m, 0.0)
456
+ p = tl.math.exp2(qk - safe_m[:, None])
457
+ p = tl.where(valid_lse[:, None], p, 0.0)
458
+
459
+ dv += tl.dot(tl.trans(p.to(do.dtype)), do)
460
+ dp = tl.dot(do, tl.trans(v))
461
+ ds = p * (dp - Di[:, None])
462
+ dk += tl.dot(tl.trans(ds.to(q.dtype)), q)
463
+
464
+ dk = dk * scale
465
+ tl.store(dk_ptr + pid_batch * stride_kv_batch + offs_n[:, None] * stride_kv_seq + offs_d[None, :], dk.to(dk_ptr.dtype.element_ty), mask=k_mask)
466
+ tl.store(dv_ptr + pid_batch * stride_kv_batch + offs_n[:, None] * stride_kv_seq + offs_d[None, :], dv.to(dv_ptr.dtype.element_ty), mask=k_mask)
467
+
468
+
469
+ @triton.jit
470
+ def flash_attention_bwd_dq_kernel(
471
+ q_ptr, k_ptr, v_ptr, do_ptr,
472
+ dq_ptr,
473
+ lse_ptr, delta_ptr,
474
+ mask_ptr,
475
+ attn_lut_ptr, attn_offsets_ptr,
476
+ stride_q_batch, stride_q_seq,
477
+ stride_kv_batch, stride_kv_seq,
478
+ stride_dim,
479
+ stride_mask_batch, stride_mask_row, stride_mask_col,
480
+ n_batches: tl.constexpr,
481
+ seq_q: tl.constexpr,
482
+ seq_k: tl.constexpr,
483
+ head_dim: tl.constexpr,
484
+ sparsity_block_size: tl.constexpr,
485
+ n_seq_blocks_q: tl.constexpr,
486
+ max_kv_blocks: tl.constexpr,
487
+ scale,
488
+ has_mask: tl.constexpr,
489
+ BLOCK_M: tl.constexpr,
490
+ BLOCK_N: tl.constexpr,
491
+ ):
492
+ """Compute dQ gradients."""
493
+ pid_m = tl.program_id(0)
494
+ pid_batch = tl.program_id(1)
495
+
496
+ n_m_tiles = sparsity_block_size // BLOCK_M
497
+ n_n_tiles = sparsity_block_size // BLOCK_N
498
+
499
+ q_seq_block = pid_m // n_m_tiles
500
+ m_tile_idx = pid_m % n_m_tiles
501
+
502
+ q_row_start = q_seq_block * sparsity_block_size + m_tile_idx * BLOCK_M
503
+ offs_m = q_row_start + tl.arange(0, BLOCK_M)
504
+ offs_d = tl.arange(0, head_dim)
505
+
506
+ qk_scale = scale * 1.44269504
507
+
508
+ q_ptrs = q_ptr + pid_batch * stride_q_batch + offs_m[:, None] * stride_q_seq + offs_d[None, :]
509
+ do_ptrs = do_ptr + pid_batch * stride_q_batch + offs_m[:, None] * stride_q_seq + offs_d[None, :]
510
+ q_mask = offs_m[:, None] < seq_q
511
+ q = tl.load(q_ptrs, mask=q_mask, other=0.0)
512
+ do = tl.load(do_ptrs, mask=q_mask, other=0.0)
513
+
514
+ m = tl.load(lse_ptr + pid_batch * seq_q + offs_m, mask=offs_m < seq_q, other=0.0)
515
+ Di = tl.load(delta_ptr + pid_batch * seq_q + offs_m, mask=offs_m < seq_q, other=0.0)
516
+
517
+ dq = tl.zeros([BLOCK_M, head_dim], dtype=tl.float32)
518
+
519
+ attn_offset_idx = pid_batch * n_seq_blocks_q + q_seq_block
520
+ attn_start = tl.load(attn_offsets_ptr + attn_offset_idx)
521
+ attn_end = tl.load(attn_offsets_ptr + attn_offset_idx + 1)
522
+ n_kv_blocks = attn_end - attn_start
523
+
524
+ for kv_idx in range(max_kv_blocks):
525
+ if kv_idx < n_kv_blocks:
526
+ k_seq_block = tl.load(attn_lut_ptr + attn_start + kv_idx)
527
+
528
+ for n_tile_idx in range(n_n_tiles):
529
+ k_row_start = k_seq_block * sparsity_block_size + n_tile_idx * BLOCK_N
530
+ offs_n = k_row_start + tl.arange(0, BLOCK_N)
531
+
532
+ k_ptrs = k_ptr + pid_batch * stride_kv_batch + offs_n[:, None] * stride_kv_seq + offs_d[None, :]
533
+ v_ptrs = v_ptr + pid_batch * stride_kv_batch + offs_n[:, None] * stride_kv_seq + offs_d[None, :]
534
+ k_mask = offs_n[:, None] < seq_k
535
+ k = tl.load(k_ptrs, mask=k_mask, other=0.0)
536
+ v = tl.load(v_ptrs, mask=k_mask, other=0.0)
537
+
538
+ qk = tl.dot(q, tl.trans(k)) * qk_scale
539
+
540
+ if has_mask:
541
+ mask_ptrs = mask_ptr + pid_batch * stride_mask_batch + offs_m[:, None] * stride_mask_row + offs_n[None, :] * stride_mask_col
542
+ mask_vals = tl.load(mask_ptrs, mask=(offs_m[:, None] < seq_q) & (offs_n[None, :] < seq_k), other=0.0)
543
+ qk = qk + mask_vals
544
+
545
+ valid_lse = m > float("-inf")
546
+ safe_m = tl.where(valid_lse, m, 0.0)
547
+ p = tl.math.exp2(qk - safe_m[:, None])
548
+ p = tl.where(valid_lse[:, None], p, 0.0)
549
+
550
+ dp = tl.dot(do, tl.trans(v))
551
+ ds = p * (dp - Di[:, None])
552
+ dq += tl.dot(ds.to(k.dtype), k)
553
+
554
+ dq = dq * scale
555
+ tl.store(dq_ptr + pid_batch * stride_q_batch + offs_m[:, None] * stride_q_seq + offs_d[None, :], dq.to(dq_ptr.dtype.element_ty), mask=q_mask)
556
+
557
+
558
+ def flash_attention_build_lut(
559
+ attention_layout: Tensor,
560
+ n_seq_blocks_q: int = None,
561
+ n_seq_blocks_k: int = None,
562
+ ) -> dict:
563
+ """Build attention LUTs for reuse across multiple calls."""
564
+ n_batches = attention_layout.shape[0]
565
+ if n_seq_blocks_q is None:
566
+ n_seq_blocks_q = attention_layout.shape[1]
567
+ if n_seq_blocks_k is None:
568
+ n_seq_blocks_k = attention_layout.shape[2]
569
+
570
+ attn_lut, attn_offsets, max_kv_blocks = _build_attention_lut_fast(
571
+ attention_layout, n_batches, n_seq_blocks_q, n_seq_blocks_k
572
+ )
573
+
574
+ attention_layout_t = attention_layout.transpose(1, 2).contiguous()
575
+ rev_attn_lut, rev_attn_offsets, max_q_per_k = _build_attention_lut_fast(
576
+ attention_layout_t, n_batches, n_seq_blocks_k, n_seq_blocks_q
577
+ )
578
+
579
+ return {
580
+ "attn_lut": attn_lut,
581
+ "attn_offsets": attn_offsets,
582
+ "max_kv_blocks": max_kv_blocks,
583
+ "rev_attn_lut": rev_attn_lut,
584
+ "rev_attn_offsets": rev_attn_offsets,
585
+ "max_q_per_k": max_q_per_k,
586
+ }
587
+
588
+
589
+ def _build_attention_lut_fast(
590
+ attention_layout: Tensor,
591
+ n_batches: int,
592
+ n_blocks_row: int,
593
+ n_blocks_col: int,
594
+ ) -> Tuple[Tensor, Tensor, int]:
595
+ """Build attention LUT efficiently."""
596
+ device = attention_layout.device
597
+
598
+ counts = attention_layout.sum(dim=2).flatten()
599
+ max_blocks_per_row = int(counts.max().item())
600
+
601
+ if max_blocks_per_row == 0:
602
+ offsets = torch.zeros(n_batches * n_blocks_row + 1, dtype=torch.int32, device=device)
603
+ lut = torch.empty(0, dtype=torch.int32, device=device)
604
+ return lut, offsets, 1
605
+
606
+ offsets = torch.zeros(n_batches * n_blocks_row + 1, dtype=torch.int32, device=device)
607
+ offsets[1:] = counts.cumsum(0).to(torch.int32)
608
+
609
+ indices = attention_layout.reshape(n_batches * n_blocks_row, n_blocks_col).nonzero(as_tuple=False)
610
+ lut = indices[:, 1].to(torch.int32)
611
+
612
+ return lut, offsets, max_blocks_per_row
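The new module above makes up the bulk of this release. A hedged usage sketch, based only on the shapes and parameter names documented in `flash_attention` above, exercising the cross-attention and custom-mask paths that the README example does not cover; all sizes are illustrative:

```python
import torch
import blksprs as bs

batch, n_heads, head_dim = 2, 4, 64
seq_q, seq_k, block = 256, 512, 128  # cross-attention: seq_q != seq_k

q = torch.randn(batch, seq_q, n_heads, head_dim, device="cuda")
k = torch.randn(batch, seq_k, n_heads, head_dim, device="cuda")
v = torch.randn(batch, seq_k, n_heads, head_dim, device="cuda")

n_batches = batch * n_heads
blocks_q, blocks_k = seq_q // block, seq_k // block

# Block-level layout: which (query block, key block) pairs are computed at all.
attention_layout = torch.ones(n_batches, blocks_q, blocks_k, device="cuda", dtype=torch.bool)

# Element-level boolean mask, True = masked out (per the docstring above).
attention_mask = torch.zeros(n_batches, seq_q, seq_k, device="cuda", dtype=torch.bool)
attention_mask[:, :, :block] = True  # e.g., hide the first key block entirely

out = bs.ops.flash_attention(q, k, v, attention_layout, block, attention_mask=attention_mask)
assert out.shape == (batch, seq_q, n_heads, head_dim)
```

As in the README example, a pre-built `lut=` can be passed to avoid rebuilding the lookup tables on every call.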
@@ -65,7 +65,6 @@ def prune_autotune_configs_conversion(autotune_configs, kernel_args, **kwargs):
      return pruned_configs


- @torch.compile
  def get_autotune_configs():
      global autotune_parameters

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: blksprs
- Version: 2.1.10
+ Version: 2.2
  Summary: A lightweight library for operations on block-sparse matrices in PyTorch.
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -17,20 +17,13 @@ Requires-Dist: coverage; extra == "test"
  Requires-Dist: build; extra == "test"
  Requires-Dist: matplotlib; extra == "test"

- # blksprs
+ # 🧊 blksprs

  [![GitHub Release](https://img.shields.io/github/v/release/FelixSchoen/blksprs?include_prereleases&label=Latest%20Release)](https://github.com/FelixSchoen/blksprs/releases)
  [![Python 3.11](https://img.shields.io/badge/Python%20Version-3.11-blue)](https://www.python.org/downloads/release/python-3119/)
  [![Python 3.12](https://img.shields.io/badge/Python%20Version-3.12-blue)](https://www.python.org/downloads/release/python-31210/)

- ## Overview
-
- ### News
-
- 🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
- LUTs, autocasting, and makes use of `torch.library.triton_op()`!
-
- ---
+ ## 📖 Overview

  A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.

@@ -46,6 +39,7 @@ Currently supported operations (includes gradient calculation):
  - Splitting and merging of matrices (_currently* only supports splitting and merging along the last dimension_)
  - Conversion to and from sparse form
  - Conversion to different sparsity layouts and different sparsity block sizes
+ - Flash Attention (_supports custom masks and cross-attention_)

  As with this library sparse matrices are represented using a tuple of `(matrix, sparsity_layout, sparsity_block_size)`,
  any element-wise operations can be applied in regular torch-like fashion.
@@ -74,7 +68,7 @@ Furthermore, the library provides a set of utility functions

  _* see the [Roadmap](#roadmap) section for more information_

- ## Installation
+ ## 🛠️ Installation

  Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible with the Linux platform**.
  Keep track of this [issue](https://github.com/triton-lang/triton/issues/1640) for updates.
@@ -89,11 +83,11 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u
  - _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.3.1)_
  - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_

- ## Changelog
+ ## 📝 Changelog

  See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.md) for a detailed changelog.

- ## Roadmap
+ ## 🗺️ Roadmap

  Note that since this library covers all our current needs it is in a **bugfix-only** state.
  This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and
@@ -105,17 +99,15 @@ We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
  It might be that this changes with future projects, but as of August 2025, we are content with the current state of the
  library.

- ## Known Limitations and Issues
+ ## ⚠️ Known Limitations and Issues

- - Triton has a bug with `tl.atomix_max()` used for the row-wise max operation.
-   In order to work around this bug a manual conversion of some values is needed, (slightly) negatively impacting
-   performance.
-   Watch the [issue](https://github.com/triton-lang/triton/issues/6376) on Triton's issue tracker for more information.
  - There will be some slight numerical differences between vanilla and blksprs operations.
    These instabilities are due to Triton and thus cannot be fixed by this library alone.
    However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.

- ## Usage
+ - Flash Attention is a recent addition. While it has been tested and appears stable, please report any issues you encounter.
+
+ ## 💻 Usage

  We provide an example below to demonstrate the usage of the library.
  For more detailed examples, please refer to
@@ -128,7 +120,6 @@ the [test cases](https://github.com/FelixSchoen/blksprs/blob/main/test/cases/tes
  import torch
  import blksprs as bs

-
  def test_readme():
      # Set up parameters (batch size, number of heads, dimensions for matrices (m, k) and (n, k))
      b, h, m, n, k = 2, 4, 64, 64, 16
@@ -193,10 +184,30 @@ def test_readme():
      # Other available functions
      bs.ops.transpose(o_sparse, sparsity_layout_o, sparsity_block_size)
      bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size, flag_fused=False)
-     bs.ops.softmax_fused(o_sparse, sparsity_layout_o, sparsity_block_size) # Significantly faster version that requires that rows of matrix fit into memory (default if flag is not set)
+     bs.ops.softmax_fused(o_sparse, sparsity_layout_o,
+                          sparsity_block_size) # Significantly faster version that requires that rows of matrix fit into memory (default if flag is not set)
      bs.ops.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size)
      bs.ops.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size)

+     # Flash Attention
+     seq_len, head_dim = 512, 64
+     sparsity_block_size_attn = 128
+
+     q = torch.randn(b, seq_len, h, head_dim, device="cuda")
+     k = torch.randn(b, seq_len, h, head_dim, device="cuda")
+     v = torch.randn(b, seq_len, h, head_dim, device="cuda")
+
+     n_batches_attn = b * h
+     n_seq_blocks = seq_len // sparsity_block_size_attn
+     attention_layout = torch.tril(torch.ones(n_batches_attn, n_seq_blocks, n_seq_blocks, device="cuda", dtype=torch.bool))
+
+     lut = bs.ops.flash_attention_build_lut(attention_layout, n_seq_blocks, n_seq_blocks)
+
+     attn_out = bs.ops.flash_attention(q, k, v, attention_layout, sparsity_block_size_attn, lut=lut)
+
+     assert attn_out.shape == (b, seq_len, h, head_dim)
+
+

  def _get_random_sparsity_layout(b, m, n, sparsity_block_size, sparsity_percentage):
      """Helper function, creates a random sparsity layout for a given shape with a given percentage of blocks marked as sparse.
@@ -10,6 +10,7 @@ blksprs/layouting/distribution_layout.py
  blksprs/layouting/sparsity_layout.py
  blksprs/ops/conversion.py
  blksprs/ops/distribution.py
+ blksprs/ops/flash_attention.py
  blksprs/ops/flow.py
  blksprs/ops/matmul.py
  blksprs/ops/partitioning.py
@@ -1,15 +1,12 @@
  [project]
  name = "blksprs"
- version = "2.1.10"
+ version = "2.2"
  authors = [{ name = "Felix Schön", email = "schoen@kr.tuwien.ac.at" }]
  description = "A lightweight library for operations on block-sparse matrices in PyTorch."
  readme = "README.md"
  requires-python = ">=3.11"
  license = { file = "LICENSE.md" }
- dependencies = [
-     "torch >= 2.8.0",
-     "numpy"
- ]
+ dependencies = ["torch >= 2.8.0", "numpy"]

  [project.urls]
  "Homepage" = "https://github.com/FelixSchoen/blksprs"
@@ -22,12 +19,15 @@ test = [
      "pytest-cov",
      "coverage",
      "build",
-     "matplotlib"
+     "matplotlib",
  ]

  [build-system]
  requires = ["setuptools", "wheel"]
  build-backend = "setuptools.build_meta"

+ [tool.setuptools.packages.find]
+ include = ["blksprs*"]
+
  [tool.setuptools.package-data]
- "*" = ["*.json", "*.conf"]
+ "*" = ["*.json", "*.conf"]