blksprs 1.8.1__py3-none-any.whl → 1.8.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/experimental/distribution_mdi.py +22 -21
- blksprs/layouting/distribution_layout.py +3 -2
- blksprs/layouting/sparsity_layout.py +3 -2
- blksprs/misc/broadcast_ops.py +5 -4
- blksprs/misc/exp.py +5 -4
- blksprs/misc/partitioning.py +13 -12
- blksprs/misc/row_wise.py +19 -18
- blksprs/ops/conversion.py +35 -25
- blksprs/ops/distribution.py +19 -18
- blksprs/ops/matmul.py +14 -13
- blksprs/ops/repeat.py +13 -12
- blksprs/ops/softmax.py +6 -5
- blksprs/ops/transpose.py +7 -6
- blksprs/utils/blksprs_tensor.py +8 -0
- {blksprs-1.8.1.dist-info → blksprs-1.8.2.dist-info}/METADATA +1 -1
- blksprs-1.8.2.dist-info/RECORD +22 -0
- {blksprs-1.8.1.dist-info → blksprs-1.8.2.dist-info}/WHEEL +1 -1
- blksprs-1.8.1.dist-info/RECORD +0 -21
- {blksprs-1.8.1.dist-info → blksprs-1.8.2.dist-info}/top_level.txt +0 -0
|
@@ -3,17 +3,18 @@ import triton
|
|
|
3
3
|
from torch import Tensor
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
|
+
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
6
7
|
from blksprs.utils.tools import get_triton_block_size, stride
|
|
7
8
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
8
9
|
validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
|
|
9
10
|
|
|
10
11
|
|
|
11
|
-
def gather_mdi(src:
|
|
12
|
-
idx_bat:
|
|
13
|
-
idx_row:
|
|
14
|
-
idx_col:
|
|
12
|
+
def gather_mdi(src: BlksprsTensor, sparsity_layout_src: Tensor,
|
|
13
|
+
idx_bat: BlksprsTensor,
|
|
14
|
+
idx_row: BlksprsTensor,
|
|
15
|
+
idx_col: BlksprsTensor,
|
|
15
16
|
sparsity_layout_idx: Tensor,
|
|
16
|
-
sparsity_block_size: int, triton_block_size: int = None) ->
|
|
17
|
+
sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
|
|
17
18
|
src = src.contiguous()
|
|
18
19
|
idx_bat = idx_bat.contiguous()
|
|
19
20
|
idx_col = idx_col.contiguous()
|
|
@@ -37,9 +38,9 @@ def gather_mdi(src: Tensor, sparsity_layout_src: Tensor,
|
|
|
37
38
|
validate_contiguous(sparsity_layout_src, sparsity_reverse_lut_x,
|
|
38
39
|
sparsity_layout_idx, sparsity_lut_i)
|
|
39
40
|
|
|
40
|
-
return _BlocksparseGatherMDI.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
|
|
41
|
-
|
|
42
|
-
|
|
41
|
+
return BlksprsTensor(_BlocksparseGatherMDI.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
|
|
42
|
+
idx_bat, idx_col, sparsity_layout_idx, sparsity_lut_i,
|
|
43
|
+
sparsity_block_size, triton_block_size))
|
|
43
44
|
|
|
44
45
|
|
|
45
46
|
class _BlocksparseGatherMDI(torch.autograd.Function):
|
|
@@ -167,13 +168,13 @@ class _BlocksparseGatherMDI(torch.autograd.Function):
|
|
|
167
168
|
tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
|
|
168
169
|
|
|
169
170
|
|
|
170
|
-
def scatter_reduce_mdi(src:
|
|
171
|
-
idx_bat:
|
|
172
|
-
idx_row:
|
|
173
|
-
idx_col:
|
|
171
|
+
def scatter_reduce_mdi(src: BlksprsTensor, sparsity_layout_src: Tensor,
|
|
172
|
+
idx_bat: BlksprsTensor,
|
|
173
|
+
idx_row: BlksprsTensor,
|
|
174
|
+
idx_col: BlksprsTensor,
|
|
174
175
|
sparsity_layout_tgt: Tensor,
|
|
175
176
|
sparsity_block_size: int,
|
|
176
|
-
reduce_op: str = "sum", triton_block_size: int = None) ->
|
|
177
|
+
reduce_op: str = "sum", triton_block_size: int = None) -> BlksprsTensor:
|
|
177
178
|
src = src.contiguous()
|
|
178
179
|
idx_bat = idx_bat.contiguous()
|
|
179
180
|
idx_col = idx_col.contiguous()
|
|
@@ -203,12 +204,12 @@ def scatter_reduce_mdi(src: Tensor, sparsity_layout_src: Tensor,
|
|
|
203
204
|
validate_contiguous(sparsity_layout_src, sparsity_lut_x,
|
|
204
205
|
sparsity_layout_tgt, sparsity_reverse_lut_o)
|
|
205
206
|
|
|
206
|
-
return _BlocksparseScatterReduceMDI.apply(src, sparsity_layout_src, sparsity_lut_x,
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
207
|
+
return BlksprsTensor(_BlocksparseScatterReduceMDI.apply(src, sparsity_layout_src, sparsity_lut_x,
|
|
208
|
+
idx_bat,
|
|
209
|
+
idx_col,
|
|
210
|
+
sparsity_layout_tgt, sparsity_reverse_lut_o,
|
|
211
|
+
sparsity_block_size, n_sparse_blocks,
|
|
212
|
+
reduce_op, triton_block_size))
|
|
212
213
|
|
|
213
214
|
|
|
214
215
|
class _BlocksparseScatterReduceMDI(torch.autograd.Function):
|
|
@@ -353,8 +354,8 @@ class _BlocksparseScatterReduceMDI(torch.autograd.Function):
|
|
|
353
354
|
tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
|
|
354
355
|
|
|
355
356
|
|
|
356
|
-
def build_distribution_layout_mdi(idx_bat:
|
|
357
|
-
size_target: torch.Size,
|
|
357
|
+
def build_distribution_layout_mdi(idx_bat: BlksprsTensor, idx_row: BlksprsTensor, idx_col: BlksprsTensor,
|
|
358
|
+
sparsity_layout_idx: Tensor, size_target: torch.Size,
|
|
358
359
|
sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
|
|
359
360
|
validate_dimensions(idx_bat, idx_col)
|
|
360
361
|
validate_contiguous(idx_bat, idx_col)
|
|
@@ -3,18 +3,19 @@ import triton
|
|
|
3
3
|
from torch import Tensor
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
|
+
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
6
7
|
from blksprs.utils.tools import get_triton_block_size, stride
|
|
7
8
|
from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
|
|
8
9
|
validate_contiguous
|
|
9
10
|
|
|
10
11
|
|
|
11
|
-
def build_distribution_layout(indices:
|
|
12
|
+
def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: Tensor,
|
|
12
13
|
size_target: torch.Size,
|
|
13
14
|
sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
|
|
14
15
|
"""Builds the sparsity layout of either the source of a gather or the target of a scatter operation.
|
|
15
16
|
|
|
16
17
|
Args:
|
|
17
|
-
indices (
|
|
18
|
+
indices (BlksprsTensor): The block-sparse indices tensor in compressed form used for the gather or scatter operation.
|
|
18
19
|
sparsity_layout_indices (Tensor): The sparsity layout of the indices block-sparse tensor.
|
|
19
20
|
size_target (torch.Size): The size of the block-sparse target tensor in regular form.
|
|
20
21
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
@@ -5,6 +5,7 @@ import triton
|
|
|
5
5
|
from torch import Tensor
|
|
6
6
|
from triton import language as tl
|
|
7
7
|
|
|
8
|
+
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
8
9
|
from blksprs.utils.tools import get_triton_block_size, stride
|
|
9
10
|
from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
|
|
10
11
|
validate_contiguous, validate_sparsity, validate_sparsity_block_size
|
|
@@ -82,14 +83,14 @@ def kernel_sparsity_layout(x,
|
|
|
82
83
|
tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
|
|
83
84
|
|
|
84
85
|
|
|
85
|
-
def build_sparsity_layout_adaption(x:
|
|
86
|
+
def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tensor,
|
|
86
87
|
sparsity_block_size_from: int, sparsity_block_size_to: int,
|
|
87
88
|
triton_block_size: int = None) -> Tensor:
|
|
88
89
|
"""Builds the sparsity layout of a block-sparse tensor in compressed form if a different sparsity block size were
|
|
89
90
|
used.
|
|
90
91
|
|
|
91
92
|
Args:
|
|
92
|
-
x (
|
|
93
|
+
x (BlksprsTensor): A block-sparse tensor in compressed form.
|
|
93
94
|
sparsity_layout_from (Tensor): The sparsity layout of the input block-sparse tensor.
|
|
94
95
|
sparsity_block_size_from (int): The size of the sparsity blocks of the input tensor.
|
|
95
96
|
sparsity_block_size_to (int): The desired size of the sparsity blocks for the resulting layout.
|
blksprs/misc/broadcast_ops.py
CHANGED
|
@@ -3,13 +3,14 @@ import triton
|
|
|
3
3
|
from torch import Tensor
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
|
+
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
6
7
|
from blksprs.utils.tools import get_triton_block_size, stride
|
|
7
8
|
from blksprs.utils.validation import validate_contiguous, validate_device, \
|
|
8
9
|
validate_sparsity_block_size, validate_triton_block_size
|
|
9
10
|
|
|
10
11
|
|
|
11
12
|
def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
|
|
12
|
-
sparsity_block_size: int, triton_block_size: int = None) ->
|
|
13
|
+
sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
|
|
13
14
|
"""Performs a broadcast and subsequent addition of two dense tensors x and y. Returns a block-sparse tensor in
|
|
14
15
|
compressed form.
|
|
15
16
|
|
|
@@ -21,7 +22,7 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
|
|
|
21
22
|
triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
|
|
22
23
|
|
|
23
24
|
Returns:
|
|
24
|
-
|
|
25
|
+
BlksprsTensor: The result of the operation as a block-sparse tensor in compressed form. Each element o(i, j) of the
|
|
25
26
|
output tensor corresponds to x(i) + y(j).
|
|
26
27
|
|
|
27
28
|
"""
|
|
@@ -70,11 +71,11 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
|
|
|
70
71
|
sparsity_block_size,
|
|
71
72
|
triton_block_size))
|
|
72
73
|
|
|
73
|
-
return output
|
|
74
|
+
return BlksprsTensor(output)
|
|
74
75
|
|
|
75
76
|
|
|
76
77
|
def broadcast_sub(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
|
|
77
|
-
sparsity_block_size: int, triton_block_size: int = None) ->
|
|
78
|
+
sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
|
|
78
79
|
"""Wrapper for ``broadcast_add`` with negated y.
|
|
79
80
|
|
|
80
81
|
"""
|
blksprs/misc/exp.py
CHANGED
|
@@ -3,12 +3,13 @@ import triton
|
|
|
3
3
|
from torch import Tensor
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
|
+
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
6
7
|
from blksprs.utils.tools import get_triton_block_size, stride
|
|
7
8
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
8
9
|
validate_sparsity_block_size, validate_triton_block_size
|
|
9
10
|
|
|
10
11
|
|
|
11
|
-
def exp(x:
|
|
12
|
+
def exp(x: BlksprsTensor, sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
|
|
12
13
|
"""Applies the element-wise exponential function to a block-sparse tensor.
|
|
13
14
|
|
|
14
15
|
Note:
|
|
@@ -16,12 +17,12 @@ def exp(x: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> T
|
|
|
16
17
|
Consider this when converting back to tensors in regular form.
|
|
17
18
|
|
|
18
19
|
Args:
|
|
19
|
-
x (
|
|
20
|
+
x (BlksprsTensor): A block-sparse tensor in compressed form.
|
|
20
21
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
21
22
|
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
22
23
|
|
|
23
24
|
Returns:
|
|
24
|
-
|
|
25
|
+
BlksprsTensor: The exponential function applied to all elements of the input tensor as a block-sparse tensor in
|
|
25
26
|
compressed form.
|
|
26
27
|
|
|
27
28
|
"""
|
|
@@ -33,7 +34,7 @@ def exp(x: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> T
|
|
|
33
34
|
validate_sparsity_block_size(sparsity_block_size, x)
|
|
34
35
|
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
35
36
|
|
|
36
|
-
return _BlocksparseExp.apply(x, sparsity_block_size, triton_block_size)
|
|
37
|
+
return BlksprsTensor(_BlocksparseExp.apply(x, sparsity_block_size, triton_block_size))
|
|
37
38
|
|
|
38
39
|
|
|
39
40
|
class _BlocksparseExp(torch.autograd.Function):
|
blksprs/misc/partitioning.py
CHANGED
|
@@ -2,24 +2,25 @@ import torch
|
|
|
2
2
|
from torch import Tensor
|
|
3
3
|
|
|
4
4
|
from blksprs.ops.repeat import forward_flow
|
|
5
|
+
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
5
6
|
|
|
6
7
|
from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
|
|
7
8
|
validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
|
|
8
9
|
|
|
9
10
|
|
|
10
|
-
def split(x:
|
|
11
|
-
sparsity_block_size: int, triton_block_size: int = None) -> (
|
|
11
|
+
def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
|
|
12
|
+
sparsity_block_size: int, triton_block_size: int = None) -> (BlksprsTensor, Tensor):
|
|
12
13
|
"""Splits a block-sparse tensor in compressed form along the last dimension into partitions.
|
|
13
14
|
|
|
14
15
|
Args:
|
|
15
|
-
x (
|
|
16
|
+
x (BlksprsTensor): A block-sparse tensor in compressed form.
|
|
16
17
|
sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
|
|
17
18
|
partitions (int): The number of partitions to split the block-sparse tensor into.
|
|
18
19
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
19
20
|
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
20
21
|
|
|
21
22
|
Returns:
|
|
22
|
-
|
|
23
|
+
BlksprsTensor: The block-sparse tensor split into partitions in compressed form.
|
|
23
24
|
Tensor: The sparsity layout of the output tensor.
|
|
24
25
|
|
|
25
26
|
"""
|
|
@@ -53,8 +54,8 @@ def split(x: Tensor, sparsity_layout: Tensor, partitions: int,
|
|
|
53
54
|
|
|
54
55
|
validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
|
|
55
56
|
|
|
56
|
-
return _BlocksparseSplit.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
|
|
57
|
-
sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_output
|
|
57
|
+
return BlksprsTensor(_BlocksparseSplit.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
|
|
58
|
+
sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_output
|
|
58
59
|
|
|
59
60
|
|
|
60
61
|
class _BlocksparseSplit(torch.autograd.Function):
|
|
@@ -79,19 +80,19 @@ class _BlocksparseSplit(torch.autograd.Function):
|
|
|
79
80
|
sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None
|
|
80
81
|
|
|
81
82
|
|
|
82
|
-
def merge(x:
|
|
83
|
-
sparsity_block_size: int, triton_block_size: int = None) -> (
|
|
83
|
+
def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
|
|
84
|
+
sparsity_block_size: int, triton_block_size: int = None) -> (BlksprsTensor, Tensor):
|
|
84
85
|
"""Merges the specified partitions of a block-sparse tensor in compressed form along the last dimension.
|
|
85
86
|
|
|
86
87
|
Args:
|
|
87
|
-
x (
|
|
88
|
+
x (BlksprsTensor): A block-sparse tensor in compressed form.
|
|
88
89
|
sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
|
|
89
90
|
partitions (int): The number of partitions to be merged.
|
|
90
91
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
91
92
|
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
92
93
|
|
|
93
94
|
Returns:
|
|
94
|
-
|
|
95
|
+
BlksprsTensor: The merged block-sparse tensor in compressed form.
|
|
95
96
|
Tensor: The sparsity layout of the output tensor.
|
|
96
97
|
|
|
97
98
|
"""
|
|
@@ -127,8 +128,8 @@ def merge(x: Tensor, sparsity_layout: Tensor, partitions: int,
|
|
|
127
128
|
|
|
128
129
|
validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
|
|
129
130
|
|
|
130
|
-
return _BlocksparseMerge.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
|
|
131
|
-
sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_output
|
|
131
|
+
return BlksprsTensor(_BlocksparseMerge.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
|
|
132
|
+
sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_output
|
|
132
133
|
|
|
133
134
|
|
|
134
135
|
class _BlocksparseMerge(torch.autograd.Function):
|
blksprs/misc/row_wise.py
CHANGED
|
@@ -3,13 +3,14 @@ import triton
|
|
|
3
3
|
from torch import Tensor
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
|
+
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
6
7
|
from blksprs.utils.tools import get_triton_block_size, stride
|
|
7
8
|
from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, validate_sparsity, \
|
|
8
9
|
validate_sparsity_block_size, validate_triton_block_size
|
|
9
10
|
|
|
10
11
|
|
|
11
|
-
def row_wise_sum(x:
|
|
12
|
-
flag_slice_only: bool = False, triton_block_size: int = None) ->
|
|
12
|
+
def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
13
|
+
flag_slice_only: bool = False, triton_block_size: int = None) -> (BlksprsTensor, Tensor):
|
|
13
14
|
"""Computes the row-wise sum of a block-sparse tensor.
|
|
14
15
|
|
|
15
16
|
Returns a block-sparse tensor in compressed form with only one block per row, where the first entry contains the sum
|
|
@@ -19,7 +20,7 @@ def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
|
19
20
|
If ``flag_slice_only`` is set the output will be of shape ``[x.size(0), x.size(1), 1]``.
|
|
20
21
|
|
|
21
22
|
Args:
|
|
22
|
-
x (
|
|
23
|
+
x (BlksprsTensor): A block-sparse tensor in compressed form.
|
|
23
24
|
sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
|
|
24
25
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
25
26
|
flag_slice_only (bool, optional): If set the output will be of shape ``[x.size(0), x.size(1), 1]``
|
|
@@ -27,7 +28,7 @@ def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
|
27
28
|
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
28
29
|
|
|
29
30
|
Returns:
|
|
30
|
-
tuple[
|
|
31
|
+
tuple[BlksprsTensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise sum
|
|
31
32
|
of the input and the sparsity layout of the output tensor.
|
|
32
33
|
|
|
33
34
|
"""
|
|
@@ -85,7 +86,7 @@ def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
|
85
86
|
sparsity_reverse_lut_output,
|
|
86
87
|
triton_block_size))
|
|
87
88
|
|
|
88
|
-
return (output, sparsity_layout_output
|
|
89
|
+
return BlksprsTensor(output), sparsity_layout_output
|
|
89
90
|
|
|
90
91
|
|
|
91
92
|
@triton.jit
|
|
@@ -131,8 +132,8 @@ def kernel_blocksparse_row_wise_sum(x,
|
|
|
131
132
|
tl.atomic_add(o + o_idx, buf, o_msk)
|
|
132
133
|
|
|
133
134
|
|
|
134
|
-
def row_wise_max(x:
|
|
135
|
-
flag_slice_only: bool = False, triton_block_size: int = None) ->
|
|
135
|
+
def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
136
|
+
flag_slice_only: bool = False, triton_block_size: int = None) -> (BlksprsTensor, Tensor):
|
|
136
137
|
"""Computes the row-wise max of a block-sparse tensor.
|
|
137
138
|
|
|
138
139
|
Returns a block-sparse tensor in compressed form with only one block per row, where the first entry contains the
|
|
@@ -142,7 +143,7 @@ def row_wise_max(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
|
142
143
|
If ``flag_slice_only`` is set the output will be of shape ``[x.size(0), x.size(1), 1]``.
|
|
143
144
|
|
|
144
145
|
Args:
|
|
145
|
-
x (
|
|
146
|
+
x (BlksprsTensor): A block-sparse tensor in compressed form.
|
|
146
147
|
sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
|
|
147
148
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
148
149
|
flag_slice_only (bool, optional): If set the output will be of shape ``[x.size(0), x.size(1), 1]``
|
|
@@ -150,7 +151,7 @@ def row_wise_max(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
|
150
151
|
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
151
152
|
|
|
152
153
|
Returns:
|
|
153
|
-
tuple[
|
|
154
|
+
tuple[BlksprsTensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise max
|
|
154
155
|
of the input and the sparsity layout of the output tensor.
|
|
155
156
|
|
|
156
157
|
"""
|
|
@@ -208,7 +209,7 @@ def row_wise_max(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
|
208
209
|
sparsity_reverse_lut_output,
|
|
209
210
|
triton_block_size))
|
|
210
211
|
|
|
211
|
-
return output, sparsity_layout_output
|
|
212
|
+
return BlksprsTensor(output), sparsity_layout_output
|
|
212
213
|
|
|
213
214
|
|
|
214
215
|
@triton.jit
|
|
@@ -254,19 +255,19 @@ def kernel_blocksparse_row_wise_max(x,
|
|
|
254
255
|
tl.atomic_max(o + o_idx, buf, o_msk)
|
|
255
256
|
|
|
256
257
|
|
|
257
|
-
def row_wise_add(x:
|
|
258
|
-
sparsity_block_size: int, triton_block_size: int = None) ->
|
|
258
|
+
def row_wise_add(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
|
|
259
|
+
sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
|
|
259
260
|
"""For each row in ``y`` adds the value to each value in the corresponding row of the block-sparse tensor ``x``.
|
|
260
261
|
|
|
261
262
|
Args:
|
|
262
|
-
x (
|
|
263
|
+
x (BlksprsTensor): A block-sparse tensor in compressed form.
|
|
263
264
|
sparsity_layout_x (Tensor): The sparsity layout of the block-sparse tensor.
|
|
264
|
-
y (
|
|
265
|
+
y (BlksprsTensor): A block-sparse tensor in compressed form with only one value per row and a single column of sparse blocks.
|
|
265
266
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
266
267
|
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
267
268
|
|
|
268
269
|
Returns:
|
|
269
|
-
|
|
270
|
+
BlksprsTensor: The values of ``x`` with the first value of ``y`` in each row added to them as a block-sparse tensor in
|
|
270
271
|
compressed form.
|
|
271
272
|
|
|
272
273
|
"""
|
|
@@ -319,11 +320,11 @@ def row_wise_add(x: Tensor, sparsity_layout_x: Tensor, y: Tensor,
|
|
|
319
320
|
triton_block_size
|
|
320
321
|
))
|
|
321
322
|
|
|
322
|
-
return output
|
|
323
|
+
return BlksprsTensor(output)
|
|
323
324
|
|
|
324
325
|
|
|
325
|
-
def row_wise_sub(x:
|
|
326
|
-
sparsity_block_size: int, triton_block_size: int = None) ->
|
|
326
|
+
def row_wise_sub(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
|
|
327
|
+
sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
|
|
327
328
|
"""Wrapper for ``row_wise_add`` with negated y.
|
|
328
329
|
|
|
329
330
|
"""
|
blksprs/ops/conversion.py
CHANGED
|
@@ -6,23 +6,27 @@ from torch import Tensor
|
|
|
6
6
|
from triton import language as tl
|
|
7
7
|
|
|
8
8
|
from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
|
|
9
|
+
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
9
10
|
from blksprs.utils.tools import get_triton_block_size, stride
|
|
10
11
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
11
12
|
validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_sparsity_dense
|
|
12
13
|
|
|
13
14
|
|
|
14
|
-
def from_blksprs(x:
|
|
15
|
+
def from_blksprs(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
|
|
15
16
|
triton_block_size: int = None) -> Tensor:
|
|
17
|
+
"""Wrapper for ``to_dense``.
|
|
18
|
+
|
|
19
|
+
"""
|
|
16
20
|
return to_dense(x, sparsity_layout, sparsity_block_size, fill_value, triton_block_size)
|
|
17
21
|
|
|
18
22
|
|
|
19
|
-
def to_dense(x:
|
|
23
|
+
def to_dense(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
|
|
20
24
|
triton_block_size: int = None) -> Tensor:
|
|
21
25
|
"""Converts a block-sparse tensor in compressed form to a block-sparse tensor in regular form based on the given
|
|
22
26
|
sparsity layout.
|
|
23
27
|
|
|
24
28
|
Args:
|
|
25
|
-
x (
|
|
29
|
+
x (BlksprsTensor): A block-sparse tensor in compressed form.
|
|
26
30
|
sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
|
|
27
31
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
28
32
|
fill_value (float): The value to fill the resulting dense tensor with where the block-sparse tensor is not
|
|
@@ -50,12 +54,12 @@ def to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_
|
|
|
50
54
|
validate_contiguous(sparsity_reverse_lut)
|
|
51
55
|
|
|
52
56
|
if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
|
|
53
|
-
return x
|
|
57
|
+
return BlksprsTensor(x)
|
|
54
58
|
|
|
55
|
-
return _BlocksparseToDense.apply(x,
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
+
return BlksprsTensor(_BlocksparseToDense.apply(x,
|
|
60
|
+
sparsity_layout, sparsity_reverse_lut,
|
|
61
|
+
sparsity_block_size, fill_value,
|
|
62
|
+
triton_block_size))
|
|
59
63
|
|
|
60
64
|
|
|
61
65
|
class _BlocksparseToDense(torch.autograd.Function):
|
|
@@ -150,11 +154,15 @@ class _BlocksparseToDense(torch.autograd.Function):
|
|
|
150
154
|
|
|
151
155
|
|
|
152
156
|
def to_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
153
|
-
triton_block_size: int = None) ->
|
|
157
|
+
triton_block_size: int = None) -> BlksprsTensor:
|
|
158
|
+
"""Wrapper for ``to_sparse``.
|
|
159
|
+
|
|
160
|
+
"""
|
|
154
161
|
return to_sparse(x, sparsity_layout, sparsity_block_size, triton_block_size)
|
|
155
162
|
|
|
156
163
|
|
|
157
|
-
def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
164
|
+
def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
165
|
+
triton_block_size: int = None) -> BlksprsTensor:
|
|
158
166
|
"""Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
|
|
159
167
|
sparsity layout.
|
|
160
168
|
|
|
@@ -165,7 +173,7 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
|
|
|
165
173
|
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
166
174
|
|
|
167
175
|
Returns:
|
|
168
|
-
|
|
176
|
+
BlksprsTensor: The block-sparse tensor converted to compressed form.
|
|
169
177
|
|
|
170
178
|
"""
|
|
171
179
|
x = x.contiguous()
|
|
@@ -183,12 +191,12 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
|
|
|
183
191
|
validate_contiguous(sparsity_layout, sparsity_lut)
|
|
184
192
|
|
|
185
193
|
if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
|
|
186
|
-
return x
|
|
194
|
+
return BlksprsTensor(x)
|
|
187
195
|
|
|
188
|
-
return _BlocksparseToSparse.apply(x,
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
196
|
+
return BlksprsTensor(_BlocksparseToSparse.apply(x,
|
|
197
|
+
sparsity_layout, sparsity_lut,
|
|
198
|
+
sparsity_block_size, n_sparse_blocks,
|
|
199
|
+
triton_block_size))
|
|
192
200
|
|
|
193
201
|
|
|
194
202
|
class _BlocksparseToSparse(torch.autograd.Function):
|
|
@@ -280,13 +288,14 @@ class _BlocksparseToSparse(torch.autograd.Function):
|
|
|
280
288
|
tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
|
|
281
289
|
|
|
282
290
|
|
|
283
|
-
def adapt_layout(x:
|
|
284
|
-
|
|
291
|
+
def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_size_from: int,
|
|
292
|
+
sparsity_block_size_to: int,
|
|
293
|
+
preprocess_data: dict = None, triton_block_size: int = None) -> BlksprsTensor:
|
|
285
294
|
"""Adapts the sparsity layout of a block-sparse tensor, resulting in a new block-sparse tensor in compressed form
|
|
286
295
|
conforming to the new sparsity layout (and sparsity block size) definition.
|
|
287
296
|
|
|
288
297
|
Args:
|
|
289
|
-
x (
|
|
298
|
+
x (BlksprsTensor): A block-sparse tensor in compressed form.
|
|
290
299
|
sparsity_layout_from (Tensor): The sparsity layout of the input block-sparse tensor.
|
|
291
300
|
sparsity_block_size_from (int): The size of the sparsity blocks of the input sparsity layout.
|
|
292
301
|
sparsity_block_size_to (int): The size of the sparsity blocks of the output sparsity layout.
|
|
@@ -294,7 +303,7 @@ def adapt_layout(x: Tensor, sparsity_layout_from: Tensor, sparsity_block_size_fr
|
|
|
294
303
|
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
295
304
|
|
|
296
305
|
Returns:
|
|
297
|
-
|
|
306
|
+
BlksprsTensor: The block-sparse tensor in compressed form with the adapted sparsity layout and sparsity block size.
|
|
298
307
|
|
|
299
308
|
"""
|
|
300
309
|
x = x.contiguous()
|
|
@@ -339,12 +348,13 @@ def adapt_layout(x: Tensor, sparsity_layout_from: Tensor, sparsity_block_size_fr
|
|
|
339
348
|
validate_contiguous(sparsity_layout_to, sparsity_reverse_lut_from, sparsity_lut_to)
|
|
340
349
|
|
|
341
350
|
if (sparsity_block_size_from == sparsity_block_size_to) and torch.equal(sparsity_layout_from, sparsity_layout_to):
|
|
342
|
-
return x
|
|
351
|
+
return BlksprsTensor(x)
|
|
343
352
|
|
|
344
|
-
return _BlocksparseAdaptLayout.apply(x,
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
353
|
+
return BlksprsTensor(_BlocksparseAdaptLayout.apply(x,
|
|
354
|
+
sparsity_layout_from, sparsity_reverse_lut_from,
|
|
355
|
+
sparsity_block_size_from,
|
|
356
|
+
sparsity_layout_to, sparsity_lut_to, sparsity_block_size_to,
|
|
357
|
+
n_sparse_blocks_to, min_sparsity_block_size, triton_block_size))
|
|
348
358
|
|
|
349
359
|
|
|
350
360
|
class _BlocksparseAdaptLayout(torch.autograd.Function):
|
blksprs/ops/distribution.py
CHANGED
|
@@ -3,25 +3,26 @@ import triton
|
|
|
3
3
|
from torch import Tensor
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
|
+
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
6
7
|
from blksprs.utils.tools import get_triton_block_size, stride
|
|
7
8
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
8
9
|
validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
|
|
9
10
|
|
|
10
11
|
|
|
11
|
-
def gather(src:
|
|
12
|
-
sparsity_block_size: int, triton_block_size: int = None) ->
|
|
12
|
+
def gather(src: BlksprsTensor, sparsity_layout_src: Tensor, idx: BlksprsTensor, sparsity_layout_idx: Tensor,
|
|
13
|
+
sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
|
|
13
14
|
"""Applies a gather operation on a block-sparse tensor in compressed form.
|
|
14
15
|
|
|
15
16
|
Args:
|
|
16
|
-
src (
|
|
17
|
+
src (BlksprsTensor): The source block-sparse tensor in compressed form to gather from.
|
|
17
18
|
sparsity_layout_src (Tensor): The sparsity layout of the source block-sparse tensor.
|
|
18
|
-
idx (
|
|
19
|
+
idx (BlksprsTensor): The block-sparse indices tensor in compressed form specifying how to gather from the source tensor.
|
|
19
20
|
sparsity_layout_idx (Tensor): The sparsity layout of the indices block-sparse tensor.
|
|
20
21
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
21
22
|
triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
|
|
22
23
|
|
|
23
24
|
Returns:
|
|
24
|
-
|
|
25
|
+
BlksprsTensor: The result of the gather operation as a block-sparse tensor in compressed form.
|
|
25
26
|
|
|
26
27
|
"""
|
|
27
28
|
src = src.contiguous()
|
|
@@ -45,9 +46,9 @@ def gather(src: Tensor, sparsity_layout_src: Tensor, idx: Tensor, sparsity_layou
|
|
|
45
46
|
validate_contiguous(sparsity_layout_src, sparsity_reverse_lut_x,
|
|
46
47
|
sparsity_layout_idx, sparsity_lut_i)
|
|
47
48
|
|
|
48
|
-
return _BlocksparseGather.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
|
|
49
|
+
return BlksprsTensor(_BlocksparseGather.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
|
|
49
50
|
idx, sparsity_layout_idx, sparsity_lut_i,
|
|
50
|
-
sparsity_block_size, triton_block_size)
|
|
51
|
+
sparsity_block_size, triton_block_size))
|
|
51
52
|
|
|
52
53
|
|
|
53
54
|
class _BlocksparseGather(torch.autograd.Function):
|
|
@@ -168,10 +169,10 @@ class _BlocksparseGather(torch.autograd.Function):
|
|
|
168
169
|
tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
|
|
169
170
|
|
|
170
171
|
|
|
171
|
-
def scatter(src:
|
|
172
|
-
idx:
|
|
172
|
+
def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
|
|
173
|
+
idx: BlksprsTensor,
|
|
173
174
|
sparsity_layout_tgt: Tensor,
|
|
174
|
-
sparsity_block_size: int, triton_block_size: int = None) ->
|
|
175
|
+
sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
|
|
175
176
|
"""Wrapper for ``scatter_reduce`` with ``reduce_op="none"``.
|
|
176
177
|
|
|
177
178
|
"""
|
|
@@ -182,17 +183,17 @@ def scatter(src: Tensor, sparsity_layout_src: Tensor,
|
|
|
182
183
|
reduce_op="none", triton_block_size=triton_block_size)
|
|
183
184
|
|
|
184
185
|
|
|
185
|
-
def scatter_reduce(src:
|
|
186
|
-
idx:
|
|
186
|
+
def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
|
|
187
|
+
idx: BlksprsTensor,
|
|
187
188
|
sparsity_layout_tgt: Tensor,
|
|
188
189
|
sparsity_block_size: int,
|
|
189
|
-
reduce_op: str = "sum", triton_block_size: int = None) ->
|
|
190
|
+
reduce_op: str = "sum", triton_block_size: int = None) -> BlksprsTensor:
|
|
190
191
|
"""Applies a scatter operation on a block-sparse tensor in compressed form.
|
|
191
192
|
|
|
192
193
|
Args:
|
|
193
|
-
src (
|
|
194
|
+
src (BlksprsTensor): The source block-sparse tensor in compressed form to scatter from.
|
|
194
195
|
sparsity_layout_src (Tensor): The sparsity layout of the source block-sparse tensor.
|
|
195
|
-
idx (
|
|
196
|
+
idx (BlksprsTensor): The block-sparse indices tensor in compressed form specifying how to scatter to the target tensor.
|
|
196
197
|
sparsity_layout_tgt (Tensor): The sparsity layout of the target block-sparse tensor.
|
|
197
198
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
198
199
|
reduce_op (str, optional): The reduction operation to apply during the scatter operation (default ``"sum"``).
|
|
@@ -200,7 +201,7 @@ def scatter_reduce(src: Tensor, sparsity_layout_src: Tensor,
|
|
|
200
201
|
triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
|
|
201
202
|
|
|
202
203
|
Returns:
|
|
203
|
-
|
|
204
|
+
BlksprsTensor: The result of the scatter operation as a block-sparse tensor in compressed form.
|
|
204
205
|
|
|
205
206
|
"""
|
|
206
207
|
src = src.contiguous()
|
|
@@ -229,11 +230,11 @@ def scatter_reduce(src: Tensor, sparsity_layout_src: Tensor,
|
|
|
229
230
|
validate_contiguous(sparsity_layout_src, sparsity_lut_x,
|
|
230
231
|
sparsity_layout_tgt, sparsity_reverse_lut_o)
|
|
231
232
|
|
|
232
|
-
return _BlocksparseScatterReduce.apply(src, sparsity_layout_src, sparsity_lut_x,
|
|
233
|
+
return BlksprsTensor(_BlocksparseScatterReduce.apply(src, sparsity_layout_src, sparsity_lut_x,
|
|
233
234
|
idx,
|
|
234
235
|
sparsity_layout_tgt, sparsity_reverse_lut_o,
|
|
235
236
|
sparsity_block_size, n_sparse_blocks,
|
|
236
|
-
reduce_op, triton_block_size)
|
|
237
|
+
reduce_op, triton_block_size))
|
|
237
238
|
|
|
238
239
|
|
|
239
240
|
class _BlocksparseScatterReduce(torch.autograd.Function):
|
blksprs/ops/matmul.py
CHANGED
|
@@ -4,22 +4,23 @@ from torch import Tensor
|
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
6
|
from blksprs.ops.transpose import transpose
|
|
7
|
+
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
7
8
|
from blksprs.utils.tools import get_triton_block_size, stride
|
|
8
9
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
9
10
|
validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_dtype_float
|
|
10
11
|
|
|
11
12
|
|
|
12
|
-
def matmul(x:
|
|
13
|
-
y:
|
|
13
|
+
def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
|
|
14
|
+
y: BlksprsTensor, sparsity_layout_y: Tensor,
|
|
14
15
|
sparsity_layout_output: Tensor,
|
|
15
|
-
sparsity_block_size: int, triton_block_size: int = None) ->
|
|
16
|
+
sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
|
|
16
17
|
"""Performs matrix multiplication between two block-sparse tensors.
|
|
17
18
|
|
|
18
19
|
The sparsity layout of the output tensor is used to only calculate blocks that will be present in the output.
|
|
19
20
|
|
|
20
21
|
Args:
|
|
21
|
-
x (
|
|
22
|
-
y (
|
|
22
|
+
x (BlksprsTensor): A block-sparse tensor in compressed form.
|
|
23
|
+
y (BlksprsTensor): A block-sparse tensor in compressed form.
|
|
23
24
|
sparsity_layout_x (Tensor): The sparsity layout of the first block-sparse tensor.
|
|
24
25
|
sparsity_layout_y (Tensor): The sparsity layout of the second block-sparse tensor.
|
|
25
26
|
sparsity_layout_output (Tensor): The sparsity layout of the output tensor.
|
|
@@ -27,7 +28,7 @@ def matmul(x: Tensor, sparsity_layout_x: Tensor,
|
|
|
27
28
|
triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
|
|
28
29
|
|
|
29
30
|
Returns:
|
|
30
|
-
|
|
31
|
+
BlksprsTensor: The result of the matrix multiplication as a block-sparse tensor in compressed form.
|
|
31
32
|
|
|
32
33
|
"""
|
|
33
34
|
x = x.contiguous()
|
|
@@ -61,13 +62,13 @@ def matmul(x: Tensor, sparsity_layout_x: Tensor,
|
|
|
61
62
|
sparsity_layout_y, sparsity_reverse_lut_y,
|
|
62
63
|
sparsity_layout_output, sparsity_lut_o)
|
|
63
64
|
|
|
64
|
-
return _BlocksparseMatmulSSS.apply(x, y,
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
65
|
+
return BlksprsTensor(_BlocksparseMatmulSSS.apply(x, y,
|
|
66
|
+
sparsity_layout_x, sparsity_reverse_lut_x,
|
|
67
|
+
sparsity_layout_y, sparsity_reverse_lut_y,
|
|
68
|
+
sparsity_layout_output, sparsity_lut_o,
|
|
69
|
+
sparsity_block_size,
|
|
70
|
+
n_sparse_blocks,
|
|
71
|
+
triton_block_size))
|
|
71
72
|
|
|
72
73
|
|
|
73
74
|
class _BlocksparseMatmulSSS(torch.autograd.Function):
|
blksprs/ops/repeat.py
CHANGED
|
@@ -3,14 +3,15 @@ import triton
|
|
|
3
3
|
from triton import language as tl
|
|
4
4
|
from torch import Tensor
|
|
5
5
|
|
|
6
|
+
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
6
7
|
from blksprs.utils.tools import get_triton_block_size, stride
|
|
7
8
|
from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
|
|
8
9
|
validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
|
|
9
10
|
|
|
10
11
|
|
|
11
|
-
def repeat(x:
|
|
12
|
+
def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
|
|
12
13
|
sparsity_block_size: int, sparsity_layout_output: Tensor = None, triton_block_size: int = None) -> (
|
|
13
|
-
|
|
14
|
+
BlksprsTensor, Tensor):
|
|
14
15
|
"""Repeats a block-spare tensor in compressed form according to the given repeats.
|
|
15
16
|
|
|
16
17
|
Repeats is a 3-tuple of integers, where each integer represents the number of times the tensor should be repeated in
|
|
@@ -22,7 +23,7 @@ def repeat(x: Tensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
|
|
|
22
23
|
them to be sparse.
|
|
23
24
|
|
|
24
25
|
Args:
|
|
25
|
-
x (
|
|
26
|
+
x (BlksprsTensor): A block-sparse tensor in compressed form.
|
|
26
27
|
sparsity_layout_x (Tensor): The sparsity layout of the block-sparse tensor.
|
|
27
28
|
repeats (tuple[int, int, int]): The number of times the tensor should be repeated in the first, second and
|
|
28
29
|
third dimension respectively.
|
|
@@ -31,7 +32,7 @@ def repeat(x: Tensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
|
|
|
31
32
|
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
32
33
|
|
|
33
34
|
Returns:
|
|
34
|
-
|
|
35
|
+
BlksprsTensor: A block-sparse tensor in compressed form containing the repeated values.
|
|
35
36
|
Tensor: The sparsity layout of the resulting output tensor.
|
|
36
37
|
|
|
37
38
|
"""
|
|
@@ -63,14 +64,14 @@ def repeat(x: Tensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
|
|
|
63
64
|
|
|
64
65
|
validate_contiguous(sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
|
|
65
66
|
|
|
66
|
-
return _BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
|
|
67
|
-
sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_o
|
|
67
|
+
return BlksprsTensor(_BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
|
|
68
|
+
sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_o
|
|
68
69
|
|
|
69
70
|
|
|
70
|
-
def repeat_interleave(x:
|
|
71
|
+
def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
|
|
71
72
|
sparsity_block_size: int, sparsity_layout_output: Tensor = None,
|
|
72
73
|
triton_block_size: int = None) -> (
|
|
73
|
-
|
|
74
|
+
BlksprsTensor, Tensor):
|
|
74
75
|
"""Repeats and interleaves the block-sparse tensor in compressed form.
|
|
75
76
|
|
|
76
77
|
Repeats each matrix contained in the tensors by ``repeats`` amount and places them consecutively in the output
|
|
@@ -81,7 +82,7 @@ def repeat_interleave(x: Tensor, sparsity_layout_x: Tensor, repeats: int,
|
|
|
81
82
|
non-sparse blocks will be filled.
|
|
82
83
|
|
|
83
84
|
Args:
|
|
84
|
-
x (
|
|
85
|
+
x (BlksprsTensor): A block-sparse tensor in compressed form.
|
|
85
86
|
sparsity_layout_x (Tensor): The sparsity layout of the block-sparse tensor.
|
|
86
87
|
repeats (int): The number of times to repeat the matrices.
|
|
87
88
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
@@ -89,7 +90,7 @@ def repeat_interleave(x: Tensor, sparsity_layout_x: Tensor, repeats: int,
|
|
|
89
90
|
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
90
91
|
|
|
91
92
|
Returns:
|
|
92
|
-
|
|
93
|
+
BlksprsTensor: A block-sparse tensor in compressed form containing the repeated and interleaved matrices.
|
|
93
94
|
Tensor: The sparsity layout of the resulting output tensor.
|
|
94
95
|
|
|
95
96
|
"""
|
|
@@ -121,8 +122,8 @@ def repeat_interleave(x: Tensor, sparsity_layout_x: Tensor, repeats: int,
|
|
|
121
122
|
|
|
122
123
|
validate_contiguous(sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
|
|
123
124
|
|
|
124
|
-
return _BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
|
|
125
|
-
sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_o
|
|
125
|
+
return BlksprsTensor(_BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
|
|
126
|
+
sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_o
|
|
126
127
|
|
|
127
128
|
|
|
128
129
|
class _BlocksparseRepeat(torch.autograd.Function):
|
blksprs/ops/softmax.py
CHANGED
|
@@ -5,25 +5,26 @@ from triton import language as tl
|
|
|
5
5
|
|
|
6
6
|
from blksprs.misc.exp import exp
|
|
7
7
|
from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
|
|
8
|
+
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
8
9
|
from blksprs.utils.tools import get_triton_block_size, stride
|
|
9
10
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
10
11
|
validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
|
|
11
12
|
|
|
12
13
|
|
|
13
|
-
def softmax(x:
|
|
14
|
+
def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
|
|
14
15
|
"""Computes the softmax of a block-sparse tensor in compressed form.
|
|
15
16
|
|
|
16
17
|
Note:
|
|
17
18
|
Sparse blocks are not considered for the calculation of the softmax, i.e., all values are assumed to be ``-inf``.
|
|
18
19
|
|
|
19
20
|
Args:
|
|
20
|
-
x (
|
|
21
|
+
x (BlksprsTensor): A block-sparse tensor in compressed form.
|
|
21
22
|
sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
|
|
22
23
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
23
24
|
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
24
25
|
|
|
25
26
|
Returns:
|
|
26
|
-
|
|
27
|
+
BlksprsTensor: The result of the softmax operation as a block-sparse tensor in compressed form.
|
|
27
28
|
|
|
28
29
|
"""
|
|
29
30
|
x = x.contiguous()
|
|
@@ -45,10 +46,10 @@ def softmax(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton
|
|
|
45
46
|
|
|
46
47
|
validate_contiguous(sparsity_layout, sparsity_lut, sparsity_reverse_lut_rws)
|
|
47
48
|
|
|
48
|
-
return _BlocksparseSoftmax.apply(x, sparsity_layout,
|
|
49
|
+
return BlksprsTensor(_BlocksparseSoftmax.apply(x, sparsity_layout,
|
|
49
50
|
sparsity_lut,
|
|
50
51
|
sparsity_reverse_lut_rws,
|
|
51
|
-
sparsity_block_size, triton_block_size)
|
|
52
|
+
sparsity_block_size, triton_block_size))
|
|
52
53
|
|
|
53
54
|
|
|
54
55
|
class _BlocksparseSoftmax(torch.autograd.Function):
|
blksprs/ops/transpose.py
CHANGED
|
@@ -3,26 +3,27 @@ import triton
|
|
|
3
3
|
from torch import Tensor
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
|
+
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
6
7
|
from blksprs.utils.tools import get_triton_block_size, stride
|
|
7
8
|
from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
|
|
8
9
|
validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
|
|
9
10
|
|
|
10
11
|
|
|
11
|
-
def transpose(x:
|
|
12
|
-
|
|
12
|
+
def transpose(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> (
|
|
13
|
+
BlksprsTensor, Tensor):
|
|
13
14
|
"""Transposes a block-sparse tensor in compressed form.
|
|
14
15
|
|
|
15
16
|
Note:
|
|
16
17
|
Returns the transposed tensor and the sparsity layout of the transposed tensor.
|
|
17
18
|
|
|
18
19
|
Args:
|
|
19
|
-
x (
|
|
20
|
+
x (BlksprsTensor): A block-sparse tensor in compressed form.
|
|
20
21
|
sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
|
|
21
22
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
22
23
|
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
23
24
|
|
|
24
25
|
Returns:
|
|
25
|
-
|
|
26
|
+
BlksprsTensor: The transposed block-sparse tensor in compressed form.
|
|
26
27
|
Tensor: The sparsity layout of the transposed tensor.
|
|
27
28
|
|
|
28
29
|
"""
|
|
@@ -49,8 +50,8 @@ def transpose(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
|
|
|
49
50
|
|
|
50
51
|
validate_contiguous(sparsity_layout_t, sparsity_lut, sparsity_reverse_lut)
|
|
51
52
|
|
|
52
|
-
return _BlocksparseTranspose.apply(x, sparsity_layout_t, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
|
|
53
|
-
n_sparse_blocks, triton_block_size), sparsity_layout_t
|
|
53
|
+
return BlksprsTensor(_BlocksparseTranspose.apply(x, sparsity_layout_t, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
|
|
54
|
+
n_sparse_blocks, triton_block_size)), sparsity_layout_t
|
|
54
55
|
|
|
55
56
|
|
|
56
57
|
class _BlocksparseTranspose(torch.autograd.Function):
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: blksprs
|
|
3
|
-
Version: 1.8.
|
|
3
|
+
Version: 1.8.2
|
|
4
4
|
Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
|
|
5
5
|
Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
|
|
6
6
|
Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
blksprs/__init__.py,sha256=np0msosWMaZNVVfuFGt8rE6HZURyIald391dKAs1dSQ,1093
|
|
2
|
+
blksprs/experimental/distribution_mdi.py,sha256=HaRUu6LTWATzjuHWgddIUE-0fgY-O87STpJO4JY7k_8,20357
|
|
3
|
+
blksprs/layouting/distribution_layout.py,sha256=wmj1SwWyY_fhbvMmh6AXrR77LoSp6xLwUWCCyO9i5lk,4239
|
|
4
|
+
blksprs/layouting/sparsity_layout.py,sha256=-sScIn4hhG35j9BXytrojEzp8jnFkMargJjtivPV1fc,9755
|
|
5
|
+
blksprs/misc/broadcast_ops.py,sha256=cPtRJa3pkZfY1QG51CJ-zDn4SK-CRpX5LEXoKGGMvRU,5418
|
|
6
|
+
blksprs/misc/exp.py,sha256=FnSFosBfJHuiEbD0MD-i4axLghRn4a0f8KvHXrKBB6M,3802
|
|
7
|
+
blksprs/misc/partitioning.py,sha256=K0ExR2a3W62d_9xxCJzsdJDLgtbxTI6P8loOOBdhPzE,7674
|
|
8
|
+
blksprs/misc/row_wise.py,sha256=SvJuNww-_QoVKTyTjMvjmzHlBuUlTKamkuq_rKzwAqs,17081
|
|
9
|
+
blksprs/ops/conversion.py,sha256=ol-iV45wDzp9G1dJEkY53EdrvnmHzcl7QQmPJ-xqQTs,22410
|
|
10
|
+
blksprs/ops/distribution.py,sha256=fXZV6UegCVpIwzh-A825OSYClHWu5k0UMYdO2UGDUpM,17067
|
|
11
|
+
blksprs/ops/matmul.py,sha256=yh2ZnO0ZltT1AgadiFP0vX28YJ4n74xO-I_5vFUmOmA,11452
|
|
12
|
+
blksprs/ops/repeat.py,sha256=IvSIRbuyFn0b57LObymLgup0LqlWQ3ndIw-QuiYQcaU,14564
|
|
13
|
+
blksprs/ops/softmax.py,sha256=D9wITz3KB24QXGGjgn_RLQ0Iiq_SjX0bTbUyv9479uU,12094
|
|
14
|
+
blksprs/ops/transpose.py,sha256=jxzFFffrj4S_9tiCrwwUMdz6EA98o1dziWXjlqb64a4,6859
|
|
15
|
+
blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
|
|
16
|
+
blksprs/utils/blksprs_tensor.py,sha256=VjplBgDhnf9sxf-1R5feA0xp5FDCDdaeZmCeoIRdCnc,151
|
|
17
|
+
blksprs/utils/tools.py,sha256=S3836Zuc-BMigv-5mLTjRznCzuaF6oYW-Ir9zzUnr3o,655
|
|
18
|
+
blksprs/utils/validation.py,sha256=WzihRPibXYzss3PMkhDt5_d3Q3NHA_d1TzTz3CoGPGg,4136
|
|
19
|
+
blksprs-1.8.2.dist-info/METADATA,sha256=Zoc860mYmFss7v5ChNoi9407v1qDo_ecc6JUWCvaesg,8009
|
|
20
|
+
blksprs-1.8.2.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
|
21
|
+
blksprs-1.8.2.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
|
|
22
|
+
blksprs-1.8.2.dist-info/RECORD,,
|
blksprs-1.8.1.dist-info/RECORD
DELETED
|
@@ -1,21 +0,0 @@
|
|
|
1
|
-
blksprs/__init__.py,sha256=np0msosWMaZNVVfuFGt8rE6HZURyIald391dKAs1dSQ,1093
|
|
2
|
-
blksprs/experimental/distribution_mdi.py,sha256=jE_SbB0SMGPcCoeM0699cceWAikBkBpGM_8Oo5A7Ets,20083
|
|
3
|
-
blksprs/layouting/distribution_layout.py,sha256=nCR3fCF6eNpi97DI6MMLF2hq_9Lwyo6_aUCIZiJfiX4,4170
|
|
4
|
-
blksprs/layouting/sparsity_layout.py,sha256=JNwbJ4L-418zCiCbt-vLfZ0xU7ReP0zr_tLHs_dytRA,9686
|
|
5
|
-
blksprs/misc/broadcast_ops.py,sha256=0RLnLMYV7GAPI2YL8RotcxjIUSBZKGxdVcsGaJFeL_I,5327
|
|
6
|
-
blksprs/misc/exp.py,sha256=cdF0s93Q9iucIXuEE3howsB0N6D60xgvem7C-a-yiGI,3704
|
|
7
|
-
blksprs/misc/partitioning.py,sha256=nBRZzfi3XYAhDLEBzYflQkvGa3MIZ-qNeIlrZ16k44g,7533
|
|
8
|
-
blksprs/misc/row_wise.py,sha256=0vDJA8uCocmebSIPIbFeND5_PQIE10pUj3DBOQXlTvE,16888
|
|
9
|
-
blksprs/ops/conversion.py,sha256=9xVdCrj38m1cMh43LQs-GrXZ5pNRjhQyKx6paaw3C6A,21898
|
|
10
|
-
blksprs/ops/distribution.py,sha256=V3TK5SlNT_JdGHNaDNl-U4U5vwAYsgkAOg4eTmYxbuA,16877
|
|
11
|
-
blksprs/ops/matmul.py,sha256=uqVe6Dz2aaCbCglM1uS2eRHVKh7PQcuecaIBWFubPEw,11256
|
|
12
|
-
blksprs/ops/repeat.py,sha256=OSsa2rj6BHL3Kedfu3wr0D82mn4HmbJ1l7XEmT-6ehg,14423
|
|
13
|
-
blksprs/ops/softmax.py,sha256=5nAgeT68nucgOugjtCy1aBIMa7Kyk1KNN-j8fgmeVuk,11996
|
|
14
|
-
blksprs/ops/transpose.py,sha256=67pDdCEb7r-Xifupl82fBKAYsxKcCUDy--cPPfduRvU,6761
|
|
15
|
-
blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
|
|
16
|
-
blksprs/utils/tools.py,sha256=S3836Zuc-BMigv-5mLTjRznCzuaF6oYW-Ir9zzUnr3o,655
|
|
17
|
-
blksprs/utils/validation.py,sha256=WzihRPibXYzss3PMkhDt5_d3Q3NHA_d1TzTz3CoGPGg,4136
|
|
18
|
-
blksprs-1.8.1.dist-info/METADATA,sha256=UDXUjS8PHyD4Zm-gWF4maXzY1k2SjKHMQllu-uOwLIA,8009
|
|
19
|
-
blksprs-1.8.1.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
|
|
20
|
-
blksprs-1.8.1.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
|
|
21
|
-
blksprs-1.8.1.dist-info/RECORD,,
|
|
File without changes
|