blksprs 1.7__tar.gz → 1.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {blksprs-1.7 → blksprs-1.8}/PKG-INFO +10 -4
- {blksprs-1.7 → blksprs-1.8}/README.md +9 -3
- {blksprs-1.7 → blksprs-1.8}/blksprs/__init__.py +11 -6
- {blksprs-1.7 → blksprs-1.8}/blksprs/ops/conversion.py +12 -1
- {blksprs-1.7 → blksprs-1.8}/blksprs/ops/repeat.py +82 -1
- {blksprs-1.7 → blksprs-1.8}/blksprs/ops/softmax.py +1 -1
- {blksprs-1.7 → blksprs-1.8}/blksprs/utils/validation.py +20 -1
- {blksprs-1.7 → blksprs-1.8}/blksprs.egg-info/PKG-INFO +10 -4
- {blksprs-1.7 → blksprs-1.8}/blksprs.egg-info/SOURCES.txt +2 -3
- {blksprs-1.7 → blksprs-1.8}/pyproject.toml +1 -1
- blksprs-1.7/blksprs/misc/repeat_interleave.py +0 -132
- {blksprs-1.7 → blksprs-1.8}/blksprs/experimental/distribution_mdi.py +0 -0
- {blksprs-1.7 → blksprs-1.8}/blksprs/layouting/distribution_layout.py +0 -0
- {blksprs-1.7 → blksprs-1.8}/blksprs/layouting/sparsity_layout.py +0 -0
- {blksprs-1.7 → blksprs-1.8}/blksprs/misc/broadcast_ops.py +0 -0
- {blksprs-1.7/blksprs/ops → blksprs-1.8/blksprs/misc}/exp.py +0 -0
- {blksprs-1.7/blksprs/ops → blksprs-1.8/blksprs/misc}/partitioning.py +0 -0
- {blksprs-1.7 → blksprs-1.8}/blksprs/misc/row_wise.py +0 -0
- {blksprs-1.7 → blksprs-1.8}/blksprs/ops/distribution.py +0 -0
- {blksprs-1.7 → blksprs-1.8}/blksprs/ops/matmul.py +0 -0
- {blksprs-1.7 → blksprs-1.8}/blksprs/ops/transpose.py +0 -0
- {blksprs-1.7 → blksprs-1.8}/blksprs/utils/benchmarking.py +0 -0
- {blksprs-1.7 → blksprs-1.8}/blksprs/utils/tools.py +0 -0
- {blksprs-1.7 → blksprs-1.8}/blksprs.egg-info/dependency_links.txt +0 -0
- {blksprs-1.7 → blksprs-1.8}/blksprs.egg-info/requires.txt +0 -0
- {blksprs-1.7 → blksprs-1.8}/blksprs.egg-info/top_level.txt +0 -0
- {blksprs-1.7 → blksprs-1.8}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: blksprs
|
|
3
|
-
Version: 1.7
|
|
3
|
+
Version: 1.8
|
|
4
4
|
Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
|
|
5
5
|
Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
|
|
6
6
|
Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
|
|
@@ -28,13 +28,13 @@ A lightweight and efficient library for operations on block-sparse matrices in P
|
|
|
28
28
|
|
|
29
29
|
Currently supported operations (includes gradient calculation):
|
|
30
30
|
|
|
31
|
-
-
|
|
32
|
-
for `sparse = sparse @ sparse` matmul_)
|
|
31
|
+
- Matrix multiplication
|
|
33
32
|
- Softmax
|
|
34
33
|
- Transpose
|
|
35
34
|
- Gather
|
|
36
35
|
- Scatter (_supports either no reduction or summation, gradients are only available for summation_)
|
|
37
36
|
- Repeat (_supports target sparsity layout_)
|
|
37
|
+
- Repeat Interleave (_supports target sparsity layout_)
|
|
38
38
|
- Splitting and merging of matrices along the last dimension
|
|
39
39
|
- Conversion to and from sparse form
|
|
40
40
|
- Conversion to different sparsity layouts and different sparsity block sizes
|
|
@@ -51,8 +51,14 @@ These include, e.g.,
|
|
|
51
51
|
Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
|
|
52
52
|
match.
|
|
53
53
|
|
|
54
|
+
Further helpful operations (included in the ``bs.misc`` module) that do **not** support gradient calculation include:
|
|
55
|
+
|
|
56
|
+
- Row-wise sum, max, addition, and subtraction
|
|
57
|
+
- Broadcast addition and subtraction between slices
|
|
58
|
+
|
|
54
59
|
Furthermore, the library provides a set of utility functions for the creation of sparsity layouts based on existing
|
|
55
|
-
dense tensors.
|
|
60
|
+
dense tensors and for the scatter operation (module ``bs.layout``), as well as utility functions to ensure correct input
|
|
61
|
+
dimensionality (module ``bs.util``).
|
|
56
62
|
|
|
57
63
|
## Installation
|
|
58
64
|
|
|
@@ -9,13 +9,13 @@ A lightweight and efficient library for operations on block-sparse matrices in P
|
|
|
9
9
|
|
|
10
10
|
Currently supported operations (includes gradient calculation):
|
|
11
11
|
|
|
12
|
-
-
|
|
13
|
-
for `sparse = sparse @ sparse` matmul_)
|
|
12
|
+
- Matrix multiplication
|
|
14
13
|
- Softmax
|
|
15
14
|
- Transpose
|
|
16
15
|
- Gather
|
|
17
16
|
- Scatter (_supports either no reduction or summation, gradients are only available for summation_)
|
|
18
17
|
- Repeat (_supports target sparsity layout_)
|
|
18
|
+
- Repeat Interleave (_supports target sparsity layout_)
|
|
19
19
|
- Splitting and merging of matrices along the last dimension
|
|
20
20
|
- Conversion to and from sparse form
|
|
21
21
|
- Conversion to different sparsity layouts and different sparsity block sizes
|
|
@@ -32,8 +32,14 @@ These include, e.g.,
|
|
|
32
32
|
Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
|
|
33
33
|
match.
|
|
34
34
|
|
|
35
|
+
Further helpful operations (included in the ``bs.misc`` module) that do **not** support gradient calculation include:
|
|
36
|
+
|
|
37
|
+
- Row-wise sum, max, addition, and subtraction
|
|
38
|
+
- Broadcast addition and subtraction between slices
|
|
39
|
+
|
|
35
40
|
Furthermore, the library provides a set of utility functions for the creation of sparsity layouts based on existing
|
|
36
|
-
dense tensors.
|
|
41
|
+
dense tensors and for the scatter operation (module ``bs.layout``), as well as utility functions to ensure correct input
|
|
42
|
+
dimensionality (module ``bs.util``).
|
|
37
43
|
|
|
38
44
|
## Installation
|
|
39
45
|
|
|
@@ -1,22 +1,27 @@
|
|
|
1
|
-
from blksprs.ops.conversion import to_dense, to_sparse
|
|
1
|
+
from blksprs.ops.conversion import to_dense, to_sparse, from_blksprs, to_blksprs
|
|
2
2
|
from blksprs.ops.distribution import gather, scatter, scatter_reduce
|
|
3
|
-
from blksprs.ops.exp import exp
|
|
4
3
|
from blksprs.ops.matmul import matmul
|
|
5
4
|
from blksprs.ops.softmax import softmax
|
|
6
5
|
from blksprs.ops.transpose import transpose
|
|
7
|
-
from blksprs.ops.
|
|
6
|
+
from blksprs.ops.repeat import repeat, repeat_interleave
|
|
7
|
+
from blksprs.misc.partitioning import split, merge
|
|
8
|
+
|
|
8
9
|
|
|
9
10
|
class layout:
|
|
10
11
|
from blksprs.layouting.distribution_layout import build_distribution_layout
|
|
11
|
-
from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption,
|
|
12
|
+
from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption, \
|
|
13
|
+
build_sparsity_layout_matmul
|
|
14
|
+
|
|
12
15
|
|
|
13
16
|
class misc:
|
|
14
17
|
from blksprs.misc.broadcast_ops import broadcast_add, broadcast_sub
|
|
15
|
-
from blksprs.misc.
|
|
18
|
+
from blksprs.misc.exp import exp
|
|
16
19
|
from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_add, row_wise_sub
|
|
17
20
|
|
|
21
|
+
|
|
18
22
|
class util:
|
|
19
23
|
from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse, disable_validation
|
|
20
24
|
|
|
25
|
+
|
|
21
26
|
class experimental:
|
|
22
|
-
from blksprs.experimental.distribution_mdi import gather_mdi, scatter_reduce_mdi
|
|
27
|
+
from blksprs.experimental.distribution_mdi import gather_mdi, scatter_reduce_mdi
|
|
@@ -8,7 +8,12 @@ from triton import language as tl
|
|
|
8
8
|
from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
|
|
9
9
|
from blksprs.utils.tools import get_triton_block_size, stride
|
|
10
10
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
11
|
-
validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
|
|
11
|
+
validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_sparsity_dense
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def from_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
|
|
15
|
+
triton_block_size: int = None) -> Tensor:
|
|
16
|
+
return to_dense(x, sparsity_layout, sparsity_block_size, fill_value, triton_block_size)
|
|
12
17
|
|
|
13
18
|
|
|
14
19
|
def to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
|
|
@@ -144,6 +149,11 @@ class _BlocksparseToDense(torch.autograd.Function):
|
|
|
144
149
|
tl.store(o + o_idx, blk, o_msk)
|
|
145
150
|
|
|
146
151
|
|
|
152
|
+
def to_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
153
|
+
triton_block_size: int = None) -> Tensor:
|
|
154
|
+
return to_sparse(x, sparsity_layout, sparsity_block_size, triton_block_size)
|
|
155
|
+
|
|
156
|
+
|
|
147
157
|
def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
|
|
148
158
|
"""Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
|
|
149
159
|
sparsity layout.
|
|
@@ -163,6 +173,7 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
|
|
|
163
173
|
validate_dimensions(x)
|
|
164
174
|
validate_contiguous(x)
|
|
165
175
|
validate_device(x)
|
|
176
|
+
validate_sparsity_dense(sparsity_block_size, (x, sparsity_layout))
|
|
166
177
|
validate_sparsity_block_size(sparsity_block_size, x)
|
|
167
178
|
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
168
179
|
|
|
@@ -11,6 +11,30 @@ from blksprs.utils.validation import validate_dimensions, validate_contiguous, v
|
|
|
11
11
|
def repeat(x: Tensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
|
|
12
12
|
sparsity_block_size: int, sparsity_layout_output: Tensor = None, triton_block_size: int = None) -> (
|
|
13
13
|
Tensor, Tensor):
|
|
14
|
+
"""Repeats a block-sparse tensor in compressed form according to the given repeats.
|
|
15
|
+
|
|
16
|
+
Repeats is a 3-tuple of integers, where each integer represents the number of times the tensor should be repeated in
|
|
17
|
+
the first, second and third dimension respectively.
|
|
18
|
+
|
|
19
|
+
Note:
|
|
20
|
+
An output sparsity layout can be provided, in which case only the indicated blocks are filled. This may result
|
|
21
|
+
in blocks not being present in the output that were present in the input if the output sparsity layout indicates
|
|
22
|
+
them to be sparse.
|
|
23
|
+
|
|
24
|
+
Args:
|
|
25
|
+
x (Tensor): A block-sparse tensor in compressed form.
|
|
26
|
+
sparsity_layout_x (Tensor): The sparsity layout of the block-sparse tensor.
|
|
27
|
+
repeats (tuple[int, int, int]): The number of times the tensor should be repeated in the first, second and
|
|
28
|
+
third dimension respectively.
|
|
29
|
+
sparsity_block_size (int): The size of the sparsity blocks.
|
|
30
|
+
sparsity_layout_output (Tensor): The desired sparsity layout of the output tensor (default ``None``).
|
|
31
|
+
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
32
|
+
|
|
33
|
+
Returns:
|
|
34
|
+
Tensor: A block-sparse tensor in compressed form containing the repeated values.
|
|
35
|
+
Tensor: The sparsity layout of the resulting output tensor.
|
|
36
|
+
|
|
37
|
+
"""
|
|
14
38
|
x = x.contiguous()
|
|
15
39
|
|
|
16
40
|
validate_dimensions(x)
|
|
@@ -43,6 +67,64 @@ def repeat(x: Tensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
|
|
|
43
67
|
sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_o
|
|
44
68
|
|
|
45
69
|
|
|
70
|
+
def repeat_interleave(x: Tensor, sparsity_layout_x: Tensor, repeats: int,
|
|
71
|
+
sparsity_block_size: int, sparsity_layout_output: Tensor = None,
|
|
72
|
+
triton_block_size: int = None) -> (
|
|
73
|
+
Tensor, Tensor):
|
|
74
|
+
"""Repeats and interleaves the block-sparse tensor in compressed form.
|
|
75
|
+
|
|
76
|
+
Repeats each matrix contained in the tensor ``repeats`` times and places the copies consecutively in the output
|
|
77
|
+
tensor.
|
|
78
|
+
|
|
79
|
+
Note:
|
|
80
|
+
In similar fashion to the regular ``repeat``, an output sparsity layout can be provided. In this case only
|
|
81
|
+
non-sparse blocks will be filled.
|
|
82
|
+
|
|
83
|
+
Args:
|
|
84
|
+
x (Tensor): A block-sparse tensor in compressed form.
|
|
85
|
+
sparsity_layout_x (Tensor): The sparsity layout of the block-sparse tensor.
|
|
86
|
+
repeats (int): The number of times to repeat the matrices.
|
|
87
|
+
sparsity_block_size (int): The size of the sparsity blocks.
|
|
88
|
+
sparsity_layout_output (Tensor): The desired sparsity layout of the output tensor (default ``None``).
|
|
89
|
+
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
90
|
+
|
|
91
|
+
Returns:
|
|
92
|
+
Tensor: A block-sparse tensor in compressed form containing the repeated and interleaved matrices.
|
|
93
|
+
Tensor: The sparsity layout of the resulting output tensor.
|
|
94
|
+
|
|
95
|
+
"""
|
|
96
|
+
x = x.contiguous()
|
|
97
|
+
|
|
98
|
+
validate_dimensions(x)
|
|
99
|
+
validate_contiguous(x)
|
|
100
|
+
validate_device(x)
|
|
101
|
+
validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
|
|
102
|
+
validate_sparsity_block_size(sparsity_block_size, x)
|
|
103
|
+
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
104
|
+
|
|
105
|
+
sparsity_layout_o = torch.repeat_interleave(sparsity_layout_x, repeats, dim=0).contiguous()
|
|
106
|
+
|
|
107
|
+
if sparsity_layout_output is not None:
|
|
108
|
+
sparsity_layout_o = torch.logical_and(sparsity_layout_o, sparsity_layout_output)
|
|
109
|
+
|
|
110
|
+
sparsity_lut = torch.nonzero(sparsity_layout_o).contiguous()
|
|
111
|
+
|
|
112
|
+
sparsity_layout_flat = sparsity_layout_x.reshape(-1)
|
|
113
|
+
sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
|
|
114
|
+
(sparsity_layout_flat == 1) -
|
|
115
|
+
(1 * (sparsity_layout_flat == 0)))
|
|
116
|
+
.reshape(sparsity_layout_x.size())
|
|
117
|
+
.repeat_interleave(repeats, dim=0)
|
|
118
|
+
.reshape(-1).contiguous())
|
|
119
|
+
|
|
120
|
+
n_sparse_blocks = torch.sum(sparsity_layout_o.to(torch.int)).item()
|
|
121
|
+
|
|
122
|
+
validate_contiguous(sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
|
|
123
|
+
|
|
124
|
+
return _BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
|
|
125
|
+
sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_o
|
|
126
|
+
|
|
127
|
+
|
|
46
128
|
class _BlocksparseRepeat(torch.autograd.Function):
|
|
47
129
|
|
|
48
130
|
@staticmethod
|
|
@@ -215,7 +297,6 @@ def forward_flow(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor
|
|
|
215
297
|
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
|
|
216
298
|
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
217
299
|
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
218
|
-
asdf = torch.tensor(sparsity_lut).stride()
|
|
219
300
|
|
|
220
301
|
if triton_block_size is None:
|
|
221
302
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -3,7 +3,7 @@ import triton
|
|
|
3
3
|
from torch import Tensor
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
|
-
from blksprs.ops.exp import exp
|
|
6
|
+
from blksprs.misc.exp import exp
|
|
7
7
|
from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
|
|
8
8
|
from blksprs.utils.tools import get_triton_block_size, stride
|
|
9
9
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
@@ -3,6 +3,7 @@ from torch import Tensor
|
|
|
3
3
|
|
|
4
4
|
VALIDATION = True
|
|
5
5
|
|
|
6
|
+
|
|
6
7
|
def validate_dimensions(*tensors: Tensor, dims=3) -> None:
|
|
7
8
|
if _check_skip_validation():
|
|
8
9
|
return
|
|
@@ -71,10 +72,25 @@ def validate_sparsity(sparsity_block_size: int, *tensor_sparsity_layout_tuples:
|
|
|
71
72
|
raise ValueError("Mismatch between sparsity layout and blocks")
|
|
72
73
|
|
|
73
74
|
|
|
75
|
+
def validate_sparsity_dense(sparsity_block_size: int, *tensor_sparsity_layout_tuples: tuple[Tensor, Tensor]) -> None:
|
|
76
|
+
if _check_skip_validation():
|
|
77
|
+
return
|
|
78
|
+
|
|
79
|
+
for (tensor, sparsity_layout) in tensor_sparsity_layout_tuples:
|
|
80
|
+
_validate_sparsity_layout_values(sparsity_layout)
|
|
81
|
+
|
|
82
|
+
if not sparsity_layout.dim() == 3:
|
|
83
|
+
raise ValueError("Sparsity layout must have exactly 3 dimensions")
|
|
84
|
+
if not (tensor.size(-1) // sparsity_block_size == sparsity_layout.size(-1) and
|
|
85
|
+
tensor.size(-2) // sparsity_block_size == sparsity_layout.size(-2)):
|
|
86
|
+
raise ValueError("Tensor not conforming to sparsity layout")
|
|
87
|
+
|
|
88
|
+
|
|
74
89
|
def _validate_sparsity_layout_values(sparsity_layout: Tensor):
|
|
75
90
|
if not torch.all(torch.logical_or(sparsity_layout == 0, sparsity_layout == 1)):
|
|
76
91
|
raise ValueError("Sparsity layout values must be either 0 or 1")
|
|
77
92
|
|
|
93
|
+
|
|
78
94
|
def validate_sparsity_block_size(sparsity_block_size: int, *tensors):
|
|
79
95
|
if _check_skip_validation():
|
|
80
96
|
return
|
|
@@ -86,6 +102,7 @@ def validate_sparsity_block_size(sparsity_block_size: int, *tensors):
|
|
|
86
102
|
if not (tensor.size(-1) % sparsity_block_size == 0 and tensor.size(-2) % sparsity_block_size == 0):
|
|
87
103
|
raise ValueError("Tensor sizes must be divisible by sparsity block size")
|
|
88
104
|
|
|
105
|
+
|
|
89
106
|
def validate_triton_block_size(triton_block_size: int, sparsity_block_size: int):
|
|
90
107
|
if _check_skip_validation():
|
|
91
108
|
return
|
|
@@ -99,9 +116,11 @@ def validate_triton_block_size(triton_block_size: int, sparsity_block_size: int)
|
|
|
99
116
|
if triton_block_size > sparsity_block_size:
|
|
100
117
|
raise ValueError("Triton block size cannot be larger than sparsity block size")
|
|
101
118
|
|
|
119
|
+
|
|
102
120
|
def _check_skip_validation():
|
|
103
121
|
return not VALIDATION
|
|
104
122
|
|
|
123
|
+
|
|
105
124
|
def _set_skip_validation(skip_validation: bool):
|
|
106
125
|
global VALIDATION
|
|
107
|
-
VALIDATION = not skip_validation
|
|
126
|
+
VALIDATION = not skip_validation
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: blksprs
|
|
3
|
-
Version: 1.7
|
|
3
|
+
Version: 1.8
|
|
4
4
|
Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
|
|
5
5
|
Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
|
|
6
6
|
Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
|
|
@@ -28,13 +28,13 @@ A lightweight and efficient library for operations on block-sparse matrices in P
|
|
|
28
28
|
|
|
29
29
|
Currently supported operations (includes gradient calculation):
|
|
30
30
|
|
|
31
|
-
-
|
|
32
|
-
for `sparse = sparse @ sparse` matmul_)
|
|
31
|
+
- Matrix multiplication
|
|
33
32
|
- Softmax
|
|
34
33
|
- Transpose
|
|
35
34
|
- Gather
|
|
36
35
|
- Scatter (_supports either no reduction or summation, gradients are only available for summation_)
|
|
37
36
|
- Repeat (_supports target sparsity layout_)
|
|
37
|
+
- Repeat Interleave (_supports target sparsity layout_)
|
|
38
38
|
- Splitting and merging of matrices along the last dimension
|
|
39
39
|
- Conversion to and from sparse form
|
|
40
40
|
- Conversion to different sparsity layouts and different sparsity block sizes
|
|
@@ -51,8 +51,14 @@ These include, e.g.,
|
|
|
51
51
|
Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
|
|
52
52
|
match.
|
|
53
53
|
|
|
54
|
+
Further helpful operations (included in the ``bs.misc`` module) that do **not** support gradient calculation include:
|
|
55
|
+
|
|
56
|
+
- Row-wise sum, max, addition, and subtraction
|
|
57
|
+
- Broadcast addition and subtraction between slices
|
|
58
|
+
|
|
54
59
|
Furthermore, the library provides a set of utility functions for the creation of sparsity layouts based on existing
|
|
55
|
-
dense tensors.
|
|
60
|
+
dense tensors and for the scatter operation (module ``bs.layout``), as well as utility functions to ensure correct input
|
|
61
|
+
dimensionality (module ``bs.util``).
|
|
56
62
|
|
|
57
63
|
## Installation
|
|
58
64
|
|
|
@@ -10,13 +10,12 @@ blksprs/experimental/distribution_mdi.py
|
|
|
10
10
|
blksprs/layouting/distribution_layout.py
|
|
11
11
|
blksprs/layouting/sparsity_layout.py
|
|
12
12
|
blksprs/misc/broadcast_ops.py
|
|
13
|
-
blksprs/misc/repeat_interleave.py
|
|
13
|
+
blksprs/misc/exp.py
|
|
14
|
+
blksprs/misc/partitioning.py
|
|
14
15
|
blksprs/misc/row_wise.py
|
|
15
16
|
blksprs/ops/conversion.py
|
|
16
17
|
blksprs/ops/distribution.py
|
|
17
|
-
blksprs/ops/exp.py
|
|
18
18
|
blksprs/ops/matmul.py
|
|
19
|
-
blksprs/ops/partitioning.py
|
|
20
19
|
blksprs/ops/repeat.py
|
|
21
20
|
blksprs/ops/softmax.py
|
|
22
21
|
blksprs/ops/transpose.py
|
|
@@ -1,132 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
import triton
|
|
3
|
-
from torch import Tensor
|
|
4
|
-
from triton import language as tl
|
|
5
|
-
|
|
6
|
-
from blksprs.utils.tools import get_triton_block_size, stride
|
|
7
|
-
from blksprs.utils.validation import validate_contiguous, validate_device, \
|
|
8
|
-
validate_sparsity_block_size, validate_triton_block_size, validate_dimensions
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
def repeat_interleave(x: Tensor, sparsity_layout: Tensor, repeats: int,
|
|
12
|
-
sparsity_block_size: int, triton_block_size: int = None) -> tuple[Tensor, Tensor]:
|
|
13
|
-
"""Repeats and interleaves the block-sparse tensor in compressed form.
|
|
14
|
-
|
|
15
|
-
Repeats each matrix contained in the tensors by ``repeats`` amount and places them consecutively in the output
|
|
16
|
-
tensor.
|
|
17
|
-
|
|
18
|
-
Args:
|
|
19
|
-
x (Tensor): A block-sparse tensor in compressed form.
|
|
20
|
-
sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
|
|
21
|
-
repeats (int): The number of times to repeat the matrices.
|
|
22
|
-
sparsity_block_size (int): The size of the sparsity blocks.
|
|
23
|
-
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
24
|
-
|
|
25
|
-
Returns:
|
|
26
|
-
Tensor: A block-sparse tensor in compressed form containing the repeated and interleaved matrices.
|
|
27
|
-
Tensor: The sparsity layout of the resulting output tensor.
|
|
28
|
-
|
|
29
|
-
"""
|
|
30
|
-
x = x.contiguous()
|
|
31
|
-
|
|
32
|
-
validate_dimensions(x)
|
|
33
|
-
validate_contiguous(x)
|
|
34
|
-
validate_device(x)
|
|
35
|
-
validate_sparsity_block_size(sparsity_block_size, x)
|
|
36
|
-
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
37
|
-
|
|
38
|
-
sparsity_layout_output = torch.repeat_interleave(sparsity_layout, repeats, dim=0).contiguous()
|
|
39
|
-
|
|
40
|
-
sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
|
|
41
|
-
|
|
42
|
-
sparsity_layout_output_flat = sparsity_layout_output.reshape(-1)
|
|
43
|
-
sparsity_output_reverse_lut = ((torch.cumsum(sparsity_layout_output_flat, dim=-1) - 1) *
|
|
44
|
-
(sparsity_layout_output_flat == 1) -
|
|
45
|
-
(1 * (sparsity_layout_output_flat == 0)))
|
|
46
|
-
|
|
47
|
-
n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
|
|
48
|
-
|
|
49
|
-
validate_contiguous(sparsity_layout, sparsity_lut, sparsity_layout_output, sparsity_output_reverse_lut)
|
|
50
|
-
|
|
51
|
-
output = torch.empty(n_sparse_blocks * repeats, sparsity_block_size, sparsity_block_size,
|
|
52
|
-
dtype=x.dtype, device=x.device)
|
|
53
|
-
|
|
54
|
-
x_b, x_r, x_c = x.size()
|
|
55
|
-
x_b_s, x_r_s, x_c_s = stride(x)
|
|
56
|
-
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
57
|
-
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
58
|
-
o_b, o_r, o_c = output.size()
|
|
59
|
-
o_b_s, o_r_s, o_c_s = stride(output)
|
|
60
|
-
s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
|
|
61
|
-
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
|
|
62
|
-
|
|
63
|
-
if triton_block_size is None:
|
|
64
|
-
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
65
|
-
|
|
66
|
-
triton_grid = lambda meta: [x_b,
|
|
67
|
-
triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
|
|
68
|
-
triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
|
|
69
|
-
|
|
70
|
-
(kernel_repeat_interleave[triton_grid]
|
|
71
|
-
(x,
|
|
72
|
-
x_b, x_b_s, x_r_s, x_c_s,
|
|
73
|
-
sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
74
|
-
output,
|
|
75
|
-
o_b, o_b_s, o_r_s, o_c_s,
|
|
76
|
-
s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
|
|
77
|
-
sparsity_output_reverse_lut,
|
|
78
|
-
repeats,
|
|
79
|
-
triton_block_size))
|
|
80
|
-
|
|
81
|
-
return output, sparsity_layout_output
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
@triton.jit
|
|
85
|
-
def kernel_repeat_interleave(x,
|
|
86
|
-
x_b, x_b_s, x_r_s, x_c_s,
|
|
87
|
-
s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
88
|
-
o,
|
|
89
|
-
o_b, o_b_s, o_r_s, o_c_s,
|
|
90
|
-
s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
|
|
91
|
-
r_lut_o,
|
|
92
|
-
repeats,
|
|
93
|
-
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
94
|
-
# Get triton block indices
|
|
95
|
-
pid_blk = tl.program_id(axis=0)
|
|
96
|
-
pid_row = tl.program_id(axis=1)
|
|
97
|
-
pid_col = tl.program_id(axis=2)
|
|
98
|
-
|
|
99
|
-
# Get sparsity index of current output block consisting of its batch, row, and column index
|
|
100
|
-
spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
|
|
101
|
-
spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
|
|
102
|
-
spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
|
|
103
|
-
|
|
104
|
-
spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
|
|
105
|
-
spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
|
|
106
|
-
spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
|
|
107
|
-
|
|
108
|
-
spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
|
|
109
|
-
spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
|
|
110
|
-
spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
|
|
111
|
-
|
|
112
|
-
# Load block
|
|
113
|
-
blk_x_idx = ((pid_blk * x_b_s) +
|
|
114
|
-
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
115
|
-
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
116
|
-
blk_x_msk = (blk_x_idx < x_b * x_b_s)
|
|
117
|
-
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
118
|
-
|
|
119
|
-
for repeat in range(repeats):
|
|
120
|
-
# Get reverse sparsity index
|
|
121
|
-
rev_idx_spa_idx = ((spa_bat * repeats + repeat) * s_l_o_b_s +
|
|
122
|
-
spa_row * s_l_o_r_s +
|
|
123
|
-
spa_col * s_l_o_c_s)
|
|
124
|
-
rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
|
|
125
|
-
rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
|
|
126
|
-
|
|
127
|
-
# Store block
|
|
128
|
-
blk_o_idx = ((rev_idx_spa * o_b_s) +
|
|
129
|
-
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
130
|
-
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
131
|
-
blk_o_msk = (blk_o_idx < o_b * o_b_s)
|
|
132
|
-
tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|