blksprs 1.11__tar.gz → 2.0rc2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {blksprs-1.11 → blksprs-2.0rc2}/PKG-INFO +42 -36
  2. {blksprs-1.11 → blksprs-2.0rc2}/README.md +39 -32
  3. {blksprs-1.11 → blksprs-2.0rc2}/blksprs/__init__.py +2 -5
  4. {blksprs-1.11 → blksprs-2.0rc2}/blksprs/layouting/distribution_layout.py +32 -25
  5. {blksprs-1.11 → blksprs-2.0rc2}/blksprs/layouting/sparsity_layout.py +65 -52
  6. blksprs-2.0rc2/blksprs/ops/conversion.py +517 -0
  7. blksprs-2.0rc2/blksprs/ops/distribution.py +496 -0
  8. {blksprs-1.11 → blksprs-2.0rc2}/blksprs/ops/flow.py +125 -106
  9. blksprs-2.0rc2/blksprs/ops/matmul.py +264 -0
  10. {blksprs-1.11 → blksprs-2.0rc2}/blksprs/ops/misc/broadcast_ops.py +53 -35
  11. {blksprs-1.11 → blksprs-2.0rc2}/blksprs/ops/misc/row_wise.py +151 -91
  12. blksprs-2.0rc2/blksprs/ops/partitioning.py +217 -0
  13. blksprs-2.0rc2/blksprs/ops/repeat.py +191 -0
  14. blksprs-2.0rc2/blksprs/ops/softmax.py +306 -0
  15. blksprs-2.0rc2/blksprs/ops/transpose.py +98 -0
  16. {blksprs-1.11 → blksprs-2.0rc2}/blksprs/utils/benchmarking.py +3 -3
  17. blksprs-2.0rc2/blksprs/utils/tools.py +56 -0
  18. {blksprs-1.11 → blksprs-2.0rc2}/blksprs/utils/validation.py +4 -18
  19. {blksprs-1.11 → blksprs-2.0rc2}/blksprs.egg-info/PKG-INFO +42 -36
  20. {blksprs-1.11 → blksprs-2.0rc2}/blksprs.egg-info/SOURCES.txt +0 -1
  21. {blksprs-1.11 → blksprs-2.0rc2}/blksprs.egg-info/requires.txt +1 -3
  22. {blksprs-1.11 → blksprs-2.0rc2}/pyproject.toml +2 -4
  23. blksprs-1.11/blksprs/ops/conversion.py +0 -495
  24. blksprs-1.11/blksprs/ops/distribution.py +0 -458
  25. blksprs-1.11/blksprs/ops/matmul.py +0 -245
  26. blksprs-1.11/blksprs/ops/partitioning.py +0 -213
  27. blksprs-1.11/blksprs/ops/repeat.py +0 -196
  28. blksprs-1.11/blksprs/ops/softmax.py +0 -278
  29. blksprs-1.11/blksprs/ops/transpose.py +0 -97
  30. blksprs-1.11/blksprs/utils/layout_utils.py +0 -17
  31. blksprs-1.11/blksprs/utils/tools.py +0 -29
  32. {blksprs-1.11 → blksprs-2.0rc2}/blksprs/utils/blksprs_tensor.py +0 -0
  33. {blksprs-1.11 → blksprs-2.0rc2}/blksprs/utils/processing.py +0 -0
  34. {blksprs-1.11 → blksprs-2.0rc2}/blksprs.egg-info/dependency_links.txt +0 -0
  35. {blksprs-1.11 → blksprs-2.0rc2}/blksprs.egg-info/top_level.txt +0 -0
  36. {blksprs-1.11 → blksprs-2.0rc2}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
- Metadata-Version: 2.2
+ Metadata-Version: 2.4
  Name: blksprs
- Version: 1.11
+ Version: 2.0rc2
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -14,9 +14,8 @@ Requires-Dist: pytest; extra == "test"
  Requires-Dist: pytest-xdist; extra == "test"
  Requires-Dist: pytest-cov; extra == "test"
  Requires-Dist: coverage; extra == "test"
+ Requires-Dist: build; extra == "test"
  Requires-Dist: matplotlib; extra == "test"
- Provides-Extra: build
- Requires-Dist: build; extra == "build"

  # blksprs

@@ -25,6 +24,13 @@ Requires-Dist: build; extra == "build"

  ## Overview

+ ### News
+
+ 🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
+ LUTs, and makes use of `torch.library.triton_op()`!
+
+ ---
+
  A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.

  Currently supported operations (includes gradient calculation):
@@ -52,23 +58,25 @@ These include, e.g.,
  Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
  match.

- Further helpful operations (included in the ``bs.ops.misc`` module) that do **not** support gradient calculation include:
+ Further helpful operations (included in the ``bs.ops.misc`` module) that do **not** support gradient calculation
+ include:

  - Row-wise sum, max, addition, and subtraction
  - Broadcast addition and subtraction between slices

- Furthermore, the library provides a set of utility functions
+ Furthermore, the library provides a set of utility functions

  - for the creation of sparsity layouts based on existing
- dense tensors and for the scatter operation (module ``bs.layouting``),
+ dense tensors and for the scatter operation (module ``bs.layouting``),
  - for the application of ``nn.Linear``, ``nn.Dropout``, and ``nn.LayerNorm`` layers to block-sparse tensors,
  - as well as utility functions to ensure correct input dimensionality, and validate input (module ``bs.utils``).

- _* see the [Roadmap](#roadmap) section for more information_
+ _* see the [Roadmap](#roadmap) section for more information_

  ## Installation

- Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible with
+ Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible
+ with
  the Linux platform**.
  Keep track of this [issue](https://github.com/triton-lang/triton/issues/1640) for updates.

@@ -78,8 +86,8 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u

  ### Dependencies

- - [PyTorch](https://pytorch.org/) (built with v2.5.1)
- - _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.2.0)_
+ - [PyTorch](https://pytorch.org/) (built with v2.6)
+ - _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.2.4)_
  - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_

  ## Changelog
@@ -89,12 +97,14 @@ See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.
  ## Roadmap

  Note that since this library covers all our current needs it is in a **bugfix-only** state.
- This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and ``merge`` operations.
+ This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and
+ ``merge`` operations.
  We will continue to maintain the library and fix any issues that arise.
  Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
  We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).

- It might be that this changes with future projects, but as of December 2024, we are content with the current state of the library.
+ It might be that this changes with future projects, but as of March 2025, we are content with the current state of the
+ library.

  ## Usage

@@ -120,10 +130,6 @@ def test_readme():
  # Must be a power of two, greater than or equal to 16 for matmul, and divide m, n, and k
  sparsity_block_size = 16

- # Must be a power of two and smaller than or equal to sparsity_block_size
- # If it is set to ``none`` a value will be chosen automatically
- triton_block_size = None
-
  # Initialise random (dense) tensors
  x = torch.randn(size=(b, h, m, k), device="cuda")
  y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()
@@ -133,53 +139,53 @@ def test_readme():
  y_dense, y_shape_original = bs.utils.do_shape_blocksparse(y)

  # Create sparsity layouts from existing tensors
- sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size,
- triton_block_size=triton_block_size)
- sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size,
- triton_block_size=triton_block_size)
+ sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size)
+ sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size)

  # Create random sparsity layout for output tensor
  sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)

  # Convert tensors to sparse tensors for matrix multiplication
- x_sparse = bs.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size, triton_block_size=triton_block_size)
- y_sparse = bs.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size, triton_block_size=triton_block_size)
+ x_sparse = bs.ops.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size)
+ y_sparse = bs.ops.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size)
+
+ # As of version 2.0, blksprs supports JIT compilation
+ matmul_compiled = torch.compile(bs.ops.matmul)

  # Perform matrix multiplication
- o_sparse = bs.matmul(x_sparse, sparsity_layout_x, y_sparse, sparsity_layout_y, sparsity_layout_o,
- sparsity_block_size,
- triton_block_size=triton_block_size)
+ o_sparse = matmul_compiled(x_sparse, sparsity_layout_x,
+ y_sparse, sparsity_layout_y,
+ sparsity_layout_o, sparsity_block_size)

  # Apply element-wise operation
  o_sparse = torch.add(o_sparse, 1)

- o_dense = bs.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+ o_dense = bs.ops.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size)

  # Sanity check
  o_torch = torch.matmul(x_dense, y_dense)
  o_torch = torch.add(o_torch, 1)

  # Perform round trip to set sparse blocks to 0
- o_torch_round_trip = bs.to_dense(
- bs.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size),
- sparsity_layout_o, sparsity_block_size, fill_value=0, triton_block_size=triton_block_size)
+ o_torch_round_trip = bs.ops.to_dense(
+ bs.ops.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size),
+ sparsity_layout_o, sparsity_block_size, fill_value=0)

  # Assert that the output is correct
  assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2) # Note that small numerical differences are expected

  # Assert that the output has the correct sparsity layout
- actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size,
- triton_block_size=triton_block_size)
+ actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size)
  assert torch.allclose(actual_sparsity_layout_o.to(torch.int), sparsity_layout_o)

  # Convert output tensor back to original shape
  o = bs.utils.undo_shape_blocksparse(o_dense, x_shape_original)

  # Other available functions
- bs.transpose(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
- bs.softmax(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
- bs.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
- bs.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+ bs.ops.transpose(o_sparse, sparsity_layout_o, sparsity_block_size)
+ bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size)
+ bs.ops.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size)
+ bs.ops.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size)


  def _get_random_sparsity_layout(b, m, n, sparsity_block_size, sparsity_percentage):
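
The News blurb and usage hunks above amount to three call-site changes in 2.0: operations moved from the top-level `bs.*` namespace into `bs.ops.*`, the manual `triton_block_size` argument disappeared in favour of kernel auto-tuning, and operations can now be wrapped in `torch.compile`. The following is a minimal sketch assembled purely from those hunks (the wrapper name `blocksparse_matmul_v2` is hypothetical; argument order is taken from the README example, not verified against the 2.0rc2 sources):

```python
import torch
import blksprs as bs


def blocksparse_matmul_v2(x_sparse, sparsity_layout_x,
                          y_sparse, sparsity_layout_y,
                          sparsity_layout_o, sparsity_block_size):
    # blksprs 1.x exposed bs.matmul(..., triton_block_size=...) at the top level.
    # In 2.0 the op lives under bs.ops, the triton_block_size argument is gone
    # (block sizes are chosen by the autotuner), and the op can be JIT-compiled.
    matmul_compiled = torch.compile(bs.ops.matmul)
    return matmul_compiled(x_sparse, sparsity_layout_x,
                           y_sparse, sparsity_layout_y,
                           sparsity_layout_o, sparsity_block_size)
```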
@@ -5,6 +5,13 @@

  ## Overview

+ ### News
+
+ 🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
+ LUTs, and makes use of `torch.library.triton_op()`!
+
+ ---
+
  A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.

  Currently supported operations (includes gradient calculation):
@@ -32,23 +39,25 @@ These include, e.g.,
  Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
  match.

- Further helpful operations (included in the ``bs.ops.misc`` module) that do **not** support gradient calculation include:
+ Further helpful operations (included in the ``bs.ops.misc`` module) that do **not** support gradient calculation
+ include:

  - Row-wise sum, max, addition, and subtraction
  - Broadcast addition and subtraction between slices

- Furthermore, the library provides a set of utility functions
+ Furthermore, the library provides a set of utility functions

  - for the creation of sparsity layouts based on existing
- dense tensors and for the scatter operation (module ``bs.layouting``),
+ dense tensors and for the scatter operation (module ``bs.layouting``),
  - for the application of ``nn.Linear``, ``nn.Dropout``, and ``nn.LayerNorm`` layers to block-sparse tensors,
  - as well as utility functions to ensure correct input dimensionality, and validate input (module ``bs.utils``).

- _* see the [Roadmap](#roadmap) section for more information_
+ _* see the [Roadmap](#roadmap) section for more information_

  ## Installation

- Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible with
+ Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible
+ with
  the Linux platform**.
  Keep track of this [issue](https://github.com/triton-lang/triton/issues/1640) for updates.

@@ -58,8 +67,8 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u

  ### Dependencies

- - [PyTorch](https://pytorch.org/) (built with v2.5.1)
- - _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.2.0)_
+ - [PyTorch](https://pytorch.org/) (built with v2.6)
+ - _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.2.4)_
  - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_

  ## Changelog
@@ -69,12 +78,14 @@ See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.
  ## Roadmap

  Note that since this library covers all our current needs it is in a **bugfix-only** state.
- This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and ``merge`` operations.
+ This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and
+ ``merge`` operations.
  We will continue to maintain the library and fix any issues that arise.
  Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
  We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).

- It might be that this changes with future projects, but as of December 2024, we are content with the current state of the library.
+ It might be that this changes with future projects, but as of March 2025, we are content with the current state of the
+ library.

  ## Usage

@@ -100,10 +111,6 @@ def test_readme():
  # Must be a power of two, greater than or equal to 16 for matmul, and divide m, n, and k
  sparsity_block_size = 16

- # Must be a power of two and smaller than or equal to sparsity_block_size
- # If it is set to ``none`` a value will be chosen automatically
- triton_block_size = None
-
  # Initialise random (dense) tensors
  x = torch.randn(size=(b, h, m, k), device="cuda")
  y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()
@@ -113,53 +120,53 @@ def test_readme():
  y_dense, y_shape_original = bs.utils.do_shape_blocksparse(y)

  # Create sparsity layouts from existing tensors
- sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size,
- triton_block_size=triton_block_size)
- sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size,
- triton_block_size=triton_block_size)
+ sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size)
+ sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size)

  # Create random sparsity layout for output tensor
  sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)

  # Convert tensors to sparse tensors for matrix multiplication
- x_sparse = bs.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size, triton_block_size=triton_block_size)
- y_sparse = bs.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size, triton_block_size=triton_block_size)
+ x_sparse = bs.ops.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size)
+ y_sparse = bs.ops.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size)
+
+ # As of version 2.0, blksprs supports JIT compilation
+ matmul_compiled = torch.compile(bs.ops.matmul)

  # Perform matrix multiplication
- o_sparse = bs.matmul(x_sparse, sparsity_layout_x, y_sparse, sparsity_layout_y, sparsity_layout_o,
- sparsity_block_size,
- triton_block_size=triton_block_size)
+ o_sparse = matmul_compiled(x_sparse, sparsity_layout_x,
+ y_sparse, sparsity_layout_y,
+ sparsity_layout_o, sparsity_block_size)

  # Apply element-wise operation
  o_sparse = torch.add(o_sparse, 1)

- o_dense = bs.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+ o_dense = bs.ops.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size)

  # Sanity check
  o_torch = torch.matmul(x_dense, y_dense)
  o_torch = torch.add(o_torch, 1)

  # Perform round trip to set sparse blocks to 0
- o_torch_round_trip = bs.to_dense(
- bs.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size),
- sparsity_layout_o, sparsity_block_size, fill_value=0, triton_block_size=triton_block_size)
+ o_torch_round_trip = bs.ops.to_dense(
+ bs.ops.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size),
+ sparsity_layout_o, sparsity_block_size, fill_value=0)

  # Assert that the output is correct
  assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2) # Note that small numerical differences are expected

  # Assert that the output has the correct sparsity layout
- actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size,
- triton_block_size=triton_block_size)
+ actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size)
  assert torch.allclose(actual_sparsity_layout_o.to(torch.int), sparsity_layout_o)

  # Convert output tensor back to original shape
  o = bs.utils.undo_shape_blocksparse(o_dense, x_shape_original)

  # Other available functions
- bs.transpose(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
- bs.softmax(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
- bs.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
- bs.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+ bs.ops.transpose(o_sparse, sparsity_layout_o, sparsity_block_size)
+ bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size)
+ bs.ops.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size)
+ bs.ops.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size)


  def _get_random_sparsity_layout(b, m, n, sparsity_block_size, sparsity_percentage):
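
Both usage hunks stop at the signature of `_get_random_sparsity_layout`; its body lies outside the changed lines shown here. Purely for readability, a hypothetical sketch of what such a helper could look like, assuming it keeps roughly `sparsity_percentage` of the blocks (the README's actual implementation may differ):

```python
import torch


def _get_random_sparsity_layout(b, m, n, sparsity_block_size, sparsity_percentage):
    # Hypothetical sketch, not the library's helper: one integer flag per
    # (batch, row-block, column-block) triple, set with probability
    # sparsity_percentage.
    rows = m // sparsity_block_size
    cols = n // sparsity_block_size
    layout = (torch.rand(b, rows, cols) < sparsity_percentage).to(torch.int)
    layout[:, 0, 0] = 1  # keep at least one block per matrix
    return layout
```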
@@ -18,19 +18,16 @@ class ops:
  class layouting:
  from blksprs.layouting.distribution_layout import build_distribution_layout
  from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption, \
- build_sparsity_layout_matmul, build_sparsity_layout_matmul_fast
- from blksprs.utils.layout_utils import build_full_sparsity_layout
+ build_sparsity_layout_matmul, build_sparsity_layout_matmul_fast, build_sparsity_layout_full

  class utils:
  from blksprs.utils.processing import apply_torch_linear, apply_torch_normalisation, apply_torch_dropout, \
  apply_function_applicable_row_wise
  from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
- from blksprs.utils.validation import disable_validation

  class validation:
  from blksprs.utils.validation import disable_validation
  from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_dtype_float, \
  validate_dtype_int, validate_device, validate_sparsity, validate_sparsity_dense, \
- validate_sparsity_block_size, \
- validate_triton_block_size
+ validate_sparsity_block_size
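
The `__init__.py` hunk above leans on plain classes as import namespaces: an `import` placed inside a class body binds the imported names as class attributes, which is what makes spellings such as `bs.ops.to_sparse(...)` or `bs.validation.disable_validation` work without instantiating anything. A self-contained toy illustration of the pattern, using standard-library names rather than blksprs internals:

```python
class math_ops:
    # Names imported inside the class body become class attributes,
    # so callers write math_ops.sqrt(...) with no instance involved.
    from math import sqrt, floor


print(math_ops.sqrt(16.0))  # 4.0
print(math_ops.floor(2.7))  # 2
```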
@@ -4,14 +4,14 @@ from torch import Tensor
  from triton import language as tl

  from blksprs.utils.blksprs_tensor import BlksprsTensor
- from blksprs.utils.tools import get_triton_block_size, stride
- from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
+ from blksprs.utils.tools import stride, get_autotune_configs
+ from blksprs.utils.validation import validate_dimensions, validate_device, \
  validate_contiguous


  def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: Tensor,
  dim: int, size_target: torch.Size,
- sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+ sparsity_block_size: int) -> Tensor:
  """Builds the sparsity layout of either the source of a gather or the target of a scatter operation.

  Args:
@@ -20,7 +20,6 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
  dim (int): The dimension along which the operation is conducted.
  size_target (torch.Size): The size of the block-sparse target tensor in regular form.
  sparsity_block_size (int): The size of the sparsity blocks.
- triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).

  Returns:
  Tensor: The sparsity layout of the source or target tensor.
@@ -44,16 +43,11 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
  o_b, o_r, o_c = output.size()
  o_b_s, o_r_s, o_c_s = stride(output)

- if triton_block_size is None:
- triton_block_size = get_triton_block_size(sparsity_block_size)
-
- validate_triton_block_size(triton_block_size, sparsity_block_size)
-
  triton_grid = lambda meta: [i_b,
  triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
  triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]

- (kernel_distribution_layout[triton_grid]
+ (build_distribution_layout_kernel[triton_grid]
  (indices,
  i_b, i_b_s, i_r_s, i_c_s,
  sparsity_lut_i,
@@ -61,27 +55,34 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
  adjusted_dim,
  output,
  o_b, o_b_s, o_r_s, o_c_s,
- sparsity_block_size,
- triton_block_size))
+ sparsity_block_size))

  return output


+ @triton.autotune(
+ configs=get_autotune_configs(),
+ key=[],
+ reset_to_zero=["o"]
+ )
  @triton.jit
- def kernel_distribution_layout(i,
- i_b, i_b_s, i_r_s, i_c_s,
- s_lut_i,
- s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
- dim,
- o,
- o_b, o_b_s, o_r_s, o_c_s,
- sparsity_block_size,
- TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+ def build_distribution_layout_kernel(i,
+ i_b, i_b_s, i_r_s, i_c_s,
+ s_lut_i,
+ s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+ dim,
+ o,
+ o_b, o_b_s, o_r_s, o_c_s,
+ sparsity_block_size,
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
  # Get triton block indices
  pid_blk = tl.program_id(axis=0)
  pid_row = tl.program_id(axis=1)
  pid_col = tl.program_id(axis=2)

+ # Get valid triton block size
+ val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
+
  # Get position of current sparsity block consisting of its batch, row, and column index
  spa_bat_i_idx = (pid_blk * s_lut_i_r_s + 0 * s_lut_i_c_s)
  spa_bat_i_msk = (spa_bat_i_idx >= 0 and spa_bat_i_idx < s_lut_i_r * s_lut_i_r_s)
@@ -96,9 +97,12 @@ def kernel_distribution_layout(i,
  spa_col_i = tl.load(s_lut_i + spa_col_i_idx, mask=spa_col_i_msk)

  blk_i_idx = (pid_blk * i_b_s +
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
- blk_i_msk = (blk_i_idx >= 0 and blk_i_idx < i_b * i_b_s)
+ ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+ ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+ blk_i_msk = ((blk_i_idx >= 0 and
+ blk_i_idx < i_b * i_b_s) and
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
  blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)

  dst_bat_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_i, dtype=tl.int32)
@@ -116,5 +120,8 @@ def kernel_distribution_layout(i,
  blk_o_idx = ((dst_bat_idx * o_b_s) +
  (dst_row_idx * o_r_s) +
  (dst_col_idx * o_c_s))
- blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
+ blk_o_msk = ((blk_o_idx >= 0 and
+ blk_o_idx < o_b * o_b_s) and
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
  tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
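
In the new `distribution_layout.py`, the caller-supplied `triton_block_size` is replaced by `@triton.autotune` over configurations returned by `get_autotune_configs()` (now imported from `blksprs.utils.tools`, whose body is not part of this excerpt), and the kernel clamps its effective tile to `val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)` so an oversized autotuned block never reads or writes outside a sparsity block. A hedged sketch of what such a config helper might return; the real values in `tools.py` may well differ:

```python
import triton


def get_autotune_configs():
    # Hypothetical sketch: candidate TRITON_BLOCK_SIZE values for Triton's
    # autotuner to benchmark; the actual helper in blksprs/utils/tools.py
    # may use different sizes, num_warps, or num_stages.
    return [
        triton.Config({"TRITON_BLOCK_SIZE": block_size}, num_warps=num_warps)
        for block_size, num_warps in ((16, 2), (32, 4), (64, 4), (128, 8))
    ]
```

With the autotuner choosing `TRITON_BLOCK_SIZE`, the launch grid simply reads `meta["TRITON_BLOCK_SIZE"]`, and `reset_to_zero=["o"]` tells the autotuner to zero the output argument before each timed configuration so benchmarking runs do not leave stale data behind.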