blksprs 1.11.tar.gz → 2.0.tar.gz
- {blksprs-1.11 → blksprs-2.0}/PKG-INFO +55 -36
- {blksprs-1.11 → blksprs-2.0}/README.md +52 -32
- {blksprs-1.11 → blksprs-2.0}/blksprs/__init__.py +4 -5
- {blksprs-1.11 → blksprs-2.0}/blksprs/layouting/distribution_layout.py +64 -48
- {blksprs-1.11 → blksprs-2.0}/blksprs/layouting/sparsity_layout.py +96 -72
- blksprs-2.0/blksprs/ops/conversion.py +506 -0
- blksprs-2.0/blksprs/ops/distribution.py +482 -0
- blksprs-2.0/blksprs/ops/flow.py +192 -0
- blksprs-2.0/blksprs/ops/matmul.py +260 -0
- {blksprs-1.11 → blksprs-2.0}/blksprs/ops/misc/broadcast_ops.py +68 -53
- blksprs-2.0/blksprs/ops/misc/row_wise.py +445 -0
- blksprs-2.0/blksprs/ops/partitioning.py +221 -0
- blksprs-2.0/blksprs/ops/repeat.py +194 -0
- blksprs-2.0/blksprs/ops/softmax.py +304 -0
- blksprs-2.0/blksprs/ops/transpose.py +100 -0
- blksprs-2.0/blksprs/utils/autotuning.py +78 -0
- {blksprs-1.11 → blksprs-2.0}/blksprs/utils/benchmarking.py +3 -3
- {blksprs-1.11 → blksprs-2.0}/blksprs/utils/processing.py +2 -1
- {blksprs-1.11 → blksprs-2.0}/blksprs/utils/tools.py +5 -6
- {blksprs-1.11 → blksprs-2.0}/blksprs/utils/validation.py +22 -16
- {blksprs-1.11 → blksprs-2.0}/blksprs.egg-info/PKG-INFO +55 -36
- {blksprs-1.11 → blksprs-2.0}/blksprs.egg-info/SOURCES.txt +1 -1
- {blksprs-1.11 → blksprs-2.0}/blksprs.egg-info/requires.txt +1 -3
- {blksprs-1.11 → blksprs-2.0}/pyproject.toml +2 -4
- blksprs-1.11/blksprs/ops/conversion.py +0 -495
- blksprs-1.11/blksprs/ops/distribution.py +0 -458
- blksprs-1.11/blksprs/ops/flow.py +0 -179
- blksprs-1.11/blksprs/ops/matmul.py +0 -245
- blksprs-1.11/blksprs/ops/misc/row_wise.py +0 -398
- blksprs-1.11/blksprs/ops/partitioning.py +0 -213
- blksprs-1.11/blksprs/ops/repeat.py +0 -196
- blksprs-1.11/blksprs/ops/softmax.py +0 -278
- blksprs-1.11/blksprs/ops/transpose.py +0 -97
- blksprs-1.11/blksprs/utils/layout_utils.py +0 -17
- {blksprs-1.11 → blksprs-2.0}/blksprs/utils/blksprs_tensor.py +0 -0
- {blksprs-1.11 → blksprs-2.0}/blksprs.egg-info/dependency_links.txt +0 -0
- {blksprs-1.11 → blksprs-2.0}/blksprs.egg-info/top_level.txt +0 -0
- {blksprs-1.11 → blksprs-2.0}/setup.cfg +0 -0
{blksprs-1.11 → blksprs-2.0}/PKG-INFO

@@ -1,6 +1,6 @@
-Metadata-Version: 2.…
+Metadata-Version: 2.4
 Name: blksprs
-Version: 1.11
+Version: 2.0
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs

@@ -14,9 +14,8 @@ Requires-Dist: pytest; extra == "test"
 Requires-Dist: pytest-xdist; extra == "test"
 Requires-Dist: pytest-cov; extra == "test"
 Requires-Dist: coverage; extra == "test"
+Requires-Dist: build; extra == "test"
 Requires-Dist: matplotlib; extra == "test"
-Provides-Extra: build
-Requires-Dist: build; extra == "build"
 
 # blksprs
 
@@ -25,6 +24,13 @@ Requires-Dist: build; extra == "build"
 
 ## Overview
 
+### News
+
+🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
+LUTs, autocasting, and makes use of `torch.library.triton_op()`!
+
+---
+
 A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
 
 Currently supported operations (includes gradient calculation):

@@ -52,23 +58,25 @@ These include, e.g.,
 Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
 match.
 
-Further helpful operations (included in the ``bs.ops.misc`` module) that do **not** support gradient calculation include:
+Further helpful operations (included in the ``bs.ops.misc`` module) that do **not** support gradient calculation
+include:
 
 - Row-wise sum, max, addition, and subtraction
 - Broadcast addition and subtraction between slices
 
-Furthermore, the library provides a set of utility functions
+Furthermore, the library provides a set of utility functions
 
 - for the creation of sparsity layouts based on existing
-dense tensors and for the scatter operation (module ``bs.layouting``),
+dense tensors and for the scatter operation (module ``bs.layouting``),
 - for the application of ``nn.Linear``, ``nn.Dropout``, and ``nn.LayerNorm`` layers to block-sparse tensors,
 - as well as utility functions to ensure correct input dimensionality, and validate input (module ``bs.utils``).
 
-_* see the [Roadmap](#roadmap) section for more information_
+_* see the [Roadmap](#roadmap) section for more information_
 
 ## Installation
 
-Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible with
+Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible
+with
 the Linux platform**.
 Keep track of this [issue](https://github.com/triton-lang/triton/issues/1640) for updates.
 
@@ -78,8 +86,8 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u
 
 ### Dependencies
 
-- [PyTorch](https://pytorch.org/) (built with v2.…)
-- _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.2.…)_
+- [PyTorch](https://pytorch.org/) (built with v2.6)
+- _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.2.4)_
 - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_
 
 ## Changelog

@@ -89,12 +97,27 @@ See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.
 ## Roadmap
 
 Note that since this library covers all our current needs it is in a **bugfix-only** state.
-This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and ``merge`` operations.
+This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and
+``merge`` operations.
 We will continue to maintain the library and fix any issues that arise.
 Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
 We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
 
-It might be that this changes with future projects, but as of…
+It might be that this changes with future projects, but as of March 2025, we are content with the current state of the
+library.
+
+## Known Limitations and Issues
+
+- Triton has a bug with `tl.atomic_max()` used for the row-wise max operation.
+In order to work around this bug a manual conversion of some values is needed, (slightly) negatively impacting
+performance.
+Watch the [issue](https://github.com/triton-lang/triton/issues/6376) on Triton's issue tracker for more information.
+- PyTorch's `wrap_triton()` currently does not support config pruning. It thus cannot be used for some of the kernels,
+which could impact graph compilation.
+- There seem to be some issues with autocasting, forcing some operations to manually cast.
+- There will be some slight numerical differences between vanilla and blksprs operations.
+These instabilities are due to Triton and thus cannot be fixed by this library alone.
+However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.
 
 ## Usage
 
@@ -120,10 +143,6 @@ def test_readme():
     # Must be a power of two, greater than or equal to 16 for matmul, and divide m, n, and k
     sparsity_block_size = 16
 
-    # Must be a power of two and smaller than or equal to sparsity_block_size
-    # If it is set to ``none`` a value will be chosen automatically
-    triton_block_size = None
-
     # Initialise random (dense) tensors
     x = torch.randn(size=(b, h, m, k), device="cuda")
     y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()

@@ -133,53 +152,53 @@ def test_readme():
     y_dense, y_shape_original = bs.utils.do_shape_blocksparse(y)
 
     # Create sparsity layouts from existing tensors
-    sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size…
-    …
-    sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size,
-        triton_block_size=triton_block_size)
+    sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size)
+    sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size)
 
     # Create random sparsity layout for output tensor
     sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)
 
     # Convert tensors to sparse tensors for matrix multiplication
-    x_sparse = bs.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size…
-    y_sparse = bs.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size…
+    x_sparse = bs.ops.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size)
+    y_sparse = bs.ops.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size)
+
+    # As of version 2.0, blksprs supports JIT compilation
+    matmul_compiled = torch.compile(bs.ops.matmul)
 
     # Perform matrix multiplication
-    o_sparse = …
-    …
-    …
+    o_sparse = matmul_compiled(x_sparse, sparsity_layout_x,
+        y_sparse, sparsity_layout_y,
+        sparsity_layout_o, sparsity_block_size)
 
     # Apply element-wise operation
     o_sparse = torch.add(o_sparse, 1)
 
-    o_dense = bs.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size…
+    o_dense = bs.ops.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size)
 
     # Sanity check
     o_torch = torch.matmul(x_dense, y_dense)
     o_torch = torch.add(o_torch, 1)
 
     # Perform round trip to set sparse blocks to 0
-    o_torch_round_trip = bs.to_dense(
-        bs.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size…
-        sparsity_layout_o, sparsity_block_size, fill_value=0…
+    o_torch_round_trip = bs.ops.to_dense(
+        bs.ops.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size),
+        sparsity_layout_o, sparsity_block_size, fill_value=0)
 
     # Assert that the output is correct
     assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2) # Note that small numerical differences are expected
 
     # Assert that the output has the correct sparsity layout
-    actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size,
-        triton_block_size=triton_block_size)
+    actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size)
 
     assert torch.allclose(actual_sparsity_layout_o.to(torch.int), sparsity_layout_o)
 
     # Convert output tensor back to original shape
     o = bs.utils.undo_shape_blocksparse(o_dense, x_shape_original)
 
     # Other available functions
-    bs.transpose(o_sparse, sparsity_layout_o, sparsity_block_size…
-    bs.softmax(o_sparse, sparsity_layout_o, sparsity_block_size…
-    bs.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size…
-    bs.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size…
+    bs.ops.transpose(o_sparse, sparsity_layout_o, sparsity_block_size)
+    bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size)
+    bs.ops.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size)
+    bs.ops.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size)
 
 
 def _get_random_sparsity_layout(b, m, n, sparsity_block_size, sparsity_percentage):
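To make the API change above easier to see at a glance, here is a condensed sketch re-assembled from the README example in this diff: operations now live under `bs.ops` (instead of the former top-level `bs.to_sparse`/`bs.matmul`), the `triton_block_size` argument is gone, and operations can be passed through `torch.compile`. The shapes and calls are taken from that example; the all-ones output layout (standing in for the README's random-layout helper) and its integer dtype are assumptions made for brevity.

```python
import torch
import blksprs as bs

# Shapes follow the README example; sparsity_block_size must be a power of two,
# >= 16 for matmul, and divide m, n, and k.
b, h, m, n, k = 2, 4, 64, 64, 32
sparsity_block_size = 16

x = torch.randn(size=(b, h, m, k), device="cuda")
y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()

x_dense, x_shape_original = bs.utils.do_shape_blocksparse(x)
y_dense, y_shape_original = bs.utils.do_shape_blocksparse(y)

# 2.0: layouting helpers no longer take triton_block_size
sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size)
sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size)

# Assumption: an all-ones layout of shape (b*h, m/bs, n/bs) for the output,
# instead of the README's _get_random_sparsity_layout helper.
sparsity_layout_o = torch.ones(b * h, m // sparsity_block_size, n // sparsity_block_size,
                               dtype=torch.int, device="cuda")

# 2.0: conversions and operations live under bs.ops
x_sparse = bs.ops.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size)
y_sparse = bs.ops.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size)

# New in 2.0: blksprs operations can be JIT compiled
matmul_compiled = torch.compile(bs.ops.matmul)
o_sparse = matmul_compiled(x_sparse, sparsity_layout_x,
                           y_sparse, sparsity_layout_y,
                           sparsity_layout_o, sparsity_block_size)

o_dense = bs.ops.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size)
o = bs.utils.undo_shape_blocksparse(o_dense, x_shape_original)
```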
{blksprs-1.11 → blksprs-2.0}/README.md

@@ -5,6 +5,13 @@
 
 ## Overview
 
+### News
+
+🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
+LUTs, autocasting, and makes use of `torch.library.triton_op()`!
+
+---
+
 A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
 
 Currently supported operations (includes gradient calculation):

@@ -32,23 +39,25 @@ These include, e.g.,
 Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
 match.
 
-Further helpful operations (included in the ``bs.ops.misc`` module) that do **not** support gradient calculation include:
+Further helpful operations (included in the ``bs.ops.misc`` module) that do **not** support gradient calculation
+include:
 
 - Row-wise sum, max, addition, and subtraction
 - Broadcast addition and subtraction between slices
 
-Furthermore, the library provides a set of utility functions
+Furthermore, the library provides a set of utility functions
 
 - for the creation of sparsity layouts based on existing
-dense tensors and for the scatter operation (module ``bs.layouting``),
+dense tensors and for the scatter operation (module ``bs.layouting``),
 - for the application of ``nn.Linear``, ``nn.Dropout``, and ``nn.LayerNorm`` layers to block-sparse tensors,
 - as well as utility functions to ensure correct input dimensionality, and validate input (module ``bs.utils``).
 
-_* see the [Roadmap](#roadmap) section for more information_
+_* see the [Roadmap](#roadmap) section for more information_
 
 ## Installation
 
-Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible with
+Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible
+with
 the Linux platform**.
 Keep track of this [issue](https://github.com/triton-lang/triton/issues/1640) for updates.
 
@@ -58,8 +67,8 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u
 
 ### Dependencies
 
-- [PyTorch](https://pytorch.org/) (built with v2.…)
-- _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.2.…)_
+- [PyTorch](https://pytorch.org/) (built with v2.6)
+- _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.2.4)_
 - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_
 
 ## Changelog

@@ -69,12 +78,27 @@ See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.
 ## Roadmap
 
 Note that since this library covers all our current needs it is in a **bugfix-only** state.
-This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and ``merge`` operations.
+This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and
+``merge`` operations.
 We will continue to maintain the library and fix any issues that arise.
 Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
 We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
 
-It might be that this changes with future projects, but as of…
+It might be that this changes with future projects, but as of March 2025, we are content with the current state of the
+library.
+
+## Known Limitations and Issues
+
+- Triton has a bug with `tl.atomic_max()` used for the row-wise max operation.
+In order to work around this bug a manual conversion of some values is needed, (slightly) negatively impacting
+performance.
+Watch the [issue](https://github.com/triton-lang/triton/issues/6376) on Triton's issue tracker for more information.
+- PyTorch's `wrap_triton()` currently does not support config pruning. It thus cannot be used for some of the kernels,
+which could impact graph compilation.
+- There seem to be some issues with autocasting, forcing some operations to manually cast.
+- There will be some slight numerical differences between vanilla and blksprs operations.
+These instabilities are due to Triton and thus cannot be fixed by this library alone.
+However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.
 
 ## Usage
 
@@ -100,10 +124,6 @@ def test_readme():
     # Must be a power of two, greater than or equal to 16 for matmul, and divide m, n, and k
     sparsity_block_size = 16
 
-    # Must be a power of two and smaller than or equal to sparsity_block_size
-    # If it is set to ``none`` a value will be chosen automatically
-    triton_block_size = None
-
     # Initialise random (dense) tensors
     x = torch.randn(size=(b, h, m, k), device="cuda")
    y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()

@@ -113,53 +133,53 @@ def test_readme():
     y_dense, y_shape_original = bs.utils.do_shape_blocksparse(y)
 
     # Create sparsity layouts from existing tensors
-    sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size…
-    …
-    sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size,
-        triton_block_size=triton_block_size)
+    sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size)
+    sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size)
 
     # Create random sparsity layout for output tensor
     sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)
 
     # Convert tensors to sparse tensors for matrix multiplication
-    x_sparse = bs.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size…
-    y_sparse = bs.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size…
+    x_sparse = bs.ops.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size)
+    y_sparse = bs.ops.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size)
+
+    # As of version 2.0, blksprs supports JIT compilation
+    matmul_compiled = torch.compile(bs.ops.matmul)
 
     # Perform matrix multiplication
-    o_sparse = …
-    …
-    …
+    o_sparse = matmul_compiled(x_sparse, sparsity_layout_x,
+        y_sparse, sparsity_layout_y,
+        sparsity_layout_o, sparsity_block_size)
 
     # Apply element-wise operation
     o_sparse = torch.add(o_sparse, 1)
 
-    o_dense = bs.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size…
+    o_dense = bs.ops.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size)
 
     # Sanity check
     o_torch = torch.matmul(x_dense, y_dense)
     o_torch = torch.add(o_torch, 1)
 
     # Perform round trip to set sparse blocks to 0
-    o_torch_round_trip = bs.to_dense(
-        bs.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size…
-        sparsity_layout_o, sparsity_block_size, fill_value=0…
+    o_torch_round_trip = bs.ops.to_dense(
+        bs.ops.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size),
+        sparsity_layout_o, sparsity_block_size, fill_value=0)
 
     # Assert that the output is correct
     assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2) # Note that small numerical differences are expected
 
     # Assert that the output has the correct sparsity layout
-    actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size,
-        triton_block_size=triton_block_size)
+    actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size)
 
     assert torch.allclose(actual_sparsity_layout_o.to(torch.int), sparsity_layout_o)
 
     # Convert output tensor back to original shape
     o = bs.utils.undo_shape_blocksparse(o_dense, x_shape_original)
 
     # Other available functions
-    bs.transpose(o_sparse, sparsity_layout_o, sparsity_block_size…
-    bs.softmax(o_sparse, sparsity_layout_o, sparsity_block_size…
-    bs.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size…
-    bs.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size…
+    bs.ops.transpose(o_sparse, sparsity_layout_o, sparsity_block_size)
+    bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size)
+    bs.ops.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size)
+    bs.ops.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size)
 
 
 def _get_random_sparsity_layout(b, m, n, sparsity_block_size, sparsity_percentage):
{blksprs-1.11 → blksprs-2.0}/blksprs/__init__.py

@@ -1,5 +1,7 @@
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 
+__version__ = "2.0"
+
 
 class ops:
     from blksprs.ops.conversion import to_dense, to_sparse, from_blksprs, to_blksprs, adapt_layout

@@ -18,19 +20,16 @@ class ops:
 class layouting:
     from blksprs.layouting.distribution_layout import build_distribution_layout
     from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption, \
-        build_sparsity_layout_matmul, build_sparsity_layout_matmul_fast
-    from blksprs.utils.layout_utils import build_full_sparsity_layout
+        build_sparsity_layout_matmul, build_sparsity_layout_matmul_fast, build_sparsity_layout_full
 
 
 class utils:
     from blksprs.utils.processing import apply_torch_linear, apply_torch_normalisation, apply_torch_dropout, \
         apply_function_applicable_row_wise
     from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
-    from blksprs.utils.validation import disable_validation
 
 class validation:
     from blksprs.utils.validation import disable_validation
     from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_dtype_float, \
         validate_dtype_int, validate_device, validate_sparsity, validate_sparsity_dense, \
-        validate_sparsity_block_size, \
-        validate_triton_block_size
+        validate_sparsity_block_size
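The `__init__.py` hunks above reorganise the public namespaces for 2.0. A small, hedged sketch of what that looks like from the caller's side; the module-level `__version__` and the moves of `disable_validation` and `build_sparsity_layout_full` are taken from the diff, while the zero-argument call to `disable_validation()` is an assumption (the diff shows only the re-export, not the signature).

```python
import blksprs as bs

print(bs.__version__)  # "2.0", newly exposed at package level

# Namespaces grouped by __init__.py:
#   bs.ops        - to_sparse, to_dense, matmul, softmax, transpose, misc.*, ...
#   bs.layouting  - build_sparsity_layout, build_sparsity_layout_full (replacing the
#                   removed blksprs.utils.layout_utils.build_full_sparsity_layout), ...
#   bs.utils      - do_shape_blocksparse, undo_shape_blocksparse, torch layer wrappers
#   bs.validation - disable_validation and the validate_* checks (no longer under bs.utils)

# Assumption: disable_validation takes no arguments.
bs.validation.disable_validation()
```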
{blksprs-1.11 → blksprs-2.0}/blksprs/layouting/distribution_layout.py

@@ -1,17 +1,23 @@
+import typing
+
 import torch
 import triton
 from torch import Tensor
+from torch._library import triton_op
+from torch._library.triton import wrap_triton
 from triton import language as tl
 
 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import …
-from blksprs.utils.…
+from blksprs.utils.tools import stride
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
+from blksprs.utils.validation import validate_dimensions, validate_device, \
     validate_contiguous
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: Tensor,
         dim: int, size_target: torch.Size,
-        sparsity_block_size: int…
+        sparsity_block_size: int) -> Tensor:
     """Builds the sparsity layout of either the source of a gather or the target of a scatter operation.
 
     Args:

@@ -20,7 +26,6 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
         dim (int): The dimension along which the operation is conducted.
         size_target (torch.Size): The size of the block-sparse target tensor in regular form.
         sparsity_block_size (int): The size of the sparsity blocks.
-        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
 
     Returns:
         Tensor: The sparsity layout of the source or target tensor.

@@ -34,49 +39,58 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
 
     adjusted_dim = dim % 3
 
-… (old lines 37–69: the previous implementation of this function; content not preserved in this view)
+    return build_distribution_layout_operation(indices, sparsity_lut_i, adjusted_dim, size_target, sparsity_block_size)
+
+
+@triton_op("blksprs::build_distribution_layout", mutates_args={})
+def build_distribution_layout_operation(indices: Tensor, sparsity_lut_i: Tensor,
+        adjusted_dim: int, size_target: typing.List[int],
+        sparsity_block_size: int) -> Tensor:
+    with torch.no_grad():
+        output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size,
+            size_target[2] // sparsity_block_size,
+            dtype=torch.bool, device=indices.device)
+
+        i_b, i_r, i_c = indices.size()
+        i_b_s, i_r_s, i_c_s = stride(indices)
+        s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
+        s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = stride(output)
+
+        triton_grid = lambda meta: [i_b,
+            triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
+            triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]
+
+        (wrap_triton(build_distribution_layout_kernel)[triton_grid]
+         (indices,
+          i_b, i_b_s, i_r_s, i_c_s,
+          sparsity_lut_i,
+          s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+          adjusted_dim,
+          output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          sparsity_block_size))
+
+    return output
+
+
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
+    reset_to_zero=["o"]
+)
 @triton.jit
-def … (old lines 71–79: the previous kernel signature; content not preserved in this view)
+def build_distribution_layout_kernel(i,
+        i_b, i_b_s, i_r_s, i_c_s,
+        s_lut_i,
+        s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+        dim,
+        o,
+        o_b, o_b_s, o_r_s, o_c_s,
+        sparsity_block_size,
+        TRITON_BLOCK_SIZE: tl.constexpr) -> None:
     # Get triton block indices
     pid_blk = tl.program_id(axis=0)
     pid_row = tl.program_id(axis=1)
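The hunk above shows the pattern that 2.0 uses throughout: the Python wrapper is registered as a PyTorch custom op via `triton_op`, and the Triton kernel is launched through `wrap_triton`, which is what makes the operations traceable by `torch.compile`. Below is a minimal, self-contained sketch of that pattern, not blksprs code: the op name `mylib::add_one` and the kernel are invented for illustration, and the public `torch.library` import path is used where blksprs imports the same helpers from the private `torch._library` modules.

```python
import torch
import triton
from torch.library import triton_op, wrap_triton
from triton import language as tl


@triton.jit
def add_one_kernel(x_ptr, o_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one block of BLOCK_SIZE contiguous elements.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(o_ptr + offsets, x + 1, mask=mask)


@triton_op("mylib::add_one", mutates_args={})
def add_one(x: torch.Tensor) -> torch.Tensor:
    output = torch.empty_like(x)
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    # wrap_triton makes the kernel launch visible to torch.compile instead of
    # treating it as an opaque call.
    wrap_triton(add_one_kernel)[grid](x, output, n_elements, BLOCK_SIZE=128)
    return output


def run(t: torch.Tensor) -> torch.Tensor:
    return add_one(t)


x = torch.randn(1024, device="cuda")
assert torch.allclose(add_one(x), x + 1)
assert torch.allclose(torch.compile(run)(x), x + 1)
```

As the README above notes, `wrap_triton()` does not yet support config pruning and therefore cannot be used for some of the kernels.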
@@ -98,7 +112,8 @@ def kernel_distribution_layout(i,
     blk_i_idx = (pid_blk * i_b_s +
         ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
         ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
-    blk_i_msk = (blk_i_idx >= 0 and…
+    blk_i_msk = (blk_i_idx >= 0 and
+        blk_i_idx < i_b * i_b_s)
     blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)
 
     dst_bat_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_i, dtype=tl.int32)

@@ -116,5 +131,6 @@ def kernel_distribution_layout(i,
     blk_o_idx = ((dst_bat_idx * o_b_s) +
         (dst_row_idx * o_r_s) +
         (dst_col_idx * o_c_s))
-    blk_o_msk = (blk_o_idx >= 0 and…
+    blk_o_msk = (blk_o_idx >= 0 and
+        blk_o_idx < o_b * o_b_s)
     tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
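The `@triton.autotune(...)` decorator in the hunks above pulls its configurations from the new `blksprs/utils/autotuning.py` (listed at the top of this diff but not shown here). Purely as a hypothetical sketch, given only the constraint from the removed 1.11 README comment that the Triton block size must be a power of two no larger than `sparsity_block_size`, such helpers could look like this:

```python
import triton


def get_autotune_configs():
    # Hypothetical: one triton.Config per candidate TRITON_BLOCK_SIZE / num_warps pair.
    return [triton.Config({"TRITON_BLOCK_SIZE": block_size}, num_warps=num_warps)
            for block_size in (16, 32, 64, 128)
            for num_warps in (2, 4)]


def prune_autotune_configs(configs, named_args, **kwargs):
    # Hypothetical early_config_prune hook: keep only configs whose Triton block size
    # does not exceed the sparsity block size passed to the kernel.
    sparsity_block_size = named_args["sparsity_block_size"]
    pruned = [config for config in configs
              if config.kwargs["TRITON_BLOCK_SIZE"] <= sparsity_block_size]
    # A real implementation would also guard against pruning every config away.
    return pruned or configs[:1]
```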