blksprs 2.1.3.tar.gz → 2.1.5.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {blksprs-2.1.3 → blksprs-2.1.5}/PKG-INFO +7 -11
- {blksprs-2.1.3 → blksprs-2.1.5}/README.md +6 -10
- {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/__init__.py +2 -2
- {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/conversion.py +12 -20
- {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/distribution.py +12 -20
- {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/flow.py +12 -20
- {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/matmul.py +6 -10
- {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/misc/broadcast_ops.py +6 -10
- {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/misc/row_wise.py +35 -35
- {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/repeat.py +2 -2
- {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/softmax.py +10 -12
- {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/utils/autotuning.py +2 -2
- {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/utils/validation.py +21 -0
- {blksprs-2.1.3 → blksprs-2.1.5}/blksprs.egg-info/PKG-INFO +7 -11
- {blksprs-2.1.3 → blksprs-2.1.5}/pyproject.toml +1 -1
- {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/layouting/distribution_layout.py +0 -0
- {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/layouting/sparsity_layout.py +0 -0
- {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/partitioning.py +0 -0
- {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/transpose.py +0 -0
- {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/utils/benchmarking.py +0 -0
- {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/utils/blksprs_tensor.py +0 -0
- {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/utils/processing.py +0 -0
- {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/utils/tools.py +0 -0
- {blksprs-2.1.3 → blksprs-2.1.5}/blksprs.egg-info/SOURCES.txt +0 -0
- {blksprs-2.1.3 → blksprs-2.1.5}/blksprs.egg-info/dependency_links.txt +0 -0
- {blksprs-2.1.3 → blksprs-2.1.5}/blksprs.egg-info/requires.txt +0 -0
- {blksprs-2.1.3 → blksprs-2.1.5}/blksprs.egg-info/top_level.txt +0 -0
- {blksprs-2.1.3 → blksprs-2.1.5}/setup.cfg +0 -0
{blksprs-2.1.3 → blksprs-2.1.5}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: blksprs
-Version: 2.1.3
+Version: 2.1.5
 Summary: A lightweight library for operations on block-sparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -20,7 +20,8 @@ Requires-Dist: matplotlib; extra == "test"
 # blksprs
 
 [](https://github.com/FelixSchoen/blksprs/releases)
-[](https://www.python.org/downloads/release/python-3119/)
+[](https://www.python.org/downloads/release/python-31210/)
 
 ## Overview
 
@@ -75,9 +76,7 @@ _* see the [Roadmap](#roadmap) section for more information_
 
 ## Installation
 
-Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible
-with
-the Linux platform**.
+Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible with the Linux platform**.
 Keep track of this [issue](https://github.com/triton-lang/triton/issues/1640) for updates.
 
 We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) using pip:
@@ -86,8 +85,8 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) using pip:
 
 ### Dependencies
 
-- [PyTorch](https://pytorch.org/) (built with v2.
-- _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.
+- [PyTorch](https://pytorch.org/) (built with v2.7.1)
+- _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.3.1)_
 - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_
 
 ## Changelog
@@ -103,7 +102,7 @@ We will continue to maintain the library and fix any issues that arise.
 Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
 We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
 
-It might be that this changes with future projects, but as of
+It might be that this changes with future projects, but as of June 2025, we are content with the current state of the
 library.
 
 ## Known Limitations and Issues
@@ -112,9 +111,6 @@ library.
   In order to work around this bug a manual conversion of some values is needed, (slightly) negatively impacting
   performance.
   Watch the [issue](https://github.com/triton-lang/triton/issues/6376) on Triton's issue tracker for more information.
-- PyTorch's `wrap_triton()` currently does not support config pruning. It thus cannot be used for some of the kernels,
-  which could impact graph compilation.
-- There seem to be some issues with autocasting, forcing some operations to manually cast.
 - There will be some slight numerical differences between vanilla and blksprs operations.
   These instabilities are due to Triton and thus cannot be fixed by this library alone.
   However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.

{blksprs-2.1.3 → blksprs-2.1.5}/README.md

@@ -1,7 +1,8 @@
 # blksprs
 
 [](https://github.com/FelixSchoen/blksprs/releases)
-[](https://www.python.org/downloads/release/python-3119/)
+[](https://www.python.org/downloads/release/python-31210/)
 
 ## Overview
 
@@ -56,9 +57,7 @@ _* see the [Roadmap](#roadmap) section for more information_
 
 ## Installation
 
-Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible
-with
-the Linux platform**.
+Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible with the Linux platform**.
 Keep track of this [issue](https://github.com/triton-lang/triton/issues/1640) for updates.
 
 We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) using pip:
@@ -67,8 +66,8 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) using pip:
 
 ### Dependencies
 
-- [PyTorch](https://pytorch.org/) (built with v2.
-- _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.
+- [PyTorch](https://pytorch.org/) (built with v2.7.1)
+- _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.3.1)_
 - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_
 
 ## Changelog
@@ -84,7 +83,7 @@ We will continue to maintain the library and fix any issues that arise.
 Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
 We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
 
-It might be that this changes with future projects, but as of
+It might be that this changes with future projects, but as of June 2025, we are content with the current state of the
 library.
 
 ## Known Limitations and Issues
@@ -93,9 +92,6 @@ library.
   In order to work around this bug a manual conversion of some values is needed, (slightly) negatively impacting
   performance.
   Watch the [issue](https://github.com/triton-lang/triton/issues/6376) on Triton's issue tracker for more information.
-- PyTorch's `wrap_triton()` currently does not support config pruning. It thus cannot be used for some of the kernels,
-  which could impact graph compilation.
-- There seem to be some issues with autocasting, forcing some operations to manually cast.
 - There will be some slight numerical differences between vanilla and blksprs operations.
   These instabilities are due to Triton and thus cannot be fixed by this library alone.
   However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.

{blksprs-2.1.3 → blksprs-2.1.5}/blksprs/__init__.py

@@ -1,6 +1,6 @@
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 
-__version__ = "2.1.3"
+__version__ = "2.1.5"
 
 
 class ops:
@@ -27,9 +27,9 @@ class utils:
     from blksprs.utils.processing import apply_torch_linear, apply_torch_normalisation, apply_torch_dropout, \
         apply_function_applicable_row_wise
     from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
+    from blksprs.utils.validation import disable_contiguous, disable_validation
 
 class validation:
-    from blksprs.utils.validation import disable_validation
     from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_dtype_float, \
         validate_dtype_int, validate_device, validate_sparsity, validate_sparsity_dense, \
         validate_sparsity_block_size

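In 2.1.5 the `disable_validation` switch moves from the `validation` namespace into `utils`, alongside the new `disable_contiguous` switch. A short usage sketch; the import alias and the call order are illustrative, not taken from the diff:

```python
import blksprs as bs

# Both switches are process-wide and affect subsequent blksprs calls.
bs.utils.disable_validation()   # skip the runtime validation checks
bs.utils.disable_contiguous()   # skip the contiguity enforcement introduced in 2.1.5
```
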
{blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/conversion.py

@@ -106,17 +106,13 @@ def to_sparse_kernel(x,
     pid_col = tl.program_id(axis=2)
 
     # Get sparsity index of current output block consisting of its batch, row, and column index
-
-
-
+    spa_val_idx = pid_blk * s_lut_r_s + tl.arange(0, 4) * s_lut_c_s
+    spa_val_msk = (tl.arange(0, 4) < 3)
+    spa_val = tl.load(s_lut + spa_val_idx, mask=spa_val_msk)
 
-
-
-
-
-    spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
-    spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
-    spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+    spa_bat = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+    spa_row = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+    spa_col = tl.sum(spa_val * (tl.arange(0, 4) == 2))
 
     # Load block from dense tensor
     blk_d_idx = (spa_bat * x_b_s +
@@ -445,17 +441,13 @@ def adapt_layout_kernel(x,
     pid_col = tl.program_id(axis=2)
 
     # Get position of current sparsity block consisting of its batch, row, and column index
-
-
-
-
-    spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
-    spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
-    spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+    spa_val_idx = pid_blk * s_lut_o_r_s + tl.arange(0, 4) * s_lut_o_c_s
+    spa_val_msk = (tl.arange(0, 4) < 3)
+    spa_val = tl.load(s_lut_o + spa_val_idx, mask=spa_val_msk)
 
-
-
-    spa_col_o = tl.
+    spa_bat_o = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+    spa_row_o = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+    spa_col_o = tl.sum(spa_val * (tl.arange(0, 4) == 2))
 
     # Get equivalent sparsity block in from layout
     spa_bat_x = spa_bat_o

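The kernel hunks above, and the analogous hunks in the files that follow, all make the same change: instead of loading the batch, row, and column entries of a sparsity-LUT row with separate scalar loads (each with its own index and mask), the new code issues one masked vector load of width 4 (the next power of two above the 3 valid entries) and unpacks the three scalars with masked sums. A standalone sketch of that pattern, not blksprs code; the kernel, its parameters, and the launcher below are illustrative and require Triton with a CUDA device:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def lut_unpack_kernel(lut_ptr, out_ptr, lut_r_s, lut_c_s):
    pid = tl.program_id(axis=0)

    # One vectorised load per LUT row: lanes 0..3, only the first 3 are valid.
    val_idx = pid * lut_r_s + tl.arange(0, 4) * lut_c_s
    val_msk = tl.arange(0, 4) < 3
    val = tl.load(lut_ptr + val_idx, mask=val_msk, other=0)

    # Unpack batch / row / column as scalars via masked sums.
    bat = tl.sum(val * (tl.arange(0, 4) == 0))
    row = tl.sum(val * (tl.arange(0, 4) == 1))
    col = tl.sum(val * (tl.arange(0, 4) == 2))

    tl.store(out_ptr + pid * 3 + 0, bat)
    tl.store(out_ptr + pid * 3 + 1, row)
    tl.store(out_ptr + pid * 3 + 2, col)


lut = torch.tensor([[0, 1, 2], [1, 3, 5]], dtype=torch.int32, device="cuda")
out = torch.empty(lut.numel(), dtype=lut.dtype, device=lut.device)
lut_unpack_kernel[(lut.shape[0],)](lut, out, lut.stride(0), lut.stride(1))
assert torch.equal(out.reshape(-1, 3), lut)  # the LUT round-trips unchanged
```

This trades three dependent scalar loads for a single coalesced load, and the masked-sum unpacking keeps the values in registers without having to index a Triton tensor with a runtime scalar.
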
{blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/distribution.py

@@ -125,17 +125,13 @@ def gather_kernel(x,
     pid_col = tl.program_id(axis=2)
 
     # Get position of current sparsity block consisting of its batch, row, and column index
-
-
-
+    spa_val_idx = pid_blk * s_lut_o_r_s + tl.arange(0, 4) * s_lut_o_c_s
+    spa_val_msk = (tl.arange(0, 4) < 3)
+    spa_val = tl.load(s_lut_o + spa_val_idx, mask=spa_val_msk)
 
-
-
-
-
-    spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
-    spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
-    spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+    spa_bat_o = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+    spa_row_o = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+    spa_col_o = tl.sum(spa_val * (tl.arange(0, 4) == 2))
 
     # Load index values
     blk_i_idx = ((pid_blk * i_b_s) +
@@ -374,17 +370,13 @@ def scatter_reduce_kernel(x,
     pid_col = tl.program_id(axis=2)
 
     # Get position of current sparsity block consisting of its batch, row, and column index
-
-
-
-
-    spa_row_x_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
-    spa_row_x_msk = (spa_row_x_idx >= 0 and spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
-    spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)
+    spa_val_idx = pid_blk * s_lut_x_r_s + tl.arange(0, 4) * s_lut_x_c_s
+    spa_val_msk = (tl.arange(0, 4) < 3)
+    spa_val = tl.load(s_lut_x + spa_val_idx, mask=spa_val_msk)
 
-
-
-    spa_col_x = tl.
+    spa_bat_x = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+    spa_row_x = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+    spa_col_x = tl.sum(spa_val * (tl.arange(0, 4) == 2))
 
     # Load x values
     blk_x_idx = ((pid_blk * x_b_s) +

{blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/flow.py

@@ -66,17 +66,13 @@ def flow_pull_kernel(x,
     pid_col = tl.program_id(axis=2)
 
     # Get sparsity index of current output block consisting of its batch, row, and column index
-
-
-
+    spa_val_idx = pid_blk * s_lut_r_s + tl.arange(0, 4) * s_lut_c_s
+    spa_val_msk = (tl.arange(0, 4) < 3)
+    spa_val = tl.load(s_lut + spa_val_idx, mask=spa_val_msk)
 
-
-
-
-
-    spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
-    spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
-    spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+    spa_bat = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+    spa_row = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+    spa_col = tl.sum(spa_val * (tl.arange(0, 4) == 2))
 
     # Load reverse sparsity index
     rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
@@ -157,17 +153,13 @@ def flow_push_kernel(x,
     pid_col = tl.program_id(axis=2)
 
     # Get sparsity index of current input block consisting of its batch, row, and column index
-
-
-
-
-    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-    spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
-    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+    spa_val_idx = pid_blk * s_lut_r_s + tl.arange(0, 4) * s_lut_c_s
+    spa_val_msk = (tl.arange(0, 4) < 3)
+    spa_val = tl.load(s_lut + spa_val_idx, mask=spa_val_msk)
 
-
-
-    spa_col = tl.
+    spa_bat = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+    spa_row = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+    spa_col = tl.sum(spa_val * (tl.arange(0, 4) == 2))
 
     # Get reverse sparsity index
     rev_idx_spa_idx = (spa_bat * s_l_x_b_s +

{blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/matmul.py

@@ -145,17 +145,13 @@ def matmul_kernel(x,
     pid_col = tl.program_id(axis=2)
 
     # Get position of current sparsity block consisting of its batch, row, and column index
-
-
-
+    spa_val_idx = pid_blk * s_lut_o_r_s + tl.arange(0, 4) * s_lut_o_c_s
+    spa_val_msk = (tl.arange(0, 4) < 3)
+    spa_val = tl.load(s_lut_o + spa_val_idx, mask=spa_val_msk)
 
-
-
-
-
-    spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
-    spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
-    spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+    spa_bat_o = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+    spa_row_o = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+    spa_col_o = tl.sum(spa_val * (tl.arange(0, 4) == 2))
 
     # Setup buffer
     buf = tl.zeros(shape=(TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), dtype=tl.float32)

{blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/misc/broadcast_ops.py

@@ -110,17 +110,13 @@ def broadcast_add_kernel(x,
     pid_col = tl.program_id(axis=2)
 
     # Get position of current sparsity block consisting of its batch, row, and column index
-
-
-
+    spa_val_idx = pid_blk * s_lut_o_r_s + tl.arange(0, 4) * s_lut_o_c_s
+    spa_val_msk = (tl.arange(0, 4) < 3)
+    spa_val = tl.load(s_lut_o + spa_val_idx, mask=spa_val_msk)
 
-
-
-
-
-    spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
-    spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
-    spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+    spa_bat_o = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+    spa_row_o = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+    spa_col_o = tl.sum(spa_val * (tl.arange(0, 4) == 2))
 
     # Load x block
     blk_x_idx = (spa_bat_o * x_b_s +

{blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/misc/row_wise.py

@@ -119,17 +119,17 @@ def row_wise_sum_kernel(x,
     pid_col = tl.program_id(axis=2)
 
     # Get position of current sparsity block consisting of its batch and row index
-
-
-
+    spa_val_idx = pid_blk * s_lut_x_r_s + tl.arange(0, 4) * s_lut_x_c_s
+    spa_val_msk = (tl.arange(0, 4) < 3)
+    spa_val = tl.load(s_lut_x + spa_val_idx, mask=spa_val_msk)
 
-
-
-
+    spa_bat_x = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+    spa_row_x = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+    spa_col_x = tl.sum(spa_val * (tl.arange(0, 4) == 2))
 
     # Load reverse sparsity index for current block
-    rev_idx_spa_idx = (
-
+    rev_idx_spa_idx = (spa_bat_x * s_l_o_b_s +
+                       spa_row_x * s_l_o_r_s)
     rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
     rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
@@ -263,17 +263,17 @@ def row_wise_max_kernel(x,
     pid_col = tl.program_id(axis=2)
 
     # Get position of current sparsity block consisting of its batch and row index
-
-
-
+    spa_val_idx = pid_blk * s_lut_x_r_s + tl.arange(0, 4) * s_lut_x_c_s
+    spa_val_msk = (tl.arange(0, 4) < 3)
+    spa_val = tl.load(s_lut_x + spa_val_idx, mask=spa_val_msk)
 
-
-
-
+    spa_bat_x = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+    spa_row_x = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+    spa_col_x = tl.sum(spa_val * (tl.arange(0, 4) == 2))
 
     # Load reverse sparsity index for current block
-    rev_idx_spa_idx = (
-
+    rev_idx_spa_idx = (spa_bat_x * s_l_o_b_s +
+                       spa_row_x * s_l_o_r_s)
     rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
     rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
@@ -361,7 +361,7 @@ def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
 
-    (wrap_triton(
+    (wrap_triton(row_wise_add_kernel)[triton_grid]
     (x,
      x_b, x_b_s, x_r_s, x_c_s,
      sparsity_lut_x, s_lut_r, s_lut_r_s, s_lut_c_s,
@@ -383,33 +383,33 @@ def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
     reset_to_zero=["o"]
 )
 @triton.jit
-def
-
-
-
-
-
-
-
-
-
+def row_wise_add_kernel(x,
+                        x_b, x_b_s, x_r_s, x_c_s,
+                        s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+                        y, y_b, y_b_s, y_r_s, y_c_s,
+                        s_l_y_b, s_l_y_b_s, s_l_y_r_s,
+                        r_lut_y,
+                        o,
+                        o_b, o_b_s, o_r_s, o_c_s,
+                        sparsity_block_size,
+                        TRITON_BLOCK_SIZE: tl.constexpr) -> None:
     # Get triton block indices
     pid_blk = tl.program_id(axis=0)
     pid_row = tl.program_id(axis=1)
     pid_col = tl.program_id(axis=2)
 
     # Get position of current sparsity block consisting of its batch and row index
-
-
-
+    spa_val_idx = pid_blk * s_lut_x_r_s + tl.arange(0, 4) * s_lut_x_c_s
+    spa_val_msk = (tl.arange(0, 4) < 3)
+    spa_val = tl.load(s_lut_x + spa_val_idx, mask=spa_val_msk)
 
-
-
-
+    spa_bat_x = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+    spa_row_x = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+    spa_col_x = tl.sum(spa_val * (tl.arange(0, 4) == 2))
 
     # Get reverse sparsity indices for s
-    rev_idx_spa_s_idx = (
-
+    rev_idx_spa_s_idx = (spa_bat_x * s_l_y_b_s +
+                         spa_row_x * s_l_y_r_s)
     rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_y_b * s_l_y_b_s)
     rev_idx_spa_s = tl.load(r_lut_y + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
 

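The `wrap_triton(row_wise_add_kernel)[triton_grid](...)` launch above goes through PyTorch's `torch.library.wrap_triton`, which makes the Triton kernel call traceable by `torch.compile` instead of an opaque callable (note the removed README limitation about `wrap_triton()` and config pruning). A minimal sketch of that pattern outside of blksprs, assuming PyTorch ≥ 2.6 with Triton and a CUDA device; the `demo::add` operator and kernel are illustrative:

```python
import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton


@triton.jit
def add_kernel(x_ptr, y_ptr, o_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(o_ptr + offsets, x + y, mask=mask)


@triton_op("demo::add", mutates_args={})
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = out.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    # Launch via wrap_triton(kernel)[grid](...) instead of kernel[grid](...) so the
    # launch stays visible to torch.compile's dispatcher-based tracing.
    wrap_triton(add_kernel)[grid](x, y, out, n_elements, BLOCK_SIZE=1024)
    return out


compiled = torch.compile(add)
a = torch.randn(4096, device="cuda")
b = torch.randn(4096, device="cuda")
assert torch.allclose(compiled(a, b), a + b)
```
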
{blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/repeat.py

@@ -142,7 +142,7 @@ def repeat_build_lut(lut: dict, sparsity_layout_x: Tensor, repeats: tuple[int, i
     n_sparse_blocks = torch.sum(lut["sparsity_layout_o"].to(torch.int)).item()
     lut["n_sparse_blocks"] = n_sparse_blocks
 
-    validate_contiguous(sparsity_layout_o, lut["sparsity_lut"], lut["sparsity_reverse_lut"])
+    validate_contiguous(lut["sparsity_layout_o"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])
 
     return lut
 
@@ -178,7 +178,7 @@ def repeat_interleave_build_lut(lut: dict, sparsity_layout_x: Tensor, repeats: i
     n_sparse_blocks = torch.sum(lut["sparsity_layout_o"].to(torch.int)).item()
     lut["n_sparse_blocks"] = n_sparse_blocks
 
-    validate_contiguous(sparsity_layout_o, lut["sparsity_lut"], lut["sparsity_reverse_lut"])
+    validate_contiguous(lut["sparsity_layout_o"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])
 
     return lut
 

{blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/softmax.py

@@ -176,13 +176,12 @@ def softmax_kernel(x,
     pid_col = tl.program_id(axis=2)
 
     # Get position of current sparsity block consisting of its batch and row index
-
-
-
+    spa_val_idx = pid_blk * s_lut_r_s + tl.arange(0, 4) * s_lut_c_s
+    spa_val_msk = (tl.arange(0, 4) < 3)
+    spa_val = tl.load(s_lut + spa_val_idx, mask=spa_val_msk)
 
-
-
-    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+    spa_bat = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+    spa_row = tl.sum(spa_val * (tl.arange(0, 4) == 1))
 
     # Get reverse sparsity indices for s
     rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
@@ -241,13 +240,12 @@ def softmax_kernel_grad(g,
     pid_col = tl.program_id(axis=2)
 
     # Get position of current sparsity block consisting of its batch and row index
-
-
-
+    spa_val_idx = pid_blk * s_lut_r_s + tl.arange(0, 4) * s_lut_c_s
+    spa_val_msk = (tl.arange(0, 4) < 3)
+    spa_val = tl.load(s_lut + spa_val_idx, mask=spa_val_msk)
 
-
-
-    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+    spa_bat = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+    spa_row = tl.sum(spa_val * (tl.arange(0, 4) == 1))
 
     rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
                          spa_row * s_l_s_r_s)

{blksprs-2.1.3 → blksprs-2.1.5}/blksprs/utils/validation.py

@@ -1,9 +1,17 @@
 import torch
 from torch import Tensor
 
+CONTIGUOUS = True
 VALIDATION = True
 
 
+def ensure_contiguous(*tensors: Tensor) -> tuple[Tensor, ...]:
+    if _check_skip_contiguous():
+        return tensors
+
+    return tuple(tensor.contiguous() for tensor in tensors)
+
+
 def validate_dimensions(*tensors: Tensor, dims=3) -> None:
     if _check_skip_validation():
         return
@@ -124,6 +132,19 @@ def validate_sparsity_block_size(sparsity_block_size: int, *tensors):
         raise ValueError("Tensor sizes must be divisible by sparsity block size")
 
 
+def _check_skip_contiguous():
+    return not CONTIGUOUS
+
+
+def _set_skip_contiguous(skip_contiguous: bool):
+    global CONTIGUOUS
+    CONTIGUOUS = not skip_contiguous
+
+
+def disable_contiguous():
+    _set_skip_contiguous(True)
+
+
 def _check_skip_validation():
     return not VALIDATION
 

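The new `ensure_contiguous` helper returns its arguments untouched once contiguity enforcement has been disabled and `.contiguous()` copies of them otherwise. A small usage sketch; the `some_op` wrapper below is hypothetical and only illustrates how such a helper is typically called, it is not code from this diff:

```python
import torch
from blksprs.utils.validation import disable_contiguous, ensure_contiguous


def some_op(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # With CONTIGUOUS enabled (the default) this may copy non-contiguous inputs;
    # after disable_contiguous() it simply passes the tensors through.
    x, y = ensure_contiguous(x, y)
    return x + y


out = some_op(torch.randn(4, 4).t(), torch.randn(4, 4))  # .t() yields a non-contiguous view
disable_contiguous()  # opt out of the implicit .contiguous() calls globally
```
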
{blksprs-2.1.3 → blksprs-2.1.5}/blksprs.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: blksprs
-Version: 2.1.3
+Version: 2.1.5
 Summary: A lightweight library for operations on block-sparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -20,7 +20,8 @@ Requires-Dist: matplotlib; extra == "test"
 # blksprs
 
 [](https://github.com/FelixSchoen/blksprs/releases)
-[](https://www.python.org/downloads/release/python-3119/)
+[](https://www.python.org/downloads/release/python-31210/)
 
 ## Overview
 
@@ -75,9 +76,7 @@ _* see the [Roadmap](#roadmap) section for more information_
 
 ## Installation
 
-Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible
-with
-the Linux platform**.
+Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible with the Linux platform**.
 Keep track of this [issue](https://github.com/triton-lang/triton/issues/1640) for updates.
 
 We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) using pip:
@@ -86,8 +85,8 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) using pip:
 
 ### Dependencies
 
-- [PyTorch](https://pytorch.org/) (built with v2.
-- _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.
+- [PyTorch](https://pytorch.org/) (built with v2.7.1)
+- _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.3.1)_
 - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_
 
 ## Changelog
@@ -103,7 +102,7 @@ We will continue to maintain the library and fix any issues that arise.
 Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
 We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
 
-It might be that this changes with future projects, but as of
+It might be that this changes with future projects, but as of June 2025, we are content with the current state of the
 library.
 
 ## Known Limitations and Issues
@@ -112,9 +111,6 @@ library.
   In order to work around this bug a manual conversion of some values is needed, (slightly) negatively impacting
   performance.
   Watch the [issue](https://github.com/triton-lang/triton/issues/6376) on Triton's issue tracker for more information.
-- PyTorch's `wrap_triton()` currently does not support config pruning. It thus cannot be used for some of the kernels,
-  which could impact graph compilation.
-- There seem to be some issues with autocasting, forcing some operations to manually cast.
 - There will be some slight numerical differences between vanilla and blksprs operations.
   These instabilities are due to Triton and thus cannot be fixed by this library alone.
   However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.

The remaining files listed above with +0 -0 are unchanged between 2.1.3 and 2.1.5.