blksprs 1.10.2.tar.gz → 1.11.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {blksprs-1.10.2 → blksprs-1.11}/PKG-INFO +2 -2
- {blksprs-1.10.2 → blksprs-1.11}/blksprs/__init__.py +0 -1
- {blksprs-1.10.2 → blksprs-1.11}/blksprs/ops/conversion.py +42 -15
- {blksprs-1.10.2 → blksprs-1.11}/blksprs/ops/distribution.py +60 -30
- {blksprs-1.10.2 → blksprs-1.11}/blksprs/ops/flow.py +63 -30
- {blksprs-1.10.2 → blksprs-1.11}/blksprs/ops/matmul.py +40 -22
- blksprs-1.11/blksprs/ops/partitioning.py +213 -0
- blksprs-1.11/blksprs/ops/repeat.py +196 -0
- {blksprs-1.10.2 → blksprs-1.11}/blksprs/ops/softmax.py +71 -63
- blksprs-1.11/blksprs/ops/transpose.py +97 -0
- {blksprs-1.10.2 → blksprs-1.11}/blksprs.egg-info/PKG-INFO +2 -2
- {blksprs-1.10.2 → blksprs-1.11}/blksprs.egg-info/SOURCES.txt +0 -1
- {blksprs-1.10.2 → blksprs-1.11}/pyproject.toml +1 -1
- blksprs-1.10.2/blksprs/ops/misc/exp.py +0 -104
- blksprs-1.10.2/blksprs/ops/partitioning.py +0 -170
- blksprs-1.10.2/blksprs/ops/repeat.py +0 -184
- blksprs-1.10.2/blksprs/ops/transpose.py +0 -160
- {blksprs-1.10.2 → blksprs-1.11}/README.md +0 -0
- {blksprs-1.10.2 → blksprs-1.11}/blksprs/layouting/distribution_layout.py +0 -0
- {blksprs-1.10.2 → blksprs-1.11}/blksprs/layouting/sparsity_layout.py +0 -0
- {blksprs-1.10.2 → blksprs-1.11}/blksprs/ops/misc/broadcast_ops.py +0 -0
- {blksprs-1.10.2 → blksprs-1.11}/blksprs/ops/misc/row_wise.py +0 -0
- {blksprs-1.10.2 → blksprs-1.11}/blksprs/utils/benchmarking.py +0 -0
- {blksprs-1.10.2 → blksprs-1.11}/blksprs/utils/blksprs_tensor.py +0 -0
- {blksprs-1.10.2 → blksprs-1.11}/blksprs/utils/layout_utils.py +0 -0
- {blksprs-1.10.2 → blksprs-1.11}/blksprs/utils/processing.py +0 -0
- {blksprs-1.10.2 → blksprs-1.11}/blksprs/utils/tools.py +0 -0
- {blksprs-1.10.2 → blksprs-1.11}/blksprs/utils/validation.py +0 -0
- {blksprs-1.10.2 → blksprs-1.11}/blksprs.egg-info/dependency_links.txt +0 -0
- {blksprs-1.10.2 → blksprs-1.11}/blksprs.egg-info/requires.txt +0 -0
- {blksprs-1.10.2 → blksprs-1.11}/blksprs.egg-info/top_level.txt +0 -0
- {blksprs-1.10.2 → blksprs-1.11}/setup.cfg +0 -0
{blksprs-1.10.2 → blksprs-1.11}/PKG-INFO

@@ -1,6 +1,6 @@
-Metadata-Version: 2.1
+Metadata-Version: 2.2
 Name: blksprs
-Version: 1.10.2
+Version: 1.11
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
{blksprs-1.10.2 → blksprs-1.11}/blksprs/__init__.py

@@ -13,7 +13,6 @@ class ops:
     class misc:
         from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_add, row_wise_sub
         from blksprs.ops.misc.broadcast_ops import broadcast_add, broadcast_sub
-        from blksprs.ops.misc.exp import exp
 
 
 class layouting:
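Worth flagging for downstream users: together with the deletion of blksprs-1.10.2/blksprs/ops/misc/exp.py listed above, this import removal drops `exp` from the public misc namespace. If calling code relied on it, an elementwise exponential over the stored block values can be written directly in PyTorch; a minimal sketch, assuming `BlksprsTensor` wraps a regular tensor in compressed form (the helper name `exp_blockwise` is hypothetical, not part of the library):

    import torch
    from blksprs.utils.blksprs_tensor import BlksprsTensor

    def exp_blockwise(x: BlksprsTensor) -> BlksprsTensor:
        # Applies exp only to values stored in present blocks; absent blocks
        # stay absent rather than becoming exp(0) = 1.
        return BlksprsTensor(torch.exp(x))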
{blksprs-1.10.2 → blksprs-1.11}/blksprs/ops/conversion.py

@@ -19,7 +19,7 @@ def from_blksprs(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
 
 
 def to_dense(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
-             triton_block_size: int = None) -> Tensor:
+             triton_block_size: int = None, lut: dict = None) -> Tensor:
     """Converts a block-sparse tensor in compressed form to a block-sparse tensor in regular form based on the given
     sparsity layout.
 
@@ -30,6 +30,7 @@ def to_dense(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int
         fill_value (float): The value to fill the resulting dense tensor with where the block-sparse tensor is not
             present (default ``0``).
         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
 
     Returns:
         Tensor: The block-sparse tensor converted to regular form.
@@ -44,24 +45,35 @@ def to_dense(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int
     validate_sparsity_block_size(sparsity_block_size, x)
     validate_triton_block_size(triton_block_size, sparsity_block_size)
 
-    sparsity_layout_flat = sparsity_layout.reshape(-1)
-    sparsity_reverse_lut = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                            (sparsity_layout_flat == 1) -
-                            (1 * (sparsity_layout_flat == 0)))
-
-    validate_contiguous(sparsity_reverse_lut)
+    lut = _BlocksparseToDense.build_lut(lut, sparsity_layout)
 
     if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
         return x
 
     return _BlocksparseToDense.apply(x,
-                                     sparsity_layout, sparsity_reverse_lut,
+                                     sparsity_layout, lut["sparsity_reverse_lut"],
                                      sparsity_block_size, fill_value,
                                      triton_block_size)
 
 
 class _BlocksparseToDense(torch.autograd.Function):
 
+    @staticmethod
+    def build_lut(lut: dict, sparsity_layout: Tensor):
+        if lut is None:
+            lut = dict()
+
+        if "sparsity_reverse_lut" not in lut:
+            sparsity_layout_flat = sparsity_layout.reshape(-1)
+            sparsity_reverse_lut = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                    (sparsity_layout_flat == 1) -
+                                    (1 * (sparsity_layout_flat == 0)))
+            lut["sparsity_reverse_lut"] = sparsity_reverse_lut
+
+        validate_contiguous(lut["sparsity_reverse_lut"])
+
+        return lut
+
     @staticmethod
     def forward(ctx, x: Tensor,
                 sparsity_layout: Tensor, sparsity_reverse_lut: Tensor,
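The `sparsity_reverse_lut` built here maps every flattened block position of the layout to its index in the compressed tensor, with `-1` marking absent blocks. A minimal worked example of the cumsum trick, using a made-up four-block layout:

    import torch

    # Flattened layout: blocks 0 and 3 are present.
    sparsity_layout_flat = torch.tensor([1, 0, 0, 1])

    # cumsum - 1 numbers the present blocks 0..n-1; multiplying by the presence
    # mask and subtracting the absence mask forces absent positions to -1.
    sparsity_reverse_lut = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
                            (sparsity_layout_flat == 1) -
                            (1 * (sparsity_layout_flat == 0)))

    print(sparsity_reverse_lut)  # tensor([ 0, -1, -1,  1])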
@@ -160,7 +172,7 @@ def to_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
 
 
 def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
-              triton_block_size: int = None) -> BlksprsTensor:
+              triton_block_size: int = None, lut: dict = None) -> BlksprsTensor:
     """Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
     sparsity layout.
 
@@ -169,6 +181,7 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
 
     Returns:
         BlksprsTensor: The block-sparse tensor converted to compressed form.
@@ -183,22 +196,36 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
     validate_sparsity_block_size(sparsity_block_size, x)
     validate_triton_block_size(triton_block_size, sparsity_block_size)
 
-    sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
-    n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
-
-    validate_contiguous(sparsity_layout, sparsity_lut)
+    lut = _BlocksparseToSparse.build_lut(lut, sparsity_layout)
 
     if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
         return BlksprsTensor(x)
 
     return BlksprsTensor(_BlocksparseToSparse.apply(x,
-                                                    sparsity_layout, sparsity_lut,
-                                                    sparsity_block_size, n_sparse_blocks,
+                                                    sparsity_layout, lut["sparsity_lut"],
+                                                    sparsity_block_size, lut["n_sparse_blocks"],
                                                     triton_block_size))
 
 
 class _BlocksparseToSparse(torch.autograd.Function):
 
+    @staticmethod
+    def build_lut(lut: dict, sparsity_layout: Tensor):
+        if lut is None:
+            lut = dict()
+
+        if "sparsity_lut" not in lut:
+            sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
+            lut["sparsity_lut"] = sparsity_lut
+
+        if "n_sparse_blocks" not in lut:
+            n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
+            lut["n_sparse_blocks"] = n_sparse_blocks
+
+        validate_contiguous(sparsity_layout, lut["sparsity_lut"])
+
+        return lut
+
     @staticmethod
     def forward(ctx, x: Tensor,
                 sparsity_layout: Tensor, sparsity_lut: Tensor,
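The point of threading `lut` through both converters: everything derived purely from the sparsity layout (`sparsity_reverse_lut`, `sparsity_lut`, `n_sparse_blocks`) can now be computed once and reused for as long as the layout is fixed, since `build_lut` fills missing keys into the passed dict in place and skips keys that are already present. A hedged usage sketch; the shapes, device, dtype, and the `bs.ops.to_sparse` access path (assuming the conversion functions are exposed on the `ops` namespace from __init__.py) are illustrative assumptions:

    import torch
    import blksprs as bs

    sparsity_block_size = 64
    sparsity_layout = (torch.rand(2, 4, 4, device="cuda") > 0.5).to(torch.int64)
    x = torch.randn(2, 256, 256, device="cuda")  # 4 x 4 blocks of size 64

    lut = {}  # filled in place on the first call, reused afterwards
    for _ in range(100):
        x_sparse = bs.ops.to_sparse(x, sparsity_layout, sparsity_block_size, lut=lut)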
{blksprs-1.10.2 → blksprs-1.11}/blksprs/ops/distribution.py

@@ -13,7 +13,7 @@ from blksprs.utils.validation import validate_contiguous, validate_dimensions, v
 def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
            dim: int,
            idx: BlksprsTensor, sparsity_layout_idx: Tensor,
-           sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
+           sparsity_block_size: int, triton_block_size: int = None, lut: dict = None) -> BlksprsTensor:
     """Applies a gather operation on a block-sparse tensor in compressed form.
 
     Args:
@@ -24,6 +24,7 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
         sparsity_layout_idx (Tensor): The sparsity layout of the indices block-sparse tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
 
     Returns:
         BlksprsTensor: The result of the gather operation as a block-sparse tensor in compressed form.
@@ -40,25 +41,38 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
     validate_sparsity_block_size(sparsity_block_size, src, idx)
     validate_triton_block_size(triton_block_size, sparsity_block_size)
 
-    sparsity_layout_x_flat = sparsity_layout_src.reshape(-1)
-    sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
-                              (sparsity_layout_x_flat == 1) -
-                              (1 * (sparsity_layout_x_flat == 0)))
-
-    sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
-
-    validate_contiguous(sparsity_layout_src, sparsity_reverse_lut_x,
-                        sparsity_layout_idx, sparsity_lut_i)
-
     adjusted_dim = dim % 3
 
-    return BlksprsTensor(_BlocksparseGather.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
-                                                  adjusted_dim, idx, sparsity_layout_idx, sparsity_lut_i,
+    lut = _BlocksparseGather.build_lut(lut, sparsity_layout_src, sparsity_layout_idx)
+
+    return BlksprsTensor(_BlocksparseGather.apply(src, sparsity_layout_src, lut["sparsity_reverse_lut_x"],
+                                                  adjusted_dim, idx, sparsity_layout_idx, lut["sparsity_lut_i"],
                                                   sparsity_block_size, triton_block_size))
 
 
 class _BlocksparseGather(torch.autograd.Function):
 
+    @staticmethod
+    def build_lut(lut: dict, sparsity_layout_src: Tensor, sparsity_layout_idx: Tensor):
+        if lut is None:
+            lut = dict()
+
+        if "sparsity_reverse_lut_x" not in lut:
+            sparsity_layout_x_flat = sparsity_layout_src.reshape(-1)
+            sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
+                                      (sparsity_layout_x_flat == 1) -
+                                      (1 * (sparsity_layout_x_flat == 0)))
+            lut["sparsity_reverse_lut_x"] = sparsity_reverse_lut_x
+
+        if "sparsity_lut_i" not in lut:
+            sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
+            lut["sparsity_lut_i"] = sparsity_lut_i
+
+        validate_contiguous(sparsity_layout_src, lut["sparsity_reverse_lut_x"],
+                            sparsity_layout_idx, lut["sparsity_lut_i"])
+
+        return lut
+
     @staticmethod
     def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
                 dim: int, i: Tensor, sparsity_layout_i: Tensor, sparsity_lut_i: Tensor,
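One small detail in the wrapper above: `adjusted_dim = dim % 3` normalizes negative dimension indices onto the three axes (batch, row, column) of the compressed form, since Python's modulo always returns a non-negative result here:

    # Python semantics, shown for the three valid negative indices:
    assert -3 % 3 == 0  # batch
    assert -2 % 3 == 1  # rows
    assert -1 % 3 == 2  # columns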
@@ -202,7 +216,7 @@ def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
             dim: int,
             idx: BlksprsTensor,
             sparsity_layout_tgt: Tensor,
-            sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
+            sparsity_block_size: int, triton_block_size: int = None, lut: dict = None) -> BlksprsTensor:
     """Wrapper for ``scatter_reduce`` with ``reduce_op="none"``.
 
     """
@@ -219,7 +233,7 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
                    idx: BlksprsTensor,
                    sparsity_layout_tgt: Tensor,
                    sparsity_block_size: int,
-                   reduce_op: str = "sum", triton_block_size: int = None) -> BlksprsTensor:
+                   reduce_op: str = "sum", triton_block_size: int = None, lut: dict = None) -> BlksprsTensor:
     """Applies a scatter operation on a block-sparse tensor in compressed form.
 
     Args:
@@ -232,6 +246,7 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
         reduce_op (str, optional): The reduction operation to apply during the scatter operation (default ``"sum"``).
             Supported operations are ``"none"`` and ``"sum"``.
         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
 
     Returns:
         BlksprsTensor: The result of the scatter operation as a block-sparse tensor in compressed form.
@@ -251,29 +266,44 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
     if reduce_op not in ["none", "sum"]:
         raise ValueError(f"Reduction operation '{reduce_op}' is not supported")
 
-    sparsity_lut_x = torch.nonzero(sparsity_layout_src).contiguous()
-
-    sparsity_layout_o_flat = sparsity_layout_tgt.reshape(-1)
-    sparsity_reverse_lut_o = ((torch.cumsum(sparsity_layout_o_flat, dim=-1) - 1) *
-                              (sparsity_layout_o_flat == 1) -
-                              (1 * (sparsity_layout_o_flat == 0)))
-
-    n_sparse_blocks = torch.sum(sparsity_layout_tgt.to(torch.int)).item()
-
-    validate_contiguous(sparsity_layout_src, sparsity_lut_x,
-                        sparsity_layout_tgt, sparsity_reverse_lut_o)
-
     adjusted_dim = dim % 3
 
-    return BlksprsTensor(_BlocksparseScatterReduce.apply(src, sparsity_layout_src, sparsity_lut_x,
+    lut = _BlocksparseScatterReduce.build_lut(lut, sparsity_layout_src, sparsity_layout_tgt)
+
+    return BlksprsTensor(_BlocksparseScatterReduce.apply(src, sparsity_layout_src, lut["sparsity_lut_x"],
                                                          adjusted_dim, idx,
-                                                         sparsity_layout_tgt, sparsity_reverse_lut_o,
-                                                         sparsity_block_size, n_sparse_blocks,
+                                                         sparsity_layout_tgt, lut["sparsity_reverse_lut_o"],
+                                                         sparsity_block_size, lut["n_sparse_blocks"],
                                                          reduce_op, triton_block_size))
 
 
 class _BlocksparseScatterReduce(torch.autograd.Function):
 
+    @staticmethod
+    def build_lut(lut: dict, sparsity_layout_src: Tensor, sparsity_layout_tgt: Tensor):
+        if lut is None:
+            lut = dict()
+
+        if "sparsity_lut_x" not in lut:
+            sparsity_lut_x = torch.nonzero(sparsity_layout_src).contiguous()
+            lut["sparsity_lut_x"] = sparsity_lut_x
+
+        if "sparsity_reverse_lut_o" not in lut:
+            sparsity_layout_o_flat = sparsity_layout_tgt.reshape(-1)
+            sparsity_reverse_lut_o = ((torch.cumsum(sparsity_layout_o_flat, dim=-1) - 1) *
+                                      (sparsity_layout_o_flat == 1) -
+                                      (1 * (sparsity_layout_o_flat == 0)))
+            lut["sparsity_reverse_lut_o"] = sparsity_reverse_lut_o
+
+        if "n_sparse_blocks" not in lut:
+            n_sparse_blocks = torch.sum(sparsity_layout_tgt.to(torch.int)).item()
+            lut["n_sparse_blocks"] = n_sparse_blocks
+
+        validate_contiguous(sparsity_layout_src, lut["sparsity_lut_x"],
+                            sparsity_layout_tgt, lut["sparsity_reverse_lut_o"])
+
+        return lut
+
     @staticmethod
     def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_lut_x: Tensor,
                 dim: int, i: Tensor,
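`scatter_reduce` follows the same caching pattern as `gather`: the source LUT, the target reverse LUT, and the target block count depend only on the two layouts, so one dict can be reused across calls with a fixed pair of layouts. A hedged sketch of the call pattern only; the tensors (`src`, `idx`, the layouts) are placeholders and the `bs.ops` access path is an assumption:

    lut = {}  # populated in place on the first call
    for step in range(num_steps):
        out = bs.ops.scatter_reduce(src, sparsity_layout_src,
                                    dim, idx, sparsity_layout_tgt,
                                    sparsity_block_size,
                                    reduce_op="sum", lut=lut)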
{blksprs-1.10.2 → blksprs-1.11}/blksprs/ops/flow.py

@@ -40,21 +40,18 @@ def kernel_blocksparse_flow_pull(x,
     rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
     rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
-    if rev_idx_spa
-
-
+    if rev_idx_spa >= 0:
+        blk_x_idx = (rev_idx_spa * x_b_s +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+        blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
+        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
-    blk_x_idx = (rev_idx_spa * x_b_s +
-                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
-    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
-    blk_o_idx = (pid_blk * o_b_s +
-                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-    blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
-    tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+        blk_o_idx = (pid_blk * o_b_s +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+        blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
+        tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
 
 
 @triton.jit
@@ -91,25 +88,22 @@ def kernel_blocksparse_flow_push(x,
     rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_x_b * s_l_x_b_s)
     rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
-    if rev_idx_spa
-
-
-
-    blk_x_idx = (pid_blk * x_b_s +
-                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
-    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+    if rev_idx_spa >= 0:
+        blk_x_idx = (pid_blk * x_b_s +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+        blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
+        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
-    blk_o_idx = (rev_idx_spa * o_b_s +
-                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-    blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
-    tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
+        blk_o_idx = (rev_idx_spa * o_b_s +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+        blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
+        tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
 
 
-def flow_forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
-                 sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+def flow_forward_pull(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                      sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
     output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                          dtype=x.dtype, device=x.device)
 
@@ -144,3 +138,42 @@ def flow_forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor
     ctx.triton_block_size = triton_block_size
 
     return output
+
+
+def flow_forward_push(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                      sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+    output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                         dtype=x.dtype, device=x.device)
+
+    x_b, x_r, x_c = x.size()
+    x_b_s, x_r_s, x_c_s = stride(x)
+    s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
+    s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
+    s_lut_r, s_lut_c = sparsity_lut.size()
+    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+    o_b, o_r, o_c = output.size()
+    o_b_s, o_r_s, o_c_s = stride(output)
+
+    if triton_block_size is None:
+        triton_block_size = get_triton_block_size(sparsity_block_size)
+
+    triton_grid = lambda meta: [x_b,
+                                triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+    (kernel_blocksparse_flow_push[triton_grid]
+     (x,
+      x_b, x_b_s, x_r_s, x_c_s,
+      s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+      sparsity_reverse_lut,
+      output,
+      o_b, o_b_s, o_r_s, o_c_s,
+      triton_block_size))
+
+    # Save for backward pass
+    if ctx is not None:
+        ctx.sparsity_block_size = sparsity_block_size
+        ctx.triton_block_size = triton_block_size
+
+    return output
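The two kernels implement opposite flow directions. The pull kernel walks over output blocks, looks up its source block via the reverse LUT, and writes with a plain `tl.store`; each output element has exactly one writer. The push kernel walks over input blocks and accumulates into the output with `tl.atomic_add`, because several input blocks may map to the same output block; this is also why `flow_forward_push` zero-initializes its output with `torch.zeros` while the pull path can use `torch.empty`. A dense PyTorch analogy, for intuition only (the maps and shapes are made up):

    import torch

    src = torch.randn(4, 8)

    # Pull: every output row gathers its source row; one writer per output.
    pull_map = torch.tensor([2, 2, 0, 3])      # out[i] = src[pull_map[i]]
    out_pull = src[pull_map]

    # Push: every input row picks a destination and accumulates; destinations
    # may collide, so the dense equivalent is index_add (the kernel uses atomics).
    push_map = torch.tensor([1, 1, 0, 2])      # out[push_map[i]] += src[i]
    out_push = torch.zeros(3, 8).index_add(0, push_map, src)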
{blksprs-1.10.2 → blksprs-1.11}/blksprs/ops/matmul.py

@@ -13,7 +13,7 @@ from blksprs.utils.validation import validate_contiguous, validate_dimensions, v
 def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
            y: BlksprsTensor, sparsity_layout_y: Tensor,
           sparsity_layout_output: Tensor,
-           sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
+           sparsity_block_size: int, triton_block_size: int = None, lut: dict = None) -> BlksprsTensor:
     """Performs matrix multiplication between two block-sparse tensors.
 
     The sparsity layout of the output tensor is used to only calculate blocks that will be present in the output.
@@ -26,6 +26,7 @@ def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
         sparsity_layout_output (Tensor): The sparsity layout of the output tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
 
     Returns:
         BlksprsTensor: The result of the matrix multiplication as a block-sparse tensor in compressed form.
@@ -44,35 +45,52 @@ def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
     validate_sparsity_block_size(sparsity_block_size, x, y)
     validate_triton_block_size(triton_block_size, sparsity_block_size)
 
-    sparsity_layout_x_flat = sparsity_layout_x.reshape(-1)
-    sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
-                              (sparsity_layout_x_flat == 1) -
-                              (1 * (sparsity_layout_x_flat == 0)))
-
-    sparsity_layout_y_flat = sparsity_layout_y.reshape(-1)
-    sparsity_reverse_lut_y = ((torch.cumsum(sparsity_layout_y_flat, dim=-1) - 1) *
-                              (sparsity_layout_y_flat == 1) -
-                              (1 * (sparsity_layout_y_flat == 0)))
-
-    sparsity_lut_o = torch.nonzero(sparsity_layout_output).contiguous()
-
-    n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
-
-    validate_contiguous(sparsity_layout_x, sparsity_reverse_lut_x,
-                        sparsity_layout_y, sparsity_reverse_lut_y,
-                        sparsity_layout_output, sparsity_lut_o)
+    lut = _BlocksparseMatmulSSS.build_lut(lut, sparsity_layout_x, sparsity_layout_y, sparsity_layout_output)
 
     return BlksprsTensor(_BlocksparseMatmulSSS.apply(x, y,
-                                                     sparsity_layout_x, sparsity_reverse_lut_x,
-                                                     sparsity_layout_y, sparsity_reverse_lut_y,
-                                                     sparsity_layout_output, sparsity_lut_o,
+                                                     sparsity_layout_x, lut["sparsity_reverse_lut_x"],
+                                                     sparsity_layout_y, lut["sparsity_reverse_lut_y"],
+                                                     sparsity_layout_output, lut["sparsity_lut_o"],
                                                      sparsity_block_size,
-                                                     n_sparse_blocks,
+                                                     lut["n_sparse_blocks"],
                                                      triton_block_size))
 
 
 class _BlocksparseMatmulSSS(torch.autograd.Function):
 
+    @staticmethod
+    def build_lut(lut: dict, sparsity_layout_x: Tensor, sparsity_layout_y: Tensor, sparsity_layout_output: Tensor):
+        if lut is None:
+            lut = dict()
+
+        if "sparsity_reverse_lut_x" not in lut:
+            sparsity_layout_x_flat = sparsity_layout_x.reshape(-1)
+            sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
+                                      (sparsity_layout_x_flat == 1) -
+                                      (1 * (sparsity_layout_x_flat == 0)))
+            lut["sparsity_reverse_lut_x"] = sparsity_reverse_lut_x
+
+        if "sparsity_reverse_lut_y" not in lut:
+            sparsity_layout_y_flat = sparsity_layout_y.reshape(-1)
+            sparsity_reverse_lut_y = ((torch.cumsum(sparsity_layout_y_flat, dim=-1) - 1) *
+                                      (sparsity_layout_y_flat == 1) -
+                                      (1 * (sparsity_layout_y_flat == 0)))
+            lut["sparsity_reverse_lut_y"] = sparsity_reverse_lut_y
+
+        if "sparsity_lut_o" not in lut:
+            sparsity_lut_o = torch.nonzero(sparsity_layout_output).contiguous()
+            lut["sparsity_lut_o"] = sparsity_lut_o
+
+        if "n_sparse_blocks" not in lut:
+            n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
+            lut["n_sparse_blocks"] = n_sparse_blocks
+
+        validate_contiguous(sparsity_layout_x, lut["sparsity_reverse_lut_x"],
+                            sparsity_layout_y, lut["sparsity_reverse_lut_y"],
+                            sparsity_layout_output, lut["sparsity_lut_o"])
+
+        return lut
+
     @staticmethod
     def forward(ctx, x: Tensor, y: Tensor,
                 sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,