blksprs-2.1.2-py3-none-any.whl → blksprs-2.1.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +1 -1
- blksprs/ops/conversion.py +1 -1
- blksprs/ops/flow.py +1 -1
- blksprs/ops/matmul.py +1 -1
- blksprs/ops/softmax.py +15 -13
- {blksprs-2.1.2.dist-info → blksprs-2.1.4.dist-info}/METADATA +9 -13
- {blksprs-2.1.2.dist-info → blksprs-2.1.4.dist-info}/RECORD +9 -9
- {blksprs-2.1.2.dist-info → blksprs-2.1.4.dist-info}/WHEEL +0 -0
- {blksprs-2.1.2.dist-info → blksprs-2.1.4.dist-info}/top_level.txt +0 -0
blksprs/__init__.py
CHANGED
blksprs/ops/conversion.py
CHANGED
@@ -56,7 +56,7 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor,
 def to_sparse_forward(x: Tensor, _: Tensor,
                       sparsity_lut: Tensor, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
     with torch.no_grad():
-        output = torch.
+        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                              dtype=x.dtype, device=x.device)

         x_b, x_r, x_c = x.size()
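The new allocation here (and the analogous ones in flow.py and matmul.py below) preallocates a compact buffer holding one (sparsity_block_size × sparsity_block_size) tile per non-zero block and lets the kernel fill it. A minimal sketch of that allocation, outside of blksprs; the helper name and demo sizes are illustrative, only torch.empty and its arguments mirror the change above:

```python
# Minimal sketch, not blksprs code: preallocate the compact block buffer the way the
# updated to_sparse_forward does. `allocate_block_buffer` and the demo sizes are
# illustrative; only torch.empty and its arguments mirror the diff above.
import torch

def allocate_block_buffer(x: torch.Tensor, n_sparse_blocks: int, sparsity_block_size: int) -> torch.Tensor:
    # One (block_size x block_size) tile per non-zero block. torch.empty skips
    # zero-initialisation, which is only safe because every tile is written afterwards.
    return torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                       dtype=x.dtype, device=x.device)

if __name__ == "__main__":
    x = torch.randn(2, 64, 64)
    out = allocate_block_buffer(x, n_sparse_blocks=5, sparsity_block_size=32)
    print(out.shape)  # torch.Size([5, 32, 32])
```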
blksprs/ops/flow.py
CHANGED
@@ -14,7 +14,7 @@ def flow_pull_forward(x: Tensor, sparsity_layout_o: Tensor,
                       sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
                       sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
     with torch.no_grad():
-        output = torch.
+        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                              dtype=x.dtype, device=x.device)

         x_b, x_r, x_c = x.size()
blksprs/ops/matmul.py
CHANGED
@@ -62,7 +62,7 @@ def matmul_forward(x: Tensor, y: Tensor,
                    _: Tensor, sparsity_lut_o: Tensor,
                    sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
     with torch.no_grad():
-        output = torch.
+        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                              dtype=x.dtype, device=x.device)

         x_b, x_r, x_c = x.size()
blksprs/ops/softmax.py
CHANGED
@@ -66,7 +66,7 @@ def softmax_forward(x: Tensor, sparsity_layout: Tensor,
                     sparsity_lut: Tensor,
                     sparsity_reverse_lut_rws: Tensor,
                     sparsity_block_size: int) -> Tensor:
-    output = torch.
+    output = torch.empty_like(x)

     x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,
                                                        flag_slice_only=True)
@@ -114,7 +114,7 @@ def softmax_backward_wrapper(ctx, grad_output):
 def softmax_backward(grad_output: Tensor, o: Tensor, sparsity_lut: Tensor, sparsity_layout: Tensor,
                      sparsity_block_size: int) -> Tensor:
     with torch.no_grad():
-        grad_x = torch.
+        grad_x = torch.empty_like(o, dtype=torch.float)

         s, sparsity_layout_s = row_wise_sum(grad_output * o, sparsity_layout, sparsity_block_size, flag_slice_only=True)

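The backward pass now builds its gradient buffer with torch.empty_like, taking shape, layout and device from the forward output while overriding the dtype. A small illustrative snippet, not blksprs code:

```python
# Illustrative only: torch.empty_like copies shape, layout and device from the reference
# tensor; the dtype override keeps the gradient buffer in float32 even when the forward
# output is half precision (mirroring grad_x = torch.empty_like(o, dtype=torch.float) above).
import torch

o = torch.ones(4, 32, 32, dtype=torch.float16)
grad_x = torch.empty_like(o, dtype=torch.float)

assert grad_x.shape == o.shape and grad_x.device == o.device
assert grad_x.dtype == torch.float32
```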
@@ -359,7 +359,7 @@ def softmax_fused_forward(x: Tensor, sparsity_layout: Tensor,
                           sparsity_reverse_lut_sorted: Tensor,
                           max_blocks_line: int,
                           sparsity_block_size: int) -> Tensor:
-    output = torch.
+    output = torch.empty_like(x)

     x_b, x_r, x_c = x.size()
     x_b_s, x_r_s, x_c_s = stride(x)
@@ -374,7 +374,7 @@ def softmax_fused_forward(x: Tensor, sparsity_layout: Tensor,
          (x,
           x_b, x_b_s, x_r_s, x_c_s,
           output,
-          s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+          s_l_b, s_l_b_s, s_l_r_s, s_l_c, s_l_c_s,
           sparsity_reverse_lut_sorted,
           max_blocks_line,
           sparsity_block_size))
@@ -399,7 +399,7 @@ def softmax_fused_backward(grad_output: Tensor,
                            max_blocks_line: int,
                            sparsity_block_size: int) -> Tensor:
     with torch.no_grad():
-        grad_x = torch.
+        grad_x = torch.empty_like(o)

         g_b, g_r, g_c = grad_output.size()
         g_b_s, g_r_s, g_c_s = stride(grad_output)
@@ -417,7 +417,7 @@ def softmax_fused_backward(grad_output: Tensor,
           g_b, g_b_s, g_r_s, g_c_s,
           o,
           o_b, o_b_s, o_r_s, o_c_s,
-          s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+          s_l_b, s_l_b_s, s_l_r_s, s_l_c, s_l_c_s,
           sparsity_reverse_lut_sorted,
           grad_x,
           max_blocks_line,
@@ -437,7 +437,7 @@ def softmax_fused_backward(grad_output: Tensor,
 def softmax_fused_kernel(x,
                          x_b, x_b_s, x_r_s, x_c_s,
                          o,
-                         s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+                         s_l_b, s_l_b_s, s_l_r_s, s_l_c, s_l_c_s,
                          r_lut_s,
                          mbs: tl.constexpr,
                          sparsity_block_size: tl.constexpr,
@@ -451,8 +451,9 @@ def softmax_fused_kernel(x,
     blk_rev_idx = (pid_bat * s_l_b_s +
                    pid_row * s_l_r_s +
                    (tl.arange(0, mbs) * s_l_c_s))
-    blk_rev_msk = (blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s)
-
+    blk_rev_msk = ((blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s) and
+                   (tl.arange(0, mbs) < s_l_c))
+    blk_rev = tl.load(r_lut_s + blk_rev_idx, mask=blk_rev_msk, other=-1).to(tl.int32)

     if (not (tl.min(blk_rev) == -1 and
              tl.max(blk_rev) == -1)):
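This change (and the identical one in the backward kernel below) additionally bounds the mask by the actual number of layout columns, s_l_c, so that lanes introduced by the power-of-two mbs padding fall back to the -1 sentinel on load. A standalone sketch of that masked-load pattern; the names are illustrative and this is not the blksprs kernel, and it needs a CUDA device as Triton kernels do:

```python
# Minimal standalone sketch (not the blksprs kernel): a bounds-masked tl.load with a
# sentinel default, mirroring the pattern above where lanes past the real column count
# read as -1. All names here are illustrative.
import torch
import triton
import triton.language as tl

@triton.jit
def masked_gather_kernel(lut_ptr, out_ptr, n_cols, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)                            # BLOCK must be a power of two
    mask = offs < n_cols                                  # drop the padding lanes
    vals = tl.load(lut_ptr + offs, mask=mask, other=-1)   # masked lanes read as -1 (sentinel)
    tl.store(out_ptr + offs, vals)

def masked_gather(lut: torch.Tensor, block: int) -> torch.Tensor:
    out = torch.empty(block, dtype=lut.dtype, device=lut.device)
    masked_gather_kernel[(1,)](lut, out, lut.numel(), BLOCK=block)
    return out

if __name__ == "__main__":
    lut = torch.arange(5, dtype=torch.int32, device="cuda")
    print(masked_gather(lut, block=8))  # tensor([ 0,  1,  2,  3,  4, -1, -1, -1], ...)
```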
@@ -488,7 +489,7 @@ def softmax_fused_kernel_grad(g,
                               g_b, g_b_s, g_r_s, g_c_s,
                               x,
                               x_b, x_b_s, x_r_s, x_c_s,
-                              s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+                              s_l_b, s_l_b_s, s_l_r_s, s_l_c, s_l_c_s,
                               r_lut_s,
                               o,
                               mbs: tl.constexpr,
@@ -503,8 +504,9 @@ def softmax_fused_kernel_grad(g,
     blk_rev_idx = (pid_bat * s_l_b_s +
                    pid_row * s_l_r_s +
                    (tl.arange(0, mbs) * s_l_c_s))
-    blk_rev_msk = (blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s)
-
+    blk_rev_msk = ((blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s) and
+                   (tl.arange(0, mbs) < s_l_c))
+    blk_rev = tl.load(r_lut_s + blk_rev_idx, mask=blk_rev_msk, other=-1).to(tl.int32)

     if (not (tl.min(blk_rev) == -1 and
              tl.max(blk_rev) == -1)):
@@ -557,7 +559,7 @@ def softmax_fused_build_lut(lut: dict, sparsity_layout: Tensor):
                                .sum(dim=-1)
                                .max()
                                .item())
-        lut["max_blocks_line"] =
+        lut["max_blocks_line"] = ceil_pow2(max(max_blocks_line, 2))

         validate_contiguous(sparsity_layout, lut["sparsity_reverse_lut_sorted"])

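The lookup-table builder now rounds max_blocks_line up to a power of two with a floor of 2, which is what allows the kernels to use it as the extent of tl.arange (Triton requires power-of-two ranges); the new `< s_l_c` mask above neutralises the lanes this rounding adds. A small sketch of the rounding behaviour, where next_pow2 is an illustrative stand-in and not the library's ceil_pow2:

```python
# Minimal sketch (not the library's ceil_pow2): round an integer up to the next power of two.
# tl.arange(0, mbs) requires a power-of-two extent, and the `tl.arange(0, mbs) < s_l_c`
# mask above switches off the lanes this rounding adds.
def next_pow2(n: int) -> int:
    return 1 << (n - 1).bit_length()

assert [next_pow2(max(n, 2)) for n in (1, 2, 3, 5, 8, 9)] == [2, 2, 4, 8, 8, 16]
```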
blksprs-2.1.2.dist-info/METADATA → blksprs-2.1.4.dist-info/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: blksprs
-Version: 2.1.
+Version: 2.1.4
 Summary: A lightweight library for operations on block-sparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -20,7 +20,8 @@ Requires-Dist: matplotlib; extra == "test"
 # blksprs

 [](https://github.com/FelixSchoen/blksprs/releases)
-[](https://www.python.org/downloads/release/python-3119/)
+[](https://www.python.org/downloads/release/python-31210/)

 ## Overview

@@ -75,9 +76,7 @@ _* see the [Roadmap](#roadmap) section for more information_

 ## Installation

-Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible
-with
-the Linux platform**.
+Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible with the Linux platform**.
 Keep track of this [issue](https://github.com/triton-lang/triton/issues/1640) for updates.

 We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) using pip:
@@ -86,8 +85,8 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) using pip:

 ### Dependencies

-- [PyTorch](https://pytorch.org/) (built with v2.
-- _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.
+- [PyTorch](https://pytorch.org/) (built with v2.7.1)
+- _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.3.1)_
 - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_

 ## Changelog
@@ -103,7 +102,7 @@ We will continue to maintain the library and fix any issues that arise.
 Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
 We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).

-It might be that this changes with future projects, but as of
+It might be that this changes with future projects, but as of June 2025, we are content with the current state of the
 library.

 ## Known Limitations and Issues
@@ -112,9 +111,6 @@ library.
 In order to work around this bug a manual conversion of some values is needed, (slightly) negatively impacting
 performance.
 Watch the [issue](https://github.com/triton-lang/triton/issues/6376) on Triton's issue tracker for more information.
-- PyTorch's `wrap_triton()` currently does not support config pruning. It thus cannot be used for some of the kernels,
-  which could impact graph compilation.
-- There seem to be some issues with autocasting, forcing some operations to manually cast.
 - There will be some slight numerical differences between vanilla and blksprs operations.
   These instabilities are due to Triton and thus cannot be fixed by this library alone.
   However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.
@@ -196,8 +192,8 @@ def test_readme():

     # Other available functions
     bs.ops.transpose(o_sparse, sparsity_layout_o, sparsity_block_size)
-    bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size)
-    bs.ops.softmax_fused(o_sparse, sparsity_layout_o, sparsity_block_size)  # Significantly faster version that requires that rows of matrix fit into memory
+    bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size, flag_fused=False)
+    bs.ops.softmax_fused(o_sparse, sparsity_layout_o, sparsity_block_size)  # Significantly faster version that requires that rows of matrix fit into memory (default if flag is not set)
     bs.ops.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size)
     bs.ops.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size)

blksprs-2.1.2.dist-info/RECORD → blksprs-2.1.4.dist-info/RECORD
CHANGED

@@ -1,13 +1,13 @@
-blksprs/__init__.py,sha256=
+blksprs/__init__.py,sha256=XERzTtkiElDeBppOO8rNrF6OktUQf_yozDiA4DUXqTY,1615
 blksprs/layouting/distribution_layout.py,sha256=ur1ty_2U-Hfj78hMWsLZvu7ZuGhzW3qGLKMc72DfTZM,5861
 blksprs/layouting/sparsity_layout.py,sha256=eXHmu2h7K5Q-YUpfOxocJoeP_5ZoQFZf_eHLxRZQbYU,11207
-blksprs/ops/conversion.py,sha256=
+blksprs/ops/conversion.py,sha256=nv5gXiyZkUtk1kCIlPr0Vpaj4G8G6dJdW7StlbV3nDw,21914
 blksprs/ops/distribution.py,sha256=0tPldv0ARzmCV1CU2jvfqpHBgOuHPrDFiCtqsLs7CZc,20789
-blksprs/ops/flow.py,sha256=
-blksprs/ops/matmul.py,sha256=
+blksprs/ops/flow.py,sha256=oUn_xDT74220-EmnBnB8bRNtbS1mjbxWpm76PFsK22o,8246
+blksprs/ops/matmul.py,sha256=ES9bpiCIRBxaynNIL5ftDP0c9LSArbj8YJqkPEzBaIU,11879
 blksprs/ops/partitioning.py,sha256=cfQmY9BZqGTvvJorIhtb-EyuGRJGPraWR-wTKdb47aI,9954
 blksprs/ops/repeat.py,sha256=TLYNxwPuT9y5K9xyM41WK5gnggAJF3lI61Q2K7zWjns,9035
-blksprs/ops/softmax.py,sha256=
+blksprs/ops/softmax.py,sha256=tfC_jaAKrA956rxGeb57klMuYRKTiyMCd5Zg5DIH3fc,23649
 blksprs/ops/transpose.py,sha256=U-VAyLRT6_NDv9qYSFzBqfVlDeIpTqAMEXkqto0VF6w,4072
 blksprs/ops/misc/broadcast_ops.py,sha256=-PrHiSJikZh8nXUmXxSCtFEP27TTxFr4wcrNxBjnimk,5987
 blksprs/ops/misc/row_wise.py,sha256=n5FJjAuOd8BHBJQx4bsQwr-HmXkR9PYVAqfk77wjOFU,19653
@@ -17,7 +17,7 @@ blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4
 blksprs/utils/processing.py,sha256=RNkEDc0g-sNHRuMPkRzNWU13d3_lIkXMJdoqES4yQTM,3738
 blksprs/utils/tools.py,sha256=TKygEKge4wJtJnXXDg8BTL8vzBpqIJsQ_A3_5FmLpcE,859
 blksprs/utils/validation.py,sha256=G8eQlvJVMKfEX3k2AwBD0A6Ck-gFoRLpLNY6HXsB3fA,4348
-blksprs-2.1.
-blksprs-2.1.
-blksprs-2.1.
-blksprs-2.1.
+blksprs-2.1.4.dist-info/METADATA,sha256=qGLQunHEIoHlmRvFnM0TVDjOSApwGzBglpZezmfhHLU,9590
+blksprs-2.1.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+blksprs-2.1.4.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
+blksprs-2.1.4.dist-info/RECORD,,
blksprs-2.1.2.dist-info/WHEEL → blksprs-2.1.4.dist-info/WHEEL
File without changes

blksprs-2.1.2.dist-info/top_level.txt → blksprs-2.1.4.dist-info/top_level.txt
File without changes