blksprs 1.0__tar.gz → 1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {blksprs-1.0 → blksprs-1.2}/PKG-INFO +39 -14
- {blksprs-1.0 → blksprs-1.2}/README.md +38 -13
- blksprs-1.2/blksprs/layouting/distribution_layout.py +114 -0
- blksprs-1.2/blksprs/layouting/sparsity_layout.py +190 -0
- blksprs-1.2/blksprs/misc/broadcast_addition.py +132 -0
- blksprs-1.2/blksprs/ops/conversion.py +451 -0
- blksprs-1.2/blksprs/ops/distribution.py +362 -0
- {blksprs-1.0 → blksprs-1.2}/blksprs/ops/exp.py +18 -8
- blksprs-1.0/blksprs/ops/matmul_sss.py → blksprs-1.2/blksprs/ops/matmul.py +28 -26
- {blksprs-1.0 → blksprs-1.2}/blksprs/ops/row_wise_sum.py +21 -5
- {blksprs-1.0 → blksprs-1.2}/blksprs/ops/softmax.py +23 -12
- {blksprs-1.0 → blksprs-1.2}/blksprs/ops/transpose.py +19 -7
- blksprs-1.2/blksprs/utils/tools.py +20 -0
- {blksprs-1.0 → blksprs-1.2}/blksprs/utils/validation.py +53 -1
- {blksprs-1.0 → blksprs-1.2}/blksprs.egg-info/PKG-INFO +39 -14
- {blksprs-1.0 → blksprs-1.2}/blksprs.egg-info/SOURCES.txt +4 -1
- {blksprs-1.0 → blksprs-1.2}/pyproject.toml +1 -1
- blksprs-1.0/blksprs/layouting/sparsity_layout.py +0 -68
- blksprs-1.0/blksprs/ops/conversion.py +0 -231
- blksprs-1.0/blksprs/utils/tools.py +0 -47
- {blksprs-1.0 → blksprs-1.2}/blksprs/utils/benchmarking.py +0 -0
- {blksprs-1.0 → blksprs-1.2}/blksprs.egg-info/dependency_links.txt +0 -0
- {blksprs-1.0 → blksprs-1.2}/blksprs.egg-info/requires.txt +0 -0
- {blksprs-1.0 → blksprs-1.2}/blksprs.egg-info/top_level.txt +0 -0
- {blksprs-1.0 → blksprs-1.2}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: blksprs
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.2
|
|
4
4
|
Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
|
|
5
5
|
Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
|
|
6
6
|
Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
|
|
@@ -23,16 +23,21 @@ Requires-Dist: pdoc3; extra == "deploy"
|
|
|
23
23
|
|
|
24
24
|
## Overview
|
|
25
25
|
|
|
26
|
-
A lightweight library for operations on
|
|
26
|
+
A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
|
|
27
27
|
|
|
28
28
|
Currently supported operations (includes gradient calculation):
|
|
29
29
|
|
|
30
|
-
- Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
|
|
30
|
+
- Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
|
|
31
|
+
for `sparse = sparse @ sparse` matmul_)
|
|
31
32
|
- Softmax
|
|
32
33
|
- Transposition
|
|
33
|
-
-
|
|
34
|
+
- Gather
|
|
35
|
+
- Scatter (_supports either no reduction or summation, gradients are only available for summation_)
|
|
36
|
+
- Conversion to and from sparse form
|
|
37
|
+
- Conversion to different sparsity layouts and different sparsity block sizes
|
|
34
38
|
|
|
35
|
-
As with this library sparse matrices are represented using a tuple of `(matrix, sparsity_layout, sparsity_block_size)`,
|
|
39
|
+
As with this library sparse matrices are represented using a tuple of `(matrix, sparsity_layout, sparsity_block_size)`,
|
|
40
|
+
any element-wise operations can be applied in regular torch-like fashion.
|
|
36
41
|
These include, e.g.,
|
|
37
42
|
|
|
38
43
|
- Element-wise addition and subtraction
|
|
@@ -40,24 +45,45 @@ These include, e.g.,
|
|
|
40
45
|
- Element-wise exponentiation
|
|
41
46
|
- ...
|
|
42
47
|
|
|
48
|
+
Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
|
|
49
|
+
match.
|
|
50
|
+
|
|
51
|
+
Furthermore, the library provides a set of utility functions for the creation of sparsity layouts based on existing
|
|
52
|
+
dense tensors.
|
|
53
|
+
|
|
43
54
|
## Installation
|
|
44
55
|
|
|
56
|
+
Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is only compatible with
|
|
57
|
+
the Linux platform.
|
|
58
|
+
|
|
45
59
|
We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) using pip:
|
|
46
60
|
|
|
47
61
|
```pip install blksprs```
|
|
48
62
|
|
|
63
|
+
### Dependencies
|
|
64
|
+
|
|
65
|
+
- [PyTorch](https://pytorch.org/) (built with v2.4.0)
|
|
66
|
+
- _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_
|
|
67
|
+
|
|
49
68
|
## Changelog
|
|
50
69
|
|
|
51
70
|
See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.md) for a detailed changelog.
|
|
52
71
|
|
|
53
72
|
## Usage
|
|
54
73
|
|
|
74
|
+
We provide an example below to demonstrate the usage of the library.
|
|
75
|
+
For more detailed examples, please refer to
|
|
76
|
+
the [test cases](https://github.com/FelixSchoen/blksprs/blob/main/test/cases/test_blocksparse.py) which cover all
|
|
77
|
+
implemented operations and functions.
|
|
78
|
+
The example below can also be found in
|
|
79
|
+
the [test cases](https://github.com/FelixSchoen/blksprs/blob/main/test/cases/test_readme.py).
|
|
80
|
+
|
|
55
81
|
```python
|
|
56
82
|
import torch
|
|
57
83
|
|
|
58
|
-
from blksprs.layouting.sparsity_layout import
|
|
84
|
+
from blksprs.layouting.sparsity_layout import build_sparsity_layout
|
|
59
85
|
from blksprs.ops.conversion import to_sparse, to_dense
|
|
60
|
-
from blksprs.ops.
|
|
86
|
+
from blksprs.ops.matmul import matmul
|
|
61
87
|
from blksprs.ops.row_wise_sum import row_wise_sum
|
|
62
88
|
from blksprs.ops.softmax import softmax
|
|
63
89
|
from blksprs.ops.transpose import transpose
|
|
@@ -65,7 +91,7 @@ from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
|
|
|
65
91
|
|
|
66
92
|
|
|
67
93
|
def test_readme():
|
|
68
|
-
# Set up parameters
|
|
94
|
+
# Set up parameters (batch size, number of heads, dimensions for matrices (m, k) and (n, k))
|
|
69
95
|
b, h, m, n, k = 2, 4, 64, 64, 16
|
|
70
96
|
|
|
71
97
|
# Percentage of blocks that will be sparse in the output for demonstration purposes
|
|
@@ -78,7 +104,6 @@ def test_readme():
|
|
|
78
104
|
# If it is set to ``none`` a value will be chosen automatically
|
|
79
105
|
triton_block_size = None
|
|
80
106
|
|
|
81
|
-
|
|
82
107
|
# Initialise random (dense) tensors
|
|
83
108
|
x = torch.randn(size=(b, h, m, k), device="cuda")
|
|
84
109
|
y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()
|
|
@@ -88,8 +113,8 @@ def test_readme():
|
|
|
88
113
|
y_dense, y_shape_original = do_shape_blocksparse(y)
|
|
89
114
|
|
|
90
115
|
# Create sparsity layouts from existing tensors
|
|
91
|
-
sparsity_layout_x =
|
|
92
|
-
sparsity_layout_y =
|
|
116
|
+
sparsity_layout_x = build_sparsity_layout(x_dense, sparsity_block_size, triton_block_size=triton_block_size)
|
|
117
|
+
sparsity_layout_y = build_sparsity_layout(y_dense, sparsity_block_size, triton_block_size=triton_block_size)
|
|
93
118
|
|
|
94
119
|
# Create random sparsity layout for output tensor
|
|
95
120
|
sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)
|
|
@@ -99,8 +124,8 @@ def test_readme():
|
|
|
99
124
|
y_sparse = to_sparse(y_dense, sparsity_layout_y, sparsity_block_size, triton_block_size=triton_block_size)
|
|
100
125
|
|
|
101
126
|
# Perform matrix multiplication
|
|
102
|
-
o_sparse =
|
|
103
|
-
|
|
127
|
+
o_sparse = matmul(x_sparse, sparsity_layout_x, y_sparse, sparsity_layout_y, sparsity_layout_o, sparsity_block_size,
|
|
128
|
+
triton_block_size=triton_block_size)
|
|
104
129
|
o_dense = to_dense(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
|
|
105
130
|
|
|
106
131
|
# Sanity check
|
|
@@ -115,7 +140,7 @@ def test_readme():
|
|
|
115
140
|
assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2) # Note that small numerical differences are expected
|
|
116
141
|
|
|
117
142
|
# Assert that the output has the correct sparsity layout
|
|
118
|
-
actual_sparsity_layout_o =
|
|
143
|
+
actual_sparsity_layout_o = build_sparsity_layout(o_dense, sparsity_block_size, triton_block_size=triton_block_size)
|
|
119
144
|
assert torch.allclose(actual_sparsity_layout_o, sparsity_layout_o)
|
|
120
145
|
|
|
121
146
|
# Convert output tensor back to original shape
|
|
@@ -2,16 +2,21 @@
|
|
|
2
2
|
|
|
3
3
|
## Overview
|
|
4
4
|
|
|
5
|
-
A lightweight library for operations on
|
|
5
|
+
A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
|
|
6
6
|
|
|
7
7
|
Currently supported operations (includes gradient calculation):
|
|
8
8
|
|
|
9
|
-
- Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
|
|
9
|
+
- Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
|
|
10
|
+
for `sparse = sparse @ sparse` matmul_)
|
|
10
11
|
- Softmax
|
|
11
12
|
- Transposition
|
|
12
|
-
-
|
|
13
|
+
- Gather
|
|
14
|
+
- Scatter (_supports either no reduction or summation, gradients are only available for summation_)
|
|
15
|
+
- Conversion to and from sparse form
|
|
16
|
+
- Conversion to different sparsity layouts and different sparsity block sizes
|
|
13
17
|
|
|
14
|
-
As with this library sparse matrices are represented using a tuple of `(matrix, sparsity_layout, sparsity_block_size)`,
|
|
18
|
+
As with this library sparse matrices are represented using a tuple of `(matrix, sparsity_layout, sparsity_block_size)`,
|
|
19
|
+
any element-wise operations can be applied in regular torch-like fashion.
|
|
15
20
|
These include, e.g.,
|
|
16
21
|
|
|
17
22
|
- Element-wise addition and subtraction
|
|
@@ -19,24 +24,45 @@ These include, e.g.,
|
|
|
19
24
|
- Element-wise exponentiation
|
|
20
25
|
- ...
|
|
21
26
|
|
|
27
|
+
Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
|
|
28
|
+
match.
|
|
29
|
+
|
|
30
|
+
Furthermore, the library provides a set of utility functions for the creation of sparsity layouts based on existing
|
|
31
|
+
dense tensors.
|
|
32
|
+
|
|
22
33
|
## Installation
|
|
23
34
|
|
|
35
|
+
Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is only compatible with
|
|
36
|
+
the Linux platform.
|
|
37
|
+
|
|
24
38
|
We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) using pip:
|
|
25
39
|
|
|
26
40
|
```pip install blksprs```
|
|
27
41
|
|
|
42
|
+
### Dependencies
|
|
43
|
+
|
|
44
|
+
- [PyTorch](https://pytorch.org/) (built with v2.4.0)
|
|
45
|
+
- _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_
|
|
46
|
+
|
|
28
47
|
## Changelog
|
|
29
48
|
|
|
30
49
|
See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.md) for a detailed changelog.
|
|
31
50
|
|
|
32
51
|
## Usage
|
|
33
52
|
|
|
53
|
+
We provide an example below to demonstrate the usage of the library.
|
|
54
|
+
For more detailed examples, please refer to
|
|
55
|
+
the [test cases](https://github.com/FelixSchoen/blksprs/blob/main/test/cases/test_blocksparse.py) which cover all
|
|
56
|
+
implemented operations and functions.
|
|
57
|
+
The example below can also be found in
|
|
58
|
+
the [test cases](https://github.com/FelixSchoen/blksprs/blob/main/test/cases/test_readme.py).
|
|
59
|
+
|
|
34
60
|
```python
|
|
35
61
|
import torch
|
|
36
62
|
|
|
37
|
-
from blksprs.layouting.sparsity_layout import
|
|
63
|
+
from blksprs.layouting.sparsity_layout import build_sparsity_layout
|
|
38
64
|
from blksprs.ops.conversion import to_sparse, to_dense
|
|
39
|
-
from blksprs.ops.
|
|
65
|
+
from blksprs.ops.matmul import matmul
|
|
40
66
|
from blksprs.ops.row_wise_sum import row_wise_sum
|
|
41
67
|
from blksprs.ops.softmax import softmax
|
|
42
68
|
from blksprs.ops.transpose import transpose
|
|
@@ -44,7 +70,7 @@ from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
|
|
|
44
70
|
|
|
45
71
|
|
|
46
72
|
def test_readme():
|
|
47
|
-
# Set up parameters
|
|
73
|
+
# Set up parameters (batch size, number of heads, dimensions for matrices (m, k) and (n, k))
|
|
48
74
|
b, h, m, n, k = 2, 4, 64, 64, 16
|
|
49
75
|
|
|
50
76
|
# Percentage of blocks that will be sparse in the output for demonstration purposes
|
|
@@ -57,7 +83,6 @@ def test_readme():
|
|
|
57
83
|
# If it is set to ``none`` a value will be chosen automatically
|
|
58
84
|
triton_block_size = None
|
|
59
85
|
|
|
60
|
-
|
|
61
86
|
# Initialise random (dense) tensors
|
|
62
87
|
x = torch.randn(size=(b, h, m, k), device="cuda")
|
|
63
88
|
y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()
|
|
@@ -67,8 +92,8 @@ def test_readme():
|
|
|
67
92
|
y_dense, y_shape_original = do_shape_blocksparse(y)
|
|
68
93
|
|
|
69
94
|
# Create sparsity layouts from existing tensors
|
|
70
|
-
sparsity_layout_x =
|
|
71
|
-
sparsity_layout_y =
|
|
95
|
+
sparsity_layout_x = build_sparsity_layout(x_dense, sparsity_block_size, triton_block_size=triton_block_size)
|
|
96
|
+
sparsity_layout_y = build_sparsity_layout(y_dense, sparsity_block_size, triton_block_size=triton_block_size)
|
|
72
97
|
|
|
73
98
|
# Create random sparsity layout for output tensor
|
|
74
99
|
sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)
|
|
@@ -78,8 +103,8 @@ def test_readme():
|
|
|
78
103
|
y_sparse = to_sparse(y_dense, sparsity_layout_y, sparsity_block_size, triton_block_size=triton_block_size)
|
|
79
104
|
|
|
80
105
|
# Perform matrix multiplication
|
|
81
|
-
o_sparse =
|
|
82
|
-
|
|
106
|
+
o_sparse = matmul(x_sparse, sparsity_layout_x, y_sparse, sparsity_layout_y, sparsity_layout_o, sparsity_block_size,
|
|
107
|
+
triton_block_size=triton_block_size)
|
|
83
108
|
o_dense = to_dense(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
|
|
84
109
|
|
|
85
110
|
# Sanity check
|
|
@@ -94,7 +119,7 @@ def test_readme():
|
|
|
94
119
|
assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2) # Note that small numerical differences are expected
|
|
95
120
|
|
|
96
121
|
# Assert that the output has the correct sparsity layout
|
|
97
|
-
actual_sparsity_layout_o =
|
|
122
|
+
actual_sparsity_layout_o = build_sparsity_layout(o_dense, sparsity_block_size, triton_block_size=triton_block_size)
|
|
98
123
|
assert torch.allclose(actual_sparsity_layout_o, sparsity_layout_o)
|
|
99
124
|
|
|
100
125
|
# Convert output tensor back to original shape
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import triton
|
|
3
|
+
from torch import Tensor
|
|
4
|
+
from triton import language as tl
|
|
5
|
+
|
|
6
|
+
from blksprs.utils.tools import get_triton_block_size
|
|
7
|
+
from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
|
|
8
|
+
validate_contiguous
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
                              size_target: torch.Size,
                              sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
    """Compute the sparsity layout of a gather source or a scatter target.

    Args:
        indices (Tensor): Block-sparse indices tensor in compressed form, as used by the gather or scatter operation.
        sparsity_layout_indices (Tensor): Sparsity layout belonging to ``indices``.
        size_target (torch.Size): Size of the block-sparse target tensor in regular form.
        sparsity_block_size (int): Edge length of a sparsity block.
        triton_block_size (int, optional): Block size for the triton kernel; when ``None`` it is obtained from
            ``get_triton_block_size`` (default ``None``).

    Returns:
        Tensor: An ``int32`` layout with one entry per sparsity block of the target; entries written by the kernel
            are set to 1, all others stay 0.

    """
    validate_dimensions(indices)
    validate_contiguous(indices)
    validate_device(indices)

    # Look-up table holding the coordinates of every non-zero entry of the indices layout
    lut_indices = torch.nonzero(sparsity_layout_indices).contiguous()

    target_batches = size_target[0]
    target_block_rows = size_target[1] // sparsity_block_size
    target_block_cols = size_target[2] // sparsity_block_size
    layout = torch.zeros(target_batches, target_block_rows, target_block_cols,
                         device=indices.device, dtype=torch.int32)

    # Sizes and strides are handed to the kernel explicitly
    idx_b, idx_r, idx_c = indices.size()
    idx_b_s, idx_r_s, idx_c_s = indices.stride()
    lay_b, lay_r, lay_c = sparsity_layout_indices.size()
    lay_b_s, lay_r_s, lay_c_s = sparsity_layout_indices.stride()
    lut_r, lut_c = lut_indices.size()
    lut_r_s, lut_c_s = lut_indices.stride()
    out_b, out_r, out_c = layout.size()
    out_b_s, out_r_s, out_c_s = layout.stride()

    if triton_block_size is None:
        triton_block_size = get_triton_block_size(sparsity_block_size)

    validate_triton_block_size(triton_block_size, sparsity_block_size)

    def launch_grid(meta):
        # One program per (compressed block, row tile, column tile) of the indices tensor
        tile = meta["TRITON_BLOCK_SIZE"]
        return [idx_b, triton.cdiv(idx_r, tile), triton.cdiv(idx_c, tile)]

    kernel_distribution_layout[launch_grid](
        indices,
        idx_b, idx_b_s, idx_r_s, idx_c_s,
        sparsity_layout_indices,
        lay_b, lay_b_s, lay_r, lay_r_s, lay_c, lay_c_s,
        lut_indices,
        lut_r, lut_r_s, lut_c, lut_c_s,
        layout,
        out_b, out_b_s, out_r, out_r_s, out_c, out_c_s,
        sparsity_block_size,
        triton_block_size)

    return layout
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@triton.jit
def kernel_distribution_layout(i,
                               i_b, i_b_s, i_r_s, i_c_s,
                               s_l_i,
                               s_l_i_b, s_l_i_b_s, s_l_i_r, s_l_i_r_s, s_l_i_c, s_l_i_c_s,
                               s_lut_i,
                               s_lut_i_r, s_lut_i_r_s, s_lut_i_c, s_lut_i_c_s,
                               o,
                               o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
                               sparsity_block_size,
                               TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    # Marks, for every index value found in the indices tensor ``i``, the sparsity block of the output layout ``o``
    # that the value falls into. ``s_lut_i`` holds the coordinates of the non-zero blocks of the indices layout.
    # Arguments ending in ``_s`` are strides, the remaining scalar arguments are sizes.

    # Get triton block indices
    pid_blk = tl.program_id(axis=0)
    pid_row = tl.program_id(axis=1)
    pid_col = tl.program_id(axis=2)

    # Get position of current sparsity block consisting of its batch and row index
    # (the column index is not needed, the output column is derived from the index values themselves)
    spa_bat_i_idx = (pid_blk * s_lut_i_r_s + 0 * s_lut_i_c_s)
    spa_bat_i_msk = (spa_bat_i_idx < s_lut_i_r * s_lut_i_r_s)
    spa_bat_i = tl.load(s_lut_i + spa_bat_i_idx, mask=spa_bat_i_msk)

    spa_row_i_idx = (pid_blk * s_lut_i_r_s + 1 * s_lut_i_c_s)
    spa_row_i_msk = (spa_row_i_idx < s_lut_i_r * s_lut_i_r_s)
    spa_row_i = tl.load(s_lut_i + spa_row_i_idx, mask=spa_row_i_msk)

    # Load the tile of index values handled by this program
    blk_i_idx = (pid_blk * i_b_s +
                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
    blk_i_msk = (blk_i_idx < i_b * i_b_s)
    blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)

    # Convert element indices to sparsity block indices
    blk_i = blk_i // sparsity_block_size
    blk_v = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), 1, dtype=tl.int32)

    # Mark the addressed blocks in the output layout
    blk_o_idx = ((spa_bat_i * o_b_s) +
                 (spa_row_i * o_r_s) +
                 (blk_i * o_c_s))
    # Lanes whose index value was masked out on load carry undefined data and therefore must not store
    blk_o_msk = blk_i_msk & (blk_o_idx < o_b * o_b_s)
    tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
|
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
import math
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import triton
|
|
5
|
+
from torch import Tensor
|
|
6
|
+
from triton import language as tl
|
|
7
|
+
|
|
8
|
+
from blksprs.utils.tools import get_triton_block_size
|
|
9
|
+
from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
|
|
10
|
+
validate_contiguous, validate_sparsity, validate_sparsity_block_size
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def build_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
    """Derive the sparsity layout of a tensor in regular form from its non-zero content.

    Args:
        x (Tensor): A block-sparse (or dense) tensor in regular form.
        sparsity_block_size (int): Edge length of a sparsity block.
        triton_block_size (int, optional): Block size for the triton kernel; when ``None`` it is obtained from
            ``get_triton_block_size`` (default ``None``).

    Returns:
        Tensor: An ``int32`` layout with one entry per sparsity block of ``x``; an entry is set to 1 when the kernel
            finds a non-zero value inside the corresponding block and stays 0 otherwise.

    """
    validate_dimensions(x)
    validate_contiguous(x)
    validate_device(x)

    block_rows = x.size(1) // sparsity_block_size
    block_cols = x.size(2) // sparsity_block_size
    layout = torch.zeros(x.size(0), block_rows, block_cols,
                         device=x.device, dtype=torch.int32)

    # Sizes and strides are handed to the kernel explicitly
    n_batches, n_rows, n_cols = x.size()
    x_stride_b, x_stride_r, x_stride_c = x.stride()
    out_batches, _, _ = layout.size()
    out_stride_b, out_stride_r, out_stride_c = layout.stride()

    if triton_block_size is None:
        triton_block_size = get_triton_block_size(sparsity_block_size)

    validate_triton_block_size(triton_block_size, sparsity_block_size)

    def launch_grid(meta):
        # One program per (batch, row tile, column tile) of the input tensor
        tile = meta["TRITON_BLOCK_SIZE"]
        return [n_batches, triton.cdiv(n_rows, tile), triton.cdiv(n_cols, tile)]

    kernel_sparsity_layout[launch_grid](
        x,
        n_batches, x_stride_b, x_stride_r, x_stride_c,
        layout,
        out_batches, out_stride_b, out_stride_r, out_stride_c,
        sparsity_block_size,
        triton_block_size)

    return layout
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@triton.jit
def kernel_sparsity_layout(x,
                           x_b, x_b_s, x_r_s, x_c_s,
                           o,
                           o_b, o_b_s, o_r_s, o_c_s,
                           sparsity_block_size,
                           TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    # Each program inspects one tile of ``x`` and flags the sparsity block containing it in ``o`` when the tile
    # holds a non-zero value. Arguments ending in ``_s`` are strides.

    # Position of this program in the launch grid
    bat = tl.program_id(axis=0)
    tile_row = tl.program_id(axis=1)
    tile_col = tl.program_id(axis=2)

    # Offsets of the tile inside x
    row_offs = (tile_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s
    col_offs = (tile_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s
    x_offs = bat * x_b_s + row_offs[:, None] + col_offs[None, :]
    x_mask = x_offs < x_b * x_b_s
    tile = tl.load(x + x_offs, mask=x_mask)

    # A tile contains a non-zero value exactly when its minimum or its maximum differs from zero
    if tl.min(tile) != 0 or tl.max(tile) != 0:
        spa_row = (tile_row * TRITON_BLOCK_SIZE) // sparsity_block_size
        spa_col = (tile_col * TRITON_BLOCK_SIZE) // sparsity_block_size
        o_offs = bat * o_b_s + (spa_row * o_r_s + spa_col * o_c_s)
        o_mask = o_offs < o_b * o_b_s
        tl.store(o + o_offs, 1, mask=o_mask)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def build_sparsity_layout_adaption(x: Tensor, sparsity_layout_from: Tensor,
                                   sparsity_block_size_from: int, sparsity_block_size_to: int,
                                   triton_block_size: int = None) -> Tensor:
    """Builds the sparsity layout of a block-sparse tensor in compressed form if a different sparsity block size were
        used.

    Args:
        x (Tensor): A block-sparse tensor in compressed form.
        sparsity_layout_from (Tensor): The sparsity layout of the input block-sparse tensor.
        sparsity_block_size_from (int): The size of the sparsity blocks of the input tensor.
        sparsity_block_size_to (int): The desired size of the sparsity blocks for the resulting layout.
        triton_block_size (int, optional): The block size to use for the triton kernel. If ``None``, it is derived
            from the smaller of the two sparsity block sizes (default ``None``).

    Returns:
        Tensor: The sparsity layout in regular form using the new sparsity block size of the input block-sparse tensor
            in compressed form. Each dimension covers the full extent of the input, rounded up to whole blocks.

    """
    validate_dimensions(x)
    validate_contiguous(x, sparsity_layout_from)
    validate_device(x)
    validate_sparsity(sparsity_block_size_from, (x, sparsity_layout_from))
    validate_sparsity_block_size(sparsity_block_size_from, x)
    validate_sparsity_block_size(sparsity_block_size_to)
    min_sparsity_block_size = min(sparsity_block_size_from, sparsity_block_size_to)

    # Resolve the default before validating, so that the value actually used is the one that gets checked.
    # The default is based on the smaller block size: each kernel program marks exactly one target block, hence a
    # tile must not span several source or target blocks.
    # NOTE(review): assumes ``get_triton_block_size`` never returns more than its argument - confirm.
    if triton_block_size is None:
        triton_block_size = get_triton_block_size(min_sparsity_block_size)

    validate_triton_block_size(triton_block_size, min_sparsity_block_size)

    sparsity_lut = torch.nonzero(sparsity_layout_from).contiguous()

    validate_contiguous(sparsity_layout_from, sparsity_lut)

    # Number of target blocks per dimension, rounded up. Integer ceiling division is used because applying
    # ``math.ceil`` to the result of ``//`` never rounds up.
    extent_r = sparsity_layout_from.size(1) * sparsity_block_size_from
    extent_c = sparsity_layout_from.size(2) * sparsity_block_size_from
    o_b = sparsity_layout_from.size(0)
    o_r = (extent_r + sparsity_block_size_to - 1) // sparsity_block_size_to
    o_c = (extent_c + sparsity_block_size_to - 1) // sparsity_block_size_to

    output = torch.zeros(o_b, o_r, o_c, device=x.device, dtype=torch.int32)

    x_b, x_r, x_c = x.size()
    x_b_s, x_r_s, x_c_s = x.stride()
    s_lut_r, s_lut_c = sparsity_lut.size()
    s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
    o_b_s, o_r_s, o_c_s = output.stride()

    # One program per (compressed block, row tile, column tile) of the input tensor
    triton_grid = lambda meta: [x_b,
                                triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
                                triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]

    (kernel_sparsity_layout_adaption[triton_grid]
     (x,
      x_b, x_b_s, x_r_s, x_c_s,
      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
      output,
      o_b, o_b_s, o_r_s, o_c_s,
      sparsity_block_size_from,
      sparsity_block_size_to,
      triton_block_size))

    return output
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
@triton.jit
def kernel_sparsity_layout_adaption(x,
                                    x_b, x_b_s, x_r_s, x_c_s,
                                    s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
                                    o,
                                    o_b, o_b_s, o_r_s, o_c_s,
                                    sparsity_block_size_from,
                                    sparsity_block_size_to,
                                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    # Each program inspects one tile of a compressed block of ``x`` and, when the tile holds a non-zero value,
    # flags the block of ``o`` (expressed in the target block size) that the tile maps to.
    # Arguments ending in ``_s`` are strides.

    # Position of this program in the launch grid
    blk = tl.program_id(axis=0)
    tile_row = tl.program_id(axis=1)
    tile_col = tl.program_id(axis=2)

    # Read batch, row and column index of the current block from the look-up table
    lut_base = blk * s_lut_r_s
    lut_limit = s_lut_r * s_lut_r_s

    bat_offs = lut_base + 0 * s_lut_c_s
    src_bat = tl.load(s_lut + bat_offs, mask=bat_offs < lut_limit)

    row_offs = lut_base + 1 * s_lut_c_s
    src_row = tl.load(s_lut + row_offs, mask=row_offs < lut_limit)

    col_offs = lut_base + 2 * s_lut_c_s
    src_col = tl.load(s_lut + col_offs, mask=col_offs < lut_limit)

    # Load the tile of x handled by this program
    tile_row_offs = (tile_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s
    tile_col_offs = (tile_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s
    x_offs = (blk * x_b_s) + tile_row_offs[:, None] + tile_col_offs[None, :]
    x_mask = x_offs < x_b * x_b_s
    tile = tl.load(x + x_offs, mask=x_mask)

    # A tile contains a non-zero value exactly when its minimum or its maximum differs from zero
    if tl.min(tile) != 0 or tl.max(tile) != 0:
        # Element position of the tile in regular form, converted to target block coordinates
        dst_row = (src_row * sparsity_block_size_from + tile_row * TRITON_BLOCK_SIZE) // sparsity_block_size_to
        dst_col = (src_col * sparsity_block_size_from + tile_col * TRITON_BLOCK_SIZE) // sparsity_block_size_to
        o_offs = (src_bat * o_b_s) + (dst_row * o_r_s) + (dst_col * o_c_s)
        o_mask = o_offs < o_b * o_b_s
        tl.store(o + o_offs, 1, mask=o_mask)
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import triton
|
|
3
|
+
from torch import Tensor
|
|
4
|
+
from triton import language as tl
|
|
5
|
+
|
|
6
|
+
from blksprs.utils.tools import get_triton_block_size
|
|
7
|
+
from blksprs.utils.validation import validate_contiguous, validate_device, \
|
|
8
|
+
validate_sparsity_block_size, validate_triton_block_size
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def broadcast_addition(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
                       sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
    """Broadcast two dense tensors against each other and add them, producing a block-sparse result.

    Only the blocks selected by ``sparsity_layout_output`` are computed.

    Args:
        x (Tensor): A dense input tensor.
        y (Tensor): A dense input tensor whose last dimension matches that of ``x``.
        sparsity_layout_output (Tensor): The sparsity layout of the output tensor.
        sparsity_block_size (int): Edge length of a sparsity block.
        triton_block_size (int, optional): Block size for the triton kernel; when ``None`` it is obtained from
            ``get_triton_block_size`` (default ``None``).

    Returns:
        Tensor: The result as a block-sparse tensor in compressed form, where element o(i, j) equals x(i) + y(j).

    Raises:
        ValueError: If the last dimensions of ``x`` and ``y`` differ.

    """
    validate_device(x, y)
    validate_contiguous(x, y)
    if x.size(-1) != y.size(-1):
        raise ValueError("Dimensions of tensors must match")
    validate_sparsity_block_size(sparsity_block_size)
    validate_triton_block_size(triton_block_size, sparsity_block_size)

    # Coordinates of the output blocks that have to be computed
    lut_output = torch.nonzero(sparsity_layout_output).contiguous()

    block_count = torch.sum(sparsity_layout_output.to(torch.int)).item()

    validate_contiguous(sparsity_layout_output, lut_output)

    result = torch.zeros(block_count, sparsity_block_size, sparsity_block_size, device=x.device)

    # Sizes and strides are handed to the kernel explicitly; the unpacking requires x and y to be two-dimensional
    x_batches, _ = x.size()
    x_stride_b, x_stride_c = x.stride()
    y_batches, _ = y.size()
    y_stride_b, y_stride_c = y.stride()
    out_blocks, out_rows, out_cols = result.size()
    out_stride_b, out_stride_r, out_stride_c = result.stride()
    lut_rows, _ = lut_output.size()
    lut_stride_r, lut_stride_c = lut_output.stride()

    if triton_block_size is None:
        triton_block_size = get_triton_block_size(sparsity_block_size)

    def launch_grid(meta):
        # One program per (output block, row tile, column tile)
        tile = meta["TRITON_BLOCK_SIZE"]
        return [out_blocks, triton.cdiv(out_rows, tile), triton.cdiv(out_cols, tile)]

    kernel_broadcast_addition[launch_grid](
        x,
        x_batches, x_stride_b, x_stride_c,
        y,
        y_batches, y_stride_b, y_stride_c,
        result,
        out_blocks, out_stride_b, out_stride_r, out_stride_c,
        lut_output, lut_rows, lut_stride_r, lut_stride_c,
        sparsity_block_size,
        triton_block_size)

    return result
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def broadcast_subtraction(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
                          sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
    """Broadcast two dense tensors against each other and subtract ``y`` from ``x``.

    Implemented by delegating to ``broadcast_addition`` with ``y`` negated; see that function for the meaning of the
    arguments and of the returned block-sparse tensor in compressed form.

    """
    negated_y = torch.neg(y)
    return broadcast_addition(x, negated_y, sparsity_layout_output, sparsity_block_size, triton_block_size)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@triton.jit
def kernel_broadcast_addition(x,
                              x_b, x_b_s, x_c_s,
                              y,
                              y_b, y_b_s, y_c_s,
                              o,
                              o_b, o_b_s, o_r_s, o_c_s,
                              s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
                              sparsity_block_size,
                              TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    # Each program computes one tile of one output block: the matching slice of ``x`` is used along the rows, the
    # matching slice of ``y`` along the columns, and their broadcast sum is written to ``o``.
    # Arguments ending in ``_s`` are strides.

    # Position of this program in the launch grid
    blk = tl.program_id(axis=0)
    tile_row = tl.program_id(axis=1)
    tile_col = tl.program_id(axis=2)

    # Read batch, row and column index of the current output block from the look-up table
    lut_base = blk * s_lut_o_r_s
    lut_limit = s_lut_o_r * s_lut_o_r_s

    bat_offs = lut_base + 0 * s_lut_o_c_s
    out_bat = tl.load(s_lut_o + bat_offs, mask=bat_offs < lut_limit)

    row_offs = lut_base + 1 * s_lut_o_c_s
    out_row = tl.load(s_lut_o + row_offs, mask=row_offs < lut_limit)

    col_offs = lut_base + 2 * s_lut_o_c_s
    out_col = tl.load(s_lut_o + col_offs, mask=col_offs < lut_limit)

    # Slice of x addressed by the rows of the tile, loaded as a (1, TRITON_BLOCK_SIZE) vector
    x_pos = out_row * sparsity_block_size + tile_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)
    x_offs = out_bat * x_b_s + (x_pos * x_c_s)[None, :]
    x_mask = x_offs < x_b * x_b_s
    x_vals = tl.load(x + x_offs, mask=x_mask)

    # Slice of y addressed by the columns of the tile, loaded as a (1, TRITON_BLOCK_SIZE) vector
    y_pos = out_col * sparsity_block_size + tile_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)
    y_offs = out_bat * y_b_s + (y_pos * y_c_s)[None, :]
    y_mask = y_offs < y_b * y_b_s
    y_vals = tl.load(y + y_offs, mask=y_mask)

    # Transpose x so that it varies along the rows, then broadcast both operands to a common shape and add
    x_vals, y_vals = tl.broadcast(tl.trans(x_vals), y_vals)
    total = x_vals + y_vals

    # Write the tile to its place inside the compressed output block
    o_row_offs = (tile_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s
    o_col_offs = (tile_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s
    o_offs = (blk * o_b_s) + o_row_offs[:, None] + o_col_offs[None, :]
    o_mask = o_offs < o_b * o_b_s
    tl.store(o + o_offs, total, mask=o_mask)
|