blksprs 2.1.9__py3-none-any.whl → 2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +2 -1
- blksprs/ops/distribution.py +3 -3
- blksprs/ops/flash_attention.py +612 -0
- blksprs/utils/autotuning.py +0 -1
- blksprs/utils/tools.py +3 -1
- {blksprs-2.1.9.dist-info → blksprs-2.2.dist-info}/METADATA +32 -21
- {blksprs-2.1.9.dist-info → blksprs-2.2.dist-info}/RECORD +9 -8
- {blksprs-2.1.9.dist-info → blksprs-2.2.dist-info}/WHEEL +1 -1
- {blksprs-2.1.9.dist-info → blksprs-2.2.dist-info}/top_level.txt +0 -0
blksprs/__init__.py
CHANGED

@@ -4,7 +4,7 @@ import torch
 # Capture scalar outputs for JIT compilation
 torch._dynamo.config.capture_scalar_outputs = True
 # Set version
-__version__ = "2.1.9"
+__version__ = "2.2"
 
 # Imports
 
@@ -14,6 +14,7 @@ from blksprs.utils.blksprs_tensor import BlksprsTensor
 class ops:
     from blksprs.ops.conversion import to_dense, to_sparse, from_blksprs, to_blksprs, adapt_layout
     from blksprs.ops.distribution import gather, scatter, scatter_reduce
+    from blksprs.ops.flash_attention import flash_attention, flash_attention_build_lut
     from blksprs.ops.matmul import matmul
     from blksprs.ops.softmax import softmax, softmax_fused
     from blksprs.ops.transpose import transpose
blksprs/ops/distribution.py
CHANGED

@@ -174,7 +174,7 @@ def gather_kernel(x,
                   dst_col_x)
     blk_x_msk = (((blk_x_idx >= 0) &
                   (blk_x_idx < x_b * x_b_s)) &
-                 (
+                 (rev_idx_spa_x >= 0))
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Store output
@@ -183,7 +183,7 @@ def gather_kernel(x,
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
     blk_o_msk = (((blk_o_idx >= 0) &
                   (blk_o_idx < o_b * o_b_s)) &
-                 (
+                 (rev_idx_spa_x >= 0))
     tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
 
 
@@ -426,7 +426,7 @@ def scatter_reduce_kernel(x,
                   dst_col_o)
     blk_o_msk = (((blk_o_idx >= 0) &
                   (blk_o_idx < o_b * o_b_s)) &
-                 (
+                 (rev_idx_spa_o >= 0))
 
     if reduce_op_ind == 0:
         tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
blksprs/ops/flash_attention.py
ADDED

@@ -0,0 +1,612 @@
+"""Block-sparse Flash Attention implementation for blksprs.
+
+This module implements Flash Attention 2 algorithm with block-sparse support,
+including cross-attention (seq_q != seq_k) and custom attention masks.
+
+Note: This implementation was developed with AI assistance.
+"""
+
+import math
+from typing import Tuple
+
+import torch
+import triton
+from torch import Tensor
+from triton import language as tl
+
+from blksprs.utils.validation import validate_contiguous, validate_device, validate_dtype_float, ensure_contiguous
+
+
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
+def flash_attention(
+        q: Tensor,
+        k: Tensor,
+        v: Tensor,
+        attention_layout: Tensor,
+        sparsity_block_size: int,
+        scale: float = None,
+        attention_mask: Tensor = None,
+        lut: dict = None,
+) -> Tensor:
+    """Block-sparse flash attention with optional attention mask.
+
+    Args:
+        q: Query tensor [batch, seq_q, n_heads, head_dim]
+        k: Key tensor [batch, seq_k, n_heads, head_dim]
+        v: Value tensor [batch, seq_k, n_heads, head_dim]
+        attention_layout: Block attention pattern [batch*heads, n_seq_blocks_q, n_seq_blocks_k]
+        sparsity_block_size: Block size for sparsity pattern
+        scale: Attention scale (default: 1/sqrt(head_dim))
+        attention_mask: Boolean mask [batch*heads, seq_q, seq_k] where True=masked (default None)
+        lut: Optional pre-computed LUT dictionary
+
+    Returns:
+        Output tensor [batch, seq_q, n_heads, head_dim]
+    """
+    q, k, v = ensure_contiguous(q, k, v)
+
+    validate_contiguous(q, k, v)
+    validate_dtype_float(q, k, v)
+    validate_device(q, k, v)
+
+    batch, seq_q, n_heads, head_dim = q.shape
+    _, seq_k, _, _ = k.shape
+
+    if k.shape[0] != batch or k.shape[2] != n_heads or k.shape[3] != head_dim:
+        raise ValueError("K must have compatible shape with Q")
+    if v.shape != k.shape:
+        raise ValueError("V must have same shape as K")
+    if not (sparsity_block_size >= 16 and (sparsity_block_size & (sparsity_block_size - 1)) == 0):
+        raise ValueError(f"sparsity_block_size must be power of 2 >= 16, got {sparsity_block_size}")
+    if seq_q % sparsity_block_size != 0:
+        raise ValueError(f"seq_q ({seq_q}) must be divisible by sparsity_block_size")
+    if seq_k % sparsity_block_size != 0:
+        raise ValueError(f"seq_k ({seq_k}) must be divisible by sparsity_block_size")
+
+    n_batches = batch * n_heads
+    n_seq_blocks_q = seq_q // sparsity_block_size
+    n_seq_blocks_k = seq_k // sparsity_block_size
+
+    expected_layout_shape = (n_batches, n_seq_blocks_q, n_seq_blocks_k)
+    if attention_layout.shape != expected_layout_shape:
+        raise ValueError(f"attention_layout shape {tuple(attention_layout.shape)} doesn't match expected {expected_layout_shape}")
+
+    if scale is None:
+        scale = 1.0 / math.sqrt(head_dim)
+
+    if lut is None:
+        lut = flash_attention_build_lut(attention_layout, n_seq_blocks_q, n_seq_blocks_k)
+
+    has_mask = attention_mask is not None
+    if has_mask:
+        if attention_mask.shape != (n_batches, seq_q, seq_k):
+            raise ValueError(f"attention_mask shape {tuple(attention_mask.shape)} doesn't match expected ({n_batches}, {seq_q}, {seq_k})")
+        attention_mask_additive = torch.where(
+            attention_mask,
+            torch.tensor(float("-inf"), device=attention_mask.device, dtype=q.dtype),
+            torch.tensor(0.0, device=attention_mask.device, dtype=q.dtype)
+        ).contiguous()
+    else:
+        attention_mask_additive = torch.empty(0, device=q.device, dtype=q.dtype)
+
+    return BlockSparseFlashAttention.apply(
+        q, k, v,
+        attention_mask_additive,
+        lut["attn_lut"], lut["attn_offsets"],
+        lut["rev_attn_lut"], lut["rev_attn_offsets"],
+        sparsity_block_size, n_seq_blocks_q, n_seq_blocks_k,
+        lut["max_kv_blocks"], lut["max_q_per_k"],
+        scale, has_mask,
+    )
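On small inputs the wrapper above should agree, up to float16 rounding, with dense scaled-dot-product attention in which every position outside the active blocks of `attention_layout`, and every `True` position of `attention_mask`, is excluded. The following dense reference is a minimal sketch for sanity checks, assuming the shapes documented in the docstring; it is not part of the package:

```python
# Hedged reference sketch (not part of the package): a dense PyTorch equivalent of
# flash_attention(), assuming q/k/v are [batch, seq, heads, head_dim], attention_layout
# is [batch*heads, q_blocks, k_blocks] (head index varying fastest), and attention_mask
# marks excluded positions with True.
import math
import torch


def dense_reference_attention(q, k, v, attention_layout, sparsity_block_size,
                              scale=None, attention_mask=None):
    batch, seq_q, n_heads, head_dim = q.shape
    seq_k = k.shape[1]
    scale = scale if scale is not None else 1.0 / math.sqrt(head_dim)

    # [batch, seq, heads, dim] -> [batch*heads, seq, dim], matching the wrapper's flattening
    qf = q.permute(0, 2, 1, 3).reshape(batch * n_heads, seq_q, head_dim)
    kf = k.permute(0, 2, 1, 3).reshape(batch * n_heads, seq_k, head_dim)
    vf = v.permute(0, 2, 1, 3).reshape(batch * n_heads, seq_k, head_dim)

    # Expand the block-level layout to an element-level keep-mask
    keep = (attention_layout.repeat_interleave(sparsity_block_size, dim=1)
            .repeat_interleave(sparsity_block_size, dim=2).bool())
    if attention_mask is not None:
        keep = keep & ~attention_mask

    scores = torch.einsum("bqd,bkd->bqk", qf.float(), kf.float()) * scale
    scores = scores.masked_fill(~keep, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    # Rows with no attendable position softmax to NaN; zero them like the kernel does
    probs = torch.nan_to_num(probs, nan=0.0)

    out = torch.einsum("bqk,bkd->bqd", probs, vf.to(probs.dtype))
    return out.reshape(batch, n_heads, seq_q, head_dim).permute(0, 2, 1, 3).to(q.dtype)
```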
+
+
+class BlockSparseFlashAttention(torch.autograd.Function):
+    """Block-sparse Flash Attention with autograd support."""
+
+    @staticmethod
+    def forward(ctx, q, k, v, attention_mask, attn_lut, attn_offsets, rev_attn_lut, rev_attn_offsets,
+                sparsity_block_size, n_seq_blocks_q, n_seq_blocks_k, max_kv_blocks, max_q_per_k, scale, has_mask):
+        batch, seq_q, n_heads, head_dim = q.shape
+        _, seq_k, _, _ = k.shape
+        n_batches = batch * n_heads
+
+        q_flat = q.permute(0, 2, 1, 3).reshape(n_batches, seq_q, head_dim).contiguous()
+        k_flat = k.permute(0, 2, 1, 3).reshape(n_batches, seq_k, head_dim).contiguous()
+        v_flat = v.permute(0, 2, 1, 3).reshape(n_batches, seq_k, head_dim).contiguous()
+
+        o_flat = torch.empty_like(q_flat)
+        lse = torch.empty(n_batches, seq_q, device=q.device, dtype=torch.float32)
+        l = torch.empty(n_batches, seq_q, device=q.device, dtype=torch.float32)
+
+        if head_dim <= 64:
+            BLOCK_M = min(128, sparsity_block_size)
+        elif head_dim <= 128:
+            BLOCK_M = min(64, sparsity_block_size)
+        else:
+            BLOCK_M = min(32, sparsity_block_size)
+        BLOCK_N = sparsity_block_size
+
+        n_m_tiles = seq_q // BLOCK_M
+        grid = (n_m_tiles, n_batches)
+
+        if has_mask:
+            mask_stride_batch = attention_mask.stride(0)
+            mask_stride_row = attention_mask.stride(1)
+            mask_stride_col = attention_mask.stride(2)
+        else:
+            mask_stride_batch = 0
+            mask_stride_row = 0
+            mask_stride_col = 0
+
+        flash_attention_fwd_kernel[grid](
+            q_flat, k_flat, v_flat, o_flat,
+            attention_mask if has_mask else q_flat,
+            attn_lut, attn_offsets,
+            lse, l,
+            q_flat.stride(0), q_flat.stride(1), q_flat.stride(2),
+            k_flat.stride(0), k_flat.stride(1), k_flat.stride(2),
+            mask_stride_batch, mask_stride_row, mask_stride_col,
+            n_batches, seq_q, seq_k, head_dim, sparsity_block_size, n_seq_blocks_q, max_kv_blocks,
+            scale,
+            has_mask,
+            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
+            num_stages=4, num_warps=4,
+        )
+
+        o = o_flat.reshape(batch, n_heads, seq_q, head_dim).permute(0, 2, 1, 3).contiguous()
+
+        ctx.save_for_backward(q_flat, k_flat, v_flat, o_flat, lse,
+                              attn_lut, attn_offsets, rev_attn_lut, rev_attn_offsets,
+                              attention_mask if has_mask else torch.empty(0, device=q.device))
+        ctx.sparsity_block_size = sparsity_block_size
+        ctx.n_seq_blocks_q = n_seq_blocks_q
+        ctx.n_seq_blocks_k = n_seq_blocks_k
+        ctx.max_kv_blocks = max_kv_blocks
+        ctx.max_q_per_k = max_q_per_k
+        ctx.scale = scale
+        ctx.has_mask = has_mask
+        ctx.batch = batch
+        ctx.n_heads = n_heads
+        ctx.seq_q = seq_q
+        ctx.seq_k = seq_k
+        ctx.head_dim = head_dim
+        ctx.BLOCK_M = BLOCK_M
+        ctx.BLOCK_N = BLOCK_N
+
+        return o
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (q_flat, k_flat, v_flat, o_flat, lse,
+         attn_lut, attn_offsets, rev_attn_lut, rev_attn_offsets, attention_mask) = ctx.saved_tensors
+
+        batch = ctx.batch
+        n_heads = ctx.n_heads
+        seq_q = ctx.seq_q
+        seq_k = ctx.seq_k
+        head_dim = ctx.head_dim
+        n_batches = batch * n_heads
+        sparsity_block_size = ctx.sparsity_block_size
+        BLOCK_M = ctx.BLOCK_M
+        BLOCK_N = ctx.BLOCK_N
+        has_mask = ctx.has_mask
+
+        do_flat = grad_output.permute(0, 2, 1, 3).reshape(n_batches, seq_q, head_dim).contiguous()
+
+        dq_flat = torch.zeros_like(q_flat)
+        dk_flat = torch.zeros_like(k_flat)
+        dv_flat = torch.zeros_like(v_flat)
+        delta = torch.empty(n_batches, seq_q, device=q_flat.device, dtype=torch.float32)
+
+        if has_mask:
+            mask_stride_batch = attention_mask.stride(0)
+            mask_stride_row = attention_mask.stride(1)
+            mask_stride_col = attention_mask.stride(2)
+        else:
+            mask_stride_batch = 0
+            mask_stride_row = 0
+            mask_stride_col = 0
+
+        n_m_tiles_q = seq_q // BLOCK_M
+        flash_attention_bwd_preprocess_kernel[(n_m_tiles_q, n_batches)](
+            o_flat, do_flat, delta,
+            o_flat.stride(0), o_flat.stride(1), o_flat.stride(2),
+            seq_q, head_dim,
+            BLOCK_M=BLOCK_M,
+        )
+
+        n_n_tiles_k = seq_k // BLOCK_N
+        flash_attention_bwd_dkdv_kernel[(n_n_tiles_k, n_batches)](
+            q_flat, k_flat, v_flat, do_flat,
+            dk_flat, dv_flat,
+            lse, delta,
+            attention_mask if has_mask else q_flat,
+            rev_attn_lut, rev_attn_offsets,
+            q_flat.stride(0), q_flat.stride(1),
+            k_flat.stride(0), k_flat.stride(1),
+            q_flat.stride(2),
+            mask_stride_batch, mask_stride_row, mask_stride_col,
+            n_batches, seq_q, seq_k, head_dim, sparsity_block_size, ctx.n_seq_blocks_k, ctx.max_q_per_k,
+            ctx.scale,
+            has_mask,
+            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
+        )
+
+        flash_attention_bwd_dq_kernel[(n_m_tiles_q, n_batches)](
+            q_flat, k_flat, v_flat, do_flat,
+            dq_flat,
+            lse, delta,
+            attention_mask if has_mask else q_flat,
+            attn_lut, attn_offsets,
+            q_flat.stride(0), q_flat.stride(1),
+            k_flat.stride(0), k_flat.stride(1),
+            q_flat.stride(2),
+            mask_stride_batch, mask_stride_row, mask_stride_col,
+            n_batches, seq_q, seq_k, head_dim, sparsity_block_size, ctx.n_seq_blocks_q, ctx.max_kv_blocks,
+            ctx.scale,
+            has_mask,
+            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
+        )
+
+        dq = dq_flat.reshape(batch, n_heads, seq_q, head_dim).permute(0, 2, 1, 3).contiguous()
+        dk = dk_flat.reshape(batch, n_heads, seq_k, head_dim).permute(0, 2, 1, 3).contiguous()
+        dv = dv_flat.reshape(batch, n_heads, seq_k, head_dim).permute(0, 2, 1, 3).contiguous()
+
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None
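Both `forward` and `backward` fold the head dimension into the batch dimension before launching the kernels, so head `h` of batch `b` ends up at flat index `b * n_heads + h`; `attention_layout` and `attention_mask` must use the same ordering on their first dimension. A short sketch of that convention (an illustration under that assumption, not package code):

```python
# Flattening convention used above: [batch, seq, heads, dim] -> [batch*heads, seq, dim],
# with the head index varying fastest. Per-batch masks or layouts can therefore be
# expanded with repeat_interleave over the first dimension.
import torch

batch, n_heads, seq = 2, 3, 8
x = torch.randn(batch, seq, n_heads, 4)
x_flat = x.permute(0, 2, 1, 3).reshape(batch * n_heads, seq, 4)
assert torch.equal(x_flat[1 * n_heads + 2], x[1, :, 2])    # flat index = b * n_heads + h

per_batch_mask = torch.zeros(batch, seq, seq, dtype=torch.bool)
mask_for_flash = per_batch_mask.repeat_interleave(n_heads, dim=0)   # [batch*heads, seq, seq]
assert mask_for_flash.shape == (batch * n_heads, seq, seq)
```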
+
+
+@triton.jit
+def flash_attention_fwd_kernel(
+        q_ptr, k_ptr, v_ptr, o_ptr,
+        mask_ptr,
+        attn_lut_ptr, attn_offsets_ptr,
+        m_ptr, l_ptr,
+        stride_q_batch, stride_q_seq, stride_q_dim,
+        stride_kv_batch, stride_kv_seq, stride_kv_dim,
+        stride_mask_batch, stride_mask_row, stride_mask_col,
+        n_batches: tl.constexpr,
+        seq_q: tl.constexpr,
+        seq_k: tl.constexpr,
+        head_dim: tl.constexpr,
+        sparsity_block_size: tl.constexpr,
+        n_seq_blocks_q: tl.constexpr,
+        max_kv_blocks: tl.constexpr,
+        scale,
+        has_mask: tl.constexpr,
+        BLOCK_M: tl.constexpr,
+        BLOCK_N: tl.constexpr,
+):
+    """Flash attention forward kernel with block-sparse mask support."""
+    pid_m = tl.program_id(0)
+    pid_batch = tl.program_id(1)
+
+    n_m_tiles: tl.constexpr = sparsity_block_size // BLOCK_M
+    n_n_tiles: tl.constexpr = sparsity_block_size // BLOCK_N
+
+    q_seq_block = pid_m // n_m_tiles
+    m_tile_idx = pid_m % n_m_tiles
+
+    q_row_start = q_seq_block * sparsity_block_size + m_tile_idx * BLOCK_M
+    offs_m = q_row_start + tl.arange(0, BLOCK_M)
+    offs_d = tl.arange(0, head_dim)
+
+    q_ptrs = q_ptr + pid_batch * stride_q_batch + offs_m[:, None] * stride_q_seq + offs_d[None, :]
+    q_mask = offs_m[:, None] < seq_q
+    q = tl.load(q_ptrs, mask=q_mask, other=0.0)
+
+    m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M, head_dim], dtype=tl.float32)
+
+    qk_scale = scale * 1.44269504
+
+    attn_offset_idx = pid_batch * n_seq_blocks_q + q_seq_block
+    attn_start = tl.load(attn_offsets_ptr + attn_offset_idx)
+    attn_end = tl.load(attn_offsets_ptr + attn_offset_idx + 1)
+    n_kv_blocks = attn_end - attn_start
+
+    for kv_idx in range(max_kv_blocks):
+        if kv_idx < n_kv_blocks:
+            k_seq_block = tl.load(attn_lut_ptr + attn_start + kv_idx)
+
+            k_row_start = k_seq_block * sparsity_block_size
+            offs_n = k_row_start + tl.arange(0, BLOCK_N)
+
+            k_ptrs = k_ptr + pid_batch * stride_kv_batch + offs_n[:, None] * stride_kv_seq + offs_d[None, :]
+            k_mask = offs_n[:, None] < seq_k
+            k = tl.load(k_ptrs, mask=k_mask, other=0.0)
+
+            qk = tl.dot(q, tl.trans(k)) * qk_scale
+
+            if has_mask:
+                mask_ptrs = mask_ptr + pid_batch * stride_mask_batch + offs_m[:, None] * stride_mask_row + offs_n[None, :] * stride_mask_col
+                mask_vals = tl.load(mask_ptrs, mask=(offs_m[:, None] < seq_q) & (offs_n[None, :] < seq_k), other=0.0)
+                qk = qk + mask_vals
+
+            m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
+            alpha = tl.math.exp2(m_i - m_ij)
+            p = tl.math.exp2(qk - m_ij[:, None])
+            l_i = l_i * alpha + tl.sum(p, axis=1)
+            acc = acc * alpha[:, None]
+
+            v_ptrs = v_ptr + pid_batch * stride_kv_batch + offs_n[:, None] * stride_kv_seq + offs_d[None, :]
+            v = tl.load(v_ptrs, mask=k_mask, other=0.0)
+            acc = tl.dot(p.to(v.dtype), v, acc)
+
+            m_i = m_ij
+
+    has_attention = l_i > 0
+    l_safe = tl.where(has_attention, l_i, 1.0)
+    acc = acc / l_safe[:, None]
+    acc = tl.where(has_attention[:, None], acc, 0.0)
+
+    o_ptrs = o_ptr + pid_batch * stride_q_batch + offs_m[:, None] * stride_q_seq + offs_d[None, :]
+    tl.store(o_ptrs, acc.to(o_ptr.dtype.element_ty), mask=offs_m[:, None] < seq_q)
+
+    lse = tl.where(has_attention, m_i + tl.math.log2(l_safe), float("-inf"))
+    tl.store(m_ptr + pid_batch * seq_q + offs_m, lse, mask=offs_m < seq_q)
+    tl.store(l_ptr + pid_batch * seq_q + offs_m, l_i, mask=offs_m < seq_q)
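The forward kernel above uses the Flash Attention streaming-softmax update in the log2 domain: `qk_scale = scale * 1.44269504` converts natural-base exponentials to `exp2`, `m_i` tracks the running row maximum, `l_i` the running denominator, `alpha` rescales the accumulator whenever the maximum grows, and `m_ptr` finally receives the base-2 log-sum-exp. A minimal sketch of that update (not package code), which can be checked against dense softmax:

```python
# Streaming-softmax sketch: fold key/value blocks in one at a time while keeping a
# running row max (m) and denominator (l). exp2 with scale * log2(e) equals exp(scale * s).
import torch

def streaming_attention_row(q, k_blocks, v_blocks, scale):
    log2e = 1.44269504
    m = torch.full((q.shape[0],), float("-inf"))
    l = torch.zeros(q.shape[0])
    acc = torch.zeros(q.shape[0], v_blocks[0].shape[1])
    for k_blk, v_blk in zip(k_blocks, v_blocks):
        qk = (q @ k_blk.T) * scale * log2e          # scores in the log2 domain
        m_new = torch.maximum(m, qk.max(dim=1).values)
        alpha = torch.exp2(m - m_new)               # rescale previous contributions
        p = torch.exp2(qk - m_new[:, None])
        l = l * alpha + p.sum(dim=1)
        acc = acc * alpha[:, None] + p @ v_blk
        m = m_new
    return acc / l[:, None]

q = torch.randn(4, 16)
k = torch.randn(32, 16)
v = torch.randn(32, 16)
scale = 16 ** -0.5
out_stream = streaming_attention_row(q, list(k.split(8)), list(v.split(8)), scale)
out_dense = torch.softmax(q @ k.T * scale, dim=-1) @ v
assert torch.allclose(out_stream, out_dense, atol=1e-5)
```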
+
+
+@triton.jit
+def flash_attention_bwd_preprocess_kernel(
+        o_ptr, do_ptr, delta_ptr,
+        stride_batch, stride_seq, stride_dim,
+        seq_len: tl.constexpr,
+        head_dim: tl.constexpr,
+        BLOCK_M: tl.constexpr,
+):
+    """Compute delta = (O * dO).sum(dim=-1)."""
+    pid_m = tl.program_id(0)
+    pid_batch = tl.program_id(1)
+
+    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_d = tl.arange(0, head_dim)
+
+    o_ptrs = o_ptr + pid_batch * stride_batch + offs_m[:, None] * stride_seq + offs_d[None, :]
+    do_ptrs = do_ptr + pid_batch * stride_batch + offs_m[:, None] * stride_seq + offs_d[None, :]
+    mask = offs_m[:, None] < seq_len
+
+    o = tl.load(o_ptrs, mask=mask, other=0.0).to(tl.float32)
+    do = tl.load(do_ptrs, mask=mask, other=0.0).to(tl.float32)
+    delta = tl.sum(o * do, axis=1)
+
+    tl.store(delta_ptr + pid_batch * seq_len + offs_m, delta, mask=offs_m < seq_len)
+
+
+@triton.jit
+def flash_attention_bwd_dkdv_kernel(
+        q_ptr, k_ptr, v_ptr, do_ptr,
+        dk_ptr, dv_ptr,
+        lse_ptr, delta_ptr,
+        mask_ptr,
+        rev_attn_lut_ptr, rev_attn_offsets_ptr,
+        stride_q_batch, stride_q_seq,
+        stride_kv_batch, stride_kv_seq,
+        stride_dim,
+        stride_mask_batch, stride_mask_row, stride_mask_col,
+        n_batches: tl.constexpr,
+        seq_q: tl.constexpr,
+        seq_k: tl.constexpr,
+        head_dim: tl.constexpr,
+        sparsity_block_size: tl.constexpr,
+        n_seq_blocks_k: tl.constexpr,
+        max_q_per_k: tl.constexpr,
+        scale,
+        has_mask: tl.constexpr,
+        BLOCK_M: tl.constexpr,
+        BLOCK_N: tl.constexpr,
+):
+    """Compute dK and dV gradients."""
+    pid_n = tl.program_id(0)
+    pid_batch = tl.program_id(1)
+
+    n_n_tiles = sparsity_block_size // BLOCK_N
+    n_m_tiles = sparsity_block_size // BLOCK_M
+
+    k_seq_block = pid_n // n_n_tiles
+    n_tile_idx = pid_n % n_n_tiles
+
+    k_row_start = k_seq_block * sparsity_block_size + n_tile_idx * BLOCK_N
+    offs_n = k_row_start + tl.arange(0, BLOCK_N)
+    offs_d = tl.arange(0, head_dim)
+
+    qk_scale = scale * 1.44269504
+
+    k_ptrs = k_ptr + pid_batch * stride_kv_batch + offs_n[:, None] * stride_kv_seq + offs_d[None, :]
+    v_ptrs = v_ptr + pid_batch * stride_kv_batch + offs_n[:, None] * stride_kv_seq + offs_d[None, :]
+    k_mask = offs_n[:, None] < seq_k
+    k = tl.load(k_ptrs, mask=k_mask, other=0.0)
+    v = tl.load(v_ptrs, mask=k_mask, other=0.0)
+
+    dk = tl.zeros([BLOCK_N, head_dim], dtype=tl.float32)
+    dv = tl.zeros([BLOCK_N, head_dim], dtype=tl.float32)
+
+    rev_offset_idx = pid_batch * n_seq_blocks_k + k_seq_block
+    rev_start = tl.load(rev_attn_offsets_ptr + rev_offset_idx)
+    rev_end = tl.load(rev_attn_offsets_ptr + rev_offset_idx + 1)
+    n_q_blocks = rev_end - rev_start
+
+    for q_idx in range(max_q_per_k):
+        if q_idx < n_q_blocks:
+            q_seq_block = tl.load(rev_attn_lut_ptr + rev_start + q_idx)
+
+            for m_tile_idx in range(n_m_tiles):
+                q_row_start = q_seq_block * sparsity_block_size + m_tile_idx * BLOCK_M
+                offs_m = q_row_start + tl.arange(0, BLOCK_M)
+
+                q_ptrs = q_ptr + pid_batch * stride_q_batch + offs_m[:, None] * stride_q_seq + offs_d[None, :]
+                do_ptrs = do_ptr + pid_batch * stride_q_batch + offs_m[:, None] * stride_q_seq + offs_d[None, :]
+                q_mask = offs_m[:, None] < seq_q
+                q = tl.load(q_ptrs, mask=q_mask, other=0.0)
+                do = tl.load(do_ptrs, mask=q_mask, other=0.0)
+
+                m = tl.load(lse_ptr + pid_batch * seq_q + offs_m, mask=offs_m < seq_q, other=0.0)
+                Di = tl.load(delta_ptr + pid_batch * seq_q + offs_m, mask=offs_m < seq_q, other=0.0)
+
+                qk = tl.dot(q, tl.trans(k)) * qk_scale
+
+                if has_mask:
+                    mask_ptrs = mask_ptr + pid_batch * stride_mask_batch + offs_m[:, None] * stride_mask_row + offs_n[None, :] * stride_mask_col
+                    mask_vals = tl.load(mask_ptrs, mask=(offs_m[:, None] < seq_q) & (offs_n[None, :] < seq_k), other=0.0)
+                    qk = qk + mask_vals
+
+                valid_lse = m > float("-inf")
+                safe_m = tl.where(valid_lse, m, 0.0)
+                p = tl.math.exp2(qk - safe_m[:, None])
+                p = tl.where(valid_lse[:, None], p, 0.0)
+
+                dv += tl.dot(tl.trans(p.to(do.dtype)), do)
+                dp = tl.dot(do, tl.trans(v))
+                ds = p * (dp - Di[:, None])
+                dk += tl.dot(tl.trans(ds.to(q.dtype)), q)
+
+    dk = dk * scale
+    tl.store(dk_ptr + pid_batch * stride_kv_batch + offs_n[:, None] * stride_kv_seq + offs_d[None, :], dk.to(dk_ptr.dtype.element_ty), mask=k_mask)
+    tl.store(dv_ptr + pid_batch * stride_kv_batch + offs_n[:, None] * stride_kv_seq + offs_d[None, :], dv.to(dv_ptr.dtype.element_ty), mask=k_mask)
+
+
+@triton.jit
+def flash_attention_bwd_dq_kernel(
+        q_ptr, k_ptr, v_ptr, do_ptr,
+        dq_ptr,
+        lse_ptr, delta_ptr,
+        mask_ptr,
+        attn_lut_ptr, attn_offsets_ptr,
+        stride_q_batch, stride_q_seq,
+        stride_kv_batch, stride_kv_seq,
+        stride_dim,
+        stride_mask_batch, stride_mask_row, stride_mask_col,
+        n_batches: tl.constexpr,
+        seq_q: tl.constexpr,
+        seq_k: tl.constexpr,
+        head_dim: tl.constexpr,
+        sparsity_block_size: tl.constexpr,
+        n_seq_blocks_q: tl.constexpr,
+        max_kv_blocks: tl.constexpr,
+        scale,
+        has_mask: tl.constexpr,
+        BLOCK_M: tl.constexpr,
+        BLOCK_N: tl.constexpr,
+):
+    """Compute dQ gradients."""
+    pid_m = tl.program_id(0)
+    pid_batch = tl.program_id(1)
+
+    n_m_tiles = sparsity_block_size // BLOCK_M
+    n_n_tiles = sparsity_block_size // BLOCK_N
+
+    q_seq_block = pid_m // n_m_tiles
+    m_tile_idx = pid_m % n_m_tiles
+
+    q_row_start = q_seq_block * sparsity_block_size + m_tile_idx * BLOCK_M
+    offs_m = q_row_start + tl.arange(0, BLOCK_M)
+    offs_d = tl.arange(0, head_dim)
+
+    qk_scale = scale * 1.44269504
+
+    q_ptrs = q_ptr + pid_batch * stride_q_batch + offs_m[:, None] * stride_q_seq + offs_d[None, :]
+    do_ptrs = do_ptr + pid_batch * stride_q_batch + offs_m[:, None] * stride_q_seq + offs_d[None, :]
+    q_mask = offs_m[:, None] < seq_q
+    q = tl.load(q_ptrs, mask=q_mask, other=0.0)
+    do = tl.load(do_ptrs, mask=q_mask, other=0.0)
+
+    m = tl.load(lse_ptr + pid_batch * seq_q + offs_m, mask=offs_m < seq_q, other=0.0)
+    Di = tl.load(delta_ptr + pid_batch * seq_q + offs_m, mask=offs_m < seq_q, other=0.0)
+
+    dq = tl.zeros([BLOCK_M, head_dim], dtype=tl.float32)
+
+    attn_offset_idx = pid_batch * n_seq_blocks_q + q_seq_block
+    attn_start = tl.load(attn_offsets_ptr + attn_offset_idx)
+    attn_end = tl.load(attn_offsets_ptr + attn_offset_idx + 1)
+    n_kv_blocks = attn_end - attn_start
+
+    for kv_idx in range(max_kv_blocks):
+        if kv_idx < n_kv_blocks:
+            k_seq_block = tl.load(attn_lut_ptr + attn_start + kv_idx)
+
+            for n_tile_idx in range(n_n_tiles):
+                k_row_start = k_seq_block * sparsity_block_size + n_tile_idx * BLOCK_N
+                offs_n = k_row_start + tl.arange(0, BLOCK_N)
+
+                k_ptrs = k_ptr + pid_batch * stride_kv_batch + offs_n[:, None] * stride_kv_seq + offs_d[None, :]
+                v_ptrs = v_ptr + pid_batch * stride_kv_batch + offs_n[:, None] * stride_kv_seq + offs_d[None, :]
+                k_mask = offs_n[:, None] < seq_k
+                k = tl.load(k_ptrs, mask=k_mask, other=0.0)
+                v = tl.load(v_ptrs, mask=k_mask, other=0.0)
+
+                qk = tl.dot(q, tl.trans(k)) * qk_scale
+
+                if has_mask:
+                    mask_ptrs = mask_ptr + pid_batch * stride_mask_batch + offs_m[:, None] * stride_mask_row + offs_n[None, :] * stride_mask_col
+                    mask_vals = tl.load(mask_ptrs, mask=(offs_m[:, None] < seq_q) & (offs_n[None, :] < seq_k), other=0.0)
+                    qk = qk + mask_vals
+
+                valid_lse = m > float("-inf")
+                safe_m = tl.where(valid_lse, m, 0.0)
+                p = tl.math.exp2(qk - safe_m[:, None])
+                p = tl.where(valid_lse[:, None], p, 0.0)
+
+                dp = tl.dot(do, tl.trans(v))
+                ds = p * (dp - Di[:, None])
+                dq += tl.dot(ds.to(k.dtype), k)
+
+    dq = dq * scale
+    tl.store(dq_ptr + pid_batch * stride_q_batch + offs_m[:, None] * stride_q_seq + offs_d[None, :], dq.to(dq_ptr.dtype.element_ty), mask=q_mask)
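The three backward kernels above implement the standard Flash Attention 2 gradients, restricted to the active blocks: `delta` is the row-wise sum of `O * dO` from the preprocess kernel, `P` is recovered from the stored log-sum-exp, and then `dV = Pᵀ dO`, `dS = P ∘ (dP − delta)`, `dK = scale · dSᵀ Q`, `dQ = scale · dS K`. A dense check of these formulas against autograd (a sketch, not package code):

```python
# Dense check of the gradient formulas used by the backward kernels above.
import torch

def dense_attention_grads(q, k, v, do, scale, neg_inf_mask=None):
    s = q @ k.transpose(-2, -1) * scale
    if neg_inf_mask is not None:
        s = s + neg_inf_mask                       # 0 where allowed, -inf where masked
    p = torch.softmax(s, dim=-1)
    o = p @ v
    delta = (o * do).sum(dim=-1, keepdim=True)     # matches the preprocess kernel
    dv = p.transpose(-2, -1) @ do
    dp = do @ v.transpose(-2, -1)
    ds = p * (dp - delta)
    dq = ds @ k * scale
    dk = ds.transpose(-2, -1) @ q * scale
    return o, dq, dk, dv

# Compare against autograd on a single [seq_q, seq_k] head
q = torch.randn(32, 16, dtype=torch.float64, requires_grad=True)
k = torch.randn(48, 16, dtype=torch.float64, requires_grad=True)
v = torch.randn(48, 16, dtype=torch.float64, requires_grad=True)
do = torch.randn(32, 16, dtype=torch.float64)
scale = 16 ** -0.5

o = torch.softmax(q @ k.T * scale, dim=-1) @ v
o.backward(do)
_, dq, dk, dv = dense_attention_grads(q.detach(), k.detach(), v.detach(), do, scale)
assert torch.allclose(dq, q.grad) and torch.allclose(dk, k.grad) and torch.allclose(dv, v.grad)
```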
+
+
+def flash_attention_build_lut(
+        attention_layout: Tensor,
+        n_seq_blocks_q: int = None,
+        n_seq_blocks_k: int = None,
+) -> dict:
+    """Build attention LUTs for reuse across multiple calls."""
+    n_batches = attention_layout.shape[0]
+    if n_seq_blocks_q is None:
+        n_seq_blocks_q = attention_layout.shape[1]
+    if n_seq_blocks_k is None:
+        n_seq_blocks_k = attention_layout.shape[2]
+
+    attn_lut, attn_offsets, max_kv_blocks = _build_attention_lut_fast(
+        attention_layout, n_batches, n_seq_blocks_q, n_seq_blocks_k
+    )
+
+    attention_layout_t = attention_layout.transpose(1, 2).contiguous()
+    rev_attn_lut, rev_attn_offsets, max_q_per_k = _build_attention_lut_fast(
+        attention_layout_t, n_batches, n_seq_blocks_k, n_seq_blocks_q
+    )
+
+    return {
+        "attn_lut": attn_lut,
+        "attn_offsets": attn_offsets,
+        "max_kv_blocks": max_kv_blocks,
+        "rev_attn_lut": rev_attn_lut,
+        "rev_attn_offsets": rev_attn_offsets,
+        "max_q_per_k": max_q_per_k,
+    }
+
+
+def _build_attention_lut_fast(
+        attention_layout: Tensor,
+        n_batches: int,
+        n_blocks_row: int,
+        n_blocks_col: int,
+) -> Tuple[Tensor, Tensor, int]:
+    """Build attention LUT efficiently."""
+    device = attention_layout.device
+
+    counts = attention_layout.sum(dim=2).flatten()
+    max_blocks_per_row = int(counts.max().item())
+
+    if max_blocks_per_row == 0:
+        offsets = torch.zeros(n_batches * n_blocks_row + 1, dtype=torch.int32, device=device)
+        lut = torch.empty(0, dtype=torch.int32, device=device)
+        return lut, offsets, 1
+
+    offsets = torch.zeros(n_batches * n_blocks_row + 1, dtype=torch.int32, device=device)
+    offsets[1:] = counts.cumsum(0).to(torch.int32)
+
+    indices = attention_layout.reshape(n_batches * n_blocks_row, n_blocks_col).nonzero(as_tuple=False)
+    lut = indices[:, 1].to(torch.int32)
+
+    return lut, offsets, max_blocks_per_row
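`flash_attention_build_lut` stores the block layout in a CSR-like form: for each `(batch·head, query block)` row, `attn_offsets[r]:attn_offsets[r + 1]` delimits the slice of `attn_lut` holding the key-block indices to visit, and the transposed layout yields the reverse tables used by the dK/dV kernel. A small worked example (the values follow from the code above; it is not taken from the package tests):

```python
# One head, two query blocks, two key blocks, block-lower-triangular layout.
import torch
import blksprs as bs

attention_layout = torch.tensor([[[1, 0],
                                  [1, 1]]], dtype=torch.bool)   # [batch*heads=1, q_blocks=2, k_blocks=2]
lut = bs.ops.flash_attention_build_lut(attention_layout)

# Query block 0 attends to key block 0; query block 1 attends to key blocks 0 and 1.
assert lut["attn_offsets"].tolist() == [0, 1, 3]   # row r uses attn_lut[offsets[r]:offsets[r + 1]]
assert lut["attn_lut"].tolist() == [0, 0, 1]
assert lut["max_kv_blocks"] == 2                   # upper bound for the kernel's inner loop

# Reverse tables, indexed by key block: key block 0 is needed by query blocks 0 and 1,
# key block 1 only by query block 1.
assert lut["rev_attn_offsets"].tolist() == [0, 2, 3]
assert lut["rev_attn_lut"].tolist() == [0, 1, 1]
assert lut["max_q_per_k"] == 2
```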
blksprs/utils/autotuning.py
CHANGED
blksprs/utils/tools.py
CHANGED

@@ -16,7 +16,9 @@ def undo_shape_blocksparse(x: Tensor, shape: Size | tuple[int, ...]) -> Tensor:
 
 
 def stride(x: Tensor):
-    if x.dim() == 2:
+    if x.dim() == 1:
+        return 1
+    elif x.dim() == 2:
         return x.size(1), 1
     elif x.dim() == 3:
         return x.size(1) * x.size(2), x.size(2), 1
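The extended `stride` helper now also covers 1-D tensors; for contiguous inputs it mirrors `Tensor.stride()`, except that the 1-D case returns a bare `1` rather than a one-element tuple. A quick illustration (not from the package tests):

```python
# stride() results for contiguous tensors, per the diff above.
import torch
from blksprs.utils.tools import stride

assert stride(torch.empty(5)) == 1                      # new 1-D branch
assert stride(torch.empty(3, 4)) == (4, 1)              # same values as torch.empty(3, 4).stride()
assert stride(torch.empty(2, 3, 4)) == (12, 4, 1)
```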
{blksprs-2.1.9.dist-info → blksprs-2.2.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: blksprs
-Version: 2.1.9
+Version: 2.2
 Summary: A lightweight library for operations on block-sparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -17,20 +17,13 @@ Requires-Dist: coverage; extra == "test"
 Requires-Dist: build; extra == "test"
 Requires-Dist: matplotlib; extra == "test"
 
-# blksprs
+# 🧊 blksprs
 
 [](https://github.com/FelixSchoen/blksprs/releases)
 [](https://www.python.org/downloads/release/python-3119/)
 [](https://www.python.org/downloads/release/python-31210/)
 
-## Overview
-
-### News
-
-🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
-LUTs, autocasting, and makes use of `torch.library.triton_op()`!
-
----
+## 📖 Overview
 
 A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
 
@@ -46,6 +39,7 @@ Currently supported operations (includes gradient calculation):
 - Splitting and merging of matrices (_currently* only supports splitting and merging along the last dimension_)
 - Conversion to and from sparse form
 - Conversion to different sparsity layouts and different sparsity block sizes
+- Flash Attention (_supports custom masks and cross-attention_)
 
 As with this library sparse matrices are represented using a tuple of `(matrix, sparsity_layout, sparsity_block_size)`,
 any element-wise operations can be applied in regular torch-like fashion.
@@ -74,7 +68,7 @@ Furthermore, the library provides a set of utility functions
 
 _* see the [Roadmap](#roadmap) section for more information_
 
-## Installation
+## 🛠️ Installation
 
 Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible with the Linux platform**.
 Keep track of this [issue](https://github.com/triton-lang/triton/issues/1640) for updates.
@@ -89,11 +83,11 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u
 - _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.3.1)_
 - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_
 
-## Changelog
+## 📝 Changelog
 
 See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.md) for a detailed changelog.
 
-## Roadmap
+## 🗺️ Roadmap
 
 Note that since this library covers all our current needs it is in a **bugfix-only** state.
 This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and
@@ -105,17 +99,15 @@ We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
 It might be that this changes with future projects, but as of August 2025, we are content with the current state of the
 library.
 
-## Known Limitations and Issues
+## ⚠️ Known Limitations and Issues
 
-- Triton has a bug with `tl.atomix_max()` used for the row-wise max operation.
-  In order to work around this bug a manual conversion of some values is needed, (slightly) negatively impacting
-  performance.
-  Watch the [issue](https://github.com/triton-lang/triton/issues/6376) on Triton's issue tracker for more information.
 - There will be some slight numerical differences between vanilla and blksprs operations.
   These instabilities are due to Triton and thus cannot be fixed by this library alone.
   However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.
 
-
+- Flash Attention is a recent addition. While it has been tested and appears stable, please report any issues you encounter.
+
+## 💻 Usage
 
 We provide an example below to demonstrate the usage of the library.
 For more detailed examples, please refer to
@@ -128,7 +120,6 @@ the [test cases](https://github.com/FelixSchoen/blksprs/blob/main/test/cases/tes
 import torch
 import blksprs as bs
 
-
 def test_readme():
     # Set up parameters (batch size, number of heads, dimensions for matrices (m, k) and (n, k))
     b, h, m, n, k = 2, 4, 64, 64, 16
@@ -193,10 +184,30 @@ def test_readme():
     # Other available functions
     bs.ops.transpose(o_sparse, sparsity_layout_o, sparsity_block_size)
     bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size, flag_fused=False)
-    bs.ops.softmax_fused(o_sparse, sparsity_layout_o,
+    bs.ops.softmax_fused(o_sparse, sparsity_layout_o,
+                         sparsity_block_size)  # Significantly faster version that requires that rows of matrix fit into memory (default if flag is not set)
     bs.ops.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size)
     bs.ops.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size)
 
+    # Flash Attention
+    seq_len, head_dim = 512, 64
+    sparsity_block_size_attn = 128
+
+    q = torch.randn(b, seq_len, h, head_dim, device="cuda")
+    k = torch.randn(b, seq_len, h, head_dim, device="cuda")
+    v = torch.randn(b, seq_len, h, head_dim, device="cuda")
+
+    n_batches_attn = b * h
+    n_seq_blocks = seq_len // sparsity_block_size_attn
+    attention_layout = torch.tril(torch.ones(n_batches_attn, n_seq_blocks, n_seq_blocks, device="cuda", dtype=torch.bool))
+
+    lut = bs.ops.flash_attention_build_lut(attention_layout, n_seq_blocks, n_seq_blocks)
+
+    attn_out = bs.ops.flash_attention(q, k, v, attention_layout, sparsity_block_size_attn, lut=lut)
+
+    assert attn_out.shape == (b, seq_len, h, head_dim)
+
+
 
 def _get_random_sparsity_layout(b, m, n, sparsity_block_size, sparsity_percentage):
     """Helper function, creates a random sparsity layout for a given shape with a given percentage of blocks marked as sparse.
{blksprs-2.1.9.dist-info → blksprs-2.2.dist-info}/RECORD
CHANGED

@@ -1,8 +1,9 @@
-blksprs/__init__.py,sha256=
+blksprs/__init__.py,sha256=x6jBdOoukS032NnaO5zR-rJjdnQavBv8bA1E9C0wv7Y,1862
 blksprs/layouting/distribution_layout.py,sha256=a2C3DG3pYhAaPpOEgSMCRqDK1RYuFenSHqp0JdWMWmQ,5934
 blksprs/layouting/sparsity_layout.py,sha256=nl4qAJxtteZ6cx4td8FktbPiIfNEZl6zWUmMahv9Wac,11320
 blksprs/ops/conversion.py,sha256=PEgXwN-UZilr7OUBlOI1NzT8902Baxa3ie9f6K1mGQc,21543
-blksprs/ops/distribution.py,sha256=
+blksprs/ops/distribution.py,sha256=na_bBldK8MXuu8u7MMMoZwl2css7cplhjkqgA3e1NPg,20221
+blksprs/ops/flash_attention.py,sha256=ktdwdyUxgqlmTGvo-sB5hdx1yy9m4SDoYrKlAc3lkG8,24571
 blksprs/ops/flow.py,sha256=e1SKZUNMWTRgG16aK7BjYNdxWuDnLl2s0ozSkUYDBYs,7818
 blksprs/ops/matmul.py,sha256=Q_mcSfHpziZYrasB1_TbH8FmFtaf-lfoigg8H0POK64,11677
 blksprs/ops/partitioning.py,sha256=88TU77uDbvZTcYdTah9oChJrbgqZdkj4tNPylf9IS1c,9995
@@ -11,13 +12,13 @@ blksprs/ops/softmax.py,sha256=SrWZaLxk0rGbyKCxH4np97mL7k10Oqg2VP2-qZFQ8ec,23679
 blksprs/ops/transpose.py,sha256=IaNdqWDZ2rNSaO8kwpQyoSUpVpsoxMREgEXzhVBTsaY,4112
 blksprs/ops/misc/broadcast_ops.py,sha256=RmLSFFugRcRn70CU5ahrTRTplk8_At-5XkaF0UFiCQs,5703
 blksprs/ops/misc/row_wise.py,sha256=UYrgteIDp7NFqbV85hEmdzXxiJ-wQPuFGJV88rnEjdg,19344
-blksprs/utils/autotuning.py,sha256=
+blksprs/utils/autotuning.py,sha256=dWFYY_xoGCFxmX9qIyul37f62Bra1R9MY_turMHxYS8,2038
 blksprs/utils/benchmarking.py,sha256=dLabDscTFn5NkmOI1g7DnKeTneUYW3RIVv9MDF-8BKc,1271
 blksprs/utils/blksprs_tensor.py,sha256=Y8YnsFPifvdCf5Khsm8bDVv-589U0N8IsCFlnDETfzE,476
 blksprs/utils/processing.py,sha256=GcsUl54DDrEoZ0iuWZV5Q0BR2ZML3jWOhypOMxDCsrs,3759
-blksprs/utils/tools.py,sha256=
+blksprs/utils/tools.py,sha256=3puJ7S-Pfb1ILnzco09pz7RQOt7Vrkj-LpPpnj3zZHY,791
 blksprs/utils/validation.py,sha256=P98sCk6PZCQB0wO3scGTJIXfkv5EpHFM_uNHBXr42n4,4844
-blksprs-2.
-blksprs-2.
-blksprs-2.
-blksprs-2.
+blksprs-2.2.dist-info/METADATA,sha256=9OCvQ0g7nMoNmazKUJUe7izBYc3d335rmO02zmY7iqc,10050
+blksprs-2.2.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+blksprs-2.2.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
+blksprs-2.2.dist-info/RECORD,,

{blksprs-2.1.9.dist-info → blksprs-2.2.dist-info}/top_level.txt
File without changes