blksprs 1.6.1__tar.gz → 1.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {blksprs-1.6.1 → blksprs-1.7}/PKG-INFO +3 -2
- {blksprs-1.6.1 → blksprs-1.7}/README.md +2 -1
- {blksprs-1.6.1 → blksprs-1.7}/blksprs/experimental/distribution_mdi.py +14 -14
- {blksprs-1.6.1 → blksprs-1.7}/blksprs/layouting/distribution_layout.py +4 -4
- {blksprs-1.6.1 → blksprs-1.7}/blksprs/layouting/sparsity_layout.py +6 -6
- {blksprs-1.6.1 → blksprs-1.7}/blksprs/misc/broadcast_ops.py +5 -5
- {blksprs-1.6.1 → blksprs-1.7}/blksprs/misc/repeat_interleave.py +5 -5
- {blksprs-1.6.1 → blksprs-1.7}/blksprs/misc/row_wise.py +16 -15
- {blksprs-1.6.1 → blksprs-1.7}/blksprs/ops/conversion.py +11 -11
- {blksprs-1.6.1 → blksprs-1.7}/blksprs/ops/distribution.py +11 -11
- {blksprs-1.6.1 → blksprs-1.7}/blksprs/ops/exp.py +3 -3
- {blksprs-1.6.1 → blksprs-1.7}/blksprs/ops/matmul.py +7 -7
- {blksprs-1.6.1 → blksprs-1.7}/blksprs/ops/partitioning.py +9 -98
- blksprs-1.7/blksprs/ops/repeat.py +241 -0
- {blksprs-1.6.1 → blksprs-1.7}/blksprs/ops/softmax.py +11 -10
- {blksprs-1.6.1 → blksprs-1.7}/blksprs/ops/transpose.py +7 -6
- {blksprs-1.6.1 → blksprs-1.7}/blksprs/utils/tools.py +3 -0
- {blksprs-1.6.1 → blksprs-1.7}/blksprs.egg-info/PKG-INFO +3 -2
- {blksprs-1.6.1 → blksprs-1.7}/blksprs.egg-info/SOURCES.txt +1 -0
- {blksprs-1.6.1 → blksprs-1.7}/pyproject.toml +1 -1
- {blksprs-1.6.1 → blksprs-1.7}/blksprs/__init__.py +0 -0
- {blksprs-1.6.1 → blksprs-1.7}/blksprs/utils/benchmarking.py +0 -0
- {blksprs-1.6.1 → blksprs-1.7}/blksprs/utils/validation.py +0 -0
- {blksprs-1.6.1 → blksprs-1.7}/blksprs.egg-info/dependency_links.txt +0 -0
- {blksprs-1.6.1 → blksprs-1.7}/blksprs.egg-info/requires.txt +0 -0
- {blksprs-1.6.1 → blksprs-1.7}/blksprs.egg-info/top_level.txt +0 -0
- {blksprs-1.6.1 → blksprs-1.7}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: blksprs
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.7
|
|
4
4
|
Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
|
|
5
5
|
Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
|
|
6
6
|
Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
|
|
@@ -34,6 +34,7 @@ Currently supported operations (includes gradient calculation):
|
|
|
34
34
|
- Transpose
|
|
35
35
|
- Gather
|
|
36
36
|
- Scatter (_supports either no reduction or summation, gradients are only available for summation_)
|
|
37
|
+
- Repeat (_supports target sparsity layout_)
|
|
37
38
|
- Splitting and merging of matrices along the last dimension
|
|
38
39
|
- Conversion to and from sparse form
|
|
39
40
|
- Conversion to different sparsity layouts and different sparsity block sizes
|
|
@@ -64,7 +65,7 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u
|
|
|
64
65
|
|
|
65
66
|
### Dependencies
|
|
66
67
|
|
|
67
|
-
- [PyTorch](https://pytorch.org/) (built with v2.
|
|
68
|
+
- [PyTorch](https://pytorch.org/) (built with v2.5.0)
|
|
68
69
|
- _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_
|
|
69
70
|
|
|
70
71
|
## Changelog
|
|
@@ -15,6 +15,7 @@ Currently supported operations (includes gradient calculation):
|
|
|
15
15
|
- Transpose
|
|
16
16
|
- Gather
|
|
17
17
|
- Scatter (_supports either no reduction or summation, gradients are only available for summation_)
|
|
18
|
+
- Repeat (_supports target sparsity layout_)
|
|
18
19
|
- Splitting and merging of matrices along the last dimension
|
|
19
20
|
- Conversion to and from sparse form
|
|
20
21
|
- Conversion to different sparsity layouts and different sparsity block sizes
|
|
@@ -45,7 +46,7 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u
|
|
|
45
46
|
|
|
46
47
|
### Dependencies
|
|
47
48
|
|
|
48
|
-
- [PyTorch](https://pytorch.org/) (built with v2.
|
|
49
|
+
- [PyTorch](https://pytorch.org/) (built with v2.5.0)
|
|
49
50
|
- _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_
|
|
50
51
|
|
|
51
52
|
## Changelog
|
|
@@ -3,7 +3,7 @@ import triton
|
|
|
3
3
|
from torch import Tensor
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
|
-
from blksprs.utils.tools import get_triton_block_size
|
|
6
|
+
from blksprs.utils.tools import get_triton_block_size, stride
|
|
7
7
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
8
8
|
validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
|
|
9
9
|
|
|
@@ -51,15 +51,15 @@ class _BlocksparseGatherMDI(torch.autograd.Function):
|
|
|
51
51
|
output = torch.empty_like(idx_col, dtype=x.dtype)
|
|
52
52
|
|
|
53
53
|
x_b, x_r, x_c = x.size()
|
|
54
|
-
x_b_s, x_r_s, x_c_s =
|
|
54
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
55
55
|
s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
|
|
56
|
-
s_l_x_b_s, s_l_x_r_s, s_l_x_c_s =
|
|
56
|
+
s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
|
|
57
57
|
i_b, i_r, i_c = idx_col.size()
|
|
58
|
-
i_b_s, i_r_s, i_c_s =
|
|
58
|
+
i_b_s, i_r_s, i_c_s = stride(idx_col)
|
|
59
59
|
s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
|
|
60
|
-
s_lut_i_r_s, s_lut_i_c_s =
|
|
60
|
+
s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
|
|
61
61
|
o_b, o_r, o_c = output.size()
|
|
62
|
-
o_b_s, o_r_s, o_c_s =
|
|
62
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
63
63
|
|
|
64
64
|
if triton_block_size is None:
|
|
65
65
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -224,15 +224,15 @@ class _BlocksparseScatterReduceMDI(torch.autograd.Function):
|
|
|
224
224
|
dtype=x.dtype, device=x.device)
|
|
225
225
|
|
|
226
226
|
x_b, x_r, x_c = x.size()
|
|
227
|
-
x_b_s, x_r_s, x_c_s =
|
|
227
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
228
228
|
s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
|
|
229
|
-
s_lut_x_r_s, s_lut_x_c_s =
|
|
229
|
+
s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut_x)
|
|
230
230
|
i_b, i_r, i_c = idx_col.size()
|
|
231
|
-
i_b_s, i_r_s, i_c_s =
|
|
231
|
+
i_b_s, i_r_s, i_c_s = stride(idx_col)
|
|
232
232
|
o_b, o_r, o_c = output.size()
|
|
233
|
-
o_b_s, o_r_s, o_c_s =
|
|
233
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
234
234
|
s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
|
|
235
|
-
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s =
|
|
235
|
+
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
|
|
236
236
|
|
|
237
237
|
if triton_block_size is None:
|
|
238
238
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -366,11 +366,11 @@ def build_distribution_layout_mdi(idx_bat: Tensor, idx_row: Tensor, idx_col: Ten
|
|
|
366
366
|
dtype=torch.bool, device=idx_col.device)
|
|
367
367
|
|
|
368
368
|
i_b, i_r, i_c = idx_col.size()
|
|
369
|
-
i_b_s, i_r_s, i_c_s =
|
|
369
|
+
i_b_s, i_r_s, i_c_s = stride(idx_col)
|
|
370
370
|
s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
|
|
371
|
-
s_lut_i_r_s, s_lut_i_c_s =
|
|
371
|
+
s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
|
|
372
372
|
o_b, o_r, o_c = output.size()
|
|
373
|
-
o_b_s, o_r_s, o_c_s =
|
|
373
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
374
374
|
|
|
375
375
|
if triton_block_size is None:
|
|
376
376
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -3,7 +3,7 @@ import triton
|
|
|
3
3
|
from torch import Tensor
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
|
-
from blksprs.utils.tools import get_triton_block_size
|
|
6
|
+
from blksprs.utils.tools import get_triton_block_size, stride
|
|
7
7
|
from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
|
|
8
8
|
validate_contiguous
|
|
9
9
|
|
|
@@ -34,11 +34,11 @@ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
|
|
|
34
34
|
dtype=torch.bool, device=indices.device)
|
|
35
35
|
|
|
36
36
|
i_b, i_r, i_c = indices.size()
|
|
37
|
-
i_b_s, i_r_s, i_c_s =
|
|
37
|
+
i_b_s, i_r_s, i_c_s = stride(indices)
|
|
38
38
|
s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
|
|
39
|
-
s_lut_i_r_s, s_lut_i_c_s =
|
|
39
|
+
s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
|
|
40
40
|
o_b, o_r, o_c = output.size()
|
|
41
|
-
o_b_s, o_r_s, o_c_s =
|
|
41
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
42
42
|
|
|
43
43
|
if triton_block_size is None:
|
|
44
44
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -5,7 +5,7 @@ import triton
|
|
|
5
5
|
from torch import Tensor
|
|
6
6
|
from triton import language as tl
|
|
7
7
|
|
|
8
|
-
from blksprs.utils.tools import get_triton_block_size
|
|
8
|
+
from blksprs.utils.tools import get_triton_block_size, stride
|
|
9
9
|
from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
|
|
10
10
|
validate_contiguous, validate_sparsity, validate_sparsity_block_size
|
|
11
11
|
|
|
@@ -30,9 +30,9 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_size
|
|
|
30
30
|
dtype=torch.bool, device=x.device)
|
|
31
31
|
|
|
32
32
|
x_b, x_r, x_c = x.size()
|
|
33
|
-
x_b_s, x_r_s, x_c_s =
|
|
33
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
34
34
|
o_b, o_r, o_c = output.size()
|
|
35
|
-
o_b_s, o_r_s, o_c_s =
|
|
35
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
36
36
|
|
|
37
37
|
if triton_block_size is None:
|
|
38
38
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -120,10 +120,10 @@ def build_sparsity_layout_adaption(x: Tensor, sparsity_layout_from: Tensor,
|
|
|
120
120
|
output = torch.zeros(o_b, o_r, o_c, dtype=torch.bool, device=x.device)
|
|
121
121
|
|
|
122
122
|
x_b, x_r, x_c = x.size()
|
|
123
|
-
x_b_s, x_r_s, x_c_s =
|
|
123
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
124
124
|
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
125
|
-
s_lut_r_s, s_lut_c_s =
|
|
126
|
-
o_b_s, o_r_s, o_c_s =
|
|
125
|
+
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
126
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
127
127
|
|
|
128
128
|
if triton_block_size is None:
|
|
129
129
|
triton_block_size = get_triton_block_size(sparsity_block_size_from)
|
|
@@ -3,7 +3,7 @@ import triton
|
|
|
3
3
|
from torch import Tensor
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
|
-
from blksprs.utils.tools import get_triton_block_size
|
|
6
|
+
from blksprs.utils.tools import get_triton_block_size, stride
|
|
7
7
|
from blksprs.utils.validation import validate_contiguous, validate_device, \
|
|
8
8
|
validate_sparsity_block_size, validate_triton_block_size
|
|
9
9
|
|
|
@@ -44,13 +44,13 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
|
|
|
44
44
|
output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, dtype=x.dtype, device=x.device)
|
|
45
45
|
|
|
46
46
|
x_b, x_c = x.size()
|
|
47
|
-
x_b_s, x_c_s =
|
|
47
|
+
x_b_s, x_c_s = stride(x)
|
|
48
48
|
y_b, y_c = y.size()
|
|
49
|
-
y_b_s, y_c_s =
|
|
49
|
+
y_b_s, y_c_s = stride(y)
|
|
50
50
|
o_b, o_r, o_c = output.size()
|
|
51
|
-
o_b_s, o_r_s, o_c_s =
|
|
51
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
52
52
|
s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
|
|
53
|
-
s_lut_o_r_s, s_lut_o_c_s =
|
|
53
|
+
s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
|
|
54
54
|
|
|
55
55
|
if triton_block_size is None:
|
|
56
56
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -3,7 +3,7 @@ import triton
|
|
|
3
3
|
from torch import Tensor
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
|
-
from blksprs.utils.tools import get_triton_block_size
|
|
6
|
+
from blksprs.utils.tools import get_triton_block_size, stride
|
|
7
7
|
from blksprs.utils.validation import validate_contiguous, validate_device, \
|
|
8
8
|
validate_sparsity_block_size, validate_triton_block_size, validate_dimensions
|
|
9
9
|
|
|
@@ -52,13 +52,13 @@ def repeat_interleave(x: Tensor, sparsity_layout: Tensor, repeats: int,
|
|
|
52
52
|
dtype=x.dtype, device=x.device)
|
|
53
53
|
|
|
54
54
|
x_b, x_r, x_c = x.size()
|
|
55
|
-
x_b_s, x_r_s, x_c_s =
|
|
55
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
56
56
|
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
57
|
-
s_lut_r_s, s_lut_c_s =
|
|
57
|
+
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
58
58
|
o_b, o_r, o_c = output.size()
|
|
59
|
-
o_b_s, o_r_s, o_c_s =
|
|
59
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
60
60
|
s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
|
|
61
|
-
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s =
|
|
61
|
+
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
|
|
62
62
|
|
|
63
63
|
if triton_block_size is None:
|
|
64
64
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -3,7 +3,7 @@ import triton
|
|
|
3
3
|
from torch import Tensor
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
|
-
from blksprs.utils.tools import get_triton_block_size
|
|
6
|
+
from blksprs.utils.tools import get_triton_block_size, stride
|
|
7
7
|
from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, validate_sparsity, \
|
|
8
8
|
validate_sparsity_block_size, validate_triton_block_size
|
|
9
9
|
|
|
@@ -60,13 +60,13 @@ def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
|
60
60
|
device=x.device)
|
|
61
61
|
|
|
62
62
|
x_b, x_r, x_c = x.size()
|
|
63
|
-
x_b_s, x_r_s, x_c_s =
|
|
63
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
64
64
|
s_lut_x_r, s_lut_x_c = sparsity_lut.size()
|
|
65
|
-
s_lut_x_r_s, s_lut_x_c_s =
|
|
65
|
+
s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
|
|
66
66
|
o_b, o_r, o_c = output.size()
|
|
67
|
-
o_b_s, o_r_s, o_c_s =
|
|
67
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
68
68
|
s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
|
|
69
|
-
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s =
|
|
69
|
+
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
|
|
70
70
|
|
|
71
71
|
if triton_block_size is None:
|
|
72
72
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -183,13 +183,13 @@ def row_wise_max(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
|
183
183
|
device=x.device)
|
|
184
184
|
|
|
185
185
|
x_b, x_r, x_c = x.size()
|
|
186
|
-
x_b_s, x_r_s, x_c_s =
|
|
186
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
187
187
|
s_lut_x_r, s_lut_x_c = sparsity_lut.size()
|
|
188
|
-
s_lut_x_r_s, s_lut_x_c_s =
|
|
188
|
+
s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
|
|
189
189
|
o_b, o_r, o_c = output.size()
|
|
190
|
-
o_b_s, o_r_s, o_c_s =
|
|
190
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
191
191
|
s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
|
|
192
|
-
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s =
|
|
192
|
+
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
|
|
193
193
|
|
|
194
194
|
if triton_block_size is None:
|
|
195
195
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -290,15 +290,15 @@ def row_wise_add(x: Tensor, sparsity_layout_x: Tensor, y: Tensor,
|
|
|
290
290
|
output = torch.empty_like(x)
|
|
291
291
|
|
|
292
292
|
x_b, x_r, x_c = x.size()
|
|
293
|
-
x_b_s, x_r_s, x_c_s =
|
|
293
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
294
294
|
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
295
|
-
s_lut_r_s, s_lut_c_s =
|
|
295
|
+
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
296
296
|
y_b, y_r, y_c = y.size()
|
|
297
|
-
y_b_s, y_r_s, y_c_s =
|
|
297
|
+
y_b_s, y_r_s, y_c_s = stride(y)
|
|
298
298
|
s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_rwm.size()
|
|
299
|
-
s_l_y_b_s, s_l_y_r_s, s_l_y_c_s =
|
|
299
|
+
s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_rwm)
|
|
300
300
|
o_b, o_r, o_c = output.size()
|
|
301
|
-
o_b_s, o_r_s, o_c_s =
|
|
301
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
302
302
|
|
|
303
303
|
if triton_block_size is None:
|
|
304
304
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -361,7 +361,8 @@ def kernel_blocksparse_row_wise_add(x,
|
|
|
361
361
|
rev_idx_spa_s = tl.load(r_lut_y + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
|
|
362
362
|
|
|
363
363
|
if rev_idx_spa_s == -1:
|
|
364
|
-
|
|
364
|
+
tl.device_assert(False)
|
|
365
|
+
return
|
|
365
366
|
|
|
366
367
|
# Load x block
|
|
367
368
|
blk_x_idx = ((pid_blk * x_b_s) +
|
|
@@ -6,7 +6,7 @@ from torch import Tensor
|
|
|
6
6
|
from triton import language as tl
|
|
7
7
|
|
|
8
8
|
from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
|
|
9
|
-
from blksprs.utils.tools import get_triton_block_size
|
|
9
|
+
from blksprs.utils.tools import get_triton_block_size, stride
|
|
10
10
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
11
11
|
validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
|
|
12
12
|
|
|
@@ -65,11 +65,11 @@ class _BlocksparseToDense(torch.autograd.Function):
|
|
|
65
65
|
dtype=x.dtype, device=x.device)
|
|
66
66
|
|
|
67
67
|
x_b, x_r, x_c = x.shape
|
|
68
|
-
x_b_s, x_r_s, x_c_s =
|
|
68
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
69
69
|
s_l_b, s_l_r, s_l_c = sparsity_layout.size()
|
|
70
|
-
s_l_b_s, s_l_r_s, s_l_c_s =
|
|
70
|
+
s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout)
|
|
71
71
|
o_b, o_r, o_c = output.size()
|
|
72
|
-
o_b_s, o_r_s, o_c_s =
|
|
72
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
73
73
|
|
|
74
74
|
if triton_block_size is None:
|
|
75
75
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -190,11 +190,11 @@ class _BlocksparseToSparse(torch.autograd.Function):
|
|
|
190
190
|
dtype=x.dtype, device=x.device)
|
|
191
191
|
|
|
192
192
|
x_b, x_r, x_c = x.size()
|
|
193
|
-
x_b_s, x_r_s, x_c_s =
|
|
193
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
194
194
|
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
195
|
-
s_lut_r_s, s_lut_c_s =
|
|
195
|
+
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
196
196
|
o_b, o_r, o_c = output.size()
|
|
197
|
-
o_b_s, o_r_s, o_c_s =
|
|
197
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
198
198
|
|
|
199
199
|
if triton_block_size is None:
|
|
200
200
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -347,13 +347,13 @@ class _BlocksparseAdaptLayout(torch.autograd.Function):
|
|
|
347
347
|
dtype=x.dtype, device=x.device)
|
|
348
348
|
|
|
349
349
|
x_b, x_r, x_c = x.size()
|
|
350
|
-
x_b_s, x_r_s, x_c_s =
|
|
350
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
351
351
|
s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_from.size()
|
|
352
|
-
s_l_x_b_s, s_l_x_r_s, s_l_x_c_s =
|
|
352
|
+
s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_from)
|
|
353
353
|
o_b, o_r, o_c = output.size()
|
|
354
|
-
o_b_s, o_r_s, o_c_s =
|
|
354
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
355
355
|
s_lut_o_r, s_lut_o_c = sparsity_lut_to.size()
|
|
356
|
-
s_lut_o_r_s, s_lut_o_c_s =
|
|
356
|
+
s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_to)
|
|
357
357
|
|
|
358
358
|
if triton_block_size is None:
|
|
359
359
|
triton_block_size = get_triton_block_size(min_sparsity_block_size)
|
|
@@ -3,7 +3,7 @@ import triton
|
|
|
3
3
|
from torch import Tensor
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
|
-
from blksprs.utils.tools import get_triton_block_size
|
|
6
|
+
from blksprs.utils.tools import get_triton_block_size, stride
|
|
7
7
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
8
8
|
validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
|
|
9
9
|
|
|
@@ -59,15 +59,15 @@ class _BlocksparseGather(torch.autograd.Function):
|
|
|
59
59
|
output = torch.empty_like(i, dtype=x.dtype)
|
|
60
60
|
|
|
61
61
|
x_b, x_r, x_c = x.size()
|
|
62
|
-
x_b_s, x_r_s, x_c_s =
|
|
62
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
63
63
|
s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
|
|
64
|
-
s_l_x_b_s, s_l_x_r_s, s_l_x_c_s =
|
|
64
|
+
s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
|
|
65
65
|
i_b, i_r, i_c = i.size()
|
|
66
|
-
i_b_s, i_r_s, i_c_s =
|
|
66
|
+
i_b_s, i_r_s, i_c_s = stride(i)
|
|
67
67
|
s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
|
|
68
|
-
s_lut_i_r_s, s_lut_i_c_s =
|
|
68
|
+
s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
|
|
69
69
|
o_b, o_r, o_c = output.size()
|
|
70
|
-
o_b_s, o_r_s, o_c_s =
|
|
70
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
71
71
|
|
|
72
72
|
if triton_block_size is None:
|
|
73
73
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -248,15 +248,15 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
|
|
|
248
248
|
dtype=x.dtype, device=x.device)
|
|
249
249
|
|
|
250
250
|
x_b, x_r, x_c = x.size()
|
|
251
|
-
x_b_s, x_r_s, x_c_s =
|
|
251
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
252
252
|
s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
|
|
253
|
-
s_lut_x_r_s, s_lut_x_c_s =
|
|
253
|
+
s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut_x)
|
|
254
254
|
i_b, i_r, i_c = i.size()
|
|
255
|
-
i_b_s, i_r_s, i_c_s =
|
|
255
|
+
i_b_s, i_r_s, i_c_s = stride(i)
|
|
256
256
|
o_b, o_r, o_c = output.size()
|
|
257
|
-
o_b_s, o_r_s, o_c_s =
|
|
257
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
258
258
|
s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
|
|
259
|
-
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s =
|
|
259
|
+
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
|
|
260
260
|
|
|
261
261
|
if triton_block_size is None:
|
|
262
262
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -3,7 +3,7 @@ import triton
|
|
|
3
3
|
from torch import Tensor
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
|
-
from blksprs.utils.tools import get_triton_block_size
|
|
6
|
+
from blksprs.utils.tools import get_triton_block_size, stride
|
|
7
7
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
8
8
|
validate_sparsity_block_size, validate_triton_block_size
|
|
9
9
|
|
|
@@ -43,9 +43,9 @@ class _BlocksparseExp(torch.autograd.Function):
|
|
|
43
43
|
output = torch.empty_like(x)
|
|
44
44
|
|
|
45
45
|
x_b, x_r, x_c = x.shape
|
|
46
|
-
x_b_s, x_r_s, x_c_s =
|
|
46
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
47
47
|
o_b, o_r, o_c = output.shape
|
|
48
|
-
o_b_s, o_r_s, o_c_s =
|
|
48
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
49
49
|
|
|
50
50
|
if triton_block_size is None:
|
|
51
51
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -4,7 +4,7 @@ from torch import Tensor
|
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
6
|
from blksprs.ops.transpose import transpose
|
|
7
|
-
from blksprs.utils.tools import get_triton_block_size
|
|
7
|
+
from blksprs.utils.tools import get_triton_block_size, stride
|
|
8
8
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
9
9
|
validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_dtype_float
|
|
10
10
|
|
|
@@ -82,17 +82,17 @@ class _BlocksparseMatmulSSS(torch.autograd.Function):
|
|
|
82
82
|
dtype=x.dtype, device=x.device)
|
|
83
83
|
|
|
84
84
|
x_b, x_r, x_c = x.size()
|
|
85
|
-
x_b_s, x_r_s, x_c_s =
|
|
85
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
86
86
|
s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
|
|
87
|
-
s_l_x_b_s, s_l_x_r_s, s_l_x_c_s =
|
|
87
|
+
s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
|
|
88
88
|
y_b, y_r, y_c = y.size()
|
|
89
|
-
y_b_s, y_r_s, y_c_s =
|
|
89
|
+
y_b_s, y_r_s, y_c_s = stride(y)
|
|
90
90
|
s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_y.size()
|
|
91
|
-
s_l_y_b_s, s_l_y_r_s, s_l_y_c_s =
|
|
91
|
+
s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_y)
|
|
92
92
|
o_b, o_r, o_c = output.size()
|
|
93
|
-
o_b_s, o_r_s, o_c_s =
|
|
93
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
94
94
|
s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
|
|
95
|
-
s_lut_o_r_s, s_lut_o_c_s =
|
|
95
|
+
s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
|
|
96
96
|
|
|
97
97
|
if triton_block_size is None:
|
|
98
98
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -1,10 +1,7 @@
|
|
|
1
1
|
import torch
|
|
2
|
-
import triton
|
|
3
|
-
from sympy.utilities.iterables import partitions
|
|
4
2
|
from torch import Tensor
|
|
5
|
-
from triton import language as tl
|
|
6
3
|
|
|
7
|
-
from blksprs.
|
|
4
|
+
from blksprs.ops.repeat import forward_flow
|
|
8
5
|
|
|
9
6
|
from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
|
|
10
7
|
validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
|
|
@@ -48,12 +45,11 @@ def split(x: Tensor, sparsity_layout: Tensor, partitions: int,
|
|
|
48
45
|
sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
|
|
49
46
|
(sparsity_layout_flat == 1) -
|
|
50
47
|
(1 * (sparsity_layout_flat == 0)))
|
|
51
|
-
.reshape(sparsity_layout.size())
|
|
52
48
|
.reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
|
|
53
49
|
sparsity_layout.size(2) // partitions)
|
|
54
50
|
.permute(0, 2, 1, 3).reshape(-1).contiguous())
|
|
55
51
|
|
|
56
|
-
n_sparse_blocks = torch.sum(
|
|
52
|
+
n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
|
|
57
53
|
|
|
58
54
|
validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
|
|
59
55
|
|
|
@@ -66,10 +62,11 @@ class _BlocksparseSplit(torch.autograd.Function):
|
|
|
66
62
|
@staticmethod
|
|
67
63
|
def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
68
64
|
num_partitions: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
|
|
65
|
+
ctx.save_for_backward(sparsity_layout_o)
|
|
69
66
|
ctx.num_partitions = num_partitions
|
|
70
67
|
|
|
71
|
-
return
|
|
72
|
-
|
|
68
|
+
return forward_flow(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
|
|
69
|
+
n_sparse_blocks, triton_block_size)
|
|
73
70
|
|
|
74
71
|
@staticmethod
|
|
75
72
|
def backward(ctx, grad_output):
|
|
@@ -126,7 +123,7 @@ def merge(x: Tensor, sparsity_layout: Tensor, partitions: int,
|
|
|
126
123
|
sparsity_layout.size(1), sparsity_layout.size(2) * partitions)
|
|
127
124
|
.reshape(-1).contiguous())
|
|
128
125
|
|
|
129
|
-
n_sparse_blocks = torch.sum(
|
|
126
|
+
n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
|
|
130
127
|
|
|
131
128
|
validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
|
|
132
129
|
|
|
@@ -139,10 +136,11 @@ class _BlocksparseMerge(torch.autograd.Function):
|
|
|
139
136
|
@staticmethod
|
|
140
137
|
def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
141
138
|
num_partitions: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
|
|
139
|
+
ctx.save_for_backward(sparsity_layout_o)
|
|
142
140
|
ctx.num_partitions = num_partitions
|
|
143
141
|
|
|
144
|
-
return
|
|
145
|
-
|
|
142
|
+
return forward_flow(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
|
|
143
|
+
n_sparse_blocks, triton_block_size)
|
|
146
144
|
|
|
147
145
|
@staticmethod
|
|
148
146
|
def backward(ctx, grad_output):
|
|
@@ -155,90 +153,3 @@ class _BlocksparseMerge(torch.autograd.Function):
|
|
|
155
153
|
sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None
|
|
156
154
|
|
|
157
155
|
|
|
158
|
-
def forward_reorder(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
159
|
-
sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
|
|
160
|
-
output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
|
|
161
|
-
dtype=x.dtype, device=x.device)
|
|
162
|
-
|
|
163
|
-
x_b, x_r, x_c = x.size()
|
|
164
|
-
x_b_s, x_r_s, x_c_s = x.stride()
|
|
165
|
-
s_l_b, s_l_r, s_l_c = sparsity_layout_o.size()
|
|
166
|
-
s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout_o.stride()
|
|
167
|
-
s_lut_r, s_lut_c = sparsity_lut.shape
|
|
168
|
-
s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
|
|
169
|
-
o_b, o_r, o_c = output.size()
|
|
170
|
-
o_b_s, o_r_s, o_c_s = output.stride()
|
|
171
|
-
|
|
172
|
-
if triton_block_size is None:
|
|
173
|
-
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
174
|
-
|
|
175
|
-
triton_grid = lambda meta: [o_b,
|
|
176
|
-
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
177
|
-
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
178
|
-
|
|
179
|
-
(kernel_blocksparse_reorder[triton_grid]
|
|
180
|
-
(x,
|
|
181
|
-
x_b, x_b_s, x_r_s, x_c_s,
|
|
182
|
-
s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
|
|
183
|
-
sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
184
|
-
sparsity_reverse_lut,
|
|
185
|
-
output,
|
|
186
|
-
o_b, o_b_s,
|
|
187
|
-
triton_block_size))
|
|
188
|
-
|
|
189
|
-
# Save for backward pass
|
|
190
|
-
ctx.save_for_backward(sparsity_layout_o)
|
|
191
|
-
ctx.sparsity_block_size = sparsity_block_size
|
|
192
|
-
ctx.triton_block_size = triton_block_size
|
|
193
|
-
|
|
194
|
-
return output
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
@triton.jit
|
|
198
|
-
def kernel_blocksparse_reorder(x,
|
|
199
|
-
x_b, x_b_s, x_r_s, x_c_s,
|
|
200
|
-
s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
|
|
201
|
-
s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
202
|
-
r_lut,
|
|
203
|
-
o,
|
|
204
|
-
o_b, o_b_s,
|
|
205
|
-
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
206
|
-
# Get triton block indices
|
|
207
|
-
pid_blk = tl.program_id(axis=0)
|
|
208
|
-
pid_row = tl.program_id(axis=1)
|
|
209
|
-
pid_col = tl.program_id(axis=2)
|
|
210
|
-
|
|
211
|
-
# Get sparsity index of current output block consisting of its batch, row, and column index
|
|
212
|
-
spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
|
|
213
|
-
spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
|
|
214
|
-
spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
|
|
215
|
-
|
|
216
|
-
spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
|
|
217
|
-
spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
|
|
218
|
-
spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
|
|
219
|
-
|
|
220
|
-
spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
|
|
221
|
-
spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
|
|
222
|
-
spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
|
|
223
|
-
|
|
224
|
-
# Get reverse sparsity index
|
|
225
|
-
rev_idx_spa_idx = (spa_bat * s_l_b_s +
|
|
226
|
-
spa_row * s_l_r_s +
|
|
227
|
-
spa_col * s_l_c_s)
|
|
228
|
-
rev_idx_spa_msk = (rev_idx_spa_idx < s_l_b * s_l_b_s)
|
|
229
|
-
rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
|
|
230
|
-
|
|
231
|
-
if rev_idx_spa == -1:
|
|
232
|
-
assert False, "Invalid sparsity block"
|
|
233
|
-
|
|
234
|
-
blk_x_idx = (rev_idx_spa * x_b_s +
|
|
235
|
-
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
236
|
-
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
237
|
-
blk_x_msk = (blk_x_idx < x_b * x_b_s)
|
|
238
|
-
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
239
|
-
|
|
240
|
-
blk_o_idx = (pid_blk * o_b_s +
|
|
241
|
-
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
242
|
-
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
243
|
-
blk_o_msk = (blk_o_idx < o_b * o_b_s)
|
|
244
|
-
tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
|
|
@@ -0,0 +1,241 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import triton
|
|
3
|
+
from triton import language as tl
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
from blksprs.utils.tools import get_triton_block_size, stride
|
|
7
|
+
from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
|
|
8
|
+
validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def repeat(x: Tensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
           sparsity_block_size: int, sparsity_layout_output: Tensor = None, triton_block_size: int = None) -> (
        Tensor, Tensor):
    """Repeat (tile) a block-sparse tensor along the batch, row, and column dimensions of its sparsity layout.

    The sparsity layout of ``x`` is tiled with ``Tensor.repeat(*repeats)`` to obtain the output layout. If
    ``sparsity_layout_output`` is given, the tiled layout is additionally intersected with it
    (``torch.logical_and``), so only blocks present in both are materialised.

    Args:
        x (Tensor): A block-sparse tensor in compressed form.
        sparsity_layout_x (Tensor): The sparsity layout of ``x``.
        repeats (tuple[int, int, int]): Number of repetitions along the batch, row, and column dimensions of the
            sparsity layout.
        sparsity_block_size (int): The size of the sparsity blocks.
        sparsity_layout_output (Tensor, optional): Target sparsity layout to intersect the tiled layout with. It may
            broadcast against the tiled layout but must not enlarge it.
        triton_block_size (int, optional): The block size to use for the triton kernel.

    Returns:
        tuple[Tensor, Tensor]: The repeated block-sparse tensor in compressed form and its sparsity layout.

    Raises:
        ValueError: If combining ``sparsity_layout_output`` with the tiled layout changes the layout's shape.
    """
    x = x.contiguous()

    validate_dimensions(x)
    validate_contiguous(x)
    validate_device(x)
    validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
    validate_sparsity_block_size(sparsity_block_size, x)
    validate_triton_block_size(triton_block_size, sparsity_block_size)

    sparsity_layout_o = sparsity_layout_x.repeat(repeats[0], repeats[1], repeats[2])
    tiled_size = sparsity_layout_o.size()

    if sparsity_layout_output is not None:
        sparsity_layout_o = torch.logical_and(sparsity_layout_o, sparsity_layout_output)
        # The reverse LUT below has exactly one entry per position of the tiled layout, but the kernel indexes it
        # using the size/strides of sparsity_layout_o. If broadcasting enlarged the layout, those indices would
        # run past the end of the LUT, so reject such shapes explicitly.
        if sparsity_layout_o.size() != tiled_size:
            raise ValueError(f"Output sparsity layout of shape {tuple(sparsity_layout_output.size())} is not "
                             f"compatible with the repeated sparsity layout of shape {tuple(tiled_size)}")

    # (batch, row, col) coordinates of every non-zero block of the output layout, one row per block
    sparsity_lut = torch.nonzero(sparsity_layout_o).contiguous()

    # For every position of the tiled layout: index of the source block in x's compressed storage, or -1 where
    # x has no block
    sparsity_layout_flat = sparsity_layout_x.reshape(-1)
    sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
                             (sparsity_layout_flat == 1) -
                             (1 * (sparsity_layout_flat == 0)))
                            .reshape(sparsity_layout_x.size())
                            .repeat(repeats[0], repeats[1], repeats[2])
                            .reshape(-1).contiguous())

    n_sparse_blocks = torch.sum(sparsity_layout_o.to(torch.int)).item()

    validate_contiguous(sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)

    return _BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
                                    sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_o
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class _BlocksparseRepeat(torch.autograd.Function):
    """Autograd function: the forward pass pulls tiled blocks from ``x``; the backward pass pushes
    (atomically adds) every output-block gradient back onto its source block."""

    @staticmethod
    def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor,
                sparsity_reverse_lut: Tensor,
                sparsity_block_size: int, n_sparse_blocks: int,
                triton_block_size: int) -> Tensor:
        # Layouts and LUTs are needed again in backward; size/stride of x describe the gradient buffer there.
        ctx.save_for_backward(sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
        ctx.x_size = x.size()
        ctx.x_stride = stride(x)

        # forward_flow also stores sparsity_block_size and the resolved triton_block_size on ctx.
        result = forward_flow(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
                              n_sparse_blocks, triton_block_size)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Accumulate the gradient of every output block onto the input block it was copied from.

        Only ``x`` receives a gradient; the seven remaining forward inputs get ``None``.
        """
        sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut = ctx.saved_tensors
        sparsity_block_size = ctx.sparsity_block_size
        triton_block_size = ctx.triton_block_size

        # One gradient block per non-zero block of x's layout; zero-initialised because the kernel accumulates
        n_blocks_x = torch.sum(sparsity_layout_x.to(torch.int)).item()
        grad_x = torch.zeros(size=(n_blocks_x, sparsity_block_size, sparsity_block_size),
                             dtype=grad_output.dtype, device=grad_output.device)

        # Source of the push: the incoming gradient, laid out like the forward output
        g_b, g_r, g_c = grad_output.size()
        g_b_s, g_r_s, g_c_s = stride(grad_output)
        s_l_o_b, _, _ = sparsity_layout_o.size()
        s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
        s_lut_r, _ = sparsity_lut.size()
        s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
        # Destination of the push: grad_x, which mirrors the forward input x
        gx_b, _, _ = ctx.x_size
        gx_b_s, gx_r_s, gx_c_s = ctx.x_stride

        if triton_block_size is None:
            triton_block_size = get_triton_block_size(sparsity_block_size)

        def grid(meta):
            tile = meta["TRITON_BLOCK_SIZE"]
            return [g_b, triton.cdiv(g_r, tile), triton.cdiv(g_c, tile)]

        kernel_blocksparse_flow_push[grid](grad_output,
                                           g_b, g_b_s, g_r_s, g_c_s,
                                           s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
                                           sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
                                           sparsity_reverse_lut,
                                           grad_x,
                                           gx_b, gx_b_s, gx_r_s, gx_c_s,
                                           triton_block_size)

        return (grad_x,) + (None,) * 7
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
@triton.jit
def kernel_blocksparse_flow_pull(x,
                                 x_b, x_b_s, x_r_s, x_c_s,
                                 o,
                                 o_b, o_b_s, o_r_s, o_c_s,
                                 s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
                                 s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
                                 r_lut,
                                 TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    """Copy one tile of an output block from the input block it originates from.

    Program axis 0 selects the output block (a row of ``s_lut``); axes 1 and 2 select a
    ``TRITON_BLOCK_SIZE x TRITON_BLOCK_SIZE`` tile inside it. The block's (batch, row, col) position is
    looked up in ``r_lut`` to obtain the index of the source block in ``x``.
    """
    blk = tl.program_id(axis=0)
    tile_r = tl.program_id(axis=1)
    tile_c = tl.program_id(axis=2)

    # Read (batch, row, col) of the current output block from its LUT row
    lut_base = blk * s_lut_r_s
    lut_end = s_lut_r * s_lut_r_s

    bat_off = lut_base
    out_bat = tl.load(s_lut + bat_off, mask=bat_off < lut_end)

    row_off = lut_base + 1 * s_lut_c_s
    out_row = tl.load(s_lut + row_off, mask=row_off < lut_end)

    col_off = lut_base + 2 * s_lut_c_s
    out_col = tl.load(s_lut + col_off, mask=col_off < lut_end)

    # Translate the position into the index of the source block in x
    rev_off = (out_bat * s_l_o_b_s +
               out_row * s_l_o_r_s +
               out_col * s_l_o_c_s)
    src_blk = tl.load(r_lut + rev_off, mask=rev_off < s_l_o_b * s_l_o_b_s).to(tl.int32)

    if src_blk == -1:
        tl.device_assert(False)
        return

    offs_r = pid_offsets_r = tile_r * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)
    offs_c = tile_c * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)

    src_off = (src_blk * x_b_s +
               (offs_r * x_r_s)[:, None] +
               (offs_c * x_c_s)[None, :])
    tile = tl.load(x + src_off, mask=src_off < x_b * x_b_s)

    dst_off = (blk * o_b_s +
               (pid_offsets_r * o_r_s)[:, None] +
               (offs_c * o_c_s)[None, :])
    tl.store(o + dst_off, tile, mask=dst_off < o_b * o_b_s)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
@triton.jit
def kernel_blocksparse_flow_push(x,
                                 x_b, x_b_s, x_r_s, x_c_s,
                                 s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
                                 s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
                                 r_lut,
                                 o,
                                 o_b, o_b_s, o_r_s, o_c_s,
                                 TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    """Atomically add one tile of an input block onto the block of ``o`` that ``r_lut`` maps it to.

    Program axis 0 selects the input block (a row of ``s_lut``); axes 1 and 2 select a
    ``TRITON_BLOCK_SIZE x TRITON_BLOCK_SIZE`` tile inside it. Several input blocks may map to the same target
    block, which is why ``tl.atomic_add`` is used instead of a plain store.
    """
    blk = tl.program_id(axis=0)
    tile_r = tl.program_id(axis=1)
    tile_c = tl.program_id(axis=2)

    # Read (batch, row, col) of the current input block from its LUT row
    lut_base = blk * s_lut_r_s
    lut_end = s_lut_r * s_lut_r_s

    bat_off = lut_base
    in_bat = tl.load(s_lut + bat_off, mask=bat_off < lut_end)

    row_off = lut_base + 1 * s_lut_c_s
    in_row = tl.load(s_lut + row_off, mask=row_off < lut_end)

    col_off = lut_base + 2 * s_lut_c_s
    in_col = tl.load(s_lut + col_off, mask=col_off < lut_end)

    # Translate the position into the index of the destination block in o
    rev_off = (in_bat * s_l_x_b_s +
               in_row * s_l_x_r_s +
               in_col * s_l_x_c_s)
    dst_blk = tl.load(r_lut + rev_off, mask=rev_off < s_l_x_b * s_l_x_b_s).to(tl.int32)

    if dst_blk == -1:
        tl.device_assert(False)
        return

    offs_r = tile_r * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)
    offs_c = tile_c * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)

    src_off = (blk * x_b_s +
               (offs_r * x_r_s)[:, None] +
               (offs_c * x_c_s)[None, :])
    tile = tl.load(x + src_off, mask=src_off < x_b * x_b_s)

    dst_off = (dst_blk * o_b_s +
               (offs_r * o_r_s)[:, None] +
               (offs_c * o_c_s)[None, :])
    tl.atomic_add(o + dst_off, tile, mask=dst_off < o_b * o_b_s)
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def forward_flow(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
                 sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
    """Materialise the output blocks by pulling each one from its source block in ``x``.

    Args:
        ctx: Autograd context; ``sparsity_block_size`` and the resolved ``triton_block_size`` are stored on it
            for the backward pass.
        x (Tensor): Input block-sparse tensor in compressed form.
        sparsity_layout_o (Tensor): Sparsity layout of the output.
        sparsity_lut (Tensor): (batch, row, col) coordinates of every non-zero block of ``sparsity_layout_o``.
        sparsity_reverse_lut (Tensor): Flattened map from layout position to source block index in ``x``.
        sparsity_block_size (int): The size of the sparsity blocks.
        n_sparse_blocks (int): Number of non-zero blocks in ``sparsity_layout_o``.
        triton_block_size (int): Triton tile size, or ``None`` to derive it from ``sparsity_block_size``.

    Returns:
        Tensor: Output of shape ``(n_sparse_blocks, sparsity_block_size, sparsity_block_size)``.
    """
    output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                         dtype=x.dtype, device=x.device)

    x_b, _, _ = x.size()
    x_b_s, x_r_s, x_c_s = stride(x)
    o_b, o_r, o_c = output.size()
    o_b_s, o_r_s, o_c_s = stride(output)
    s_l_o_b, _, _ = sparsity_layout_o.size()
    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
    s_lut_r, _ = sparsity_lut.size()
    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)

    if triton_block_size is None:
        triton_block_size = get_triton_block_size(sparsity_block_size)

    # One program per (output block, tile row, tile column)
    triton_grid = lambda meta: [o_b,
                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

    (kernel_blocksparse_flow_pull[triton_grid]
     (x,
      x_b, x_b_s, x_r_s, x_c_s,
      output,
      o_b, o_b_s, o_r_s, o_c_s,
      s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
      sparsity_reverse_lut,
      triton_block_size))

    # Save for backward pass
    ctx.sparsity_block_size = sparsity_block_size
    ctx.triton_block_size = triton_block_size

    return output
|
|
@@ -5,7 +5,7 @@ from triton import language as tl
|
|
|
5
5
|
|
|
6
6
|
from blksprs.ops.exp import exp
|
|
7
7
|
from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
|
|
8
|
-
from blksprs.utils.tools import get_triton_block_size
|
|
8
|
+
from blksprs.utils.tools import get_triton_block_size, stride
|
|
9
9
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
10
10
|
validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
|
|
11
11
|
|
|
@@ -61,9 +61,9 @@ class _BlocksparseSoftmax(torch.autograd.Function):
|
|
|
61
61
|
output = torch.empty_like(x)
|
|
62
62
|
|
|
63
63
|
x_b, x_r, x_c = x.size()
|
|
64
|
-
x_b_s, x_r_s, x_c_s =
|
|
64
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
65
65
|
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
66
|
-
s_lut_r_s, s_lut_c_s =
|
|
66
|
+
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
67
67
|
o_b, o_r, o_c = output.size()
|
|
68
68
|
|
|
69
69
|
x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,
|
|
@@ -76,9 +76,9 @@ class _BlocksparseSoftmax(torch.autograd.Function):
|
|
|
76
76
|
triton_block_size=triton_block_size)
|
|
77
77
|
|
|
78
78
|
s_b, s_r, s_c = x_exp_row_wise_sum.shape
|
|
79
|
-
s_b_s, s_r_s, s_c_s =
|
|
79
|
+
s_b_s, s_r_s, s_c_s = stride(x_exp_row_wise_sum)
|
|
80
80
|
s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_rws.shape
|
|
81
|
-
s_l_s_b_s, s_l_s_r_s, s_l_s_c_s =
|
|
81
|
+
s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_rws)
|
|
82
82
|
|
|
83
83
|
if triton_block_size is None:
|
|
84
84
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -119,13 +119,13 @@ class _BlocksparseSoftmax(torch.autograd.Function):
|
|
|
119
119
|
(1 * (sparsity_layout_s_flat == 0)))
|
|
120
120
|
|
|
121
121
|
o_b, o_r, o_c = o.size()
|
|
122
|
-
o_b_s, o_r_s, o_c_s =
|
|
122
|
+
o_b_s, o_r_s, o_c_s = stride(o)
|
|
123
123
|
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
124
|
-
s_lut_r_s, s_lut_c_s =
|
|
124
|
+
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
125
125
|
s_b, s_r, s_c = s.size()
|
|
126
|
-
s_b_s, s_r_s, s_c_s =
|
|
126
|
+
s_b_s, s_r_s, s_c_s = stride(s)
|
|
127
127
|
s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
|
|
128
|
-
s_l_s_b_s, s_l_s_r_s, s_l_s_c_s =
|
|
128
|
+
s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)
|
|
129
129
|
|
|
130
130
|
grad_x = torch.empty_like(o, dtype=torch.float)
|
|
131
131
|
|
|
@@ -181,7 +181,8 @@ class _BlocksparseSoftmax(torch.autograd.Function):
|
|
|
181
181
|
rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
|
|
182
182
|
|
|
183
183
|
if rev_idx_spa_s == -1:
|
|
184
|
-
|
|
184
|
+
tl.device_assert(False)
|
|
185
|
+
return
|
|
185
186
|
|
|
186
187
|
# Load x block
|
|
187
188
|
blk_x_idx = ((pid_blk * x_b_s) +
|
|
@@ -3,7 +3,7 @@ import triton
|
|
|
3
3
|
from torch import Tensor
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
|
-
from blksprs.utils.tools import get_triton_block_size
|
|
6
|
+
from blksprs.utils.tools import get_triton_block_size, stride
|
|
7
7
|
from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
|
|
8
8
|
validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
|
|
9
9
|
|
|
@@ -63,13 +63,13 @@ class _BlocksparseTranspose(torch.autograd.Function):
|
|
|
63
63
|
dtype=x.dtype, device=x.device)
|
|
64
64
|
|
|
65
65
|
x_b, x_r, x_c = x.size()
|
|
66
|
-
x_b_s, x_r_s, x_c_s =
|
|
66
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
67
67
|
s_l_b, s_l_r, s_l_c = sparsity_layout_o.size()
|
|
68
|
-
s_l_b_s, s_l_r_s, s_l_c_s =
|
|
68
|
+
s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout_o)
|
|
69
69
|
s_lut_r, s_lut_c = sparsity_lut.shape
|
|
70
|
-
s_lut_r_s, s_lut_c_s =
|
|
70
|
+
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
71
71
|
o_b, o_r, o_c = output.size()
|
|
72
|
-
o_b_s, o_r_s, o_c_s =
|
|
72
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
73
73
|
|
|
74
74
|
if triton_block_size is None:
|
|
75
75
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -140,7 +140,8 @@ class _BlocksparseTranspose(torch.autograd.Function):
|
|
|
140
140
|
rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
|
|
141
141
|
|
|
142
142
|
if rev_idx_spa == -1:
|
|
143
|
-
|
|
143
|
+
tl.device_assert(False)
|
|
144
|
+
return
|
|
144
145
|
|
|
145
146
|
blk_x_idx = (rev_idx_spa * x_b_s +
|
|
146
147
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: blksprs
|
|
3
|
-
Version: 1.6.1
|
|
3
|
+
Version: 1.7
|
|
4
4
|
Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
|
|
5
5
|
Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
|
|
6
6
|
Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
|
|
@@ -34,6 +34,7 @@ Currently supported operations (includes gradient calculation):
|
|
|
34
34
|
- Transpose
|
|
35
35
|
- Gather
|
|
36
36
|
- Scatter (_supports either no reduction or summation, gradients are only available for summation_)
|
|
37
|
+
- Repeat (_supports target sparsity layout_)
|
|
37
38
|
- Splitting and merging of matrices along the last dimension
|
|
38
39
|
- Conversion to and from sparse form
|
|
39
40
|
- Conversion to different sparsity layouts and different sparsity block sizes
|
|
@@ -64,7 +65,7 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u
|
|
|
64
65
|
|
|
65
66
|
### Dependencies
|
|
66
67
|
|
|
67
|
-
- [PyTorch](https://pytorch.org/) (built with v2.
|
|
68
|
+
- [PyTorch](https://pytorch.org/) (built with v2.5.0)
|
|
68
69
|
- _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_
|
|
69
70
|
|
|
70
71
|
## Changelog
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|