blksprs 1.5.tar.gz → 1.6.1.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {blksprs-1.5 → blksprs-1.6.1}/PKG-INFO +3 -2
- {blksprs-1.5 → blksprs-1.6.1}/README.md +2 -1
- {blksprs-1.5 → blksprs-1.6.1}/blksprs/__init__.py +3 -2
- {blksprs-1.5 → blksprs-1.6.1}/blksprs/layouting/sparsity_layout.py +36 -0
- {blksprs-1.5 → blksprs-1.6.1}/blksprs/misc/repeat_interleave.py +1 -1
- blksprs-1.6.1/blksprs/ops/partitioning.py +244 -0
- {blksprs-1.5 → blksprs-1.6.1}/blksprs/ops/transpose.py +6 -7
- {blksprs-1.5 → blksprs-1.6.1}/blksprs/utils/validation.py +2 -0
- {blksprs-1.5 → blksprs-1.6.1}/blksprs.egg-info/PKG-INFO +3 -2
- {blksprs-1.5 → blksprs-1.6.1}/blksprs.egg-info/SOURCES.txt +1 -0
- {blksprs-1.5 → blksprs-1.6.1}/pyproject.toml +1 -1
- {blksprs-1.5 → blksprs-1.6.1}/blksprs/experimental/distribution_mdi.py +0 -0
- {blksprs-1.5 → blksprs-1.6.1}/blksprs/layouting/distribution_layout.py +0 -0
- {blksprs-1.5 → blksprs-1.6.1}/blksprs/misc/broadcast_ops.py +0 -0
- {blksprs-1.5 → blksprs-1.6.1}/blksprs/misc/row_wise.py +0 -0
- {blksprs-1.5 → blksprs-1.6.1}/blksprs/ops/conversion.py +0 -0
- {blksprs-1.5 → blksprs-1.6.1}/blksprs/ops/distribution.py +0 -0
- {blksprs-1.5 → blksprs-1.6.1}/blksprs/ops/exp.py +0 -0
- {blksprs-1.5 → blksprs-1.6.1}/blksprs/ops/matmul.py +0 -0
- {blksprs-1.5 → blksprs-1.6.1}/blksprs/ops/softmax.py +0 -0
- {blksprs-1.5 → blksprs-1.6.1}/blksprs/utils/benchmarking.py +0 -0
- {blksprs-1.5 → blksprs-1.6.1}/blksprs/utils/tools.py +0 -0
- {blksprs-1.5 → blksprs-1.6.1}/blksprs.egg-info/dependency_links.txt +0 -0
- {blksprs-1.5 → blksprs-1.6.1}/blksprs.egg-info/requires.txt +0 -0
- {blksprs-1.5 → blksprs-1.6.1}/blksprs.egg-info/top_level.txt +0 -0
- {blksprs-1.5 → blksprs-1.6.1}/setup.cfg +0 -0
{blksprs-1.5 → blksprs-1.6.1}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.5
+Version: 1.6.1
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -31,9 +31,10 @@ Currently supported operations (includes gradient calculation):
 - Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
   for `sparse = sparse @ sparse` matmul_)
 - Softmax
--
+- Transpose
 - Gather
 - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
+- Splitting and merging of matrices along the last dimension
 - Conversion to and from sparse form
 - Conversion to different sparsity layouts and different sparsity block sizes
 
{blksprs-1.5 → blksprs-1.6.1}/README.md

@@ -12,9 +12,10 @@ Currently supported operations (includes gradient calculation):
 - Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
   for `sparse = sparse @ sparse` matmul_)
 - Softmax
--
+- Transpose
 - Gather
 - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
+- Splitting and merging of matrices along the last dimension
 - Conversion to and from sparse form
 - Conversion to different sparsity layouts and different sparsity block sizes
 
{blksprs-1.5 → blksprs-1.6.1}/blksprs/__init__.py

@@ -4,10 +4,11 @@ from blksprs.ops.exp import exp
 from blksprs.ops.matmul import matmul
 from blksprs.ops.softmax import softmax
 from blksprs.ops.transpose import transpose
+from blksprs.ops.partitioning import split, merge
 
 class layout:
     from blksprs.layouting.distribution_layout import build_distribution_layout
-    from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption
+    from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption, build_sparsity_layout_matmul
 
 class misc:
     from blksprs.misc.broadcast_ops import broadcast_add, broadcast_sub
@@ -18,4 +19,4 @@ class util:
     from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse, disable_validation
 
 class experimental:
-    from blksprs.experimental.distribution_mdi import gather_mdi
+    from blksprs.experimental.distribution_mdi import gather_mdi, scatter_reduce_mdi
{blksprs-1.5 → blksprs-1.6.1}/blksprs/layouting/sparsity_layout.py

@@ -188,3 +188,39 @@ def kernel_sparsity_layout_adaption(x,
                          // sparsity_block_size_to) * o_c_s))
     blk_o_msk = (blk_o_idx < o_b * o_b_s)
     tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
+
+
+def build_sparsity_layout_matmul(sparsity_layout_x: Tensor, sparsity_layout_y: Tensor):
+    """Builds the precise sparsity layout of the result of a matrix multiplication between the two input tensors.
+
+    Args:
+        sparsity_layout_x (Tensor): The sparsity layout of the first block-sparse tensor.
+        sparsity_layout_y (Tensor): The sparsity layout of the second block-sparse tensor.
+
+    Returns:
+        Tensor: The precise sparsity layout of the result of a matrix multiplication between the two input tensors.
+
+    """
+    return torch.matmul(sparsity_layout_x.to(torch.float), sparsity_layout_y.to(torch.float)).to(torch.bool)
+
+
+def build_sparsity_layout_matmul_fast(sparsity_layout_x: Tensor, sparsity_layout_y: Tensor):
+    """Builds the approximate sparsity layout of the result of a matrix multiplication between the two input tensors.
+
+    Note:
+        This function is faster than the ``build_sparsity_layout_matmul`` function due to the fact that it only checks
+        whether at least one of the blocks in either of the vectors participating in the matmul is non-sparse. The
+        resulting sparsity layout may thus overestimate the actual sparsity of the result.
+
+    Args:
+        sparsity_layout_x (Tensor): The sparsity layout of the first block-sparse tensor.
+        sparsity_layout_y (Tensor): The sparsity layout of the second block-sparse tensor.
+
+    Returns:
+        Tensor: The approximate sparsity layout of the result of a matrix multiplication between the two input tensors.
+
+    """
+    sparsity_layout_x_slice = torch.max(sparsity_layout_x, dim=-1).values.unsqueeze(-1)
+    sparsity_layout_y_slice = torch.max(sparsity_layout_y, dim=-2).values.unsqueeze(1)
+
+    return torch.logical_or(sparsity_layout_x_slice, sparsity_layout_y_slice)
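The two new helpers trade precision for speed: the precise variant does a boolean matmul over the layouts, while the fast variant only checks whether a block row of `x` or a block column of `y` contains any non-sparse block, so it can overestimate. A minimal sketch of the difference (the `bs` alias and example values are assumptions; per the `__init__.py` hunk above only `build_sparsity_layout_matmul` is re-exported via `blksprs.layout`, so the fast variant is imported from its module here):

```python
import torch
import blksprs as bs
from blksprs.layouting.sparsity_layout import build_sparsity_layout_matmul_fast

# Block-level layouts of shape [batch, rows, cols]; hypothetical values.
sparsity_layout_x = torch.tensor([[[1, 0],
                                   [0, 0]]])
sparsity_layout_y = torch.tensor([[[0, 0],
                                   [0, 1]]])

# Precise: block (i, j) is non-sparse only if x[i, k] and y[k, j] meet for some k.
precise = bs.layout.build_sparsity_layout_matmul(sparsity_layout_x, sparsity_layout_y)
# -> [[[False, False], [False, False]]]; no block row of x meets a block column of y.

# Fast: block (i, j) is non-sparse if block row i of x OR block column j of y
# contains any non-sparse block, which here overestimates the result.
fast = build_sparsity_layout_matmul_fast(sparsity_layout_x, sparsity_layout_y)
# -> [[[True, True], [False, True]]]
```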
{blksprs-1.5 → blksprs-1.6.1}/blksprs/misc/repeat_interleave.py

@@ -35,7 +35,7 @@ def repeat_interleave(x: Tensor, sparsity_layout: Tensor, repeats: int,
     validate_sparsity_block_size(sparsity_block_size, x)
     validate_triton_block_size(triton_block_size, sparsity_block_size)
 
-    sparsity_layout_output = torch.repeat_interleave(sparsity_layout,
+    sparsity_layout_output = torch.repeat_interleave(sparsity_layout, repeats, dim=0).contiguous()
 
     sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
 
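The fix forwards `repeats` and `dim=0` to `torch.repeat_interleave`, so the output layout is actually expanded along the batch dimension. A minimal sketch of the corrected expression on hypothetical values:

```python
import torch

# Two batches of 1x2 block layouts (hypothetical values).
sparsity_layout = torch.tensor([[[1, 0]],
                                [[0, 1]]])
repeats = 2

sparsity_layout_output = torch.repeat_interleave(sparsity_layout, repeats, dim=0).contiguous()
# Each batch entry is repeated back-to-back along dim 0:
# [[[1, 0]], [[1, 0]], [[0, 1]], [[0, 1]]], i.e. shape [4, 1, 2].
```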
blksprs-1.6.1/blksprs/ops/partitioning.py

@@ -0,0 +1,244 @@
+import torch
+import triton
+from sympy.utilities.iterables import partitions
+from torch import Tensor
+from triton import language as tl
+
+from blksprs.utils.tools import get_triton_block_size
+
+from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
+    validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
+
+
+def split(x: Tensor, sparsity_layout: Tensor, partitions: int,
+          sparsity_block_size: int, triton_block_size: int = None) -> (Tensor, Tensor):
+    """Splits a block-sparse tensor in compressed form along the last dimension into partitions.
+
+    Args:
+        x (Tensor): A block-sparse tensor in compressed form.
+        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+        partitions (int): The number of partitions to split the block-sparse tensor into.
+        sparsity_block_size (int): The size of the sparsity blocks.
+        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+    Returns:
+        Tensor: The block-sparse tensor split into partitions in compressed form.
+        Tensor: The sparsity layout of the output tensor.
+
+    """
+    x = x.contiguous()
+
+    validate_dimensions(x)
+    validate_contiguous(x)
+    validate_device(x)
+    validate_sparsity(sparsity_block_size, (x, sparsity_layout))
+    validate_sparsity_block_size(sparsity_block_size, x)
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+    sparsity_layout_output = (sparsity_layout
+                              .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
+                                       sparsity_layout.size(2) // partitions)
+                              .permute(0, 2, 1, 3)
+                              .reshape(sparsity_layout.size(0) * partitions, sparsity_layout.size(1),
+                                       sparsity_layout.size(2) // partitions).contiguous())
+
+    sparsity_lut = torch.nonzero(sparsity_layout_output).contiguous()
+
+    sparsity_layout_flat = sparsity_layout.reshape(-1)
+    sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                             (sparsity_layout_flat == 1) -
+                             (1 * (sparsity_layout_flat == 0)))
+                            .reshape(sparsity_layout.size())
+                            .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
+                                     sparsity_layout.size(2) // partitions)
+                            .permute(0, 2, 1, 3).reshape(-1).contiguous())
+
+    n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
+
+    validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
+
+    return _BlocksparseSplit.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
+                                   sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_output
+
+
+class _BlocksparseSplit(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                num_partitions: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+        ctx.num_partitions = num_partitions
+
+        return forward_reorder(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                               n_sparse_blocks, triton_block_size)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        sparsity_layout = ctx.saved_tensors[0]
+        num_partitions = ctx.num_partitions
+        sparsity_block_size = ctx.sparsity_block_size
+        triton_block_size = ctx.triton_block_size
+
+        return merge(grad_output, sparsity_layout, num_partitions,
+                     sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None
+
+
+def merge(x: Tensor, sparsity_layout: Tensor, partitions: int,
+          sparsity_block_size: int, triton_block_size: int = None) -> (Tensor, Tensor):
+    """Merges the specified partitions of a block-sparse tensor in compressed form along the last dimension.
+
+    Args:
+        x (Tensor): A block-sparse tensor in compressed form.
+        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+        partitions (int): The number of partitions to be merged.
+        sparsity_block_size (int): The size of the sparsity blocks.
+        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+    Returns:
+        Tensor: The merged block-sparse tensor in compressed form.
+        Tensor: The sparsity layout of the output tensor.
+
+    """
+    x = x.contiguous()
+
+    validate_dimensions(x)
+    validate_contiguous(x)
+    validate_device(x)
+    validate_sparsity(sparsity_block_size, (x, sparsity_layout))
+    validate_sparsity_block_size(sparsity_block_size, x)
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+    sparsity_layout_output = (sparsity_layout.reshape(sparsity_layout.size(0) // partitions, partitions,
+                                                      sparsity_layout.size(1), sparsity_layout.size(2))
+                              .permute(0, 2, 1, 3)
+                              .reshape(sparsity_layout.size(0) // partitions,
+                                       sparsity_layout.size(1), sparsity_layout.size(2) * partitions).contiguous())
+
+    sparsity_lut = torch.nonzero(sparsity_layout_output).contiguous()
+
+    sparsity_layout_flat = sparsity_layout.reshape(-1)
+    sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                             (sparsity_layout_flat == 1) -
+                             (1 * (sparsity_layout_flat == 0)))
+                            .reshape(sparsity_layout.size(0) // partitions, partitions,
+                                     sparsity_layout.size(1), sparsity_layout.size(2))
+                            .permute(0, 2, 1, 3)
+                            .reshape(sparsity_layout.size(0) // partitions,
+                                     sparsity_layout.size(1), sparsity_layout.size(2) * partitions)
+                            .reshape(-1).contiguous())
+
+    n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
+
+    validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
+
+    return _BlocksparseMerge.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
+                                   sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_output
+
+
+class _BlocksparseMerge(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                num_partitions: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+        ctx.num_partitions = num_partitions
+
+        return forward_reorder(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                               n_sparse_blocks, triton_block_size)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        sparsity_layout = ctx.saved_tensors[0]
+        num_partitions = ctx.num_partitions
+        sparsity_block_size = ctx.sparsity_block_size
+        triton_block_size = ctx.triton_block_size
+
+        return split(grad_output, sparsity_layout, num_partitions,
+                     sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None
+
+
+def forward_reorder(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                    sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+    output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                         dtype=x.dtype, device=x.device)
+
+    x_b, x_r, x_c = x.size()
+    x_b_s, x_r_s, x_c_s = x.stride()
+    s_l_b, s_l_r, s_l_c = sparsity_layout_o.size()
+    s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout_o.stride()
+    s_lut_r, s_lut_c = sparsity_lut.shape
+    s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
+    o_b, o_r, o_c = output.size()
+    o_b_s, o_r_s, o_c_s = output.stride()
+
+    if triton_block_size is None:
+        triton_block_size = get_triton_block_size(sparsity_block_size)
+
+    triton_grid = lambda meta: [o_b,
+                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+    (kernel_blocksparse_reorder[triton_grid]
+     (x,
+      x_b, x_b_s, x_r_s, x_c_s,
+      s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+      sparsity_reverse_lut,
+      output,
+      o_b, o_b_s,
+      triton_block_size))
+
+    # Save for backward pass
+    ctx.save_for_backward(sparsity_layout_o)
+    ctx.sparsity_block_size = sparsity_block_size
+    ctx.triton_block_size = triton_block_size
+
+    return output
+
+
+@triton.jit
+def kernel_blocksparse_reorder(x,
+                               x_b, x_b_s, x_r_s, x_c_s,
+                               s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+                               s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                               r_lut,
+                               o,
+                               o_b, o_b_s,
+                               TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get sparsity index of current output block consisting of its batch, row, and column index
+    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+    spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+    spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+    spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
+    spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
+    spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+
+    # Get reverse sparsity index
+    rev_idx_spa_idx = (spa_bat * s_l_b_s +
+                       spa_row * s_l_r_s +
+                       spa_col * s_l_c_s)
+    rev_idx_spa_msk = (rev_idx_spa_idx < s_l_b * s_l_b_s)
+    rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
+
+    if rev_idx_spa == -1:
+        assert False, "Invalid sparsity block"
+
+    blk_x_idx = (rev_idx_spa * x_b_s +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_x_msk = (blk_x_idx < x_b * x_b_s)
+    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+    blk_o_idx = (pid_blk * o_b_s +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_o_msk = (blk_o_idx < o_b * o_b_s)
+    tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
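`split` and `merge` are inverses built on the same reorder kernel: the output layout is a reshape/permute of the input layout, the reverse LUT (the `cumsum` expression, with `-1` marking sparse positions) maps each output block to its source block in `x`, and each op's backward simply applies the other. A minimal usage sketch (the `bs` alias, CUDA device, and concrete shapes are assumptions; `split` and `merge` are re-exported at package level per the `__init__.py` hunk above):

```python
import torch
import blksprs as bs

sparsity_block_size = 16

# Layout of a single [32 x 64] matrix in 16x16 blocks: [1 batch, 2 rows, 4 cols].
sparsity_layout = torch.ones(1, 2, 4, dtype=torch.int, device="cuda")
# Compressed form: one [16 x 16] slice per non-sparse block.
x = torch.randn(int(sparsity_layout.sum()), sparsity_block_size, sparsity_block_size,
                device="cuda")

# Split the last dimension into 2 partitions; partitions are stacked along the
# batch dimension, so the layout becomes [2, 2, 2].
x_split, sparsity_layout_split = bs.split(x, sparsity_layout, 2, sparsity_block_size)

# Merging the 2 partitions restores the original layout.
x_merged, sparsity_layout_merged = bs.merge(x_split, sparsity_layout_split, 2,
                                            sparsity_block_size)
assert torch.equal(sparsity_layout_merged, sparsity_layout)
```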
{blksprs-1.5 → blksprs-1.6.1}/blksprs/ops/transpose.py

@@ -56,16 +56,16 @@ def transpose(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
 class _BlocksparseTranspose(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, x: Tensor,
-
-                n_sparse_blocks: int, triton_block_size: int) ->
+    def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                sparsity_block_size: int,
+                n_sparse_blocks: int, triton_block_size: int) -> Tensor:
         output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                              dtype=x.dtype, device=x.device)
 
         x_b, x_r, x_c = x.size()
         x_b_s, x_r_s, x_c_s = x.stride()
-        s_l_b, s_l_r, s_l_c =
-        s_l_b_s, s_l_r_s, s_l_c_s =
+        s_l_b, s_l_r, s_l_c = sparsity_layout_o.size()
+        s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout_o.stride()
         s_lut_r, s_lut_c = sparsity_lut.shape
         s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
         o_b, o_r, o_c = output.size()
@@ -89,8 +89,7 @@ class _BlocksparseTranspose(torch.autograd.Function):
          triton_block_size))
 
         # Save for backward pass
-        ctx.save_for_backward(
-        ctx.sparsity_layout = sparsity_layout
+        ctx.save_for_backward(sparsity_layout_o)
         ctx.sparsity_block_size = sparsity_block_size
         ctx.triton_block_size = triton_block_size
 
{blksprs-1.5 → blksprs-1.6.1}/blksprs/utils/validation.py

@@ -63,6 +63,8 @@ def validate_sparsity(sparsity_block_size: int, *tensor_sparsity_layout_tuples:
     for (tensor, sparsity_layout) in tensor_sparsity_layout_tuples:
         _validate_sparsity_layout_values(sparsity_layout)
 
+        if not sparsity_layout.dim() == 3:
+            raise ValueError("Sparsity layout must have exactly 3 dimensions")
         if not (tensor.size(-1) == tensor.size(-2) == sparsity_block_size):
             raise ValueError("Blocks not conforming to sparsity block size")
         if not tensor.size(0) == torch.sum(sparsity_layout.reshape(-1)):
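A minimal sketch of what the added check rejects (hypothetical shapes; this assumes the value check in `_validate_sparsity_layout_values` accepts plain 0/1 integer layouts):

```python
import torch
from blksprs.utils.validation import validate_sparsity

sparsity_block_size = 16
x = torch.randn(1, sparsity_block_size, sparsity_block_size)

layout_3d = torch.ones(1, 1, 1, dtype=torch.int)  # [batch, rows, cols] as required
layout_2d = torch.ones(1, 1, dtype=torch.int)     # missing the batch dimension

validate_sparsity(sparsity_block_size, (x, layout_3d))  # passes
validate_sparsity(sparsity_block_size, (x, layout_2d))  # raises ValueError
```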
{blksprs-1.5 → blksprs-1.6.1}/blksprs.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.5
+Version: 1.6.1
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -31,9 +31,10 @@ Currently supported operations (includes gradient calculation):
 - Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
   for `sparse = sparse @ sparse` matmul_)
 - Softmax
--
+- Transpose
 - Gather
 - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
+- Splitting and merging of matrices along the last dimension
 - Conversion to and from sparse form
 - Conversion to different sparsity layouts and different sparsity block sizes
 