blksprs 1.8.1__tar.gz → 1.8.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {blksprs-1.8.1 → blksprs-1.8.3}/PKG-INFO +21 -13
- {blksprs-1.8.1 → blksprs-1.8.3}/README.md +20 -12
- blksprs-1.8.3/blksprs/__init__.py +40 -0
- {blksprs-1.8.1 → blksprs-1.8.3}/blksprs/layouting/distribution_layout.py +3 -2
- {blksprs-1.8.1 → blksprs-1.8.3}/blksprs/layouting/sparsity_layout.py +3 -2
- {blksprs-1.8.1 → blksprs-1.8.3}/blksprs/ops/conversion.py +35 -25
- {blksprs-1.8.1 → blksprs-1.8.3}/blksprs/ops/distribution.py +19 -18
- {blksprs-1.8.1/blksprs → blksprs-1.8.3/blksprs/ops}/experimental/distribution_mdi.py +22 -21
- {blksprs-1.8.1 → blksprs-1.8.3}/blksprs/ops/matmul.py +14 -13
- {blksprs-1.8.1/blksprs → blksprs-1.8.3/blksprs/ops}/misc/broadcast_ops.py +5 -4
- {blksprs-1.8.1/blksprs → blksprs-1.8.3/blksprs/ops}/misc/exp.py +5 -4
- {blksprs-1.8.1/blksprs → blksprs-1.8.3/blksprs/ops}/misc/row_wise.py +19 -18
- {blksprs-1.8.1/blksprs/misc → blksprs-1.8.3/blksprs/ops}/partitioning.py +13 -12
- {blksprs-1.8.1 → blksprs-1.8.3}/blksprs/ops/repeat.py +13 -12
- {blksprs-1.8.1 → blksprs-1.8.3}/blksprs/ops/softmax.py +8 -7
- {blksprs-1.8.1 → blksprs-1.8.3}/blksprs/ops/transpose.py +7 -6
- blksprs-1.8.3/blksprs/utils/blksprs_tensor.py +8 -0
- blksprs-1.8.3/blksprs/utils/processing.py +41 -0
- {blksprs-1.8.1 → blksprs-1.8.3}/blksprs/utils/tools.py +1 -6
- {blksprs-1.8.1 → blksprs-1.8.3}/blksprs/utils/validation.py +4 -0
- {blksprs-1.8.1 → blksprs-1.8.3}/blksprs.egg-info/PKG-INFO +21 -13
- {blksprs-1.8.1 → blksprs-1.8.3}/blksprs.egg-info/SOURCES.txt +7 -5
- {blksprs-1.8.1 → blksprs-1.8.3}/pyproject.toml +1 -1
- blksprs-1.8.1/blksprs/__init__.py +0 -27
- {blksprs-1.8.1 → blksprs-1.8.3}/blksprs/utils/benchmarking.py +0 -0
- {blksprs-1.8.1 → blksprs-1.8.3}/blksprs.egg-info/dependency_links.txt +0 -0
- {blksprs-1.8.1 → blksprs-1.8.3}/blksprs.egg-info/requires.txt +0 -0
- {blksprs-1.8.1 → blksprs-1.8.3}/blksprs.egg-info/top_level.txt +0 -0
- {blksprs-1.8.1 → blksprs-1.8.3}/setup.cfg +0 -0
{blksprs-1.8.1 → blksprs-1.8.3}/PKG-INFO +21 -13

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.8.1
+Version: 1.8.3
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -22,6 +22,14 @@ Requires-Dist: build; extra == "build"
 [](https://github.com/FelixSchoen/blksprs/releases)
 [](https://www.python.org/downloads/release/python-3119/)

+## Important Notice
+
+🚨 **Non-Final API** 🚨
+
+Although it already supports a wide variety of functions, this library is still under active development and the API is
+subject to change. For feature requests or bug reports, please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
+We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
+
 ## Overview

 A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
@@ -51,14 +59,14 @@ These include, e.g.,
 Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
 match.

-Further helpful operations (included in the ``bs.misc`` module) that do **not** support gradient calculation include:
+Further helpful operations (included in the ``bs.ops.misc`` module) that do **not** support gradient calculation include:

 - Row-wise sum, max, addition, and subtraction
 - Broadcast addition and subtraction between slices

 Furthermore, the library provides a set of utility functions for the creation of sparsity layouts based on existing
-dense tensors and for the scatter operation (module ``bs.
-dimensionality (module ``bs.
+dense tensors and for the scatter operation (module ``bs.layouting``), as well as utility functions to apply linear layers,
+ensure correct input dimensionality, and validate input (module ``bs.utils``).

 ## Installation

@@ -111,14 +119,14 @@ def test_readme():
     y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()

     # Convert tensors to three-dimensional (dense) tensors since Triton can only handle tensors of exactly three dimensions
-    x_dense, x_shape_original = bs.
-    y_dense, y_shape_original = bs.
+    x_dense, x_shape_original = bs.utils.do_shape_blocksparse(x)
+    y_dense, y_shape_original = bs.utils.do_shape_blocksparse(y)

     # Create sparsity layouts from existing tensors
-    sparsity_layout_x = bs.
-
-    sparsity_layout_y = bs.
-
+    sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size,
+                                                           triton_block_size=triton_block_size)
+    sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size,
+                                                           triton_block_size=triton_block_size)

     # Create random sparsity layout for output tensor
     sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)
@@ -150,12 +158,12 @@ def test_readme():
     assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2) # Note that small numerical differences are expected

     # Assert that the output has the correct sparsity layout
-    actual_sparsity_layout_o = bs.
-
+    actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size,
+                                                                  triton_block_size=triton_block_size)
     assert torch.allclose(actual_sparsity_layout_o.to(torch.int), sparsity_layout_o)

     # Convert output tensor back to original shape
-    o = bs.
+    o = bs.utils.undo_shape_blocksparse(o_dense, x_shape_original)

     # Other available functions
     bs.transpose(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
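The note above that element-wise operations between two sparse tensors require matching sparsity layouts can be illustrated with a short sketch. This is not code from the package: it assumes a CUDA device with Triton available, uses the ``bs.ops`` and ``bs.layouting`` namespaces introduced by the new ``__init__.py`` later in this diff, and assumes that the compressed form is simply the stack of present blocks in layout order, so plain PyTorch operators stay block-aligned.

```python
# Hedged sketch (not from the package): element-wise ops on compressed tensors
# are only meaningful when both operands share the same sparsity layout.
import torch
import blksprs as bs

sparsity_block_size = 32
a = torch.randn(2, 64, 64, device="cuda")
b = torch.randn(2, 64, 64, device="cuda")

# One shared layout for both operands (here derived from `a`).
layout = bs.layouting.build_sparsity_layout(a, sparsity_block_size)

a_sparse = bs.ops.to_sparse(a, layout, sparsity_block_size)
b_sparse = bs.ops.to_sparse(b, layout, sparsity_block_size)

# Both compressed tensors list their blocks in the same layout order, so a plain
# PyTorch element-wise operator lines up block for block.
c_sparse = a_sparse + b_sparse
c_dense = bs.ops.to_dense(c_sparse, layout, sparsity_block_size)
```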
{blksprs-1.8.1 → blksprs-1.8.3}/README.md +20 -12

@@ -3,6 +3,14 @@
 [](https://github.com/FelixSchoen/blksprs/releases)
 [](https://www.python.org/downloads/release/python-3119/)

+## Important Notice
+
+🚨 **Non-Final API** 🚨
+
+Although it already supports a wide variety of functions, this library is still under active development and the API is
+subject to change. For feature requests or bug reports, please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
+We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
+
 ## Overview

 A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
@@ -32,14 +40,14 @@ These include, e.g.,
 Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
 match.

-Further helpful operations (included in the ``bs.misc`` module) that do **not** support gradient calculation include:
+Further helpful operations (included in the ``bs.ops.misc`` module) that do **not** support gradient calculation include:

 - Row-wise sum, max, addition, and subtraction
 - Broadcast addition and subtraction between slices

 Furthermore, the library provides a set of utility functions for the creation of sparsity layouts based on existing
-dense tensors and for the scatter operation (module ``bs.
-dimensionality (module ``bs.
+dense tensors and for the scatter operation (module ``bs.layouting``), as well as utility functions to apply linear layers,
+ensure correct input dimensionality, and validate input (module ``bs.utils``).

 ## Installation

@@ -92,14 +100,14 @@ def test_readme():
     y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()

     # Convert tensors to three-dimensional (dense) tensors since Triton can only handle tensors of exactly three dimensions
-    x_dense, x_shape_original = bs.
-    y_dense, y_shape_original = bs.
+    x_dense, x_shape_original = bs.utils.do_shape_blocksparse(x)
+    y_dense, y_shape_original = bs.utils.do_shape_blocksparse(y)

     # Create sparsity layouts from existing tensors
-    sparsity_layout_x = bs.
-
-    sparsity_layout_y = bs.
-
+    sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size,
+                                                           triton_block_size=triton_block_size)
+    sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size,
+                                                           triton_block_size=triton_block_size)

     # Create random sparsity layout for output tensor
     sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)
@@ -131,12 +139,12 @@ def test_readme():
     assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2) # Note that small numerical differences are expected

     # Assert that the output has the correct sparsity layout
-    actual_sparsity_layout_o = bs.
-
+    actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size,
+                                                                  triton_block_size=triton_block_size)
     assert torch.allclose(actual_sparsity_layout_o.to(torch.int), sparsity_layout_o)

     # Convert output tensor back to original shape
-    o = bs.
+    o = bs.utils.undo_shape_blocksparse(o_dense, x_shape_original)

     # Other available functions
     bs.transpose(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
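For reference, the README example above can be condensed into a self-contained sketch that uses only signatures visible in this diff. It is a hedged illustration, not the package's test: it assumes a CUDA device with Triton, dimensions divisible by the sparsity block size, and replaces the random output-layout helper with an all-ones (fully dense) layout.

```python
import torch
import blksprs as bs

b, h, m, n, k = 2, 4, 64, 64, 64
sparsity_block_size = 32

x = torch.randn(size=(b, h, m, k), device="cuda")
y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()

# Flatten to exactly three dimensions, as required by the Triton kernels.
x_dense, x_shape_original = bs.utils.do_shape_blocksparse(x)
y_dense, y_shape_original = bs.utils.do_shape_blocksparse(y)

# Derive sparsity layouts from the dense tensors.
sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size)
sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size)

# All-ones output layout (an assumption to keep the sketch short).
sparsity_layout_o = torch.ones(b * h, m // sparsity_block_size, n // sparsity_block_size,
                               dtype=torch.int, device="cuda")

# Compress, multiply, decompress, and restore the original shape.
x_sparse = bs.ops.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size)
y_sparse = bs.ops.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size)
o_sparse = bs.ops.matmul(x_sparse, sparsity_layout_x, y_sparse, sparsity_layout_y,
                         sparsity_layout_o, sparsity_block_size)
o_dense = bs.ops.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size)
o = bs.utils.undo_shape_blocksparse(o_dense, x_shape_original)
```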
blksprs-1.8.3/blksprs/__init__.py +40 -0 (new file)

@@ -0,0 +1,40 @@
+from blksprs.utils.blksprs_tensor import BlksprsTensor
+
+class ops:
+    from blksprs.ops.conversion import to_dense, to_sparse, from_blksprs, to_blksprs, adapt_layout
+    from blksprs.ops.distribution import gather, scatter, scatter_reduce
+    from blksprs.ops.matmul import matmul
+    from blksprs.ops.softmax import softmax
+    from blksprs.ops.transpose import transpose
+    from blksprs.ops.repeat import repeat, repeat_interleave
+    from blksprs.ops.partitioning import split, merge
+
+    class misc:
+        from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_add, row_wise_sub
+        from blksprs.ops.misc.broadcast_ops import broadcast_add, broadcast_sub
+        from blksprs.ops.misc.exp import exp
+
+    class experimental:
+        from blksprs.ops.experimental.distribution_mdi import gather_mdi, scatter_reduce_mdi
+
+
+class layouting:
+    from blksprs.layouting.distribution_layout import build_distribution_layout
+    from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption, \
+        build_sparsity_layout_matmul, build_sparsity_layout_matmul_fast
+
+    class experimental:
+        from blksprs.ops.experimental.distribution_mdi import build_distribution_layout_mdi
+
+
+class utils:
+    from blksprs.utils.processing import apply_torch_linear
+    from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
+    from blksprs.utils.validation import disable_validation
+
+class validation:
+    from blksprs.utils.validation import disable_validation
+    from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_dtype_float, \
+        validate_dtype_int, validate_device, validate_sparsity, validate_sparsity_dense, \
+        validate_sparsity_block_size, \
+        validate_triton_block_size
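The new top-level ``__init__.py`` uses plain classes purely as namespaces: nothing is instantiated, and the ``from ... import ...`` statements inside each class body simply become class attributes. A minimal, hedged check of what this exposes (only names that appear in the file above are used):

```python
import blksprs as bs

# The classes act as namespaces; the imported functions are reachable as attributes.
assert callable(bs.ops.matmul)
assert callable(bs.ops.misc.exp)
assert callable(bs.ops.experimental.gather_mdi)
assert callable(bs.layouting.build_sparsity_layout)
assert callable(bs.utils.do_shape_blocksparse)
assert callable(bs.validation.disable_validation)

# The marker type for compressed tensors is exported at the top level.
print(bs.BlksprsTensor)
```

Note that ``disable_validation`` is reachable both as ``bs.utils.disable_validation`` and as ``bs.validation.disable_validation``, mirroring the two imports in the file.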
{blksprs-1.8.1 → blksprs-1.8.3}/blksprs/layouting/distribution_layout.py +3 -2

@@ -3,18 +3,19 @@ import triton
 from torch import Tensor
 from triton import language as tl

+from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
     validate_contiguous


-def build_distribution_layout(indices:
+def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: Tensor,
                               size_target: torch.Size,
                               sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
     """Builds the sparsity layout of either the source of a gather or the target of a scatter operation.

     Args:
-        indices (
+        indices (BlksprsTensor): The block-sparse indices tensor in compressed form used for the gather or scatter operation.
         sparsity_layout_indices (Tensor): The sparsity layout of the indices block-sparse tensor.
         size_target (torch.Size): The size of the block-sparse target tensor in regular form.
         sparsity_block_size (int): The size of the sparsity blocks.
{blksprs-1.8.1 → blksprs-1.8.3}/blksprs/layouting/sparsity_layout.py +3 -2

@@ -5,6 +5,7 @@ import triton
 from torch import Tensor
 from triton import language as tl

+from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
     validate_contiguous, validate_sparsity, validate_sparsity_block_size
@@ -82,14 +83,14 @@ kernel_sparsity_layout(x,
     tl.store(o + blk_o_idx, 1, mask=blk_o_msk)


-def build_sparsity_layout_adaption(x:
+def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tensor,
                                    sparsity_block_size_from: int, sparsity_block_size_to: int,
                                    triton_block_size: int = None) -> Tensor:
     """Builds the sparsity layout of a block-sparse tensor in compressed form if a different sparsity block size were
     used.

     Args:
-        x (
+        x (BlksprsTensor): A block-sparse tensor in compressed form.
         sparsity_layout_from (Tensor): The sparsity layout of the input block-sparse tensor.
         sparsity_block_size_from (int): The size of the sparsity blocks of the input tensor.
         sparsity_block_size_to (int): The desired size of the sparsity blocks for the resulting layout.
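The adapted-layout functionality added here can be sketched as follows. This is a hedged example based only on the signatures shown in this diff (``build_sparsity_layout_adaption`` above and ``adapt_layout`` in ``conversion.py`` below); it assumes a CUDA device with Triton and block sizes that divide the tensor dimensions.

```python
import torch
import blksprs as bs

x = torch.randn(2, 128, 128, device="cuda")

# Compress with 64x64 sparsity blocks.
layout_64 = bs.layouting.build_sparsity_layout(x, sparsity_block_size=64)
x_sparse_64 = bs.ops.to_sparse(x, layout_64, sparsity_block_size=64)

# Layout the same data would have if 32x32 sparsity blocks were used instead.
layout_32 = bs.layouting.build_sparsity_layout_adaption(x_sparse_64, layout_64,
                                                        sparsity_block_size_from=64,
                                                        sparsity_block_size_to=32)

# Re-compress the tensor itself to the new block size.
x_sparse_32 = bs.ops.adapt_layout(x_sparse_64, layout_64,
                                  sparsity_block_size_from=64,
                                  sparsity_block_size_to=32)
```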
{blksprs-1.8.1 → blksprs-1.8.3}/blksprs/ops/conversion.py +35 -25

@@ -6,23 +6,27 @@ from torch import Tensor
 from triton import language as tl

 from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
+from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_sparsity_dense


-def from_blksprs(x:
+def from_blksprs(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
                  triton_block_size: int = None) -> Tensor:
+    """Wrapper for ``to_dense``.
+
+    """
     return to_dense(x, sparsity_layout, sparsity_block_size, fill_value, triton_block_size)


-def to_dense(x:
+def to_dense(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
             triton_block_size: int = None) -> Tensor:
     """Converts a block-sparse tensor in compressed form to a block-sparse tensor in regular form based on the given
     sparsity layout.

     Args:
-        x (
+        x (BlksprsTensor): A block-sparse tensor in compressed form.
         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
         fill_value (float): The value to fill the resulting dense tensor with where the block-sparse tensor is not
@@ -50,12 +54,12 @@ def to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_
     validate_contiguous(sparsity_reverse_lut)

     if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
-        return x
+        return BlksprsTensor(x)

-    return _BlocksparseToDense.apply(x,
-
-
-
+    return BlksprsTensor(_BlocksparseToDense.apply(x,
+                                                   sparsity_layout, sparsity_reverse_lut,
+                                                   sparsity_block_size, fill_value,
+                                                   triton_block_size))


 class _BlocksparseToDense(torch.autograd.Function):
@@ -150,11 +154,15 @@ class _BlocksparseToDense(torch.autograd.Function):


 def to_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
-               triton_block_size: int = None) ->
+               triton_block_size: int = None) -> BlksprsTensor:
+    """Wrapper for ``to_sparse``.
+
+    """
     return to_sparse(x, sparsity_layout, sparsity_block_size, triton_block_size)


-def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
+def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
+              triton_block_size: int = None) -> BlksprsTensor:
     """Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
     sparsity layout.

@@ -165,7 +173,7 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
         triton_block_size (int): The block size to use for the triton kernel (default ``None``).

     Returns:
-
+        BlksprsTensor: The block-sparse tensor converted to compressed form.

     """
     x = x.contiguous()
@@ -183,12 +191,12 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
     validate_contiguous(sparsity_layout, sparsity_lut)

     if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
-        return x
+        return BlksprsTensor(x)

-    return _BlocksparseToSparse.apply(x,
-
-
-
+    return BlksprsTensor(_BlocksparseToSparse.apply(x,
+                                                    sparsity_layout, sparsity_lut,
+                                                    sparsity_block_size, n_sparse_blocks,
+                                                    triton_block_size))


 class _BlocksparseToSparse(torch.autograd.Function):
@@ -280,13 +288,14 @@ class _BlocksparseToSparse(torch.autograd.Function):
         tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)


-def adapt_layout(x:
-
+def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_size_from: int,
+                 sparsity_block_size_to: int,
+                 preprocess_data: dict = None, triton_block_size: int = None) -> BlksprsTensor:
     """Adapts the sparsity layout of a block-sparse tensor, resulting in a new block-sparse tensor in compressed form
     conforming to the new sparsity layout (and sparsity block size) definition.

     Args:
-        x (
+        x (BlksprsTensor): A block-sparse tensor in compressed form.
         sparsity_layout_from (Tensor): The sparsity layout of the input block-sparse tensor.
         sparsity_block_size_from (int): The size of the sparsity blocks of the input sparsity layout.
         sparsity_block_size_to (int): The size of the sparsity blocks of the output sparsity layout.
@@ -294,7 +303,7 @@ def adapt_layout(x: Tensor, sparsity_layout_from: Tensor, sparsity_block_size_fr
         triton_block_size (int): The block size to use for the triton kernel (default ``None``).

     Returns:
-
+        BlksprsTensor: The block-sparse tensor in compressed form with the adapted sparsity layout and sparsity block size.

     """
     x = x.contiguous()
@@ -339,12 +348,13 @@ def adapt_layout(x: Tensor, sparsity_layout_from: Tensor, sparsity_block_size_fr
     validate_contiguous(sparsity_layout_to, sparsity_reverse_lut_from, sparsity_lut_to)

     if (sparsity_block_size_from == sparsity_block_size_to) and torch.equal(sparsity_layout_from, sparsity_layout_to):
-        return x
+        return BlksprsTensor(x)

-    return _BlocksparseAdaptLayout.apply(x,
-
-
-
+    return BlksprsTensor(_BlocksparseAdaptLayout.apply(x,
+                                                       sparsity_layout_from, sparsity_reverse_lut_from,
+                                                       sparsity_block_size_from,
+                                                       sparsity_layout_to, sparsity_lut_to, sparsity_block_size_to,
+                                                       n_sparse_blocks_to, min_sparsity_block_size, triton_block_size))


 class _BlocksparseAdaptLayout(torch.autograd.Function):
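The new ``from_blksprs``/``to_blksprs`` wrappers and the ``BlksprsTensor`` return type can be exercised as below. This is a hedged sketch: the definition of ``BlksprsTensor`` (``blksprs/utils/blksprs_tensor.py``, listed as +8 lines) is not shown in this diff, so the example only assumes it behaves like a ``torch.Tensor``-compatible marker type; a CUDA device with Triton is assumed as well.

```python
import torch
import blksprs as bs

sparsity_block_size = 32
x = torch.randn(2, 64, 64, device="cuda")
layout = bs.layouting.build_sparsity_layout(x, sparsity_block_size)

x_sparse = bs.ops.to_blksprs(x, layout, sparsity_block_size)               # wrapper for to_sparse
x_round_trip = bs.ops.from_blksprs(x_sparse, layout, sparsity_block_size)  # wrapper for to_dense

print(type(x_sparse))                   # expected: the BlksprsTensor marker type
print(torch.allclose(x, x_round_trip))  # expected: True when the layout covers all blocks
```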
{blksprs-1.8.1 → blksprs-1.8.3}/blksprs/ops/distribution.py +19 -18

@@ -3,25 +3,26 @@ import triton
 from torch import Tensor
 from triton import language as tl

+from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size


-def gather(src:
-           sparsity_block_size: int, triton_block_size: int = None) ->
+def gather(src: BlksprsTensor, sparsity_layout_src: Tensor, idx: BlksprsTensor, sparsity_layout_idx: Tensor,
+           sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
     """Applies a gather operation on a block-sparse tensor in compressed form.

     Args:
-        src (
+        src (BlksprsTensor): The source block-sparse tensor in compressed form to gather from.
         sparsity_layout_src (Tensor): The sparsity layout of the source block-sparse tensor.
-        idx (
+        idx (BlksprsTensor): The block-sparse indices tensor in compressed form specifying how to gather from the source tensor.
         sparsity_layout_idx (Tensor): The sparsity layout of the indices block-sparse tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).

     Returns:
-
+        BlksprsTensor: The result of the gather operation as a block-sparse tensor in compressed form.

     """
     src = src.contiguous()
@@ -45,9 +46,9 @@ def gather(src: Tensor, sparsity_layout_src: Tensor, idx: Tensor, sparsity_layou
     validate_contiguous(sparsity_layout_src, sparsity_reverse_lut_x,
                         sparsity_layout_idx, sparsity_lut_i)

-    return _BlocksparseGather.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
+    return BlksprsTensor(_BlocksparseGather.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
                                     idx, sparsity_layout_idx, sparsity_lut_i,
-                                    sparsity_block_size, triton_block_size)
+                                    sparsity_block_size, triton_block_size))


 class _BlocksparseGather(torch.autograd.Function):
@@ -168,10 +169,10 @@ class _BlocksparseGather(torch.autograd.Function):
         tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)


-def scatter(src:
-            idx:
+def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
+            idx: BlksprsTensor,
            sparsity_layout_tgt: Tensor,
-            sparsity_block_size: int, triton_block_size: int = None) ->
+            sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
     """Wrapper for ``scatter_reduce`` with ``reduce_op="none"``.

     """
@@ -182,17 +183,17 @@ def scatter(src: Tensor, sparsity_layout_src: Tensor,
                           reduce_op="none", triton_block_size=triton_block_size)


-def scatter_reduce(src:
-                   idx:
+def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
+                   idx: BlksprsTensor,
                    sparsity_layout_tgt: Tensor,
                    sparsity_block_size: int,
-                   reduce_op: str = "sum", triton_block_size: int = None) ->
+                   reduce_op: str = "sum", triton_block_size: int = None) -> BlksprsTensor:
     """Applies a scatter operation on a block-sparse tensor in compressed form.

     Args:
-        src (
+        src (BlksprsTensor): The source block-sparse tensor in compressed form to scatter from.
         sparsity_layout_src (Tensor): The sparsity layout of the source block-sparse tensor.
-        idx (
+        idx (BlksprsTensor): The block-sparse indices tensor in compressed form specifying how to scatter to the target tensor.
         sparsity_layout_tgt (Tensor): The sparsity layout of the target block-sparse tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
         reduce_op (str, optional): The reduction operation to apply during the scatter operation (default ``"sum"``).
@@ -200,7 +201,7 @@ def scatter_reduce(src: Tensor, sparsity_layout_src: Tensor,
         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).

     Returns:
-
+        BlksprsTensor: The result of the scatter operation as a block-sparse tensor in compressed form.

     """
     src = src.contiguous()
@@ -229,11 +230,11 @@ def scatter_reduce(src: Tensor, sparsity_layout_src: Tensor,
     validate_contiguous(sparsity_layout_src, sparsity_lut_x,
                         sparsity_layout_tgt, sparsity_reverse_lut_o)

-    return _BlocksparseScatterReduce.apply(src, sparsity_layout_src, sparsity_lut_x,
+    return BlksprsTensor(_BlocksparseScatterReduce.apply(src, sparsity_layout_src, sparsity_lut_x,
                                            idx,
                                            sparsity_layout_tgt, sparsity_reverse_lut_o,
                                            sparsity_block_size, n_sparse_blocks,
-                                           reduce_op, triton_block_size)
+                                           reduce_op, triton_block_size))


 class _BlocksparseScatterReduce(torch.autograd.Function):
{blksprs-1.8.1/blksprs → blksprs-1.8.3/blksprs/ops}/experimental/distribution_mdi.py +22 -21

@@ -3,17 +3,18 @@ import triton
 from torch import Tensor
 from triton import language as tl

+from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size


-def gather_mdi(src:
-               idx_bat:
-               idx_row:
-               idx_col:
+def gather_mdi(src: BlksprsTensor, sparsity_layout_src: Tensor,
+               idx_bat: BlksprsTensor,
+               idx_row: BlksprsTensor,
+               idx_col: BlksprsTensor,
               sparsity_layout_idx: Tensor,
-               sparsity_block_size: int, triton_block_size: int = None) ->
+               sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
     src = src.contiguous()
     idx_bat = idx_bat.contiguous()
     idx_col = idx_col.contiguous()
@@ -37,9 +38,9 @@ def gather_mdi(src: Tensor, sparsity_layout_src: Tensor,
     validate_contiguous(sparsity_layout_src, sparsity_reverse_lut_x,
                         sparsity_layout_idx, sparsity_lut_i)

-    return _BlocksparseGatherMDI.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
-
-
+    return BlksprsTensor(_BlocksparseGatherMDI.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
+                                                     idx_bat, idx_col, sparsity_layout_idx, sparsity_lut_i,
+                                                     sparsity_block_size, triton_block_size))


 class _BlocksparseGatherMDI(torch.autograd.Function):
@@ -167,13 +168,13 @@ class _BlocksparseGatherMDI(torch.autograd.Function):
         tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)


-def scatter_reduce_mdi(src:
-                       idx_bat:
-                       idx_row:
-                       idx_col:
+def scatter_reduce_mdi(src: BlksprsTensor, sparsity_layout_src: Tensor,
+                       idx_bat: BlksprsTensor,
+                       idx_row: BlksprsTensor,
+                       idx_col: BlksprsTensor,
                        sparsity_layout_tgt: Tensor,
                        sparsity_block_size: int,
-                       reduce_op: str = "sum", triton_block_size: int = None) ->
+                       reduce_op: str = "sum", triton_block_size: int = None) -> BlksprsTensor:
     src = src.contiguous()
     idx_bat = idx_bat.contiguous()
     idx_col = idx_col.contiguous()
@@ -203,12 +204,12 @@ def scatter_reduce_mdi(src: Tensor, sparsity_layout_src: Tensor,
     validate_contiguous(sparsity_layout_src, sparsity_lut_x,
                         sparsity_layout_tgt, sparsity_reverse_lut_o)

-    return _BlocksparseScatterReduceMDI.apply(src, sparsity_layout_src, sparsity_lut_x,
-
-
-
-
-
+    return BlksprsTensor(_BlocksparseScatterReduceMDI.apply(src, sparsity_layout_src, sparsity_lut_x,
+                                                            idx_bat,
+                                                            idx_col,
+                                                            sparsity_layout_tgt, sparsity_reverse_lut_o,
+                                                            sparsity_block_size, n_sparse_blocks,
+                                                            reduce_op, triton_block_size))


 class _BlocksparseScatterReduceMDI(torch.autograd.Function):
@@ -353,8 +354,8 @@ class _BlocksparseScatterReduceMDI(torch.autograd.Function):
         tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)


-def build_distribution_layout_mdi(idx_bat:
-                                  size_target: torch.Size,
+def build_distribution_layout_mdi(idx_bat: BlksprsTensor, idx_row: BlksprsTensor, idx_col: BlksprsTensor,
+                                  sparsity_layout_idx: Tensor, size_target: torch.Size,
                                   sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
     validate_dimensions(idx_bat, idx_col)
     validate_contiguous(idx_bat, idx_col)
{blksprs-1.8.1 → blksprs-1.8.3}/blksprs/ops/matmul.py +14 -13

@@ -4,22 +4,23 @@ from torch import Tensor
 from triton import language as tl

 from blksprs.ops.transpose import transpose
+from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_dtype_float


-def matmul(x:
-           y:
+def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
+           y: BlksprsTensor, sparsity_layout_y: Tensor,
           sparsity_layout_output: Tensor,
-           sparsity_block_size: int, triton_block_size: int = None) ->
+           sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
     """Performs matrix multiplication between two block-sparse tensors.

     The sparsity layout of the output tensor is used to only calculate blocks that will be present in the output.

     Args:
-        x (
-        y (
+        x (BlksprsTensor): A block-sparse tensor in compressed form.
+        y (BlksprsTensor): A block-sparse tensor in compressed form.
         sparsity_layout_x (Tensor): The sparsity layout of the first block-sparse tensor.
         sparsity_layout_y (Tensor): The sparsity layout of the second block-sparse tensor.
         sparsity_layout_output (Tensor): The sparsity layout of the output tensor.
@@ -27,7 +28,7 @@ def matmul(x: Tensor, sparsity_layout_x: Tensor,
         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).

     Returns:
-
+        BlksprsTensor: The result of the matrix multiplication as a block-sparse tensor in compressed form.

     """
     x = x.contiguous()
@@ -61,13 +62,13 @@ def matmul(x: Tensor, sparsity_layout_x: Tensor,
                         sparsity_layout_y, sparsity_reverse_lut_y,
                         sparsity_layout_output, sparsity_lut_o)

-    return _BlocksparseMatmulSSS.apply(x, y,
-
-
-
-
-
-
+    return BlksprsTensor(_BlocksparseMatmulSSS.apply(x, y,
+                                                     sparsity_layout_x, sparsity_reverse_lut_x,
+                                                     sparsity_layout_y, sparsity_reverse_lut_y,
+                                                     sparsity_layout_output, sparsity_lut_o,
+                                                     sparsity_block_size,
+                                                     n_sparse_blocks,
+                                                     triton_block_size))


 class _BlocksparseMatmulSSS(torch.autograd.Function):