blksprs 1.5.tar.gz → 1.7.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {blksprs-1.5 → blksprs-1.7}/PKG-INFO +5 -3
- {blksprs-1.5 → blksprs-1.7}/README.md +4 -2
- {blksprs-1.5 → blksprs-1.7}/blksprs/__init__.py +3 -2
- {blksprs-1.5 → blksprs-1.7}/blksprs/experimental/distribution_mdi.py +14 -14
- {blksprs-1.5 → blksprs-1.7}/blksprs/layouting/distribution_layout.py +4 -4
- {blksprs-1.5 → blksprs-1.7}/blksprs/layouting/sparsity_layout.py +42 -6
- {blksprs-1.5 → blksprs-1.7}/blksprs/misc/broadcast_ops.py +5 -5
- {blksprs-1.5 → blksprs-1.7}/blksprs/misc/repeat_interleave.py +6 -6
- {blksprs-1.5 → blksprs-1.7}/blksprs/misc/row_wise.py +16 -15
- {blksprs-1.5 → blksprs-1.7}/blksprs/ops/conversion.py +11 -11
- {blksprs-1.5 → blksprs-1.7}/blksprs/ops/distribution.py +11 -11
- {blksprs-1.5 → blksprs-1.7}/blksprs/ops/exp.py +3 -3
- {blksprs-1.5 → blksprs-1.7}/blksprs/ops/matmul.py +7 -7
- blksprs-1.7/blksprs/ops/partitioning.py +155 -0
- blksprs-1.7/blksprs/ops/repeat.py +241 -0
- {blksprs-1.5 → blksprs-1.7}/blksprs/ops/softmax.py +11 -10
- {blksprs-1.5 → blksprs-1.7}/blksprs/ops/transpose.py +12 -12
- {blksprs-1.5 → blksprs-1.7}/blksprs/utils/tools.py +3 -0
- {blksprs-1.5 → blksprs-1.7}/blksprs/utils/validation.py +2 -0
- {blksprs-1.5 → blksprs-1.7}/blksprs.egg-info/PKG-INFO +5 -3
- {blksprs-1.5 → blksprs-1.7}/blksprs.egg-info/SOURCES.txt +2 -0
- {blksprs-1.5 → blksprs-1.7}/pyproject.toml +1 -1
- {blksprs-1.5 → blksprs-1.7}/blksprs/utils/benchmarking.py +0 -0
- {blksprs-1.5 → blksprs-1.7}/blksprs.egg-info/dependency_links.txt +0 -0
- {blksprs-1.5 → blksprs-1.7}/blksprs.egg-info/requires.txt +0 -0
- {blksprs-1.5 → blksprs-1.7}/blksprs.egg-info/top_level.txt +0 -0
- {blksprs-1.5 → blksprs-1.7}/setup.cfg +0 -0
{blksprs-1.5 → blksprs-1.7}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.5
+Version: 1.7
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -31,9 +31,11 @@ Currently supported operations (includes gradient calculation):
 - Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
   for `sparse = sparse @ sparse` matmul_)
 - Softmax
--
+- Transpose
 - Gather
 - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
+- Repeat (_supports target sparsity layout_)
+- Splitting and merging of matrices along the last dimension
 - Conversion to and from sparse form
 - Conversion to different sparsity layouts and different sparsity block sizes
 
@@ -63,7 +65,7 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u
 
 ### Dependencies
 
-- [PyTorch](https://pytorch.org/) (built with v2.
+- [PyTorch](https://pytorch.org/) (built with v2.5.0)
 - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_
 
 ## Changelog
{blksprs-1.5 → blksprs-1.7}/README.md

@@ -12,9 +12,11 @@ Currently supported operations (includes gradient calculation):
 - Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
   for `sparse = sparse @ sparse` matmul_)
 - Softmax
--
+- Transpose
 - Gather
 - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
+- Repeat (_supports target sparsity layout_)
+- Splitting and merging of matrices along the last dimension
 - Conversion to and from sparse form
 - Conversion to different sparsity layouts and different sparsity block sizes
 
@@ -44,7 +46,7 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u
 
 ### Dependencies
 
-- [PyTorch](https://pytorch.org/) (built with v2.
+- [PyTorch](https://pytorch.org/) (built with v2.5.0)
 - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_
 
 ## Changelog
{blksprs-1.5 → blksprs-1.7}/blksprs/__init__.py

@@ -4,10 +4,11 @@ from blksprs.ops.exp import exp
 from blksprs.ops.matmul import matmul
 from blksprs.ops.softmax import softmax
 from blksprs.ops.transpose import transpose
+from blksprs.ops.partitioning import split, merge
 
 class layout:
     from blksprs.layouting.distribution_layout import build_distribution_layout
-    from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption
+    from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption, build_sparsity_layout_matmul
 
 class misc:
     from blksprs.misc.broadcast_ops import broadcast_add, broadcast_sub

@@ -18,4 +19,4 @@ class util:
     from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse, disable_validation
 
 class experimental:
-    from blksprs.experimental.distribution_mdi import gather_mdi
+    from blksprs.experimental.distribution_mdi import gather_mdi, scatter_reduce_mdi
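Taken together, the `__init__.py` hunks define the new 1.7 API surface. A minimal sketch of how it reads from user code follows; the import paths come from the diff above, while `to_sparse`, the shapes, and the layout values are illustrative assumptions:

```python
# Hypothetical usage of the blksprs 1.7 API surface. Only the import paths are
# taken from the diff; to_sparse is an assumed conversion helper, and all
# shapes and layouts are invented for illustration.
import torch
import blksprs as bs

# One batch of 2x2 sparsity blocks, each block 16x16, with one block pruned.
sparsity_layout = torch.tensor([[[1, 0],
                                 [1, 1]]], dtype=torch.bool, device="cuda")
x_dense = torch.randn(1, 32, 32, device="cuda")
x = bs.to_sparse(x_dense, sparsity_layout, 16)  # assumed helper

# New in 1.7: split/merge along the last dimension, exported at top level ...
halves, layout_halves = bs.split(x, sparsity_layout, 2, 16)
restored, _ = bs.merge(halves, layout_halves, 2, 16)

# ... and the precise matmul layout builder, exposed via the layout namespace.
layout_mm = bs.layout.build_sparsity_layout_matmul(sparsity_layout, sparsity_layout)
```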
{blksprs-1.5 → blksprs-1.7}/blksprs/experimental/distribution_mdi.py

@@ -3,7 +3,7 @@ import triton
 from torch import Tensor
 from triton import language as tl
 
-from blksprs.utils.tools import get_triton_block_size
+from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
 

@@ -51,15 +51,15 @@ class _BlocksparseGatherMDI(torch.autograd.Function):
         output = torch.empty_like(idx_col, dtype=x.dtype)
 
         x_b, x_r, x_c = x.size()
-        x_b_s, x_r_s, x_c_s =
+        x_b_s, x_r_s, x_c_s = stride(x)
         s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
-        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s =
+        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
         i_b, i_r, i_c = idx_col.size()
-        i_b_s, i_r_s, i_c_s =
+        i_b_s, i_r_s, i_c_s = stride(idx_col)
         s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
-        s_lut_i_r_s, s_lut_i_c_s =
+        s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
         o_b, o_r, o_c = output.size()
-        o_b_s, o_r_s, o_c_s =
+        o_b_s, o_r_s, o_c_s = stride(output)
 
         if triton_block_size is None:
             triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -224,15 +224,15 @@ class _BlocksparseScatterReduceMDI(torch.autograd.Function):
                              dtype=x.dtype, device=x.device)
 
         x_b, x_r, x_c = x.size()
-        x_b_s, x_r_s, x_c_s =
+        x_b_s, x_r_s, x_c_s = stride(x)
         s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
-        s_lut_x_r_s, s_lut_x_c_s =
+        s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut_x)
         i_b, i_r, i_c = idx_col.size()
-        i_b_s, i_r_s, i_c_s =
+        i_b_s, i_r_s, i_c_s = stride(idx_col)
         o_b, o_r, o_c = output.size()
-        o_b_s, o_r_s, o_c_s =
+        o_b_s, o_r_s, o_c_s = stride(output)
         s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
-        s_l_o_b_s, s_l_o_r_s, s_l_o_c_s =
+        s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
 
         if triton_block_size is None:
             triton_block_size = get_triton_block_size(sparsity_block_size)

@@ -366,11 +366,11 @@ def build_distribution_layout_mdi(idx_bat: Tensor, idx_row: Tensor, idx_col: Ten
                          dtype=torch.bool, device=idx_col.device)
 
     i_b, i_r, i_c = idx_col.size()
-    i_b_s, i_r_s, i_c_s =
+    i_b_s, i_r_s, i_c_s = stride(idx_col)
    s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
-    s_lut_i_r_s, s_lut_i_c_s =
+    s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
     o_b, o_r, o_c = output.size()
-    o_b_s, o_r_s, o_c_s =
+    o_b_s, o_r_s, o_c_s = stride(output)
 
     if triton_block_size is None:
         triton_block_size = get_triton_block_size(sparsity_block_size)
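The `x_b_s, x_r_s, x_c_s = stride(x)` pattern above recurs in nearly every module below. The removed right-hand sides are truncated in this view (presumably direct `Tensor.stride()` calls), and the three-line `tools.py` hunk that introduces the helper is not expanded here, so the following reconstruction is purely an assumption:

```python
# Assumed shape of the new blksprs.utils.tools.stride helper; the +3-line
# tools.py change is not shown in this diff view, so this is a best guess.
from torch import Tensor

def stride(tensor: Tensor) -> tuple[int, ...]:
    # Thin wrapper so every call site reads uniformly as stride(x)
    # rather than x.stride().
    return tensor.stride()
```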
{blksprs-1.5 → blksprs-1.7}/blksprs/layouting/distribution_layout.py

@@ -3,7 +3,7 @@ import triton
 from torch import Tensor
 from triton import language as tl
 
-from blksprs.utils.tools import get_triton_block_size
+from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
     validate_contiguous
 

@@ -34,11 +34,11 @@ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
                          dtype=torch.bool, device=indices.device)
 
     i_b, i_r, i_c = indices.size()
-    i_b_s, i_r_s, i_c_s =
+    i_b_s, i_r_s, i_c_s = stride(indices)
     s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
-    s_lut_i_r_s, s_lut_i_c_s =
+    s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
     o_b, o_r, o_c = output.size()
-    o_b_s, o_r_s, o_c_s =
+    o_b_s, o_r_s, o_c_s = stride(output)
 
     if triton_block_size is None:
         triton_block_size = get_triton_block_size(sparsity_block_size)
{blksprs-1.5 → blksprs-1.7}/blksprs/layouting/sparsity_layout.py

@@ -5,7 +5,7 @@ import triton
 from torch import Tensor
 from triton import language as tl
 
-from blksprs.utils.tools import get_triton_block_size
+from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
     validate_contiguous, validate_sparsity, validate_sparsity_block_size
 

@@ -30,9 +30,9 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_size
                          dtype=torch.bool, device=x.device)
 
     x_b, x_r, x_c = x.size()
-    x_b_s, x_r_s, x_c_s =
+    x_b_s, x_r_s, x_c_s = stride(x)
     o_b, o_r, o_c = output.size()
-    o_b_s, o_r_s, o_c_s =
+    o_b_s, o_r_s, o_c_s = stride(output)
 
     if triton_block_size is None:
         triton_block_size = get_triton_block_size(sparsity_block_size)

@@ -120,10 +120,10 @@ def build_sparsity_layout_adaption(x: Tensor, sparsity_layout_from: Tensor,
     output = torch.zeros(o_b, o_r, o_c, dtype=torch.bool, device=x.device)
 
     x_b, x_r, x_c = x.size()
-    x_b_s, x_r_s, x_c_s =
+    x_b_s, x_r_s, x_c_s = stride(x)
     s_lut_r, s_lut_c = sparsity_lut.size()
-    s_lut_r_s, s_lut_c_s =
-    o_b_s, o_r_s, o_c_s =
+    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+    o_b_s, o_r_s, o_c_s = stride(output)
 
     if triton_block_size is None:
         triton_block_size = get_triton_block_size(sparsity_block_size_from)

@@ -188,3 +188,39 @@ def kernel_sparsity_layout_adaption(x,
                  // sparsity_block_size_to) * o_c_s))
     blk_o_msk = (blk_o_idx < o_b * o_b_s)
     tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
+
+
+def build_sparsity_layout_matmul(sparsity_layout_x: Tensor, sparsity_layout_y: Tensor):
+    """Builds the precise sparsity layout of the result of a matrix multiplication between the two input tensors.
+
+    Args:
+        sparsity_layout_x (Tensor): The sparsity layout of the first block-sparse tensor.
+        sparsity_layout_y (Tensor): The sparsity layout of the second block-sparse tensor.
+
+    Returns:
+        Tensor: The precise sparsity layout of the result of a matrix multiplication between the two input tensors.
+
+    """
+    return torch.matmul(sparsity_layout_x.to(torch.float), sparsity_layout_y.to(torch.float)).to(torch.bool)
+
+
+def build_sparsity_layout_matmul_fast(sparsity_layout_x: Tensor, sparsity_layout_y: Tensor):
+    """Builds the approximate sparsity layout of the result of a matrix multiplication between the two input tensors.
+
+    Note:
+        This function is faster than the ``build_sparsity_layout_matmul`` function due to the fact that it only checks
+        whether at least one of the blocks in either of the vectors participating in the matmul is non-sparse. The
+        resulting sparsity layout may thus overestimate the actual sparsity of the result.
+
+    Args:
+        sparsity_layout_x (Tensor): The sparsity layout of the first block-sparse tensor.
+        sparsity_layout_y (Tensor): The sparsity layout of the second block-sparse tensor.
+
+    Returns:
+        Tensor: The approximate sparsity layout of the result of a matrix multiplication between the two input tensors.
+
+    """
+    sparsity_layout_x_slice = torch.max(sparsity_layout_x, dim=-1).values.unsqueeze(-1)
+    sparsity_layout_y_slice = torch.max(sparsity_layout_y, dim=-2).values.unsqueeze(1)
+
+    return torch.logical_or(sparsity_layout_x_slice, sparsity_layout_y_slice)
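The two new builders trade precision for speed: the precise variant takes a boolean matrix product over the block indicators, while the fast variant only ORs the row occupancy of `x` with the column occupancy of `y` and may therefore overestimate. A small illustration (layouts invented for the example):

```python
# Comparison of the two new layout builders on invented layouts.
import torch
from blksprs.layouting.sparsity_layout import build_sparsity_layout_matmul, \
    build_sparsity_layout_matmul_fast

sparsity_layout_x = torch.tensor([[[1, 0],
                                   [0, 0]]], dtype=torch.bool)
sparsity_layout_y = torch.tensor([[[0, 0],
                                   [0, 1]]], dtype=torch.bool)

precise = build_sparsity_layout_matmul(sparsity_layout_x, sparsity_layout_y)
fast = build_sparsity_layout_matmul_fast(sparsity_layout_x, sparsity_layout_y)

# No inner block index is occupied in both inputs, so the precise result is
# empty, while the fast result marks every block whose row of x or column of y
# contains an occupied block.
assert not precise.any() and fast.any()
```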
{blksprs-1.5 → blksprs-1.7}/blksprs/misc/broadcast_ops.py

@@ -3,7 +3,7 @@ import triton
 from torch import Tensor
 from triton import language as tl
 
-from blksprs.utils.tools import get_triton_block_size
+from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_contiguous, validate_device, \
     validate_sparsity_block_size, validate_triton_block_size
 

@@ -44,13 +44,13 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
     output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, dtype=x.dtype, device=x.device)
 
     x_b, x_c = x.size()
-    x_b_s, x_c_s =
+    x_b_s, x_c_s = stride(x)
     y_b, y_c = y.size()
-    y_b_s, y_c_s =
+    y_b_s, y_c_s = stride(y)
     o_b, o_r, o_c = output.size()
-    o_b_s, o_r_s, o_c_s =
+    o_b_s, o_r_s, o_c_s = stride(output)
     s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
-    s_lut_o_r_s, s_lut_o_c_s =
+    s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
 
     if triton_block_size is None:
         triton_block_size = get_triton_block_size(sparsity_block_size)
{blksprs-1.5 → blksprs-1.7}/blksprs/misc/repeat_interleave.py

@@ -3,7 +3,7 @@ import triton
 from torch import Tensor
 from triton import language as tl
 
-from blksprs.utils.tools import get_triton_block_size
+from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_contiguous, validate_device, \
     validate_sparsity_block_size, validate_triton_block_size, validate_dimensions
 

@@ -35,7 +35,7 @@ def repeat_interleave(x: Tensor, sparsity_layout: Tensor, repeats: int,
     validate_sparsity_block_size(sparsity_block_size, x)
     validate_triton_block_size(triton_block_size, sparsity_block_size)
 
-    sparsity_layout_output = torch.repeat_interleave(sparsity_layout,
+    sparsity_layout_output = torch.repeat_interleave(sparsity_layout, repeats, dim=0).contiguous()
 
     sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
 

@@ -52,13 +52,13 @@ def repeat_interleave(x: Tensor, sparsity_layout: Tensor, repeats: int,
                          dtype=x.dtype, device=x.device)
 
     x_b, x_r, x_c = x.size()
-    x_b_s, x_r_s, x_c_s =
+    x_b_s, x_r_s, x_c_s = stride(x)
     s_lut_r, s_lut_c = sparsity_lut.size()
-    s_lut_r_s, s_lut_c_s =
+    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
     o_b, o_r, o_c = output.size()
-    o_b_s, o_r_s, o_c_s =
+    o_b_s, o_r_s, o_c_s = stride(output)
     s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
-    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s =
+    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
 
     if triton_block_size is None:
         triton_block_size = get_triton_block_size(sparsity_block_size)
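The only functional change in this file completes the layout expansion: the output layout is now the input layout repeated `repeats` times along the batch dimension (the removed line is truncated in this view, so the old arguments are not visible). The fixed expression, shown standalone with an invented layout:

```python
# Standalone view of the fixed layout expansion (layout invented).
import torch

sparsity_layout = torch.tensor([[[1, 0]], [[0, 1]]], dtype=torch.bool)
expanded = torch.repeat_interleave(sparsity_layout, 2, dim=0).contiguous()
# Batch entries are repeated in place, i.e. [b0, b0, b1, b1], matching the
# semantics of torch.repeat_interleave rather than torch.Tensor.repeat.
assert expanded.shape == (4, 1, 2) and torch.equal(expanded[0], expanded[1])
```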
{blksprs-1.5 → blksprs-1.7}/blksprs/misc/row_wise.py

@@ -3,7 +3,7 @@ import triton
 from torch import Tensor
 from triton import language as tl
 
-from blksprs.utils.tools import get_triton_block_size
+from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, validate_sparsity, \
     validate_sparsity_block_size, validate_triton_block_size
 

@@ -60,13 +60,13 @@ def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
                          device=x.device)
 
     x_b, x_r, x_c = x.size()
-    x_b_s, x_r_s, x_c_s =
+    x_b_s, x_r_s, x_c_s = stride(x)
     s_lut_x_r, s_lut_x_c = sparsity_lut.size()
-    s_lut_x_r_s, s_lut_x_c_s =
+    s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
     o_b, o_r, o_c = output.size()
-    o_b_s, o_r_s, o_c_s =
+    o_b_s, o_r_s, o_c_s = stride(output)
     s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
-    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s =
+    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
 
     if triton_block_size is None:
         triton_block_size = get_triton_block_size(sparsity_block_size)

@@ -183,13 +183,13 @@ def row_wise_max(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
                          device=x.device)
 
     x_b, x_r, x_c = x.size()
-    x_b_s, x_r_s, x_c_s =
+    x_b_s, x_r_s, x_c_s = stride(x)
     s_lut_x_r, s_lut_x_c = sparsity_lut.size()
-    s_lut_x_r_s, s_lut_x_c_s =
+    s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
     o_b, o_r, o_c = output.size()
-    o_b_s, o_r_s, o_c_s =
+    o_b_s, o_r_s, o_c_s = stride(output)
     s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
-    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s =
+    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
 
     if triton_block_size is None:
         triton_block_size = get_triton_block_size(sparsity_block_size)

@@ -290,15 +290,15 @@ def row_wise_add(x: Tensor, sparsity_layout_x: Tensor, y: Tensor,
     output = torch.empty_like(x)
 
     x_b, x_r, x_c = x.size()
-    x_b_s, x_r_s, x_c_s =
+    x_b_s, x_r_s, x_c_s = stride(x)
     s_lut_r, s_lut_c = sparsity_lut.size()
-    s_lut_r_s, s_lut_c_s =
+    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
     y_b, y_r, y_c = y.size()
-    y_b_s, y_r_s, y_c_s =
+    y_b_s, y_r_s, y_c_s = stride(y)
     s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_rwm.size()
-    s_l_y_b_s, s_l_y_r_s, s_l_y_c_s =
+    s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_rwm)
     o_b, o_r, o_c = output.size()
-    o_b_s, o_r_s, o_c_s =
+    o_b_s, o_r_s, o_c_s = stride(output)
 
     if triton_block_size is None:
         triton_block_size = get_triton_block_size(sparsity_block_size)

@@ -361,7 +361,8 @@ def kernel_blocksparse_row_wise_add(x,
     rev_idx_spa_s = tl.load(r_lut_y + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
 
     if rev_idx_spa_s == -1:
-
+        tl.device_assert(False)
+        return
 
     # Load x block
     blk_x_idx = ((pid_blk * x_b_s) +
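The kernel change here, mirrored in softmax.py and transpose.py below, fills a previously empty branch: a `-1` reverse-LUT entry (a block absent from the layout) now traps via `tl.device_assert` when device assertions are enabled and otherwise ends the program instance early instead of falling through. The guard in isolation, as a toy sketch rather than blksprs code:

```python
# Toy sketch of the new guard pattern; not blksprs code.
import torch
import triton
from triton import language as tl

@triton.jit
def kernel_guard_demo(r_lut, o):
    pid = tl.program_id(axis=0)
    rev_idx = tl.load(r_lut + pid).to(tl.int32)
    if rev_idx == -1:
        tl.device_assert(False)  # traps when device assertions are enabled
        return                   # otherwise skips this program instance
    tl.store(o + pid, rev_idx)

r_lut = torch.tensor([0, -1, 2], dtype=torch.int32, device="cuda")
out = torch.zeros(3, dtype=torch.int32, device="cuda")
kernel_guard_demo[(3,)](r_lut, out)  # out becomes [0, 0, 2]
```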
{blksprs-1.5 → blksprs-1.7}/blksprs/ops/conversion.py

@@ -6,7 +6,7 @@ from torch import Tensor
 from triton import language as tl
 
 from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
-from blksprs.utils.tools import get_triton_block_size
+from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
 

@@ -65,11 +65,11 @@ class _BlocksparseToDense(torch.autograd.Function):
                              dtype=x.dtype, device=x.device)
 
         x_b, x_r, x_c = x.shape
-        x_b_s, x_r_s, x_c_s =
+        x_b_s, x_r_s, x_c_s = stride(x)
         s_l_b, s_l_r, s_l_c = sparsity_layout.size()
-        s_l_b_s, s_l_r_s, s_l_c_s =
+        s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout)
         o_b, o_r, o_c = output.size()
-        o_b_s, o_r_s, o_c_s =
+        o_b_s, o_r_s, o_c_s = stride(output)
 
         if triton_block_size is None:
             triton_block_size = get_triton_block_size(sparsity_block_size)

@@ -190,11 +190,11 @@ class _BlocksparseToSparse(torch.autograd.Function):
                              dtype=x.dtype, device=x.device)
 
         x_b, x_r, x_c = x.size()
-        x_b_s, x_r_s, x_c_s =
+        x_b_s, x_r_s, x_c_s = stride(x)
         s_lut_r, s_lut_c = sparsity_lut.size()
-        s_lut_r_s, s_lut_c_s =
+        s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
         o_b, o_r, o_c = output.size()
-        o_b_s, o_r_s, o_c_s =
+        o_b_s, o_r_s, o_c_s = stride(output)
 
         if triton_block_size is None:
             triton_block_size = get_triton_block_size(sparsity_block_size)

@@ -347,13 +347,13 @@ class _BlocksparseAdaptLayout(torch.autograd.Function):
                              dtype=x.dtype, device=x.device)
 
         x_b, x_r, x_c = x.size()
-        x_b_s, x_r_s, x_c_s =
+        x_b_s, x_r_s, x_c_s = stride(x)
         s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_from.size()
-        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s =
+        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_from)
         o_b, o_r, o_c = output.size()
-        o_b_s, o_r_s, o_c_s =
+        o_b_s, o_r_s, o_c_s = stride(output)
         s_lut_o_r, s_lut_o_c = sparsity_lut_to.size()
-        s_lut_o_r_s, s_lut_o_c_s =
+        s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_to)
 
         if triton_block_size is None:
             triton_block_size = get_triton_block_size(min_sparsity_block_size)
{blksprs-1.5 → blksprs-1.7}/blksprs/ops/distribution.py

@@ -3,7 +3,7 @@ import triton
 from torch import Tensor
 from triton import language as tl
 
-from blksprs.utils.tools import get_triton_block_size
+from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
 

@@ -59,15 +59,15 @@ class _BlocksparseGather(torch.autograd.Function):
         output = torch.empty_like(i, dtype=x.dtype)
 
         x_b, x_r, x_c = x.size()
-        x_b_s, x_r_s, x_c_s =
+        x_b_s, x_r_s, x_c_s = stride(x)
         s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
-        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s =
+        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
         i_b, i_r, i_c = i.size()
-        i_b_s, i_r_s, i_c_s =
+        i_b_s, i_r_s, i_c_s = stride(i)
         s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
-        s_lut_i_r_s, s_lut_i_c_s =
+        s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
         o_b, o_r, o_c = output.size()
-        o_b_s, o_r_s, o_c_s =
+        o_b_s, o_r_s, o_c_s = stride(output)
 
         if triton_block_size is None:
             triton_block_size = get_triton_block_size(sparsity_block_size)

@@ -248,15 +248,15 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
                              dtype=x.dtype, device=x.device)
 
         x_b, x_r, x_c = x.size()
-        x_b_s, x_r_s, x_c_s =
+        x_b_s, x_r_s, x_c_s = stride(x)
         s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
-        s_lut_x_r_s, s_lut_x_c_s =
+        s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut_x)
         i_b, i_r, i_c = i.size()
-        i_b_s, i_r_s, i_c_s =
+        i_b_s, i_r_s, i_c_s = stride(i)
         o_b, o_r, o_c = output.size()
-        o_b_s, o_r_s, o_c_s =
+        o_b_s, o_r_s, o_c_s = stride(output)
         s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
-        s_l_o_b_s, s_l_o_r_s, s_l_o_c_s =
+        s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
 
         if triton_block_size is None:
             triton_block_size = get_triton_block_size(sparsity_block_size)
{blksprs-1.5 → blksprs-1.7}/blksprs/ops/exp.py

@@ -3,7 +3,7 @@ import triton
 from torch import Tensor
 from triton import language as tl
 
-from blksprs.utils.tools import get_triton_block_size
+from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity_block_size, validate_triton_block_size
 

@@ -43,9 +43,9 @@ class _BlocksparseExp(torch.autograd.Function):
         output = torch.empty_like(x)
 
         x_b, x_r, x_c = x.shape
-        x_b_s, x_r_s, x_c_s =
+        x_b_s, x_r_s, x_c_s = stride(x)
         o_b, o_r, o_c = output.shape
-        o_b_s, o_r_s, o_c_s =
+        o_b_s, o_r_s, o_c_s = stride(output)
 
         if triton_block_size is None:
             triton_block_size = get_triton_block_size(sparsity_block_size)
{blksprs-1.5 → blksprs-1.7}/blksprs/ops/matmul.py

@@ -4,7 +4,7 @@ from torch import Tensor
 from triton import language as tl
 
 from blksprs.ops.transpose import transpose
-from blksprs.utils.tools import get_triton_block_size
+from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_dtype_float
 

@@ -82,17 +82,17 @@ class _BlocksparseMatmulSSS(torch.autograd.Function):
                              dtype=x.dtype, device=x.device)
 
         x_b, x_r, x_c = x.size()
-        x_b_s, x_r_s, x_c_s =
+        x_b_s, x_r_s, x_c_s = stride(x)
         s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
-        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s =
+        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
         y_b, y_r, y_c = y.size()
-        y_b_s, y_r_s, y_c_s =
+        y_b_s, y_r_s, y_c_s = stride(y)
         s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_y.size()
-        s_l_y_b_s, s_l_y_r_s, s_l_y_c_s =
+        s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_y)
         o_b, o_r, o_c = output.size()
-        o_b_s, o_r_s, o_c_s =
+        o_b_s, o_r_s, o_c_s = stride(output)
         s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
-        s_lut_o_r_s, s_lut_o_c_s =
+        s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
 
         if triton_block_size is None:
             triton_block_size = get_triton_block_size(sparsity_block_size)
blksprs-1.7/blksprs/ops/partitioning.py (new file)

@@ -0,0 +1,155 @@
+import torch
+from torch import Tensor
+
+from blksprs.ops.repeat import forward_flow
+
+from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
+    validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
+
+
+def split(x: Tensor, sparsity_layout: Tensor, partitions: int,
+          sparsity_block_size: int, triton_block_size: int = None) -> (Tensor, Tensor):
+    """Splits a block-sparse tensor in compressed form along the last dimension into partitions.
+
+    Args:
+        x (Tensor): A block-sparse tensor in compressed form.
+        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+        partitions (int): The number of partitions to split the block-sparse tensor into.
+        sparsity_block_size (int): The size of the sparsity blocks.
+        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+    Returns:
+        Tensor: The block-sparse tensor split into partitions in compressed form.
+        Tensor: The sparsity layout of the output tensor.
+
+    """
+    x = x.contiguous()
+
+    validate_dimensions(x)
+    validate_contiguous(x)
+    validate_device(x)
+    validate_sparsity(sparsity_block_size, (x, sparsity_layout))
+    validate_sparsity_block_size(sparsity_block_size, x)
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+    sparsity_layout_output = (sparsity_layout
+                              .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
+                                       sparsity_layout.size(2) // partitions)
+                              .permute(0, 2, 1, 3)
+                              .reshape(sparsity_layout.size(0) * partitions, sparsity_layout.size(1),
+                                       sparsity_layout.size(2) // partitions).contiguous())
+
+    sparsity_lut = torch.nonzero(sparsity_layout_output).contiguous()
+
+    sparsity_layout_flat = sparsity_layout.reshape(-1)
+    sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                             (sparsity_layout_flat == 1) -
+                             (1 * (sparsity_layout_flat == 0)))
+                            .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
+                                     sparsity_layout.size(2) // partitions)
+                            .permute(0, 2, 1, 3).reshape(-1).contiguous())
+
+    n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
+
+    validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
+
+    return _BlocksparseSplit.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
+                                   sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_output
+
+
+class _BlocksparseSplit(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                num_partitions: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+        ctx.save_for_backward(sparsity_layout_o)
+        ctx.num_partitions = num_partitions
+
+        return forward_flow(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                            n_sparse_blocks, triton_block_size)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        sparsity_layout = ctx.saved_tensors[0]
+        num_partitions = ctx.num_partitions
+        sparsity_block_size = ctx.sparsity_block_size
+        triton_block_size = ctx.triton_block_size
+
+        return merge(grad_output, sparsity_layout, num_partitions,
+                     sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None
+
+
+def merge(x: Tensor, sparsity_layout: Tensor, partitions: int,
+          sparsity_block_size: int, triton_block_size: int = None) -> (Tensor, Tensor):
+    """Merges the specified partitions of a block-sparse tensor in compressed form along the last dimension.
+
+    Args:
+        x (Tensor): A block-sparse tensor in compressed form.
+        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+        partitions (int): The number of partitions to be merged.
+        sparsity_block_size (int): The size of the sparsity blocks.
+        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+    Returns:
+        Tensor: The merged block-sparse tensor in compressed form.
+        Tensor: The sparsity layout of the output tensor.
+
+    """
+    x = x.contiguous()
+
+    validate_dimensions(x)
+    validate_contiguous(x)
+    validate_device(x)
+    validate_sparsity(sparsity_block_size, (x, sparsity_layout))
+    validate_sparsity_block_size(sparsity_block_size, x)
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+    sparsity_layout_output = (sparsity_layout.reshape(sparsity_layout.size(0) // partitions, partitions,
+                                                      sparsity_layout.size(1), sparsity_layout.size(2))
+                              .permute(0, 2, 1, 3)
+                              .reshape(sparsity_layout.size(0) // partitions,
+                                       sparsity_layout.size(1), sparsity_layout.size(2) * partitions).contiguous())
+
+    sparsity_lut = torch.nonzero(sparsity_layout_output).contiguous()
+
+    sparsity_layout_flat = sparsity_layout.reshape(-1)
+    sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                             (sparsity_layout_flat == 1) -
+                             (1 * (sparsity_layout_flat == 0)))
+                            .reshape(sparsity_layout.size(0) // partitions, partitions,
+                                     sparsity_layout.size(1), sparsity_layout.size(2))
+                            .permute(0, 2, 1, 3)
+                            .reshape(sparsity_layout.size(0) // partitions,
+                                     sparsity_layout.size(1), sparsity_layout.size(2) * partitions)
+                            .reshape(-1).contiguous())
+
+    n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
+
+    validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
+
+    return _BlocksparseMerge.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
+                                   sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_output
+
+
+class _BlocksparseMerge(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                num_partitions: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+        ctx.save_for_backward(sparsity_layout_o)
+        ctx.num_partitions = num_partitions
+
+        return forward_flow(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                            n_sparse_blocks, triton_block_size)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        sparsity_layout = ctx.saved_tensors[0]
+        num_partitions = ctx.num_partitions
+        sparsity_block_size = ctx.sparsity_block_size
+        triton_block_size = ctx.triton_block_size
+
+        return split(grad_output, sparsity_layout, num_partitions,
+                     sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None
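`split` folds each partition of the last dimension into its own batch entry (the layout batch dimension grows by a factor of `partitions`), and `merge` is its exact inverse; note that each op's backward pass simply applies the other. A hedged round-trip sketch (`bs.to_sparse` and all shapes are assumptions):

```python
# Hypothetical round trip through the new split/merge ops; to_sparse is an
# assumed conversion helper and the shapes are invented.
import torch
import blksprs as bs

sparsity_block_size = 16
sparsity_layout = torch.ones(2, 2, 4, dtype=torch.bool, device="cuda")
x_dense = torch.randn(2, 32, 64, device="cuda")
x = bs.to_sparse(x_dense, sparsity_layout, sparsity_block_size)

# Split the last dimension into two partitions: the (2, 2, 4) layout becomes
# (4, 2, 2), with the partitions folded into the batch dimension.
parts, layout_parts = bs.split(x, sparsity_layout, 2, sparsity_block_size)

# merge undoes split (and also serves as split's backward pass, and vice versa).
merged, layout_merged = bs.merge(parts, layout_parts, 2, sparsity_block_size)
assert torch.equal(layout_merged, sparsity_layout)
```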
blksprs-1.7/blksprs/ops/repeat.py (new file)

@@ -0,0 +1,241 @@
+import torch
+import triton
+from triton import language as tl
+from torch import Tensor
+
+from blksprs.utils.tools import get_triton_block_size, stride
+from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
+    validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
+
+
+def repeat(x: Tensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
+           sparsity_block_size: int, sparsity_layout_output: Tensor = None, triton_block_size: int = None) -> (
+        Tensor, Tensor):
+    x = x.contiguous()
+
+    validate_dimensions(x)
+    validate_contiguous(x)
+    validate_device(x)
+    validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
+    validate_sparsity_block_size(sparsity_block_size, x)
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+    sparsity_layout_o = sparsity_layout_x.repeat(repeats[0], repeats[1], repeats[2])
+
+    if sparsity_layout_output is not None:
+        sparsity_layout_o = torch.logical_and(sparsity_layout_o, sparsity_layout_output)
+
+    sparsity_lut = torch.nonzero(sparsity_layout_o).contiguous()
+
+    sparsity_layout_flat = sparsity_layout_x.reshape(-1)
+    sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                             (sparsity_layout_flat == 1) -
+                             (1 * (sparsity_layout_flat == 0)))
+                            .reshape(sparsity_layout_x.size())
+                            .repeat(repeats[0], repeats[1], repeats[2])
+                            .reshape(-1).contiguous())
+
+    n_sparse_blocks = torch.sum(sparsity_layout_o.to(torch.int)).item()
+
+    validate_contiguous(sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
+
+    return _BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
+                                    sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_o
+
+
+class _BlocksparseRepeat(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor,
+                sparsity_reverse_lut: Tensor,
+                sparsity_block_size: int, n_sparse_blocks: int,
+                triton_block_size: int) -> Tensor:
+        ctx.save_for_backward(sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
+        ctx.x_size = x.size()
+        ctx.x_stride = stride(x)
+
+        return forward_flow(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                            n_sparse_blocks, triton_block_size)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut = ctx.saved_tensors
+        x_size = ctx.x_size
+        x_stride = ctx.x_stride
+        sparsity_block_size = ctx.sparsity_block_size
+        triton_block_size = ctx.triton_block_size
+
+        n_sparse_blocks = torch.sum(sparsity_layout_x.to(torch.int)).item()
+
+        output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                             dtype=grad_output.dtype, device=grad_output.device)
+
+        x_b, x_r, x_c = grad_output.size()
+        x_b_s, x_r_s, x_c_s = stride(grad_output)
+        s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_o.size()
+        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_o)
+        s_lut_r, s_lut_c = sparsity_lut.size()
+        s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+        o_b, o_r, o_c = x_size
+        o_b_s, o_r_s, o_c_s = x_stride
+
+        if triton_block_size is None:
+            triton_block_size = get_triton_block_size(sparsity_block_size)
+
+        triton_grid = lambda meta: [x_b,
+                                    triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+        (kernel_blocksparse_flow_push[triton_grid]
+         (grad_output,
+          x_b, x_b_s, x_r_s, x_c_s,
+          s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+          sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+          sparsity_reverse_lut,
+          output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          triton_block_size))
+
+        return output, None, None, None, None, None, None, None
+
+
+@triton.jit
+def kernel_blocksparse_flow_pull(x,
+                                 x_b, x_b_s, x_r_s, x_c_s,
+                                 o,
+                                 o_b, o_b_s, o_r_s, o_c_s,
+                                 s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+                                 s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                                 r_lut,
+                                 TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get sparsity index of current output block consisting of its batch, row, and column index
+    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+    spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+    spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+    spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
+    spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
+    spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+
+    # Get reverse sparsity index
+    rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
+                       spa_row * s_l_o_r_s +
+                       spa_col * s_l_o_c_s)
+    rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
+    rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
+
+    if rev_idx_spa == -1:
+        tl.device_assert(False)
+        return
+
+    blk_x_idx = (rev_idx_spa * x_b_s +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_x_msk = (blk_x_idx < x_b * x_b_s)
+    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+    blk_o_idx = (pid_blk * o_b_s +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+    blk_o_msk = (blk_o_idx < o_b * o_b_s)
+    tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+@triton.jit
+def kernel_blocksparse_flow_push(x,
+                                 x_b, x_b_s, x_r_s, x_c_s,
+                                 s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+                                 s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                                 r_lut,
+                                 o,
+                                 o_b, o_b_s, o_r_s, o_c_s,
+                                 TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get sparsity index of current input block consisting of its batch, row, and column index
+    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+    spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+    spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+    spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
+    spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
+    spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+
+    # Get reverse sparsity index
+    rev_idx_spa_idx = (spa_bat * s_l_x_b_s +
+                       spa_row * s_l_x_r_s +
+                       spa_col * s_l_x_c_s)
+    rev_idx_spa_msk = (rev_idx_spa_idx < s_l_x_b * s_l_x_b_s)
+    rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
+
+    if rev_idx_spa == -1:
+        tl.device_assert(False)
+        return
+
+    blk_x_idx = (pid_blk * x_b_s +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_x_msk = (blk_x_idx < x_b * x_b_s)
+    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+    blk_o_idx = (rev_idx_spa * o_b_s +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+    blk_o_msk = (blk_o_idx < o_b * o_b_s)
+    tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+def forward_flow(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                 sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+    output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                         dtype=x.dtype, device=x.device)
+    output = torch.zeros_like(output)
+
+    x_b, x_r, x_c = x.size()
+    x_b_s, x_r_s, x_c_s = stride(x)
+    o_b, o_r, o_c = output.size()
+    o_b_s, o_r_s, o_c_s = stride(output)
+    s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
+    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
+    s_lut_r, s_lut_c = sparsity_lut.size()
+    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+    asdf = torch.tensor(sparsity_lut).stride()
+
+    if triton_block_size is None:
+        triton_block_size = get_triton_block_size(sparsity_block_size)
+
+    triton_grid = lambda meta: [o_b,
+                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+    (kernel_blocksparse_flow_pull[triton_grid]
+     (x,
+      x_b, x_b_s, x_r_s, x_c_s,
+      output,
+      o_b, o_b_s, o_r_s, o_c_s,
+      s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+      sparsity_reverse_lut,
+      triton_block_size))
+
+    # Save for backward pass
+    ctx.sparsity_block_size = sparsity_block_size
+    ctx.triton_block_size = triton_block_size
+
+    return output
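`repeat` tiles a compressed tensor along all three dimensions; when a target layout is supplied, the result layout is the logical AND of the tiled layout and the target, which is the "supports target sparsity layout" feature advertised in the README. A hedged sketch (`bs.to_sparse` and the shapes are assumptions; `repeat`'s signature comes from the hunk above):

```python
# Hypothetical use of the new repeat op with a target sparsity layout.
import torch
import blksprs as bs
from blksprs.ops.repeat import repeat

sparsity_block_size = 16
sparsity_layout = torch.ones(1, 2, 2, dtype=torch.bool, device="cuda")
x = bs.to_sparse(torch.randn(1, 32, 32, device="cuda"),  # assumed helper
                 sparsity_layout, sparsity_block_size)

# Tile 3x along the batch dimension, but mask with a target layout: the
# result layout is (tiled layout) AND (target layout).
target = torch.ones(3, 2, 2, dtype=torch.bool, device="cuda")
target[1] = False  # drop every block of the middle copy
y, layout_y = repeat(x, sparsity_layout, (3, 1, 1), sparsity_block_size,
                     sparsity_layout_output=target)
assert layout_y.sum() == 8  # two surviving copies of four blocks each
```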
{blksprs-1.5 → blksprs-1.7}/blksprs/ops/softmax.py

@@ -5,7 +5,7 @@ from triton import language as tl
 
 from blksprs.ops.exp import exp
 from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
-from blksprs.utils.tools import get_triton_block_size
+from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
 

@@ -61,9 +61,9 @@ class _BlocksparseSoftmax(torch.autograd.Function):
         output = torch.empty_like(x)
 
         x_b, x_r, x_c = x.size()
-        x_b_s, x_r_s, x_c_s =
+        x_b_s, x_r_s, x_c_s = stride(x)
         s_lut_r, s_lut_c = sparsity_lut.size()
-        s_lut_r_s, s_lut_c_s =
+        s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
         o_b, o_r, o_c = output.size()
 
         x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,

@@ -76,9 +76,9 @@ class _BlocksparseSoftmax(torch.autograd.Function):
                                                            triton_block_size=triton_block_size)
 
         s_b, s_r, s_c = x_exp_row_wise_sum.shape
-        s_b_s, s_r_s, s_c_s =
+        s_b_s, s_r_s, s_c_s = stride(x_exp_row_wise_sum)
         s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_rws.shape
-        s_l_s_b_s, s_l_s_r_s, s_l_s_c_s =
+        s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_rws)
 
         if triton_block_size is None:
             triton_block_size = get_triton_block_size(sparsity_block_size)

@@ -119,13 +119,13 @@ class _BlocksparseSoftmax(torch.autograd.Function):
                                  (1 * (sparsity_layout_s_flat == 0)))
 
         o_b, o_r, o_c = o.size()
-        o_b_s, o_r_s, o_c_s =
+        o_b_s, o_r_s, o_c_s = stride(o)
         s_lut_r, s_lut_c = sparsity_lut.size()
-        s_lut_r_s, s_lut_c_s =
+        s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
         s_b, s_r, s_c = s.size()
-        s_b_s, s_r_s, s_c_s =
+        s_b_s, s_r_s, s_c_s = stride(s)
         s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
-        s_l_s_b_s, s_l_s_r_s, s_l_s_c_s =
+        s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)
 
         grad_x = torch.empty_like(o, dtype=torch.float)
 

@@ -181,7 +181,8 @@ class _BlocksparseSoftmax(torch.autograd.Function):
     rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
 
     if rev_idx_spa_s == -1:
-
+        tl.device_assert(False)
+        return
 
     # Load x block
     blk_x_idx = ((pid_blk * x_b_s) +
{blksprs-1.5 → blksprs-1.7}/blksprs/ops/transpose.py

@@ -3,7 +3,7 @@ import triton
 from torch import Tensor
 from triton import language as tl
 
-from blksprs.utils.tools import get_triton_block_size
+from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
 

@@ -56,20 +56,20 @@ def transpose(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
 class _BlocksparseTranspose(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, x: Tensor,
-
-                n_sparse_blocks: int, triton_block_size: int) ->
+    def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                sparsity_block_size: int,
+                n_sparse_blocks: int, triton_block_size: int) -> Tensor:
         output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                              dtype=x.dtype, device=x.device)
 
         x_b, x_r, x_c = x.size()
-        x_b_s, x_r_s, x_c_s =
-        s_l_b, s_l_r, s_l_c =
-        s_l_b_s, s_l_r_s, s_l_c_s =
+        x_b_s, x_r_s, x_c_s = stride(x)
+        s_l_b, s_l_r, s_l_c = sparsity_layout_o.size()
+        s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout_o)
         s_lut_r, s_lut_c = sparsity_lut.shape
-        s_lut_r_s, s_lut_c_s =
+        s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
         o_b, o_r, o_c = output.size()
-        o_b_s, o_r_s, o_c_s =
+        o_b_s, o_r_s, o_c_s = stride(output)
 
         if triton_block_size is None:
             triton_block_size = get_triton_block_size(sparsity_block_size)

@@ -89,8 +89,7 @@ class _BlocksparseTranspose(torch.autograd.Function):
           triton_block_size))
 
         # Save for backward pass
-        ctx.save_for_backward(
-        ctx.sparsity_layout = sparsity_layout
+        ctx.save_for_backward(sparsity_layout_o)
         ctx.sparsity_block_size = sparsity_block_size
         ctx.triton_block_size = triton_block_size
 

@@ -141,7 +140,8 @@ class _BlocksparseTranspose(torch.autograd.Function):
     rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
     if rev_idx_spa == -1:
-
+        tl.device_assert(False)
+        return
 
     blk_x_idx = (rev_idx_spa * x_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
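Besides restoring the truncated forward signature, the middle hunk moves the saved layout from a plain `ctx` attribute into `save_for_backward`, the autograd-sanctioned way to retain tensors between forward and backward. A toy illustration of the distinction (not blksprs code):

```python
# Toy autograd.Function showing the ctx difference: tensors stashed as plain
# attributes bypass autograd's lifetime and versioning checks, while
# save_for_backward participates in them.
import torch

class Double(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, mask):
        ctx.save_for_backward(mask)  # sanctioned: tracked and freed correctly
        # ctx.mask = mask            # works, but skips autograd's safeguards
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        mask, = ctx.saved_tensors
        return grad_output * 2 * mask, None

x = torch.randn(3, requires_grad=True)
Double.apply(x, torch.ones(3)).sum().backward()  # x.grad == [2., 2., 2.]
```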
{blksprs-1.5 → blksprs-1.7}/blksprs/utils/validation.py

@@ -63,6 +63,8 @@ def validate_sparsity(sparsity_block_size: int, *tensor_sparsity_layout_tuples:
     for (tensor, sparsity_layout) in tensor_sparsity_layout_tuples:
         _validate_sparsity_layout_values(sparsity_layout)
 
+        if not sparsity_layout.dim() == 3:
+            raise ValueError("Sparsity layout must have exactly 3 dimensions")
         if not (tensor.size(-1) == tensor.size(-2) == sparsity_block_size):
             raise ValueError("Blocks not conforming to sparsity block size")
         if not tensor.size(0) == torch.sum(sparsity_layout.reshape(-1)):
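`validate_sparsity` now additionally rejects layouts that are not exactly three-dimensional. A sketch of the new failure mode, assuming validation is enabled and the other checks behave as the surrounding context suggests:

```python
# Sketch: the tightened validation rejects a 2-D layout up front.
import torch
from blksprs.utils.validation import validate_sparsity

blocks = torch.zeros(1, 16, 16)  # one 16x16 block in compressed form
layout_2d = torch.tensor([[1]], dtype=torch.bool)  # batch dimension missing

try:
    validate_sparsity(16, (blocks, layout_2d))
except ValueError as err:
    print(err)  # "Sparsity layout must have exactly 3 dimensions"
```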
{blksprs-1.5 → blksprs-1.7}/blksprs.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.5
+Version: 1.7
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs

@@ -31,9 +31,11 @@ Currently supported operations (includes gradient calculation):
 - Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
   for `sparse = sparse @ sparse` matmul_)
 - Softmax
--
+- Transpose
 - Gather
 - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
+- Repeat (_supports target sparsity layout_)
+- Splitting and merging of matrices along the last dimension
 - Conversion to and from sparse form
 - Conversion to different sparsity layouts and different sparsity block sizes
 

@@ -63,7 +65,7 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u
 
 ### Dependencies
 
-- [PyTorch](https://pytorch.org/) (built with v2.
+- [PyTorch](https://pytorch.org/) (built with v2.5.0)
 - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_
 
 ## Changelog
The five remaining files (blksprs/utils/benchmarking.py, blksprs.egg-info/dependency_links.txt, blksprs.egg-info/requires.txt, blksprs.egg-info/top_level.txt, setup.cfg) are unchanged between the two versions.