blksprs 1.10.2.tar.gz → 2.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {blksprs-1.10.2 → blksprs-2.0}/PKG-INFO +55 -36
  2. {blksprs-1.10.2 → blksprs-2.0}/README.md +52 -32
  3. {blksprs-1.10.2 → blksprs-2.0}/blksprs/__init__.py +4 -6
  4. {blksprs-1.10.2 → blksprs-2.0}/blksprs/layouting/distribution_layout.py +64 -48
  5. {blksprs-1.10.2 → blksprs-2.0}/blksprs/layouting/sparsity_layout.py +96 -72
  6. blksprs-2.0/blksprs/ops/conversion.py +506 -0
  7. blksprs-2.0/blksprs/ops/distribution.py +482 -0
  8. blksprs-2.0/blksprs/ops/flow.py +192 -0
  9. blksprs-2.0/blksprs/ops/matmul.py +260 -0
  10. {blksprs-1.10.2 → blksprs-2.0}/blksprs/ops/misc/broadcast_ops.py +68 -53
  11. blksprs-2.0/blksprs/ops/misc/row_wise.py +445 -0
  12. blksprs-2.0/blksprs/ops/partitioning.py +221 -0
  13. blksprs-2.0/blksprs/ops/repeat.py +194 -0
  14. blksprs-2.0/blksprs/ops/softmax.py +304 -0
  15. blksprs-2.0/blksprs/ops/transpose.py +100 -0
  16. blksprs-2.0/blksprs/utils/autotuning.py +78 -0
  17. {blksprs-1.10.2 → blksprs-2.0}/blksprs/utils/benchmarking.py +3 -3
  18. {blksprs-1.10.2 → blksprs-2.0}/blksprs/utils/processing.py +2 -1
  19. {blksprs-1.10.2 → blksprs-2.0}/blksprs/utils/tools.py +5 -6
  20. {blksprs-1.10.2 → blksprs-2.0}/blksprs/utils/validation.py +22 -16
  21. {blksprs-1.10.2 → blksprs-2.0}/blksprs.egg-info/PKG-INFO +55 -36
  22. {blksprs-1.10.2 → blksprs-2.0}/blksprs.egg-info/SOURCES.txt +1 -2
  23. {blksprs-1.10.2 → blksprs-2.0}/blksprs.egg-info/requires.txt +1 -3
  24. {blksprs-1.10.2 → blksprs-2.0}/pyproject.toml +2 -4
  25. blksprs-1.10.2/blksprs/ops/conversion.py +0 -468
  26. blksprs-1.10.2/blksprs/ops/distribution.py +0 -428
  27. blksprs-1.10.2/blksprs/ops/flow.py +0 -146
  28. blksprs-1.10.2/blksprs/ops/matmul.py +0 -227
  29. blksprs-1.10.2/blksprs/ops/misc/exp.py +0 -104
  30. blksprs-1.10.2/blksprs/ops/misc/row_wise.py +0 -398
  31. blksprs-1.10.2/blksprs/ops/partitioning.py +0 -170
  32. blksprs-1.10.2/blksprs/ops/repeat.py +0 -184
  33. blksprs-1.10.2/blksprs/ops/softmax.py +0 -270
  34. blksprs-1.10.2/blksprs/ops/transpose.py +0 -160
  35. blksprs-1.10.2/blksprs/utils/layout_utils.py +0 -17
  36. {blksprs-1.10.2 → blksprs-2.0}/blksprs/utils/blksprs_tensor.py +0 -0
  37. {blksprs-1.10.2 → blksprs-2.0}/blksprs.egg-info/dependency_links.txt +0 -0
  38. {blksprs-1.10.2 → blksprs-2.0}/blksprs.egg-info/top_level.txt +0 -0
  39. {blksprs-1.10.2 → blksprs-2.0}/setup.cfg +0 -0
{blksprs-1.10.2 → blksprs-2.0}/PKG-INFO

@@ -1,6 +1,6 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.4
  Name: blksprs
- Version: 1.10.2
+ Version: 2.0
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -14,9 +14,8 @@ Requires-Dist: pytest; extra == "test"
  Requires-Dist: pytest-xdist; extra == "test"
  Requires-Dist: pytest-cov; extra == "test"
  Requires-Dist: coverage; extra == "test"
+ Requires-Dist: build; extra == "test"
  Requires-Dist: matplotlib; extra == "test"
- Provides-Extra: build
- Requires-Dist: build; extra == "build"

  # blksprs

@@ -25,6 +24,13 @@ Requires-Dist: build; extra == "build"

  ## Overview

+ ### News
+
+ 🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
+ LUTs, autocasting, and makes use of `torch.library.triton_op()`!
+
+ ---
+
  A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.

  Currently supported operations (includes gradient calculation):
@@ -52,23 +58,25 @@ These include, e.g.,
  Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
  match.

- Further helpful operations (included in the ``bs.ops.misc`` module) that do **not** support gradient calculation include:
+ Further helpful operations (included in the ``bs.ops.misc`` module) that do **not** support gradient calculation
+ include:

  - Row-wise sum, max, addition, and subtraction
  - Broadcast addition and subtraction between slices

- Furthermore, the library provides a set of utility functions
+ Furthermore, the library provides a set of utility functions

  - for the creation of sparsity layouts based on existing
- dense tensors and for the scatter operation (module ``bs.layouting``),
+ dense tensors and for the scatter operation (module ``bs.layouting``),
  - for the application of ``nn.Linear``, ``nn.Dropout``, and ``nn.LayerNorm`` layers to block-sparse tensors,
  - as well as utility functions to ensure correct input dimensionality, and validate input (module ``bs.utils``).

- _* see the [Roadmap](#roadmap) section for more information_
+ _* see the [Roadmap](#roadmap) section for more information_

  ## Installation

- Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible with
+ Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible
+ with
  the Linux platform**.
  Keep track of this [issue](https://github.com/triton-lang/triton/issues/1640) for updates.

@@ -78,8 +86,8 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u

  ### Dependencies

- - [PyTorch](https://pytorch.org/) (built with v2.5.1)
- - _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.2.0)_
+ - [PyTorch](https://pytorch.org/) (built with v2.6)
+ - _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.2.4)_
  - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_

  ## Changelog
@@ -89,12 +97,27 @@ See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.
  ## Roadmap

  Note that since this library covers all our current needs it is in a **bugfix-only** state.
- This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and ``merge`` operations.
+ This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and
+ ``merge`` operations.
  We will continue to maintain the library and fix any issues that arise.
  Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
  We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).

- It might be that this changes with future projects, but as of December 2024, we are content with the current state of the library.
+ It might be that this changes with future projects, but as of March 2025, we are content with the current state of the
+ library.
+
+ ## Known Limitations and Issues
+
+ - Triton has a bug with `tl.atomix_max()` used for the row-wise max operation.
+ In order to work around this bug a manual conversion of some values is needed, (slightly) negatively impacting
+ performance.
+ Watch the [issue](https://github.com/triton-lang/triton/issues/6376) on Triton's issue tracker for more information.
+ - PyTorch's `wrap_triton()` currently does not support config pruning. It thus cannot be used for some of the kernels,
+ which could impact graph compilation.
+ - There seem to be some issues with autocasting, forcing some operations to manually cast.
+ - There will be some slight numerical differences between vanilla and blksprs operations.
+ These instabilities are due to Triton and thus cannot be fixed by this library alone.
+ However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.

  ## Usage

@@ -120,10 +143,6 @@ def test_readme():
  # Must be a power of two, greater than or equal to 16 for matmul, and divide m, n, and k
  sparsity_block_size = 16

- # Must be a power of two and smaller than or equal to sparsity_block_size
- # If it is set to ``none`` a value will be chosen automatically
- triton_block_size = None
-
  # Initialise random (dense) tensors
  x = torch.randn(size=(b, h, m, k), device="cuda")
  y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()
@@ -133,53 +152,53 @@ def test_readme():
  y_dense, y_shape_original = bs.utils.do_shape_blocksparse(y)

  # Create sparsity layouts from existing tensors
- sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size,
- triton_block_size=triton_block_size)
- sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size,
- triton_block_size=triton_block_size)
+ sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size)
+ sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size)

  # Create random sparsity layout for output tensor
  sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)

  # Convert tensors to sparse tensors for matrix multiplication
- x_sparse = bs.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size, triton_block_size=triton_block_size)
- y_sparse = bs.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size, triton_block_size=triton_block_size)
+ x_sparse = bs.ops.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size)
+ y_sparse = bs.ops.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size)
+
+ # As of version 2.0, blksprs supports JIT compilation
+ matmul_compiled = torch.compile(bs.ops.matmul)

  # Perform matrix multiplication
- o_sparse = bs.matmul(x_sparse, sparsity_layout_x, y_sparse, sparsity_layout_y, sparsity_layout_o,
- sparsity_block_size,
- triton_block_size=triton_block_size)
+ o_sparse = matmul_compiled(x_sparse, sparsity_layout_x,
+ y_sparse, sparsity_layout_y,
+ sparsity_layout_o, sparsity_block_size)

  # Apply element-wise operation
  o_sparse = torch.add(o_sparse, 1)

- o_dense = bs.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+ o_dense = bs.ops.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size)

  # Sanity check
  o_torch = torch.matmul(x_dense, y_dense)
  o_torch = torch.add(o_torch, 1)

  # Perform round trip to set sparse blocks to 0
- o_torch_round_trip = bs.to_dense(
- bs.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size),
- sparsity_layout_o, sparsity_block_size, fill_value=0, triton_block_size=triton_block_size)
+ o_torch_round_trip = bs.ops.to_dense(
+ bs.ops.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size),
+ sparsity_layout_o, sparsity_block_size, fill_value=0)

  # Assert that the output is correct
  assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2) # Note that small numerical differences are expected

  # Assert that the output has the correct sparsity layout
- actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size,
- triton_block_size=triton_block_size)
+ actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size)
  assert torch.allclose(actual_sparsity_layout_o.to(torch.int), sparsity_layout_o)

  # Convert output tensor back to original shape
  o = bs.utils.undo_shape_blocksparse(o_dense, x_shape_original)

  # Other available functions
- bs.transpose(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
- bs.softmax(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
- bs.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
- bs.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+ bs.ops.transpose(o_sparse, sparsity_layout_o, sparsity_block_size)
+ bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size)
+ bs.ops.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size)
+ bs.ops.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size)


  def _get_random_sparsity_layout(b, m, n, sparsity_block_size, sparsity_percentage):
{blksprs-1.10.2 → blksprs-2.0}/README.md

@@ -5,6 +5,13 @@

  ## Overview

+ ### News
+
+ 🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
+ LUTs, autocasting, and makes use of `torch.library.triton_op()`!
+
+ ---
+
  A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.

  Currently supported operations (includes gradient calculation):
@@ -32,23 +39,25 @@ These include, e.g.,
  Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
  match.

- Further helpful operations (included in the ``bs.ops.misc`` module) that do **not** support gradient calculation include:
+ Further helpful operations (included in the ``bs.ops.misc`` module) that do **not** support gradient calculation
+ include:

  - Row-wise sum, max, addition, and subtraction
  - Broadcast addition and subtraction between slices

- Furthermore, the library provides a set of utility functions
+ Furthermore, the library provides a set of utility functions

  - for the creation of sparsity layouts based on existing
- dense tensors and for the scatter operation (module ``bs.layouting``),
+ dense tensors and for the scatter operation (module ``bs.layouting``),
  - for the application of ``nn.Linear``, ``nn.Dropout``, and ``nn.LayerNorm`` layers to block-sparse tensors,
  - as well as utility functions to ensure correct input dimensionality, and validate input (module ``bs.utils``).

- _* see the [Roadmap](#roadmap) section for more information_
+ _* see the [Roadmap](#roadmap) section for more information_

  ## Installation

- Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible with
+ Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible
+ with
  the Linux platform**.
  Keep track of this [issue](https://github.com/triton-lang/triton/issues/1640) for updates.

@@ -58,8 +67,8 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u

  ### Dependencies

- - [PyTorch](https://pytorch.org/) (built with v2.5.1)
- - _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.2.0)_
+ - [PyTorch](https://pytorch.org/) (built with v2.6)
+ - _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.2.4)_
  - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_

  ## Changelog
@@ -69,12 +78,27 @@ See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.
  ## Roadmap

  Note that since this library covers all our current needs it is in a **bugfix-only** state.
- This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and ``merge`` operations.
+ This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and
+ ``merge`` operations.
  We will continue to maintain the library and fix any issues that arise.
  Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
  We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).

- It might be that this changes with future projects, but as of December 2024, we are content with the current state of the library.
+ It might be that this changes with future projects, but as of March 2025, we are content with the current state of the
+ library.
+
+ ## Known Limitations and Issues
+
+ - Triton has a bug with `tl.atomix_max()` used for the row-wise max operation.
+ In order to work around this bug a manual conversion of some values is needed, (slightly) negatively impacting
+ performance.
+ Watch the [issue](https://github.com/triton-lang/triton/issues/6376) on Triton's issue tracker for more information.
+ - PyTorch's `wrap_triton()` currently does not support config pruning. It thus cannot be used for some of the kernels,
+ which could impact graph compilation.
+ - There seem to be some issues with autocasting, forcing some operations to manually cast.
+ - There will be some slight numerical differences between vanilla and blksprs operations.
+ These instabilities are due to Triton and thus cannot be fixed by this library alone.
+ However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.

  ## Usage

@@ -100,10 +124,6 @@ def test_readme():
  # Must be a power of two, greater than or equal to 16 for matmul, and divide m, n, and k
  sparsity_block_size = 16

- # Must be a power of two and smaller than or equal to sparsity_block_size
- # If it is set to ``none`` a value will be chosen automatically
- triton_block_size = None
-
  # Initialise random (dense) tensors
  x = torch.randn(size=(b, h, m, k), device="cuda")
  y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()
@@ -113,53 +133,53 @@ def test_readme():
  y_dense, y_shape_original = bs.utils.do_shape_blocksparse(y)

  # Create sparsity layouts from existing tensors
- sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size,
- triton_block_size=triton_block_size)
- sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size,
- triton_block_size=triton_block_size)
+ sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size)
+ sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size)

  # Create random sparsity layout for output tensor
  sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)

  # Convert tensors to sparse tensors for matrix multiplication
- x_sparse = bs.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size, triton_block_size=triton_block_size)
- y_sparse = bs.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size, triton_block_size=triton_block_size)
+ x_sparse = bs.ops.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size)
+ y_sparse = bs.ops.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size)
+
+ # As of version 2.0, blksprs supports JIT compilation
+ matmul_compiled = torch.compile(bs.ops.matmul)

  # Perform matrix multiplication
- o_sparse = bs.matmul(x_sparse, sparsity_layout_x, y_sparse, sparsity_layout_y, sparsity_layout_o,
- sparsity_block_size,
- triton_block_size=triton_block_size)
+ o_sparse = matmul_compiled(x_sparse, sparsity_layout_x,
+ y_sparse, sparsity_layout_y,
+ sparsity_layout_o, sparsity_block_size)

  # Apply element-wise operation
  o_sparse = torch.add(o_sparse, 1)

- o_dense = bs.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+ o_dense = bs.ops.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size)

  # Sanity check
  o_torch = torch.matmul(x_dense, y_dense)
  o_torch = torch.add(o_torch, 1)

  # Perform round trip to set sparse blocks to 0
- o_torch_round_trip = bs.to_dense(
- bs.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size),
- sparsity_layout_o, sparsity_block_size, fill_value=0, triton_block_size=triton_block_size)
+ o_torch_round_trip = bs.ops.to_dense(
+ bs.ops.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size),
+ sparsity_layout_o, sparsity_block_size, fill_value=0)

  # Assert that the output is correct
  assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2) # Note that small numerical differences are expected

  # Assert that the output has the correct sparsity layout
- actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size,
- triton_block_size=triton_block_size)
+ actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size)
  assert torch.allclose(actual_sparsity_layout_o.to(torch.int), sparsity_layout_o)

  # Convert output tensor back to original shape
  o = bs.utils.undo_shape_blocksparse(o_dense, x_shape_original)

  # Other available functions
- bs.transpose(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
- bs.softmax(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
- bs.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
- bs.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+ bs.ops.transpose(o_sparse, sparsity_layout_o, sparsity_block_size)
+ bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size)
+ bs.ops.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size)
+ bs.ops.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size)


  def _get_random_sparsity_layout(b, m, n, sparsity_block_size, sparsity_percentage):
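
Read together, the PKG-INFO and README hunks above capture the user-facing 2.0 migration: the `triton_block_size` argument is gone (kernel block sizes are now autotuned), and the top-level entry points (`bs.to_sparse`, `bs.matmul`, `bs.misc.*`, ...) move under `bs.ops`. A minimal before/after sketch with illustrative shapes; the commented-out line shows the 1.10.2 signature, which no longer exists in 2.0:

```python
import torch
import blksprs as bs

# Dense 3D input (batch, rows, cols) and its sparsity layout, as in the README example above.
x = torch.randn(size=(2, 64, 64), device="cuda")
sparsity_block_size = 16
sparsity_layout_x = bs.layouting.build_sparsity_layout(x, sparsity_block_size)

# blksprs 1.10.2 (removed): top-level ops taking an explicit triton_block_size.
# x_sparse = bs.to_sparse(x, sparsity_layout_x, sparsity_block_size, triton_block_size=None)

# blksprs 2.0: ops live under bs.ops and the Triton block size is chosen by the autotuner.
x_sparse = bs.ops.to_sparse(x, sparsity_layout_x, sparsity_block_size)
x_roundtrip = bs.ops.to_dense(x_sparse, sparsity_layout_x, sparsity_block_size)
assert torch.allclose(x, x_roundtrip)
```

As in the README example, the `bs.ops` functions can additionally be wrapped in `torch.compile`.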
{blksprs-1.10.2 → blksprs-2.0}/blksprs/__init__.py

@@ -1,5 +1,7 @@
  from blksprs.utils.blksprs_tensor import BlksprsTensor

+ __version__ = "2.0"
+

  class ops:
  from blksprs.ops.conversion import to_dense, to_sparse, from_blksprs, to_blksprs, adapt_layout
@@ -13,25 +15,21 @@ class ops:
  class misc:
  from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_add, row_wise_sub
  from blksprs.ops.misc.broadcast_ops import broadcast_add, broadcast_sub
- from blksprs.ops.misc.exp import exp


  class layouting:
  from blksprs.layouting.distribution_layout import build_distribution_layout
  from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption, \
- build_sparsity_layout_matmul, build_sparsity_layout_matmul_fast
- from blksprs.utils.layout_utils import build_full_sparsity_layout
+ build_sparsity_layout_matmul, build_sparsity_layout_matmul_fast, build_sparsity_layout_full


  class utils:
  from blksprs.utils.processing import apply_torch_linear, apply_torch_normalisation, apply_torch_dropout, \
  apply_function_applicable_row_wise
  from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
- from blksprs.utils.validation import disable_validation

  class validation:
  from blksprs.utils.validation import disable_validation
  from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_dtype_float, \
  validate_dtype_int, validate_device, validate_sparsity, validate_sparsity_dense, \
- validate_sparsity_block_size, \
- validate_triton_block_size
+ validate_sparsity_block_size
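
The `__init__.py` hunks above define the 2.0 namespace: a `__version__` attribute is added, `ops.misc.exp` is removed, the full-layout helper moves from `blksprs.utils.layout_utils.build_full_sparsity_layout` to `bs.layouting.build_sparsity_layout_full`, and `disable_validation` is now reached via `bs.validation` rather than `bs.utils`. A short sketch of the resulting import surface; the no-argument call to `disable_validation()` is an assumption, since the diff only shows where the function is imported from:

```python
import blksprs as bs

# New in 2.0: the package version is exposed directly on the module.
assert bs.__version__ == "2.0"

# Validation helpers now live under bs.validation (no longer re-exported by bs.utils).
bs.validation.disable_validation()  # assumed no-arg; only the import location is shown above

# The full-layout helper was renamed and moved under bs.layouting in 2.0.
build_full = bs.layouting.build_sparsity_layout_full  # was blksprs.utils.layout_utils.build_full_sparsity_layout

# Row-wise and broadcast helpers remain under bs.ops.misc; bs.ops.misc.exp was removed.
row_wise_max = bs.ops.misc.row_wise_max
```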
{blksprs-1.10.2 → blksprs-2.0}/blksprs/layouting/distribution_layout.py

@@ -1,17 +1,23 @@
+ import typing
+
  import torch
  import triton
  from torch import Tensor
+ from torch._library import triton_op
+ from torch._library.triton import wrap_triton
  from triton import language as tl

  from blksprs.utils.blksprs_tensor import BlksprsTensor
- from blksprs.utils.tools import get_triton_block_size, stride
- from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
+ from blksprs.utils.tools import stride
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
+ from blksprs.utils.validation import validate_dimensions, validate_device, \
  validate_contiguous


+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
  def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: Tensor,
  dim: int, size_target: torch.Size,
- sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+ sparsity_block_size: int) -> Tensor:
  """Builds the sparsity layout of either the source of a gather or the target of a scatter operation.

  Args:
@@ -20,7 +26,6 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
  dim (int): The dimension along which the operation is conducted.
  size_target (torch.Size): The size of the block-sparse target tensor in regular form.
  sparsity_block_size (int): The size of the sparsity blocks.
- triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).

  Returns:
  Tensor: The sparsity layout of the source or target tensor.
@@ -34,49 +39,58 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T

  adjusted_dim = dim % 3

- output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size, size_target[2] // sparsity_block_size,
- dtype=torch.bool, device=indices.device)
-
- i_b, i_r, i_c = indices.size()
- i_b_s, i_r_s, i_c_s = stride(indices)
- s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
- s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
- o_b, o_r, o_c = output.size()
- o_b_s, o_r_s, o_c_s = stride(output)
-
- if triton_block_size is None:
- triton_block_size = get_triton_block_size(sparsity_block_size)
-
- validate_triton_block_size(triton_block_size, sparsity_block_size)
-
- triton_grid = lambda meta: [i_b,
- triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
- triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]
-
- (kernel_distribution_layout[triton_grid]
- (indices,
- i_b, i_b_s, i_r_s, i_c_s,
- sparsity_lut_i,
- s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
- adjusted_dim,
- output,
- o_b, o_b_s, o_r_s, o_c_s,
- sparsity_block_size,
- triton_block_size))
-
- return output
-
-
+ return build_distribution_layout_operation(indices, sparsity_lut_i, adjusted_dim, size_target, sparsity_block_size)
+
+
+ @triton_op("blksprs::build_distribution_layout", mutates_args={})
+ def build_distribution_layout_operation(indices: Tensor, sparsity_lut_i: Tensor,
+ adjusted_dim: int, size_target: typing.List[int],
+ sparsity_block_size: int) -> Tensor:
+ with torch.no_grad():
+ output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size,
+ size_target[2] // sparsity_block_size,
+ dtype=torch.bool, device=indices.device)
+
+ i_b, i_r, i_c = indices.size()
+ i_b_s, i_r_s, i_c_s = stride(indices)
+ s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
+ s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
+ o_b, o_r, o_c = output.size()
+ o_b_s, o_r_s, o_c_s = stride(output)
+
+ triton_grid = lambda meta: [i_b,
+ triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
+ triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]
+
+ (wrap_triton(build_distribution_layout_kernel)[triton_grid]
+ (indices,
+ i_b, i_b_s, i_r_s, i_c_s,
+ sparsity_lut_i,
+ s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+ adjusted_dim,
+ output,
+ o_b, o_b_s, o_r_s, o_c_s,
+ sparsity_block_size))
+
+ return output
+
+
+ @triton.autotune(
+ configs=get_autotune_configs(),
+ key=["sparsity_block_size"],
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
+ reset_to_zero=["o"]
+ )
  @triton.jit
- def kernel_distribution_layout(i,
- i_b, i_b_s, i_r_s, i_c_s,
- s_lut_i,
- s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
- dim,
- o,
- o_b, o_b_s, o_r_s, o_c_s,
- sparsity_block_size,
- TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+ def build_distribution_layout_kernel(i,
+ i_b, i_b_s, i_r_s, i_c_s,
+ s_lut_i,
+ s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+ dim,
+ o,
+ o_b, o_b_s, o_r_s, o_c_s,
+ sparsity_block_size,
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
  # Get triton block indices
  pid_blk = tl.program_id(axis=0)
  pid_row = tl.program_id(axis=1)
@@ -98,7 +112,8 @@ def kernel_distribution_layout(i,
  blk_i_idx = (pid_blk * i_b_s +
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
- blk_i_msk = (blk_i_idx >= 0 and blk_i_idx < i_b * i_b_s)
+ blk_i_msk = (blk_i_idx >= 0 and
+ blk_i_idx < i_b * i_b_s)
  blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)

  dst_bat_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_i, dtype=tl.int32)
@@ -116,5 +131,6 @@ def kernel_distribution_layout(i,
  blk_o_idx = ((dst_bat_idx * o_b_s) +
  (dst_row_idx * o_r_s) +
  (dst_col_idx * o_c_s))
- blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
+ blk_o_msk = (blk_o_idx >= 0 and
+ blk_o_idx < o_b * o_b_s)
  tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
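
The rewrite of `distribution_layout.py` above is the pattern applied throughout 2.0: the public function keeps its signature (minus `triton_block_size`) and gains an autocast decorator, its body is registered as a custom op via `triton_op`, the kernel is launched through `wrap_triton` so that `torch.compile` can treat it as an opaque op, and the hand-picked `TRITON_BLOCK_SIZE` is replaced by `triton.autotune` keyed on `sparsity_block_size`. Below is a minimal sketch of that registration pattern on a deliberately trivial kernel; `double`, `double_operation`, `double_kernel`, and `example_autotune_configs` are hypothetical names, and the config helper only stands in for the (not shown) `blksprs.utils.autotuning` module.

```python
import torch
import triton
import triton.language as tl
# Import paths as used in the diff above.
from torch._library import triton_op
from torch._library.triton import wrap_triton


def example_autotune_configs():
    # Hypothetical stand-in for blksprs.utils.autotuning.get_autotune_configs, whose
    # contents are not shown in this diff: a few candidate block sizes and warp
    # counts for the autotuner to benchmark.
    return [triton.Config({"TRITON_BLOCK_SIZE": size}, num_warps=warps)
            for size in (16, 32, 64) for warps in (2, 4)]


@triton.autotune(configs=example_autotune_configs(), key=["n"])
@triton.jit
def double_kernel(x_ptr, o_ptr, n, TRITON_BLOCK_SIZE: tl.constexpr):
    # Trivial element-wise kernel standing in for the layout kernel above;
    # TRITON_BLOCK_SIZE is supplied by the autotuner instead of the caller.
    pid = tl.program_id(axis=0)
    offs = pid * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)
    msk = offs < n
    x = tl.load(x_ptr + offs, mask=msk)
    tl.store(o_ptr + offs, x * 2, mask=msk)


@triton_op("example::double", mutates_args={})
def double_operation(x: torch.Tensor) -> torch.Tensor:
    # Launching through wrap_triton exposes the kernel call to torch.compile
    # as an opaque custom op, as in the 2.0 kernels above.
    output = torch.empty_like(x)
    n = x.numel()
    grid = lambda meta: (triton.cdiv(n, meta["TRITON_BLOCK_SIZE"]),)
    wrap_triton(double_kernel)[grid](x, output, n)
    return output


def double(x: torch.Tensor) -> torch.Tensor:
    # Thin public wrapper delegating to the registered op, mirroring
    # build_distribution_layout -> build_distribution_layout_operation.
    return double_operation(x)


x = torch.randn(1024, device="cuda")
assert torch.allclose(double(x), x * 2)                 # eager
assert torch.allclose(torch.compile(double)(x), x * 2)  # compiled
```

Note that this sketch uses plain autotuning only; as the Known Limitations section above points out, combining `wrap_triton()` with config pruning (`prune_configs_by`) is not currently supported.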