blksprs 1.11__tar.gz → 2.0rc2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {blksprs-1.11 → blksprs-2.0rc2}/PKG-INFO +42 -36
- {blksprs-1.11 → blksprs-2.0rc2}/README.md +39 -32
- {blksprs-1.11 → blksprs-2.0rc2}/blksprs/__init__.py +2 -5
- {blksprs-1.11 → blksprs-2.0rc2}/blksprs/layouting/distribution_layout.py +32 -25
- {blksprs-1.11 → blksprs-2.0rc2}/blksprs/layouting/sparsity_layout.py +65 -52
- blksprs-2.0rc2/blksprs/ops/conversion.py +517 -0
- blksprs-2.0rc2/blksprs/ops/distribution.py +496 -0
- {blksprs-1.11 → blksprs-2.0rc2}/blksprs/ops/flow.py +125 -106
- blksprs-2.0rc2/blksprs/ops/matmul.py +264 -0
- {blksprs-1.11 → blksprs-2.0rc2}/blksprs/ops/misc/broadcast_ops.py +53 -35
- {blksprs-1.11 → blksprs-2.0rc2}/blksprs/ops/misc/row_wise.py +151 -91
- blksprs-2.0rc2/blksprs/ops/partitioning.py +217 -0
- blksprs-2.0rc2/blksprs/ops/repeat.py +191 -0
- blksprs-2.0rc2/blksprs/ops/softmax.py +306 -0
- blksprs-2.0rc2/blksprs/ops/transpose.py +98 -0
- {blksprs-1.11 → blksprs-2.0rc2}/blksprs/utils/benchmarking.py +3 -3
- blksprs-2.0rc2/blksprs/utils/tools.py +56 -0
- {blksprs-1.11 → blksprs-2.0rc2}/blksprs/utils/validation.py +4 -18
- {blksprs-1.11 → blksprs-2.0rc2}/blksprs.egg-info/PKG-INFO +42 -36
- {blksprs-1.11 → blksprs-2.0rc2}/blksprs.egg-info/SOURCES.txt +0 -1
- {blksprs-1.11 → blksprs-2.0rc2}/blksprs.egg-info/requires.txt +1 -3
- {blksprs-1.11 → blksprs-2.0rc2}/pyproject.toml +2 -4
- blksprs-1.11/blksprs/ops/conversion.py +0 -495
- blksprs-1.11/blksprs/ops/distribution.py +0 -458
- blksprs-1.11/blksprs/ops/matmul.py +0 -245
- blksprs-1.11/blksprs/ops/partitioning.py +0 -213
- blksprs-1.11/blksprs/ops/repeat.py +0 -196
- blksprs-1.11/blksprs/ops/softmax.py +0 -278
- blksprs-1.11/blksprs/ops/transpose.py +0 -97
- blksprs-1.11/blksprs/utils/layout_utils.py +0 -17
- blksprs-1.11/blksprs/utils/tools.py +0 -29
- {blksprs-1.11 → blksprs-2.0rc2}/blksprs/utils/blksprs_tensor.py +0 -0
- {blksprs-1.11 → blksprs-2.0rc2}/blksprs/utils/processing.py +0 -0
- {blksprs-1.11 → blksprs-2.0rc2}/blksprs.egg-info/dependency_links.txt +0 -0
- {blksprs-1.11 → blksprs-2.0rc2}/blksprs.egg-info/top_level.txt +0 -0
- {blksprs-1.11 → blksprs-2.0rc2}/setup.cfg +0 -0

{blksprs-1.11 → blksprs-2.0rc2}/PKG-INFO
@@ -1,6 +1,6 @@
-Metadata-Version: 2.
+Metadata-Version: 2.4
 Name: blksprs
-Version: 1.11
+Version: 2.0rc2
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -14,9 +14,8 @@ Requires-Dist: pytest; extra == "test"
 Requires-Dist: pytest-xdist; extra == "test"
 Requires-Dist: pytest-cov; extra == "test"
 Requires-Dist: coverage; extra == "test"
+Requires-Dist: build; extra == "test"
 Requires-Dist: matplotlib; extra == "test"
-Provides-Extra: build
-Requires-Dist: build; extra == "build"

 # blksprs

@@ -25,6 +24,13 @@ Requires-Dist: build; extra == "build"

 ## Overview

+### News
+
+🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
+LUTs, and makes use of `torch.library.triton_op()`!
+
+---
+
 A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.

 Currently supported operations (includes gradient calculation):
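The JIT-compilation support announced in the News entry above is demonstrated by the README example further down in this diff. A minimal sketch of the pattern, reusing the operand and layout names from that example (they are assumed to already exist):

    import torch
    import blksprs as bs

    # torch.compile wraps the blksprs op directly; wrapping the kernels with
    # torch.library.triton_op() in 2.0 is what makes them traceable.
    # x_sparse, y_sparse and the sparsity layouts are taken from the README
    # example below and are assumed to have been built already.
    matmul_compiled = torch.compile(bs.ops.matmul)
    o_sparse = matmul_compiled(x_sparse, sparsity_layout_x,
                               y_sparse, sparsity_layout_y,
                               sparsity_layout_o, sparsity_block_size)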
@@ -52,23 +58,25 @@ These include, e.g.,
 Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
 match.

-Further helpful operations (included in the ``bs.ops.misc`` module) that do **not** support gradient calculation
+Further helpful operations (included in the ``bs.ops.misc`` module) that do **not** support gradient calculation
+include:

 - Row-wise sum, max, addition, and subtraction
 - Broadcast addition and subtraction between slices

-Furthermore, the library provides a set of utility functions
+Furthermore, the library provides a set of utility functions

 - for the creation of sparsity layouts based on existing
-dense tensors and for the scatter operation (module ``bs.layouting``),
+dense tensors and for the scatter operation (module ``bs.layouting``),
 - for the application of ``nn.Linear``, ``nn.Dropout``, and ``nn.LayerNorm`` layers to block-sparse tensors,
 - as well as utility functions to ensure correct input dimensionality, and validate input (module ``bs.utils``).

-_* see the [Roadmap](#roadmap) section for more information_
+_* see the [Roadmap](#roadmap) section for more information_

 ## Installation

-Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible
+Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible
+with
 the Linux platform**.
 Keep track of this [issue](https://github.com/triton-lang/triton/issues/1640) for updates.

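As the note in the hunk above states, element-wise operations between two block-sparse tensors require identical sparsity layouts, since the compressed representations must line up block for block. A small sketch under that assumption (tensor sizes and the CUDA device are illustrative, not taken from the package):

    import torch
    import blksprs as bs

    sparsity_block_size = 16
    x = torch.randn(2, 64, 64, device="cuda")
    y = torch.randn(2, 64, 64, device="cuda")

    # Compress both operands with the *same* layout so their blocks align;
    # ordinary torch element-wise ops can then be applied directly.
    layout = bs.layouting.build_sparsity_layout(x, sparsity_block_size)
    x_sparse = bs.ops.to_sparse(x, layout, sparsity_block_size)
    y_sparse = bs.ops.to_sparse(y, layout, sparsity_block_size)
    z_sparse = torch.add(x_sparse, y_sparse)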
@@ -78,8 +86,8 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u

 ### Dependencies

-- [PyTorch](https://pytorch.org/) (built with v2.
-- _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.2.
+- [PyTorch](https://pytorch.org/) (built with v2.6)
+- _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.2.4)_
 - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_

 ## Changelog
@@ -89,12 +97,14 @@ See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.
 ## Roadmap

 Note that since this library covers all our current needs it is in a **bugfix-only** state.
-This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and
+This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and
+``merge`` operations.
 We will continue to maintain the library and fix any issues that arise.
 Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
 We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).

-It might be that this changes with future projects, but as of
+It might be that this changes with future projects, but as of March 2025, we are content with the current state of the
+library.

 ## Usage

@@ -120,10 +130,6 @@ def test_readme():
     # Must be a power of two, greater than or equal to 16 for matmul, and divide m, n, and k
     sparsity_block_size = 16

-    # Must be a power of two and smaller than or equal to sparsity_block_size
-    # If it is set to ``none`` a value will be chosen automatically
-    triton_block_size = None
-
     # Initialise random (dense) tensors
     x = torch.randn(size=(b, h, m, k), device="cuda")
     y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()
@@ -133,53 +139,53 @@ def test_readme():
     y_dense, y_shape_original = bs.utils.do_shape_blocksparse(y)

     # Create sparsity layouts from existing tensors
-    sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size
-
-    sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size,
-                                                           triton_block_size=triton_block_size)
+    sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size)
+    sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size)

     # Create random sparsity layout for output tensor
     sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)

     # Convert tensors to sparse tensors for matrix multiplication
-    x_sparse = bs.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size
-    y_sparse = bs.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size
+    x_sparse = bs.ops.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size)
+    y_sparse = bs.ops.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size)
+
+    # As of version 2.0, blksprs supports JIT compilation
+    matmul_compiled = torch.compile(bs.ops.matmul)

     # Perform matrix multiplication
-    o_sparse =
-
-
+    o_sparse = matmul_compiled(x_sparse, sparsity_layout_x,
+                               y_sparse, sparsity_layout_y,
+                               sparsity_layout_o, sparsity_block_size)

     # Apply element-wise operation
     o_sparse = torch.add(o_sparse, 1)

-    o_dense = bs.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size
+    o_dense = bs.ops.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size)

     # Sanity check
     o_torch = torch.matmul(x_dense, y_dense)
     o_torch = torch.add(o_torch, 1)

     # Perform round trip to set sparse blocks to 0
-    o_torch_round_trip = bs.to_dense(
-        bs.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size
-        sparsity_layout_o, sparsity_block_size, fill_value=0
+    o_torch_round_trip = bs.ops.to_dense(
+        bs.ops.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size),
+        sparsity_layout_o, sparsity_block_size, fill_value=0)

     # Assert that the output is correct
     assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2) # Note that small numerical differences are expected

     # Assert that the output has the correct sparsity layout
-    actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size
-                                                                  triton_block_size=triton_block_size)
+    actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size)
     assert torch.allclose(actual_sparsity_layout_o.to(torch.int), sparsity_layout_o)

     # Convert output tensor back to original shape
     o = bs.utils.undo_shape_blocksparse(o_dense, x_shape_original)

     # Other available functions
-    bs.transpose(o_sparse, sparsity_layout_o, sparsity_block_size
-    bs.softmax(o_sparse, sparsity_layout_o, sparsity_block_size
-    bs.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size
-    bs.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size
+    bs.ops.transpose(o_sparse, sparsity_layout_o, sparsity_block_size)
+    bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size)
+    bs.ops.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size)
+    bs.ops.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size)


 def _get_random_sparsity_layout(b, m, n, sparsity_block_size, sparsity_percentage):

{blksprs-1.11 → blksprs-2.0rc2}/README.md
@@ -5,6 +5,13 @@

 ## Overview

+### News
+
+🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
+LUTs, and makes use of `torch.library.triton_op()`!
+
+---
+
 A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.

 Currently supported operations (includes gradient calculation):
@@ -32,23 +39,25 @@ These include, e.g.,
 Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
 match.

-Further helpful operations (included in the ``bs.ops.misc`` module) that do **not** support gradient calculation
+Further helpful operations (included in the ``bs.ops.misc`` module) that do **not** support gradient calculation
+include:

 - Row-wise sum, max, addition, and subtraction
 - Broadcast addition and subtraction between slices

-Furthermore, the library provides a set of utility functions
+Furthermore, the library provides a set of utility functions

 - for the creation of sparsity layouts based on existing
-dense tensors and for the scatter operation (module ``bs.layouting``),
+dense tensors and for the scatter operation (module ``bs.layouting``),
 - for the application of ``nn.Linear``, ``nn.Dropout``, and ``nn.LayerNorm`` layers to block-sparse tensors,
 - as well as utility functions to ensure correct input dimensionality, and validate input (module ``bs.utils``).

-_* see the [Roadmap](#roadmap) section for more information_
+_* see the [Roadmap](#roadmap) section for more information_

 ## Installation

-Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible
+Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible
+with
 the Linux platform**.
 Keep track of this [issue](https://github.com/triton-lang/triton/issues/1640) for updates.

@@ -58,8 +67,8 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u

 ### Dependencies

-- [PyTorch](https://pytorch.org/) (built with v2.
-- _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.2.
+- [PyTorch](https://pytorch.org/) (built with v2.6)
+- _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.2.4)_
 - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_

 ## Changelog
@@ -69,12 +78,14 @@ See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.
 ## Roadmap

 Note that since this library covers all our current needs it is in a **bugfix-only** state.
-This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and
+This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and
+``merge`` operations.
 We will continue to maintain the library and fix any issues that arise.
 Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
 We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).

-It might be that this changes with future projects, but as of
+It might be that this changes with future projects, but as of March 2025, we are content with the current state of the
+library.

 ## Usage

@@ -100,10 +111,6 @@ def test_readme():
     # Must be a power of two, greater than or equal to 16 for matmul, and divide m, n, and k
     sparsity_block_size = 16

-    # Must be a power of two and smaller than or equal to sparsity_block_size
-    # If it is set to ``none`` a value will be chosen automatically
-    triton_block_size = None
-
     # Initialise random (dense) tensors
     x = torch.randn(size=(b, h, m, k), device="cuda")
     y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()
@@ -113,53 +120,53 @@ def test_readme():
     y_dense, y_shape_original = bs.utils.do_shape_blocksparse(y)

     # Create sparsity layouts from existing tensors
-    sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size
-
-    sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size,
-                                                           triton_block_size=triton_block_size)
+    sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size)
+    sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size)

     # Create random sparsity layout for output tensor
     sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)

     # Convert tensors to sparse tensors for matrix multiplication
-    x_sparse = bs.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size
-    y_sparse = bs.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size
+    x_sparse = bs.ops.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size)
+    y_sparse = bs.ops.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size)
+
+    # As of version 2.0, blksprs supports JIT compilation
+    matmul_compiled = torch.compile(bs.ops.matmul)

     # Perform matrix multiplication
-    o_sparse =
-
-
+    o_sparse = matmul_compiled(x_sparse, sparsity_layout_x,
+                               y_sparse, sparsity_layout_y,
+                               sparsity_layout_o, sparsity_block_size)

     # Apply element-wise operation
     o_sparse = torch.add(o_sparse, 1)

-    o_dense = bs.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size
+    o_dense = bs.ops.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size)

     # Sanity check
     o_torch = torch.matmul(x_dense, y_dense)
     o_torch = torch.add(o_torch, 1)

     # Perform round trip to set sparse blocks to 0
-    o_torch_round_trip = bs.to_dense(
-        bs.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size
-        sparsity_layout_o, sparsity_block_size, fill_value=0
+    o_torch_round_trip = bs.ops.to_dense(
+        bs.ops.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size),
+        sparsity_layout_o, sparsity_block_size, fill_value=0)

     # Assert that the output is correct
     assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2) # Note that small numerical differences are expected

     # Assert that the output has the correct sparsity layout
-    actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size
-                                                                  triton_block_size=triton_block_size)
+    actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size)
     assert torch.allclose(actual_sparsity_layout_o.to(torch.int), sparsity_layout_o)

     # Convert output tensor back to original shape
     o = bs.utils.undo_shape_blocksparse(o_dense, x_shape_original)

     # Other available functions
-    bs.transpose(o_sparse, sparsity_layout_o, sparsity_block_size
-    bs.softmax(o_sparse, sparsity_layout_o, sparsity_block_size
-    bs.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size
-    bs.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size
+    bs.ops.transpose(o_sparse, sparsity_layout_o, sparsity_block_size)
+    bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size)
+    bs.ops.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size)
+    bs.ops.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size)


 def _get_random_sparsity_layout(b, m, n, sparsity_block_size, sparsity_percentage):

{blksprs-1.11 → blksprs-2.0rc2}/blksprs/__init__.py
@@ -18,19 +18,16 @@ class ops:
 class layouting:
     from blksprs.layouting.distribution_layout import build_distribution_layout
     from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption, \
-        build_sparsity_layout_matmul, build_sparsity_layout_matmul_fast
-    from blksprs.utils.layout_utils import build_full_sparsity_layout
+        build_sparsity_layout_matmul, build_sparsity_layout_matmul_fast, build_sparsity_layout_full


 class utils:
     from blksprs.utils.processing import apply_torch_linear, apply_torch_normalisation, apply_torch_dropout, \
         apply_function_applicable_row_wise
     from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
-    from blksprs.utils.validation import disable_validation

 class validation:
     from blksprs.utils.validation import disable_validation
     from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_dtype_float, \
         validate_dtype_int, validate_device, validate_sparsity, validate_sparsity_dense, \
-        validate_sparsity_block_size
-        validate_triton_block_size
+        validate_sparsity_block_size
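A brief sketch of how the re-exported names in this hunk are reached from user code after the reorganisation; the attribute names come directly from the hunk above, everything else is illustrative:

    import blksprs as bs

    # Layouting helpers, including the relocated build_sparsity_layout_full
    build_full_layout = bs.layouting.build_sparsity_layout_full

    # Shape utilities remain under bs.utils
    reshape_blocksparse = bs.utils.do_shape_blocksparse

    # disable_validation is now exposed only via the validation namespace
    toggle_validation = bs.validation.disable_validation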

{blksprs-1.11 → blksprs-2.0rc2}/blksprs/layouting/distribution_layout.py
@@ -4,14 +4,14 @@ from torch import Tensor
 from triton import language as tl

 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import
-from blksprs.utils.validation import
+from blksprs.utils.tools import stride, get_autotune_configs
+from blksprs.utils.validation import validate_dimensions, validate_device, \
     validate_contiguous


 def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: Tensor,
                               dim: int, size_target: torch.Size,
-                              sparsity_block_size: int
+                              sparsity_block_size: int) -> Tensor:
     """Builds the sparsity layout of either the source of a gather or the target of a scatter operation.

     Args:
@@ -20,7 +20,6 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
         dim (int): The dimension along which the operation is conducted.
         size_target (torch.Size): The size of the block-sparse target tensor in regular form.
         sparsity_block_size (int): The size of the sparsity blocks.
-        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).

     Returns:
         Tensor: The sparsity layout of the source or target tensor.
@@ -44,16 +43,11 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
     o_b, o_r, o_c = output.size()
     o_b_s, o_r_s, o_c_s = stride(output)

-    if triton_block_size is None:
-        triton_block_size = get_triton_block_size(sparsity_block_size)
-
-    validate_triton_block_size(triton_block_size, sparsity_block_size)
-
     triton_grid = lambda meta: [i_b,
                                 triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
                                 triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]

-    (
+    (build_distribution_layout_kernel[triton_grid]
      (indices,
       i_b, i_b_s, i_r_s, i_c_s,
       sparsity_lut_i,
@@ -61,27 +55,34 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
       adjusted_dim,
       output,
       o_b, o_b_s, o_r_s, o_c_s,
-      sparsity_block_size
-      triton_block_size))
+      sparsity_block_size))

     return output


+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=[],
+    reset_to_zero=["o"]
+)
 @triton.jit
-def kernel_distribution_layout(i,
-
-
-
-
-
-
-
-
+def build_distribution_layout_kernel(i,
+                                     i_b, i_b_s, i_r_s, i_c_s,
+                                     s_lut_i,
+                                     s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+                                     dim,
+                                     o,
+                                     o_b, o_b_s, o_r_s, o_c_s,
+                                     sparsity_block_size,
+                                     TRITON_BLOCK_SIZE: tl.constexpr) -> None:
     # Get triton block indices
     pid_blk = tl.program_id(axis=0)
     pid_row = tl.program_id(axis=1)
     pid_col = tl.program_id(axis=2)

+    # Get valid triton block size
+    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
+
     # Get position of current sparsity block consisting of its batch, row, and column index
     spa_bat_i_idx = (pid_blk * s_lut_i_r_s + 0 * s_lut_i_c_s)
     spa_bat_i_msk = (spa_bat_i_idx >= 0 and spa_bat_i_idx < s_lut_i_r * s_lut_i_r_s)
@@ -96,9 +97,12 @@ def kernel_distribution_layout(i,
     spa_col_i = tl.load(s_lut_i + spa_col_i_idx, mask=spa_col_i_msk)

     blk_i_idx = (pid_blk * i_b_s +
-                 ((pid_row *
-                 ((pid_col *
-    blk_i_msk = (blk_i_idx >= 0 and
+                 ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                 ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+    blk_i_msk = ((blk_i_idx >= 0 and
+                  blk_i_idx < i_b * i_b_s) and
+                 (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                  tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
     blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)

     dst_bat_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_i, dtype=tl.int32)
@@ -116,5 +120,8 @@ def kernel_distribution_layout(i,
     blk_o_idx = ((dst_bat_idx * o_b_s) +
                  (dst_row_idx * o_r_s) +
                  (dst_col_idx * o_c_s))
-    blk_o_msk = (blk_o_idx >= 0 and
+    blk_o_msk = ((blk_o_idx >= 0 and
+                  blk_o_idx < o_b * o_b_s) and
+                 (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                  tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
     tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)