blksprs 2.1.7.tar.gz → 2.1.9.tar.gz
- {blksprs-2.1.7 → blksprs-2.1.9}/PKG-INFO +3 -3
- {blksprs-2.1.7 → blksprs-2.1.9}/README.md +1 -1
- {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/__init__.py +10 -2
- {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/layouting/distribution_layout.py +11 -8
- {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/layouting/sparsity_layout.py +15 -10
- {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/conversion.py +28 -25
- {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/distribution.py +28 -28
- {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/flow.py +13 -11
- {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/matmul.py +16 -14
- {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/misc/broadcast_ops.py +8 -8
- {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/misc/row_wise.py +24 -23
- {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/partitioning.py +2 -2
- {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/repeat.py +2 -2
- {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/softmax.py +38 -33
- {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/transpose.py +3 -3
- {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/utils/autotuning.py +1 -1
- {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/utils/blksprs_tensor.py +10 -1
- {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/utils/processing.py +2 -1
- {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/utils/tools.py +2 -5
- {blksprs-2.1.7 → blksprs-2.1.9}/blksprs.egg-info/PKG-INFO +3 -3
- {blksprs-2.1.7 → blksprs-2.1.9}/blksprs.egg-info/requires.txt +1 -1
- {blksprs-2.1.7 → blksprs-2.1.9}/pyproject.toml +2 -2
- {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/utils/benchmarking.py +0 -0
- {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/utils/validation.py +0 -0
- {blksprs-2.1.7 → blksprs-2.1.9}/blksprs.egg-info/SOURCES.txt +0 -0
- {blksprs-2.1.7 → blksprs-2.1.9}/blksprs.egg-info/dependency_links.txt +0 -0
- {blksprs-2.1.7 → blksprs-2.1.9}/blksprs.egg-info/top_level.txt +0 -0
- {blksprs-2.1.7 → blksprs-2.1.9}/setup.cfg +0 -0
{blksprs-2.1.7 → blksprs-2.1.9}/PKG-INFO

@@ -1,13 +1,13 @@
 Metadata-Version: 2.4
 Name: blksprs
-Version: 2.1.7
+Version: 2.1.9
 Summary: A lightweight library for operations on block-sparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
 Project-URL: Bugtracker, https://github.com/FelixSchoen/blksprs/issues
 Requires-Python: >=3.11
 Description-Content-Type: text/markdown
-Requires-Dist: torch
+Requires-Dist: torch>=2.8.0
 Requires-Dist: numpy
 Provides-Extra: test
 Requires-Dist: pytest; extra == "test"

@@ -102,7 +102,7 @@ We will continue to maintain the library and fix any issues that arise.
 Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
 We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
 
-It might be that this changes with future projects, but as of
+It might be that this changes with future projects, but as of August 2025, we are content with the current state of the
 library.
 
 ## Known Limitations and Issues
{blksprs-2.1.7 → blksprs-2.1.9}/README.md

@@ -83,7 +83,7 @@ We will continue to maintain the library and fix any issues that arise.
 Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
 We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
 
-It might be that this changes with future projects, but as of
+It might be that this changes with future projects, but as of August 2025, we are content with the current state of the
 library.
 
 ## Known Limitations and Issues
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/__init__.py

@@ -1,6 +1,14 @@
-__version__ = "2.1.7"
+# Settings
+import torch
+
+# Capture scalar outputs for JIT compilation
+torch._dynamo.config.capture_scalar_outputs = True
+# Set version
+__version__ = "2.1.9"
 
-from blksprs.utils.blksprs_tensor import BlksprsTensor
+# Imports
+
+from blksprs.utils.blksprs_tensor import BlksprsTensor
 
 
 class ops:
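The new `__init__.py` makes two import-time changes visible above: it pins `__version__` and globally enables dynamo's scalar-output capture. A minimal sketch of what a downstream user observes after `import blksprs` (assuming a working install; the `assert` is illustrative, not part of the package):

```python
import torch
import blksprs

# capture_scalar_outputs lets torch.compile trace .item()-style scalar
# reads instead of breaking the graph; 2.1.9 enables it as a side effect
# of importing the package.
assert torch._dynamo.config.capture_scalar_outputs is True
print(blksprs.__version__)  # "2.1.9"
```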
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/layouting/distribution_layout.py

@@ -7,9 +7,9 @@ from torch._library import triton_op
 from torch._library.triton import wrap_triton
 from triton import language as tl
 
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import stride
-from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_dimensions, validate_device, \
     validate_contiguous
 

@@ -98,22 +98,25 @@ def build_distribution_layout_kernel(i,
 
     # Get position of current sparsity block consisting of its batch, row, and column index
     spa_bat_i_idx = (pid_blk * s_lut_i_r_s + 0 * s_lut_i_c_s)
-    spa_bat_i_msk = (spa_bat_i_idx >= 0
+    spa_bat_i_msk = ((spa_bat_i_idx >= 0) &
+                     (spa_bat_i_idx < s_lut_i_r * s_lut_i_r_s))
     spa_bat_i = tl.load(s_lut_i + spa_bat_i_idx, mask=spa_bat_i_msk)
 
     spa_row_i_idx = (pid_blk * s_lut_i_r_s + 1 * s_lut_i_c_s)
-    spa_row_i_msk = (spa_row_i_idx >= 0
+    spa_row_i_msk = ((spa_row_i_idx >= 0) &
+                     (spa_row_i_idx < s_lut_i_r * s_lut_i_r_s))
     spa_row_i = tl.load(s_lut_i + spa_row_i_idx, mask=spa_row_i_msk)
 
     spa_col_i_idx = (pid_blk * s_lut_i_r_s + 2 * s_lut_i_c_s)
-    spa_col_i_msk = (spa_col_i_idx >= 0
+    spa_col_i_msk = ((spa_col_i_idx >= 0) &
+                     (spa_col_i_idx < s_lut_i_r * s_lut_i_r_s))
     spa_col_i = tl.load(s_lut_i + spa_col_i_idx, mask=spa_col_i_msk)
 
     blk_i_idx = (pid_blk * i_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
-    blk_i_msk = (blk_i_idx >= 0
-                 blk_i_idx < i_b * i_b_s)
+    blk_i_msk = ((blk_i_idx >= 0) &
+                 (blk_i_idx < i_b * i_b_s))
     blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)
 
     dst_bat_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_i, dtype=tl.int32)

@@ -131,6 +134,6 @@ def build_distribution_layout_kernel(i,
     blk_o_idx = ((dst_bat_idx * o_b_s) +
                  (dst_row_idx * o_r_s) +
                  (dst_col_idx * o_c_s))
-    blk_o_msk = (blk_o_idx >= 0
-                 blk_o_idx < o_b * o_b_s)
+    blk_o_msk = ((blk_o_idx >= 0) &
+                 (blk_o_idx < o_b * o_b_s))
     tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
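Every kernel hunk in this release applies the same fix: a load/store mask that previously checked only the lower bound (or combined bounds without elementwise `&`) now checks both bounds as parenthesized comparisons joined with `&`. A self-contained sketch of the pattern outside blksprs, where `bounded_copy_kernel`, `n`, and `BLOCK` are illustrative names rather than library API (requires a CUDA device):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def bounded_copy_kernel(x_ptr, o_ptr, n, BLOCK: tl.constexpr):
    idx = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    # Bounds must be separate parenthesized comparisons joined with `&`;
    # Python's `and` and chained comparisons do not vectorize in Triton.
    msk = (idx >= 0) & (idx < n)
    val = tl.load(x_ptr + idx, mask=msk)
    tl.store(o_ptr + idx, val, mask=msk)

x = torch.arange(10, device="cuda", dtype=torch.float32)
o = torch.zeros_like(x)
bounded_copy_kernel[(triton.cdiv(x.numel(), 8),)](x, o, x.numel(), BLOCK=8)
assert torch.equal(x, o)
```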
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/layouting/sparsity_layout.py

@@ -6,9 +6,9 @@ from torch import Tensor
 from torch._library.triton import wrap_triton, triton_op
 from triton import language as tl
 
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs, prune_autotune_configs_conversion
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import stride
-from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs, prune_autotune_configs_conversion
 from blksprs.utils.validation import validate_dimensions, validate_device, \
     validate_contiguous, validate_sparsity, validate_sparsity_block_size
 

@@ -79,8 +79,8 @@ def build_sparsity_layout_kernel(x,
     blk_x_idx = (pid_bat * x_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx >= 0
-                 blk_x_idx < x_b * x_b_s)
+    blk_x_msk = ((blk_x_idx >= 0) &
+                 (blk_x_idx < x_b * x_b_s))
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Store sparsity layout value

@@ -88,7 +88,8 @@ def build_sparsity_layout_kernel(x,
     blk_o_idx = (pid_bat * o_b_s +
                  (((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +
                   ((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))
-    blk_o_msk = (blk_o_idx >= 0
+    blk_o_msk = ((blk_o_idx >= 0) &
+                 (blk_o_idx < o_b * o_b_s))
     tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
 
 

@@ -178,23 +179,26 @@ def build_sparsity_layout_adaption_kernel(x,
 
     # Get sparsity index of current output block consisting of its batch, row, and column index
     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-    spa_bat_msk = (spa_bat_idx >= 0
+    spa_bat_msk = ((spa_bat_idx >= 0) &
+                   (spa_bat_idx < s_lut_r * s_lut_r_s))
     spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
 
     spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-    spa_row_msk = (spa_row_idx >= 0
+    spa_row_msk = ((spa_row_idx >= 0) &
+                   (spa_row_idx < s_lut_r * s_lut_r_s))
     spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
 
     spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
-    spa_col_msk = (spa_col_idx >= 0
+    spa_col_msk = ((spa_col_idx >= 0) &
+                   (spa_col_idx < s_lut_r * s_lut_r_s))
     spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
 
     # Load x values
     blk_x_idx = ((pid_blk * x_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx >= 0
-                 blk_x_idx < x_b * x_b_s)
+    blk_x_msk = ((blk_x_idx >= 0) &
+                 (blk_x_idx < x_b * x_b_s))
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Store sparsity layout value

@@ -204,7 +208,8 @@ def build_sparsity_layout_adaption_kernel(x,
                    // sparsity_block_size_to) * o_r_s) +
                  (((pid_col * TRITON_BLOCK_SIZE + spa_col * sparsity_block_size_from)
                    // sparsity_block_size_to) * o_c_s))
-    blk_o_msk = (blk_o_idx >= 0
+    blk_o_msk = ((blk_o_idx >= 0) &
+                 (blk_o_idx < o_b * o_b_s))
     tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
 
 
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/conversion.py

@@ -5,9 +5,9 @@ from torch._library.triton import wrap_triton, triton_op
 from triton import language as tl
 
 from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs, prune_autotune_configs_conversion
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import stride
-from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs, prune_autotune_configs_conversion
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_sparsity_dense, ensure_contiguous
 

@@ -46,10 +46,10 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor,
     lut = to_sparse_build_lut(lut, sparsity_layout)
 
     if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
-        return BlksprsTensor(x)
+        return BlksprsTensor.wrap(x)
 
-    return BlksprsTensor(to_sparse_forward(x, sparsity_layout,
-
+    return BlksprsTensor.wrap(to_sparse_forward(x, sparsity_layout,
+                                                lut["sparsity_lut"], sparsity_block_size, lut["n_sparse_blocks"]))
 
 
 @triton_op("blksprs::to_sparse_forward", mutates_args={})

@@ -120,16 +120,16 @@ def to_sparse_kernel(x,
                   tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + spa_col * sparsity_block_size +
                   tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_d_msk = (blk_d_idx >= 0
-                 blk_d_idx < x_b * x_b_s)
+    blk_d_msk = ((blk_d_idx >= 0) &
+                 (blk_d_idx < x_b * x_b_s))
     blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)
 
     # Store block in sparse tensor
     blk_o_idx = ((pid_blk * o_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
-    blk_o_msk = (blk_o_idx >= 0
-                 blk_o_idx < (pid_blk + 1) * o_b_s)
+    blk_o_msk = ((blk_o_idx >= 0) &
+                 (blk_o_idx < (pid_blk + 1) * o_b_s))
     tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
 
 

@@ -201,7 +201,7 @@ def to_dense(x: BlksprsTensor, sparsity_layout: Tensor,
         return x
 
     return Tensor(to_dense_forward(x, sparsity_layout,
-
+                                   lut["sparsity_reverse_lut"], sparsity_block_size, fill_value))
 
 
 @triton_op("blksprs::to_dense_forward", mutates_args={})

@@ -269,7 +269,8 @@ def to_dense_kernel(x,
 
     # Get reverse sparsity index for current block
     rev_idx_spa_idx = (pid_blk * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)
-    rev_idx_spa_msk = (rev_idx_spa_idx >= 0
+    rev_idx_spa_msk = ((rev_idx_spa_idx >= 0) &
+                       (rev_idx_spa_idx < s_l_b * s_l_b_s))
     rev_idx_spa = tl.load(sparsity_reverse_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
     # If block is present commence operations

@@ -279,14 +280,15 @@ def to_dense_kernel(x,
                       tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                      (((pid_col % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
                       tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_msk = (blk_idx >= 0
-                   blk_idx < x_b * x_b_s)
+        blk_msk = ((blk_idx >= 0) &
+                   (blk_idx < x_b * x_b_s))
         blk = tl.load(x + blk_idx, mask=blk_msk)
 
         o_idx = (pid_blk * o_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-        o_msk = (o_idx >= 0
+        o_msk = ((o_idx >= 0) &
+                 (o_idx < o_b * o_b_s))
         tl.store(o + o_idx, blk, o_msk)
 
 

@@ -360,14 +362,14 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
     validate_contiguous(sparsity_reverse_lut_from, sparsity_layout_to, sparsity_lut_to)
 
     if (sparsity_block_size_from == sparsity_block_size_to) and torch.equal(sparsity_layout_from, sparsity_layout_to):
-        return BlksprsTensor(x), sparsity_layout_to
+        return BlksprsTensor.wrap(x), sparsity_layout_to
 
-    return BlksprsTensor(adapt_layout_forward(x,
-
-
-
-
-
+    return BlksprsTensor.wrap(adapt_layout_forward(x,
+                                                   sparsity_layout_from, sparsity_reverse_lut_from,
+                                                   sparsity_block_size_from,
+                                                   sparsity_layout_to, sparsity_lut_to,
+                                                   sparsity_block_size_to,
+                                                   n_sparse_blocks_to)), sparsity_layout_to
 
 
 @triton_op("blksprs::adapt_layout_forward", mutates_args={})

@@ -458,7 +460,8 @@ def adapt_layout_kernel(x,
     rev_idx_spa_x_idx = (spa_bat_x * s_l_x_b_s +
                          spa_row_x * s_l_x_r_s +
                          spa_col_x * s_l_x_c_s)
-    rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0
+    rev_idx_spa_x_msk = ((rev_idx_spa_x_idx >= 0) &
+                         (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s))
     rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
 
     # If block is present commence operations

@@ -473,16 +476,16 @@ def adapt_layout_kernel(x,
         blk_x_idx = ((rev_idx_spa_x * x_b_s) +
                      ((shift_row_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                      ((shift_col_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_x_msk = (blk_x_idx >= 0
-                     blk_x_idx < x_b * x_b_s)
+        blk_x_msk = ((blk_x_idx >= 0) &
+                     (blk_x_idx < x_b * x_b_s))
        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
         # Store output
         blk_o_idx = ((pid_blk * o_b_s) +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-        blk_o_msk = (blk_o_idx >= 0
-                     blk_o_idx < o_b * o_b_s)
+        blk_o_msk = ((blk_o_idx >= 0) &
+                     (blk_o_idx < o_b * o_b_s))
         tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
 
 
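For orientation, a usage sketch of the conversion path touched above. The call shape is inferred from the hunks (a single all-true block short-circuits to `BlksprsTensor.wrap(x)`), so treat the `bs.ops` names and argument order as assumptions rather than documented API:

```python
import torch
import blksprs as bs

x = torch.randn(2, 64, 64, device="cuda")
# One 64x64 block per batch, all present: to_sparse takes the fast path
# visible in the hunk and returns the input wrapped unchanged.
layout = torch.ones(2, 1, 1, dtype=torch.int32, device="cuda")
xs = bs.ops.to_sparse(x, layout, 64)
xd = bs.ops.to_dense(xs, layout, 64)
assert torch.equal(x, xd)
```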
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/distribution.py

@@ -5,9 +5,9 @@ from torch._library import triton_op
 from torch._library.triton import wrap_triton
 from triton import language as tl
 
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import stride
-from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_dtype_int, validate_sparsity_block_size, ensure_contiguous
 

@@ -45,9 +45,9 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
 
     lut = gather_build_lut(lut, sparsity_layout_src, sparsity_layout_idx)
 
-    return BlksprsTensor(gather_forward(src, sparsity_layout_src, lut["sparsity_reverse_lut_x"],
-
-
+    return BlksprsTensor.wrap(gather_forward(src, sparsity_layout_src, lut["sparsity_reverse_lut_x"],
+                                             adjusted_dim, idx, sparsity_layout_idx, lut["sparsity_lut_i"],
+                                             sparsity_block_size))
 
 
 @triton_op("blksprs::gather_forward", mutates_args={})

@@ -136,8 +136,8 @@ def gather_kernel(x,
     blk_i_idx = ((pid_blk * i_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
-    blk_i_msk = (blk_i_idx >= 0
-                 blk_i_idx < i_b * i_b_s)
+    blk_i_msk = ((blk_i_idx >= 0) &
+                 (blk_i_idx < i_b * i_b_s))
     blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
 
     # Get indices of sparsity blocks and positions within the blocks

@@ -164,26 +164,26 @@ def gather_kernel(x,
     rev_idx_spa_x_idx = ((rev_dst_bat_x * s_l_x_b_s) +
                         (rev_dst_row_x * s_l_x_r_s) +
                         (rev_dst_col_x * s_l_x_c_s))
-    rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0
-                         rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+    rev_idx_spa_x_msk = ((rev_idx_spa_x_idx >= 0) &
+                         (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s))
     rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
 
     # Load x values
     blk_x_idx = ((rev_idx_spa_x * x_b_s) +
                  dst_row_x +
                  dst_col_x)
-    blk_x_msk = ((blk_x_idx >= 0
-                  blk_x_idx < x_b * x_b_s)
-                 rev_idx_spa_x_msk != -1)
+    blk_x_msk = (((blk_x_idx >= 0) &
+                  (blk_x_idx < x_b * x_b_s)) &
+                 (rev_idx_spa_x_msk != -1))
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Store output
     blk_o_idx = ((pid_blk * o_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-    blk_o_msk = ((blk_o_idx >= 0
-                  blk_o_idx < o_b * o_b_s)
-                 rev_idx_spa_x_msk != -1)
+    blk_o_msk = (((blk_o_idx >= 0) &
+                  (blk_o_idx < o_b * o_b_s)) &
+                 (rev_idx_spa_x_msk != -1))
     tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
 
 

@@ -276,11 +276,11 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
 
     lut = scatter_reduce_build_lut(lut, sparsity_layout_src, sparsity_layout_tgt)
 
-    return BlksprsTensor(scatter_reduce_forward(src, sparsity_layout_src, lut["sparsity_lut_x"],
-
-
-
-
+    return BlksprsTensor.wrap(scatter_reduce_forward(src, sparsity_layout_src, lut["sparsity_lut_x"],
+                                                     adjusted_dim, idx,
+                                                     sparsity_layout_tgt, lut["sparsity_reverse_lut_o"],
+                                                     sparsity_block_size, lut["n_sparse_blocks"],
+                                                     reduce_op))
 
 
 @triton_op("blksprs::scatter_reduce_forward", mutates_args={})

@@ -380,16 +380,16 @@ def scatter_reduce_kernel(x,
     blk_x_idx = ((pid_blk * x_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx >= 0
-                 blk_x_idx < x_b * x_b_s)
+    blk_x_msk = ((blk_x_idx >= 0) &
+                 (blk_x_idx < x_b * x_b_s))
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Load index values
     blk_i_idx = ((pid_blk * i_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
-    blk_i_msk = (blk_i_idx >= 0
-                 blk_i_idx < i_b * i_b_s)
+    blk_i_msk = ((blk_i_idx >= 0) &
+                 (blk_i_idx < i_b * i_b_s))
     blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
 
     # Get indices of sparsity blocks and positions within the blocks

@@ -416,17 +416,17 @@ def scatter_reduce_kernel(x,
     rev_idx_spa_o_idx = ((rev_dst_bat_o * s_l_o_b_s) +
                         (rev_dst_row_o * s_l_o_r_s) +
                         (rev_dst_col_o * s_l_o_c_s))
-    rev_idx_spa_o_msk = (rev_idx_spa_o_idx >= 0
-                         rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
+    rev_idx_spa_o_msk = ((rev_idx_spa_o_idx >= 0) &
+                         (rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s))
     rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
 
     # Store output
     blk_o_idx = ((rev_idx_spa_o * o_b_s) +
                  dst_row_o +
                  dst_col_o)
-    blk_o_msk = ((blk_o_idx >= 0
-                  blk_o_idx < o_b * o_b_s)
-                 rev_idx_spa_o_msk != -1)
+    blk_o_msk = (((blk_o_idx >= 0) &
+                  (blk_o_idx < o_b * o_b_s)) &
+                 (rev_idx_spa_o_msk != -1))
 
     if reduce_op_ind == 0:
         tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
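The gather/scatter masks above fold a third condition into the bounds check: the reverse LUT marks absent blocks with -1, and loads/stores are additionally masked on that sentinel. A dense sketch of the convention (layout and LUT shapes here are illustrative, not the library's internal format):

```python
import torch

sparsity_layout = torch.tensor([[1, 0],
                                [0, 1]])
# Reverse LUT: flat block position -> index into the sparse block list,
# with -1 as the "block absent" sentinel the kernels test against.
rev_lut = torch.full((sparsity_layout.numel(),), -1, dtype=torch.int32)
rev_lut[sparsity_layout.flatten().bool()] = torch.arange(
    int(sparsity_layout.sum()), dtype=torch.int32)
print(rev_lut)  # tensor([ 0, -1, -1,  1], dtype=torch.int32)
```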
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/flow.py

@@ -5,8 +5,8 @@ from torch._library import triton_op
 from torch._library.triton import wrap_triton
 from triton import language as tl
 
-from blksprs.utils.tools import stride
 from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
+from blksprs.utils.tools import stride
 
 
 @triton_op("blksprs::flow_pull_forward", mutates_args={})

@@ -78,22 +78,23 @@ def flow_pull_kernel(x,
     rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
                        spa_row * s_l_o_r_s +
                        spa_col * s_l_o_c_s)
-    rev_idx_spa_msk = (rev_idx_spa_idx >= 0
+    rev_idx_spa_msk = ((rev_idx_spa_idx >= 0) &
+                       (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s))
     rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
     if rev_idx_spa >= 0:
         blk_x_idx = (rev_idx_spa * x_b_s +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_x_msk = (blk_x_idx >= 0
-                     blk_x_idx < x_b * x_b_s)
+        blk_x_msk = ((blk_x_idx >= 0) &
+                     (blk_x_idx < x_b * x_b_s))
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
         blk_o_idx = (pid_blk * o_b_s +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-        blk_o_msk = (blk_o_idx >= 0
-                     blk_o_idx < o_b * o_b_s)
+        blk_o_msk = ((blk_o_idx >= 0) &
+                     (blk_o_idx < o_b * o_b_s))
         tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
 
 

@@ -165,20 +166,21 @@ def flow_push_kernel(x,
     rev_idx_spa_idx = (spa_bat * s_l_x_b_s +
                        spa_row * s_l_x_r_s +
                        spa_col * s_l_x_c_s)
-    rev_idx_spa_msk = (rev_idx_spa_idx >= 0
+    rev_idx_spa_msk = ((rev_idx_spa_idx >= 0) &
+                       (rev_idx_spa_idx < s_l_x_b * s_l_x_b_s))
     rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
     if rev_idx_spa >= 0:
         blk_x_idx = (pid_blk * x_b_s +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_x_msk = (blk_x_idx >= 0
-                     blk_x_idx < x_b * x_b_s)
+        blk_x_msk = ((blk_x_idx >= 0) &
+                     (blk_x_idx < x_b * x_b_s))
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
         blk_o_idx = (rev_idx_spa * o_b_s +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-        blk_o_msk = (blk_o_idx >= 0
-                     blk_o_idx < o_b * o_b_s)
+        blk_o_msk = ((blk_o_idx >= 0) &
+                     (blk_o_idx < o_b * o_b_s))
         tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
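`flow_push_kernel` stores with `tl.atomic_add`, so several source blocks that map to the same destination accumulate rather than overwrite. A dense analogue using `index_add_` (shapes illustrative, not the blksprs API):

```python
import torch

dst = torch.zeros(2, 4)
src = torch.ones(3, 4)
mapping = torch.tensor([0, 1, 0])  # two source blocks land on row 0
dst.index_add_(0, mapping, src)
print(dst)  # row 0 sums to 2.0 everywhere, row 1 stays at 1.0
```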
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/matmul.py

@@ -5,9 +5,9 @@ from torch.library import triton_op, wrap_triton
 from triton import language as tl
 
 from blksprs.ops.transpose import transpose
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import stride
-from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_dtype_float, ensure_contiguous
 

@@ -47,11 +47,11 @@ def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
 
     lut = matmul_build_lut(lut, sparsity_layout_x, sparsity_layout_y, sparsity_layout_output)
 
-    return BlksprsTensor(matmul_forward(x, y,
-
-
-
-
+    return BlksprsTensor.wrap(matmul_forward(x, y,
+                                             sparsity_layout_x, lut["sparsity_reverse_lut_x"],
+                                             sparsity_layout_y, lut["sparsity_reverse_lut_y"],
+                                             sparsity_layout_output, lut["sparsity_lut_o"],
+                                             sparsity_block_size, lut["n_sparse_blocks"]))
 
 
 @triton_op("blksprs::matmul_forward", mutates_args={})

@@ -169,12 +169,14 @@ def matmul_kernel(x,
     rev_idx_spa_x_idx = (spa_bat_o * s_l_x_b_s +
                          spa_row_o * s_l_x_r_s +
                          i_seg_spa * s_l_x_c_s)
-    rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0
+    rev_idx_spa_x_msk = ((rev_idx_spa_x_idx >= 0) &
+                         (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s))
     rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
 
     # Get reverse sparsity indices for y
     rev_idx_spa_y_idx = (spa_bat_o * s_l_y_b_s + i_seg_spa * s_l_y_r_s + spa_col_o * s_l_y_c_s)
-    rev_idx_spa_y_msk = (rev_idx_spa_y_idx >= 0
+    rev_idx_spa_y_msk = ((rev_idx_spa_y_idx >= 0) &
+                         (rev_idx_spa_y_idx < s_l_y_b * s_l_y_b_s))
     rev_idx_spa_y = tl.load(r_lut_y + rev_idx_spa_y_idx, mask=rev_idx_spa_y_msk).to(tl.int32)
 
     # If both blocks are present commence calculation

@@ -183,16 +185,16 @@ def matmul_kernel(x,
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                      ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
                       tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_x_msk = (blk_x_idx >= 0
-                     blk_x_idx < x_b * x_b_s)
+        blk_x_msk = ((blk_x_idx >= 0) &
+                     (blk_x_idx < x_b * x_b_s))
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
         blk_y_idx = ((rev_idx_spa_y * y_b_s) +
                      ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
                       tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
-        blk_y_msk = (blk_y_idx >= 0
-                     blk_y_idx < y_b * y_b_s)
+        blk_y_msk = ((blk_y_idx >= 0) &
+                     (blk_y_idx < y_b * y_b_s))
         blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
 
         # Perform matrix multiplication

@@ -205,8 +207,8 @@ def matmul_kernel(x,
     blk_o_idx = ((pid_blk * o_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-    blk_o_msk = (blk_o_idx >= 0
-                 blk_o_idx < o_b * o_b_s)
+    blk_o_msk = ((blk_o_idx >= 0) &
+                 (blk_o_idx < o_b * o_b_s))
     tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
 
 
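As a reading aid for the kernel hunks above, a dense reference of the computation `matmul_kernel` performs blockwise: an output block accumulates products only over the k-segments where both operand blocks are present in their layouts. This is a plain-PyTorch sketch with illustrative names, not the library's code path:

```python
import torch

def block_sparse_matmul_ref(x, lx, y, ly, bs):
    # x: (B, R, K) dense, y: (B, K, C) dense; lx, ly: boolean block layouts
    B, R, K = x.shape
    C = y.shape[2]
    o = torch.zeros(B, R, C, dtype=x.dtype, device=x.device)
    for b in range(B):
        for i in range(R // bs):
            for j in range(C // bs):
                for k in range(K // bs):
                    # Skip segments where either operand block is absent
                    if lx[b, i, k] and ly[b, k, j]:
                        o[b, i*bs:(i+1)*bs, j*bs:(j+1)*bs] += (
                            x[b, i*bs:(i+1)*bs, k*bs:(k+1)*bs]
                            @ y[b, k*bs:(k+1)*bs, j*bs:(j+1)*bs])
    return o
```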
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/misc/broadcast_ops.py

@@ -5,9 +5,9 @@ from torch._library import triton_op
 from torch._library.triton import wrap_triton
 from triton import language as tl
 
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import stride
-from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_device, \
     validate_sparsity_block_size, ensure_contiguous
 

@@ -43,7 +43,7 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
 
     validate_contiguous(sparsity_layout_output, sparsity_lut_o)
 
-    return BlksprsTensor(broadcast_add_forward(x, y, sparsity_lut_o, sparsity_block_size, n_sparse_blocks))
+    return BlksprsTensor.wrap(broadcast_add_forward(x, y, sparsity_lut_o, sparsity_block_size, n_sparse_blocks))
 
 
 def broadcast_sub(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,

@@ -121,16 +121,16 @@ def broadcast_add_kernel(x,
     blk_x_idx = (spa_bat_o * x_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + spa_row_o * sparsity_block_size +
                   tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx >= 0
-                 blk_x_idx < x_b * x_b_s)
+    blk_x_msk = ((blk_x_idx >= 0) &
+                 (blk_x_idx < x_b * x_b_s))
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Load y block
     blk_y_idx = (spa_bat_o * y_b_s +
                  ((pid_col * TRITON_BLOCK_SIZE + spa_col_o * sparsity_block_size +
                   tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
-    blk_y_msk = (blk_y_idx >= 0
-                 blk_y_idx < y_b * y_b_s)
+    blk_y_msk = ((blk_y_idx >= 0) &
+                 (blk_y_idx < y_b * y_b_s))
     blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
 
     # Compute sum

@@ -141,6 +141,6 @@ def broadcast_add_kernel(x,
     blk_o_idx = ((pid_blk * o_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-    blk_o_msk = (blk_o_idx >= 0
-                 blk_o_idx < o_b * o_b_s)
+    blk_o_msk = ((blk_o_idx >= 0) &
+                 (blk_o_idx < o_b * o_b_s))
     tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/misc/row_wise.py

@@ -55,7 +55,7 @@ def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
     validate_contiguous(sparsity_layout, sparsity_lut,
                         sparsity_layout_output, sparsity_reverse_lut_output)
 
-    return BlksprsTensor(row_wise_sum_forward(
+    return BlksprsTensor.wrap(row_wise_sum_forward(
         x, sparsity_lut, sparsity_layout_output, sparsity_reverse_lut_output,
         sparsity_block_size, n_sparse_blocks_output, flag_slice_only)), sparsity_layout_output
 

@@ -130,15 +130,16 @@ def row_wise_sum_kernel(x,
     # Load reverse sparsity index for current block
     rev_idx_spa_idx = (spa_bat_x * s_l_o_b_s +
                        spa_row_x * s_l_o_r_s)
-    rev_idx_spa_msk = (rev_idx_spa_idx >= 0
+    rev_idx_spa_msk = ((rev_idx_spa_idx >= 0) &
+                       (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s))
     rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
     if rev_idx_spa >= 0:
         blk_idx = ((pid_blk * x_b_s) +
                    ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                    ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_msk = (blk_idx >= 0
-                   blk_idx < x_b * x_b_s)
+        blk_msk = ((blk_idx >= 0) &
+                   (blk_idx < x_b * x_b_s))
         blk = tl.load(x + blk_idx, mask=blk_msk)
 
         buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))

@@ -146,8 +147,8 @@ def row_wise_sum_kernel(x,
         o_idx = (rev_idx_spa * o_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  (tl.arange(0, 1))[None, :])
-        o_msk = (o_idx >= 0
-                 o_idx < o_b * o_b_s)
+        o_msk = ((o_idx >= 0) &
+                 (o_idx < o_b * o_b_s))
         tl.atomic_add(o + o_idx, buf, o_msk)
 
 

@@ -174,8 +175,6 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
         of the input and the sparsity layout of the output tensor.
 
     """
-    # TODO Fix for triton bug, see https://github.com/triton-lang/triton/issues/6376, should be fixed with the upcoming 3.4.0 release
-    x = torch.where(x == -0.0, torch.tensor(0.0), x)
     x = ensure_contiguous(x)
 
     validate_dimensions(x)

@@ -197,7 +196,7 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
     validate_contiguous(sparsity_layout, sparsity_lut,
                         sparsity_layout_output, sparsity_reverse_lut_output)
 
-    return BlksprsTensor(
+    return BlksprsTensor.wrap(
         row_wise_max_forward(x, sparsity_lut, sparsity_layout_output, sparsity_reverse_lut_output, sparsity_block_size,
                              n_sparse_blocks_output, flag_slice_only)), sparsity_layout_output
 

@@ -274,15 +273,16 @@ def row_wise_max_kernel(x,
     # Load reverse sparsity index for current block
     rev_idx_spa_idx = (spa_bat_x * s_l_o_b_s +
                        spa_row_x * s_l_o_r_s)
-    rev_idx_spa_msk = (rev_idx_spa_idx >= 0
+    rev_idx_spa_msk = ((rev_idx_spa_idx >= 0) &
+                       (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s))
     rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
     if rev_idx_spa >= 0:
         blk_idx = ((pid_blk * x_b_s) +
                    ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                    ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_msk = (blk_idx >= 0
-                   blk_idx < x_b * x_b_s)
+        blk_msk = ((blk_idx >= 0) &
+                   (blk_idx < x_b * x_b_s))
         blk = tl.load(x + blk_idx, mask=blk_msk)
 
         buf = tl.reshape(tl.max(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))

@@ -290,8 +290,8 @@ def row_wise_max_kernel(x,
         o_idx = (rev_idx_spa * o_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  (tl.arange(0, 1))[None, :])
-        o_msk = (o_idx >= 0
-                 o_idx < o_b * o_b_s)
+        o_msk = ((o_idx >= 0) &
+                 (o_idx < o_b * o_b_s))
         tl.atomic_max(o + o_idx, buf, o_msk)
 
 

@@ -329,8 +329,8 @@ def row_wise_add(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
 
     validate_contiguous(sparsity_layout_x, sparsity_lut_x, sparsity_reverse_lut_rwm)
 
-    return BlksprsTensor(row_wise_add_forward(x, sparsity_lut_x, sparsity_layout_rwm,
-
+    return BlksprsTensor.wrap(row_wise_add_forward(x, sparsity_lut_x, sparsity_layout_rwm,
+                                                   sparsity_reverse_lut_rwm, y, sparsity_block_size))
 
 
 def row_wise_sub(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,

@@ -412,7 +412,8 @@ def row_wise_add_kernel(x,
     # Get reverse sparsity indices for s
     rev_idx_spa_s_idx = (spa_bat_x * s_l_y_b_s +
                          spa_row_x * s_l_y_r_s)
-    rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0
+    rev_idx_spa_s_msk = ((rev_idx_spa_s_idx >= 0) &
+                         (rev_idx_spa_s_idx < s_l_y_b * s_l_y_b_s))
     rev_idx_spa_s = tl.load(r_lut_y + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
 
     if rev_idx_spa_s == -1:

@@ -423,16 +424,16 @@ def row_wise_add_kernel(x,
     blk_x_idx = ((pid_blk * x_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx >= 0
-                 blk_x_idx < x_b * x_b_s)
+    blk_x_msk = ((blk_x_idx >= 0) &
+                 (blk_x_idx < x_b * x_b_s))
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Load sum block
     blk_s_idx = (rev_idx_spa_s * y_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
                  (tl.arange(0, 1) * y_c_s)[None, :])
-    blk_s_msk = (blk_s_idx >= 0
-                 blk_s_idx < y_b * y_b_s)
+    blk_s_msk = ((blk_s_idx >= 0) &
+                 (blk_s_idx < y_b * y_b_s))
     blk_s = tl.load(y + blk_s_idx, mask=blk_s_msk)
 
     # Compute exp

@@ -442,6 +443,6 @@ def row_wise_add_kernel(x,
     blk_o_idx = ((pid_blk * o_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-    blk_o_msk = (blk_o_idx >= 0
-                 blk_o_idx < o_b * o_b_s)
+    blk_o_msk = ((blk_o_idx >= 0) &
+                 (blk_o_idx < o_b * o_b_s))
     tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/partitioning.py

@@ -41,7 +41,7 @@ def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
 
     lut = split_build_lut(lut, sparsity_layout, partitions)
 
-    return BlksprsTensor(split_forward(
+    return BlksprsTensor.wrap(split_forward(
         x, lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
         partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_output"]
 

@@ -146,7 +146,7 @@ def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
 
     lut = merge_build_lut(lut, sparsity_layout, partitions)
 
-    return BlksprsTensor(merge_forward(
+    return BlksprsTensor.wrap(merge_forward(
         x, lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
         partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_output"]
 
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/repeat.py

@@ -46,7 +46,7 @@ def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int,
 
     lut = repeat_build_lut(lut, sparsity_layout_x, repeats, sparsity_layout_output)
 
-    return BlksprsTensor(repeat_forward(
+    return BlksprsTensor.wrap(repeat_forward(
         x, sparsity_layout_x, lut["sparsity_layout_o"], lut["sparsity_lut"],
         lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_o"]
 

@@ -87,7 +87,7 @@ def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
 
     lut = repeat_interleave_build_lut(lut, sparsity_layout_x, repeats, sparsity_layout_output)
 
-    return BlksprsTensor(repeat_forward(
+    return BlksprsTensor.wrap(repeat_forward(
         x, sparsity_layout_x, lut["sparsity_layout_o"], lut["sparsity_lut"],
         lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_o"]
 
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/softmax.py

@@ -1,5 +1,3 @@
-import pdb
-
 import torch
 import triton
 from torch import Tensor

@@ -8,9 +6,9 @@ from torch._library.triton import wrap_triton
 from triton import language as tl
 
 from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import stride, ceil_pow2
-from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_dtype_float_32, ensure_contiguous
 

@@ -55,10 +53,10 @@ def softmax_regular(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_si
 
     lut = softmax_build_lut(lut, sparsity_layout)
 
-    return BlksprsTensor(softmax_forward(x, sparsity_layout,
-
-
-
+    return BlksprsTensor.wrap(softmax_forward(x, sparsity_layout,
+                                              lut["sparsity_lut"],
+                                              lut["sparsity_reverse_lut_rws"],
+                                              sparsity_block_size))
 
 
 @triton_op("blksprs::softmax_forward", mutates_args={})

@@ -186,7 +184,8 @@ def softmax_kernel(x,
     # Get reverse sparsity indices for s
     rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
                          spa_row * s_l_s_r_s)
-    rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0
+    rev_idx_spa_s_msk = ((rev_idx_spa_s_idx >= 0) &
+                         (rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s))
     rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
 
     if rev_idx_spa_s >= 0:

@@ -194,16 +193,16 @@ def softmax_kernel(x,
         blk_x_idx = ((pid_blk * x_b_s) +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_x_msk = (blk_x_idx >= 0
-                     blk_x_idx < x_b * x_b_s)
+        blk_x_msk = ((blk_x_idx >= 0) &
+                     (blk_x_idx < x_b * x_b_s))
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
         # Load sum block
         blk_s_idx = (rev_idx_spa_s * s_b_s +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
                      (tl.arange(0, 1) * s_c_s)[None, :])
-        blk_s_msk = (blk_s_idx >= 0
-                     blk_s_idx < s_b * s_b_s)
+        blk_s_msk = ((blk_s_idx >= 0) &
+                     (blk_s_idx < s_b * s_b_s))
         blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
 
         # Compute softmax

@@ -249,29 +248,30 @@ def softmax_kernel_grad(g,
 
     rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
                          spa_row * s_l_s_r_s)
-    rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0
+    rev_idx_spa_s_msk = ((rev_idx_spa_s_idx >= 0) &
+                         (rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s))
     rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
 
     if rev_idx_spa_s >= 0:
         blk_s_idx = (rev_idx_spa_s * s_b_s +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
                      (tl.arange(0, 1) * s_c_s)[None, :])
-        blk_s_msk = (blk_s_idx >= 0
-                     blk_s_idx < s_b * s_b_s)
+        blk_s_msk = ((blk_s_idx >= 0) &
+                     (blk_s_idx < s_b * s_b_s))
         blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
 
         blk_g_idx = ((pid_blk * g_b_s) +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])
-        blk_g_msk = (blk_g_idx >= 0
-                     blk_g_idx < g_b * g_b_s)
+        blk_g_msk = ((blk_g_idx >= 0) &
+                     (blk_g_idx < g_b * g_b_s))
         blk_g = tl.load(g + blk_g_idx, mask=blk_g_msk)
 
         blk_x_idx = ((pid_blk * x_b_s) +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_x_msk = (blk_x_idx >= 0
-                     blk_x_idx < x_b * x_b_s)
+        blk_x_msk = ((blk_x_idx >= 0) &
+                     (blk_x_idx < x_b * x_b_s))
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
         buf = blk_x * (blk_g - blk_s)

@@ -279,8 +279,8 @@ def softmax_kernel_grad(g,
         blk_o_idx = ((pid_blk * o_b_s) +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-        blk_o_msk = (blk_o_idx >= 0
-                     blk_o_idx < o_b * o_b_s)
+        blk_o_msk = ((blk_o_idx >= 0) &
+                     (blk_o_idx < o_b * o_b_s))
         tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
 
 

@@ -346,10 +346,10 @@ def softmax_fused(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size
 
     lut = softmax_fused_build_lut(lut, sparsity_layout)
 
-    return BlksprsTensor(softmax_fused_forward(x, sparsity_layout,
-
-
-
+    return BlksprsTensor.wrap(softmax_fused_forward(x, sparsity_layout,
+                                                    lut["sparsity_reverse_lut_sorted"],
+                                                    lut["max_blocks_line"],
+                                                    sparsity_block_size))
 
 
 @triton_op("blksprs::softmax_fused_forward", mutates_args={})

@@ -449,7 +449,8 @@ def softmax_fused_kernel(x,
     blk_rev_idx = (pid_bat * s_l_b_s +
                    pid_row * s_l_r_s +
                    (tl.arange(0, mbs) * s_l_c_s))
-    blk_rev_msk = ((blk_rev_idx >= 0
+    blk_rev_msk = (((blk_rev_idx >= 0) &
+                    (blk_rev_idx < s_l_b * s_l_b_s)) &
                    (tl.arange(0, mbs) < s_l_c))
     blk_rev = tl.load(r_lut_s + blk_rev_idx, mask=blk_rev_msk, other=-1).to(tl.int32)
 

@@ -464,8 +465,9 @@ def softmax_fused_kernel(x,
     blk_x_idx = (blk_rev_ext * x_b_s +
                  pid_lin * x_r_s +
                  (tl.arange(0, mbs * sparsity_block_size) % sparsity_block_size) * x_c_s)
-    blk_x_mask = ((blk_x_idx >= 0
-
+    blk_x_mask = (((blk_x_idx >= 0) &
+                   (blk_x_idx < x_b * x_b_s)) &
+                  (blk_rev_ext != -1))
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_mask, other=float("-inf"))
 
     # Compute softmax

@@ -502,7 +504,8 @@ def softmax_fused_kernel_grad(g,
     blk_rev_idx = (pid_bat * s_l_b_s +
                    pid_row * s_l_r_s +
                    (tl.arange(0, mbs) * s_l_c_s))
-    blk_rev_msk = ((blk_rev_idx >= 0
+    blk_rev_msk = (((blk_rev_idx >= 0) &
+                    (blk_rev_idx < s_l_b * s_l_b_s)) &
                    (tl.arange(0, mbs) < s_l_c))
     blk_rev = tl.load(r_lut_s + blk_rev_idx, mask=blk_rev_msk, other=-1).to(tl.int32)
 

@@ -517,16 +520,18 @@ def softmax_fused_kernel_grad(g,
     blk_g_idx = (blk_rev_ext * g_b_s +
                  pid_lin * g_r_s +
                  (tl.arange(0, mbs * sparsity_block_size) % sparsity_block_size) * g_c_s)
-    blk_g_mask = ((blk_g_idx >= 0
-
+    blk_g_mask = (((blk_g_idx >= 0) &
+                   (blk_g_idx < g_b * g_b_s)) &
+                  (blk_rev_ext != -1))
     blk_g = tl.load(g + blk_g_idx, mask=blk_g_mask)
 
     # Load line of x
     blk_x_idx = (blk_rev_ext * x_b_s +
                  pid_lin * x_r_s +
                  (tl.arange(0, mbs * sparsity_block_size) % sparsity_block_size) * x_c_s)
-    blk_x_mask = ((blk_x_idx >= 0
-
+    blk_x_mask = (((blk_x_idx >= 0) &
+                   (blk_x_idx < x_b * x_b_s)) &
+                  (blk_rev_ext != -1))
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_mask)
 
     # Compute gradients
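The fused softmax loads masked lanes with `other=float("-inf")`: since exp(-inf) is 0, padding lanes contribute nothing to the row sum. The same effect in plain PyTorch:

```python
import torch

row = torch.tensor([1.0, 2.0, float("-inf"), float("-inf")])
print(torch.softmax(row, dim=-1))  # padding positions come out exactly 0
```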
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/transpose.py

@@ -37,9 +37,9 @@ def transpose(x: BlksprsTensor, sparsity_layout: Tensor,
 
     lut = transpose_build_lut(lut, sparsity_layout)
 
-    return BlksprsTensor(transpose_forward(x, lut["sparsity_layout_t"],
-
-
+    return BlksprsTensor.wrap(transpose_forward(x, lut["sparsity_layout_t"],
+                                                lut["sparsity_lut"], lut["sparsity_reverse_lut"],
+                                                sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_t"]
 
 
 @triton_op("blksprs::transpose_forward", mutates_args={})
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/utils/blksprs_tensor.py

@@ -1,3 +1,5 @@
+from typing import Union
+
 import torch
 from torch import Tensor
 

@@ -7,4 +9,11 @@ class BlksprsTensor(Tensor):
     """
 
     def __repr__(self):
-        return f"BlksprsTensor({torch.Tensor(self).__repr__()})"
+        return f"BlksprsTensor({torch.Tensor(self).__repr__()})"
+
+    @staticmethod
+    def wrap(tensor: Tensor) -> Union[Tensor, "BlksprsTensor"]:
+        if torch._dynamo.is_compiling():
+            return tensor
+        else:
+            return BlksprsTensor(tensor)
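The new `wrap` is the mechanism behind the many `BlksprsTensor(...)` to `BlksprsTensor.wrap(...)` call-site changes in this diff: constructing a `Tensor` subclass inside a dynamo-traced region is problematic, so `wrap` hands back the plain tensor while compiling and only subclasses in eager mode. A small demonstration, where `f` and the compiled call are illustrative:

```python
import torch
from blksprs.utils.blksprs_tensor import BlksprsTensor

t = torch.randn(4, 4)
print(type(BlksprsTensor.wrap(t)).__name__)  # "BlksprsTensor" in eager mode

def f(x):
    # While torch.compile traces this, wrap returns the plain Tensor.
    return BlksprsTensor.wrap(x * 2)

out = torch.compile(f)(t)
assert torch.equal(out, t * 2)
```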
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/utils/processing.py

@@ -26,7 +26,8 @@ def apply_torch_linear(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block
 
     # Apply weights
     sparsity_layout_xw = build_sparsity_layout_matmul_fast(sparsity_layout, sparsity_layout_w_t)
-    xw = matmul(x, sparsity_layout, BlksprsTensor(w_t_bs.to(x.dtype)), sparsity_layout_w_t, sparsity_layout_xw,
+    xw = matmul(x, sparsity_layout, BlksprsTensor.wrap(w_t_bs.to(x.dtype)), sparsity_layout_w_t, sparsity_layout_xw,
+                sparsity_block_size)
     interim = xw
 
     # Apply bias
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/utils/tools.py

@@ -1,9 +1,5 @@
-import torch
 from torch import Tensor, Size
 
-# Capture scalar outputs for JIT compilation
-torch._dynamo.config.capture_scalar_outputs = True
-
 
 def do_shape_blocksparse(x: Tensor) -> tuple[Tensor, Size]:
     if x.dim() == 3:

@@ -27,7 +23,8 @@ def stride(x: Tensor):
     else:
         raise NotImplementedError
 
+
 def ceil_pow2(x: int) -> int:
     if x <= 0:
         raise ValueError("Input must be a positive integer.")
-    return 1 << (x - 1).bit_length()
+    return 1 << (x - 1).bit_length()
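`ceil_pow2` is only reflowed here, but the bit trick deserves a worked example: for x at least 1, `(x - 1).bit_length()` counts the bits needed to represent x - 1, so the shift yields the smallest power of two greater than or equal to x.

```python
def ceil_pow2(x: int) -> int:
    if x <= 0:
        raise ValueError("Input must be a positive integer.")
    # (x - 1).bit_length() bits cover every value below x, so shifting 1
    # by that amount gives the next power of two at or above x.
    return 1 << (x - 1).bit_length()

assert [ceil_pow2(n) for n in (1, 2, 3, 5, 8, 9)] == [1, 2, 4, 8, 8, 16]
```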
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs.egg-info/PKG-INFO

@@ -1,13 +1,13 @@
 Metadata-Version: 2.4
 Name: blksprs
-Version: 2.1.7
+Version: 2.1.9
 Summary: A lightweight library for operations on block-sparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
 Project-URL: Bugtracker, https://github.com/FelixSchoen/blksprs/issues
 Requires-Python: >=3.11
 Description-Content-Type: text/markdown
-Requires-Dist: torch
+Requires-Dist: torch>=2.8.0
 Requires-Dist: numpy
 Provides-Extra: test
 Requires-Dist: pytest; extra == "test"

@@ -102,7 +102,7 @@ We will continue to maintain the library and fix any issues that arise.
 Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
 We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
 
-It might be that this changes with future projects, but as of
+It might be that this changes with future projects, but as of August 2025, we are content with the current state of the
 library.
 
 ## Known Limitations and Issues
{blksprs-2.1.7 → blksprs-2.1.9}/pyproject.toml

@@ -1,13 +1,13 @@
 [project]
 name = "blksprs"
-version = "2.1.7"
+version = "2.1.9"
 authors = [{ name = "Felix Schön", email = "schoen@kr.tuwien.ac.at" }]
 description = "A lightweight library for operations on block-sparse matrices in PyTorch."
 readme = "README.md"
 requires-python = ">=3.11"
 license = { file = "LICENSE.md" }
 dependencies = [
-    "torch",
+    "torch >= 2.8.0",
     "numpy"
 ]
 
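With the dependency bumped to `torch >= 2.8.0`, an environment that previously resolved the unpinned `torch` may now be too old. A quick check one might run before upgrading, using the `packaging` helper; the comparison is simplified for illustration:

```python
import torch
from packaging.version import Version

installed = Version(torch.__version__.split("+")[0])
assert installed >= Version("2.8.0"), "blksprs 2.1.9 requires torch>=2.8.0"
```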