blksprs-1.0-py3-none-any.whl → blksprs-1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/layouting/distribution_layout.py +114 -0
- blksprs/layouting/sparsity_layout.py +17 -7
- blksprs/misc/broadcast_addition.py +132 -0
- blksprs/ops/conversion.py +40 -15
- blksprs/ops/distribution.py +362 -0
- blksprs/ops/exp.py +18 -8
- blksprs/ops/{matmul_sss.py → matmul.py} +28 -26
- blksprs/ops/row_wise_sum.py +21 -5
- blksprs/ops/softmax.py +23 -12
- blksprs/ops/transpose.py +19 -7
- blksprs/utils/tools.py +1 -28
- blksprs/utils/validation.py +53 -1
- {blksprs-1.0.dist-info → blksprs-1.1.dist-info}/METADATA +32 -13
- blksprs-1.1.dist-info/RECORD +17 -0
- {blksprs-1.0.dist-info → blksprs-1.1.dist-info}/WHEEL +1 -1
- blksprs-1.0.dist-info/RECORD +0 -14
- {blksprs-1.0.dist-info → blksprs-1.1.dist-info}/top_level.txt +0 -0
blksprs/utils/tools.py
CHANGED
@@ -15,33 +15,6 @@ def undo_shape_blocksparse(x: Tensor, shape: Size):
 
     return x.reshape((*shape[:-2], *x.shape[-2:]))
 
+
 def get_triton_block_size(sparsity_block_size: int, limit: int = 128):
     return min(sparsity_block_size, limit)
-
-#
-
-def slow_to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int):
-    output = torch.zeros(size=(sparsity_layout.size(0), sparsity_layout.size(1) * sparsity_block_size,
-                               sparsity_layout.size(2) * sparsity_block_size), device=x.device)
-    indices_sparse_blocks = sparsity_layout.nonzero(as_tuple=True)
-
-    for idx, (b, r, c) in enumerate(zip(*indices_sparse_blocks)):
-        t_r = r * sparsity_block_size
-        t_c = c * sparsity_block_size
-        to_insert = x[idx]
-        output[b, t_r:t_r + sparsity_block_size, t_c:t_c + sparsity_block_size] = to_insert
-
-    return output
-
-def slow_to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int):
-    num_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
-    output = torch.zeros(size=(num_sparse_blocks, sparsity_block_size, sparsity_block_size), device=x.device)
-    indices_sparse_blocks = sparsity_layout.nonzero(as_tuple=True)
-
-    for idx, (b, r, c) in enumerate(zip(*indices_sparse_blocks)):
-        t_r = r * sparsity_block_size
-        t_c = c * sparsity_block_size
-        to_insert = x[b, t_r:t_r + sparsity_block_size, t_c:t_c + sparsity_block_size]
-        output[idx] = to_insert
-
-    return output
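The 1.1 release drops the `slow_to_dense` and `slow_to_sparse` reference conversions from `tools.py`; conversion between dense and compact sparse form remains available through `blksprs.ops.conversion` (`to_sparse` / `to_dense`), as used in the README example further below. The retained helpers only handle reshaping and choosing a Triton block size. A minimal usage sketch of those helpers follows, assuming `do_shape_blocksparse` returns the flattened tensor together with its original shape (as suggested by the README example); this is illustrative, not library documentation.

```python
# Hypothetical usage sketch of the helpers kept in blksprs/utils/tools.py.
# Assumes do_shape_blocksparse returns (flattened_tensor, original_shape),
# matching how the README example below uses it; not verified here.
import torch
from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse, get_triton_block_size

x = torch.randn(size=(2, 4, 64, 16), device="cuda")

# Collapse leading batch dimensions so downstream kernels see a 3D tensor.
x_flat, x_shape_original = do_shape_blocksparse(x)

# Pick a Triton block size capped at 128 (the helper returns min(size, limit)).
triton_block_size = get_triton_block_size(sparsity_block_size=32)

# Restore the original leading dimensions afterwards.
x_restored = undo_shape_blocksparse(x_flat, x_shape_original)
assert x_restored.shape == x.shape
```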
blksprs/utils/validation.py
CHANGED
@@ -3,24 +3,45 @@ from torch import Tensor
 
 
 def validate_dimensions(*tensors: Tensor) -> None:
+    if _skip_validation():
+        return
+
     for tensor in tensors:
         if tensor.dim() != 3:
             raise ValueError("Tensor must have 3 dimensions")
 
 
 def validate_contiguous(*tensors: Tensor) -> None:
+    if _skip_validation():
+        return
+
     for tensor in tensors:
         if not tensor.is_contiguous():
             raise ValueError("Tensor must be contiguous")
 
 
 def validate_dtype_float(*tensors: Tensor) -> None:
+    if _skip_validation():
+        return
+
     for tensor in tensors:
         if tensor.dtype != torch.float32:
             raise ValueError("Tensor must have float32 dtype")
 
 
+def validate_dtype_int(*tensors: Tensor) -> None:
+    if _skip_validation():
+        return
+
+    for tensor in tensors:
+        if tensor.dtype != torch.int32 and tensor.dtype != torch.int64:
+            raise ValueError("Tensor must have int32 or int64 dtype")
+
+
 def validate_device(*tensors: Tensor) -> None:
+    if _skip_validation():
+        return
+
     device = None
 
     for i, tensor in enumerate(tensors):
@@ -33,13 +54,44 @@ def validate_device(*tensors: Tensor) -> None:
         if tensor.device != device:
             raise ValueError("Tensors must be on same device")
 
+
 def validate_sparsity(sparsity_block_size: int, *tensor_sparsity_layout_tuples: tuple[Tensor, Tensor]) -> None:
+    if _skip_validation():
+        return
+
     for (tensor, sparsity_layout) in tensor_sparsity_layout_tuples:
+        _validate_sparsity_layout_values(sparsity_layout)
+
         if not (tensor.size(-1) == tensor.size(-2) == sparsity_block_size):
             raise ValueError("Blocks not conforming to sparsity block size")
         if not tensor.size(0) == torch.sum(sparsity_layout.reshape(-1)):
             raise ValueError("Mismatch between sparsity layout and blocks")
 
+
+def _validate_sparsity_layout_values(sparsity_layout: Tensor):
+    if not torch.all(torch.logical_or(sparsity_layout == 0, sparsity_layout == 1)):
+        raise ValueError("Sparsity layout values must be either 0 or 1")
+
+def validate_sparsity_block_size(sparsity_block_size: int, *tensors):
+    if _skip_validation():
+        return
+
+    if not (sparsity_block_size & (sparsity_block_size - 1)) == 0:
+        raise ValueError("Sparsity block size must be a power of 2")
+
+    for tensor in tensors:
+        if not (tensor.size(-1) % sparsity_block_size == 0 and tensor.size(-2) % sparsity_block_size == 0):
+            raise ValueError("Tensor sizes must be divisible by sparsity block size")
+
 def validate_triton_block_size(triton_block_size: int, sparsity_block_size: int):
+    if _skip_validation():
+        return
+
+    if triton_block_size is None:
+        return
+
     if triton_block_size > sparsity_block_size:
-        raise ValueError("Triton block size cannot be larger than sparsity block size")
+        raise ValueError("Triton block size cannot be larger than sparsity block size")
+
+def _skip_validation():
+    return False
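Every validator in 1.1 now starts with a `_skip_validation()` guard (which at this point always returns `False`), and the new `validate_sparsity_block_size` enforces a power-of-two block size via the bit trick `n & (n - 1) == 0`. A standalone illustration of that check, not part of the library:

```python
# Standalone illustration of the power-of-two test used in
# validate_sparsity_block_size; this helper is not part of blksprs.
def is_power_of_two(n: int) -> bool:
    # A power of two has exactly one bit set, so clearing the lowest
    # set bit with n & (n - 1) must yield zero.
    return n > 0 and (n & (n - 1)) == 0

assert is_power_of_two(64)       # valid sparsity block size
assert not is_power_of_two(48)   # rejected with "must be a power of 2"
```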
{blksprs-1.0.dist-info → blksprs-1.1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.0
+Version: 1.1
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -23,16 +23,20 @@ Requires-Dist: matplotlib; extra == "test"
 
 ## Overview
 
-A lightweight library for operations on
+A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
 
 Currently supported operations (includes gradient calculation):
 
-- Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
+- Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
+  for `sparse = sparse @ sparse` matmul_)
 - Softmax
 - Transposition
+- Gather
+- Scatter (_supports either no reduction or summation, gradients are only available for summation_)
 - Conversion from and to sparse form
 
-As with this library sparse matrices are represented using a tuple of `(matrix, sparsity_layout, sparsity_block_size)`,
+As with this library sparse matrices are represented using a tuple of `(matrix, sparsity_layout, sparsity_block_size)`,
+any element-wise operations can be applied in regular torch-like fashion.
 These include, e.g.,
 
 - Element-wise addition and subtraction
@@ -40,8 +44,17 @@ These include, e.g.,
 - Element-wise exponentiation
 - ...
 
+Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
+match.
+
+Furthermore, the library provides a set of utility functions for the creation of sparsity layouts based on existing
+dense tensors.
+
 ## Installation
 
+Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is only compatible with
+the Linux platform.
+
 We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) using pip:
 
 ```pip install blksprs```
@@ -52,12 +65,19 @@ See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.
 
 ## Usage
 
+We provide an example below to demonstrate the usage of the library.
+For more detailed examples, please refer to
+the [test cases](https://github.com/FelixSchoen/blksprs/blob/main/test/cases/test_blocksparse.py) which cover all
+implemented operations and functions.
+The example below can also be found in
+the [test cases](https://github.com/FelixSchoen/blksprs/blob/main/test/cases/test_readme.py).
+
 ```python
 import torch
 
-from blksprs.layouting.sparsity_layout import
+from blksprs.layouting.sparsity_layout import build_sparsity_layout
 from blksprs.ops.conversion import to_sparse, to_dense
-from blksprs.ops.
+from blksprs.ops.matmul import matmul
 from blksprs.ops.row_wise_sum import row_wise_sum
 from blksprs.ops.softmax import softmax
 from blksprs.ops.transpose import transpose
@@ -65,7 +85,7 @@ from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
 
 
 def test_readme():
-    # Set up parameters
+    # Set up parameters (batch size, number of heads, dimensions for matrices (m, k) and (n, k))
     b, h, m, n, k = 2, 4, 64, 64, 16
 
     # Percentage of blocks that will be sparse in the output for demonstration purposes
@@ -78,7 +98,6 @@ def test_readme():
     # If it is set to ``none`` a value will be chosen automatically
     triton_block_size = None
 
-
     # Initialise random (dense) tensors
     x = torch.randn(size=(b, h, m, k), device="cuda")
     y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()
@@ -88,8 +107,8 @@ def test_readme():
     y_dense, y_shape_original = do_shape_blocksparse(y)
 
     # Create sparsity layouts from existing tensors
-    sparsity_layout_x =
-    sparsity_layout_y =
+    sparsity_layout_x = build_sparsity_layout(x_dense, sparsity_block_size, triton_block_size=triton_block_size)
+    sparsity_layout_y = build_sparsity_layout(y_dense, sparsity_block_size, triton_block_size=triton_block_size)
 
     # Create random sparsity layout for output tensor
     sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)
@@ -99,8 +118,8 @@ def test_readme():
     y_sparse = to_sparse(y_dense, sparsity_layout_y, sparsity_block_size, triton_block_size=triton_block_size)
 
     # Perform matrix multiplication
-    o_sparse =
-
+    o_sparse = matmul(x_sparse, sparsity_layout_x, y_sparse, sparsity_layout_y, sparsity_layout_o, sparsity_block_size,
+                      triton_block_size=triton_block_size)
     o_dense = to_dense(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
 
     # Sanity check
@@ -115,7 +134,7 @@ def test_readme():
     assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2)  # Note that small numerical differences are expected
 
     # Assert that the output has the correct sparsity layout
-    actual_sparsity_layout_o =
+    actual_sparsity_layout_o = build_sparsity_layout(o_dense, sparsity_block_size, triton_block_size=triton_block_size)
     assert torch.allclose(actual_sparsity_layout_o, sparsity_layout_o)
 
     # Convert output tensor back to original shape
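The updated README stresses that element-wise operations between two sparse tensors require matching sparsity layouts. Per the checks in `blksprs/utils/validation.py`, a tensor in compact sparse form is a plain 3D tensor of shape `(num_sparse_blocks, sparsity_block_size, sparsity_block_size)`, so once the layouts agree, ordinary torch operators apply directly to the blocks. A minimal sketch under that assumption (shapes and values are illustrative only):

```python
# Minimal sketch: element-wise ops on blksprs tensors in compact sparse form.
# Assumes x_sparse and y_sparse were produced by to_sparse with the SAME
# sparsity layout; shapes follow the validation checks quoted above.
import torch

sparsity_block_size = 16
num_sparse_blocks = 8  # illustrative count of non-zero blocks in the shared layout

x_sparse = torch.randn(num_sparse_blocks, sparsity_block_size, sparsity_block_size, device="cuda")
y_sparse = torch.randn_like(x_sparse)

# With matching layouts these are just regular torch operations on the blocks.
sum_sparse = x_sparse + y_sparse
diff_sparse = x_sparse - y_sparse
exp_sparse = torch.exp(x_sparse)
```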
blksprs-1.1.dist-info/RECORD
ADDED
@@ -0,0 +1,17 @@
+blksprs/layouting/distribution_layout.py,sha256=GQ-ZRXbeImiLcbaqnL2FuUZ6DoFwmB0naT_YrOpD84Q,4940
+blksprs/layouting/sparsity_layout.py,sha256=Z9Ac88kQZVaPp27ymlLwyGN14ZfMyljXp6oM_gSsFMQ,2902
+blksprs/misc/broadcast_addition.py,sha256=vf1Hdqz9Uyqykto3DCjmdyepMzpMXL238SpANQqRAwI,5297
+blksprs/ops/conversion.py,sha256=COhHE5KvwhrtdUTLZX1wmxFe0kDNMY97iIhnkMmztBA,11362
+blksprs/ops/distribution.py,sha256=_fQb6fWpLxocAh86D74ATahChi0EK0eBb4eUOUEBVps,16769
+blksprs/ops/exp.py,sha256=qs8fVtCzxl4CKT4GepaqurjEL62jyi8VjMY12JFrFAU,3674
+blksprs/ops/matmul.py,sha256=x3lrYg4g8fIf5PeMtZY_SEpi11kP9RFcRoemCIxcSDE,11086
+blksprs/ops/row_wise_sum.py,sha256=ojuSejV37cLtRNS3lBfknA5KY3TEg8EHxOqVT6JZzoM,11387
+blksprs/ops/softmax.py,sha256=ZyeAVqmG_VzJ72FArGrpUSFfoSM4GPxyubrmNKERVIA,11654
+blksprs/ops/transpose.py,sha256=DVEXoxo2MoTNL3NZrjxsukMDrzk2vnEXL1uRnKFWkn0,6722
+blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
+blksprs/utils/tools.py,sha256=P2UALvccRjJJ7w05YGuaxB3qmNObgct4idfM0jlE2wg,465
+blksprs/utils/validation.py,sha256=gJYZO5C48YUrXV3Fy_Z_lCaOpiFj951FT-Od7sKfprg,3007
+blksprs-1.1.dist-info/METADATA,sha256=NIdEtqxj4SBUOP1eMlBz2RoOppwlQx9sJnRmDicWvp4,6982
+blksprs-1.1.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+blksprs-1.1.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
+blksprs-1.1.dist-info/RECORD,,
blksprs-1.0.dist-info/RECORD
DELETED
@@ -1,14 +0,0 @@
-blksprs/layouting/sparsity_layout.py,sha256=fmzp0vasDDJuNwM-sPk3IBg1svCB2-ItELscppoSLPE,2553
-blksprs/ops/conversion.py,sha256=AYceIsv_g7xswoBqP3TGR2vjOjiLjTpwWG0vZ7XKCa8,10062
-blksprs/ops/exp.py,sha256=b0IuUVA_UoKNYDNT4Q3EFuXm7EEv_J2-DR7hfgCeT1Q,3222
-blksprs/ops/matmul_sss.py,sha256=34JSkO_9OOnQXB4KZHraElGDjbCx8p0dr9J5JebVdhY,10639
-blksprs/ops/row_wise_sum.py,sha256=ltoZpGVIApQBt_rbmknhsd-7MnibZRX6lkIzDnAC9k8,10462
-blksprs/ops/softmax.py,sha256=fJGKFshFMIVxaYJ_pgPD7EC0ooG-31XH9-qzRQ-xY5A,11018
-blksprs/ops/transpose.py,sha256=WEsXWRYDTWk2U36mt0aEQeReR3TG9TFrWAwVh3NNVYk,5985
-blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
-blksprs/utils/tools.py,sha256=bqJtXXUKmcUxQsQ4ZkrpDZ5P7gnCburSW5hVd9U2M3E,1708
-blksprs/utils/validation.py,sha256=d4BFxzX-zVa5mUv_t3IW_bZbbP3vzSWan_KC1lyw7bs,1639
-blksprs-1.0.dist-info/METADATA,sha256=hrQQ8iK3-F2b38ogT8gp0O-sxm7UFn3i6QNv6vsS2so,5991
-blksprs-1.0.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
-blksprs-1.0.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
-blksprs-1.0.dist-info/RECORD,,
{blksprs-1.0.dist-info → blksprs-1.1.dist-info}/top_level.txt
File without changes