blksprs 1.10.2__py3-none-any.whl → 2.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +2 -6
- blksprs/layouting/distribution_layout.py +32 -25
- blksprs/layouting/sparsity_layout.py +65 -52
- blksprs/ops/conversion.py +423 -374
- blksprs/ops/distribution.py +403 -335
- blksprs/ops/flow.py +135 -83
- blksprs/ops/matmul.py +221 -187
- blksprs/ops/misc/broadcast_ops.py +53 -35
- blksprs/ops/misc/row_wise.py +151 -91
- blksprs/ops/partitioning.py +136 -89
- blksprs/ops/repeat.py +115 -108
- blksprs/ops/softmax.py +244 -208
- blksprs/ops/transpose.py +69 -131
- blksprs/utils/benchmarking.py +3 -3
- blksprs/utils/tools.py +31 -4
- blksprs/utils/validation.py +0 -14
- {blksprs-1.10.2.dist-info → blksprs-2.0rc1.dist-info}/METADATA +42 -36
- blksprs-2.0rc1.dist-info/RECORD +22 -0
- {blksprs-1.10.2.dist-info → blksprs-2.0rc1.dist-info}/WHEEL +1 -1
- blksprs/ops/misc/exp.py +0 -104
- blksprs/utils/layout_utils.py +0 -17
- blksprs-1.10.2.dist-info/RECORD +0 -24
- {blksprs-1.10.2.dist-info → blksprs-2.0rc1.dist-info}/top_level.txt +0 -0
blksprs/ops/transpose.py
CHANGED
|
@@ -1,16 +1,15 @@
|
|
|
1
1
|
import torch
|
|
2
|
-
import triton
|
|
3
2
|
from torch import Tensor
|
|
4
|
-
from
|
|
3
|
+
from torch._library import triton_op
|
|
5
4
|
|
|
5
|
+
from blksprs.ops.flow import flow_pull_forward
|
|
6
6
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
7
|
-
from blksprs.utils.tools import get_triton_block_size, stride
|
|
8
7
|
from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
|
|
9
|
-
validate_sparsity, validate_sparsity_block_size
|
|
8
|
+
validate_sparsity, validate_sparsity_block_size
|
|
10
9
|
|
|
11
10
|
|
|
12
|
-
def transpose(x: BlksprsTensor, sparsity_layout: Tensor,
|
|
13
|
-
|
|
11
|
+
def transpose(x: BlksprsTensor, sparsity_layout: Tensor,
|
|
12
|
+
sparsity_block_size: int, lut: dict = None) -> (BlksprsTensor, Tensor):
|
|
14
13
|
"""Transposes a block-sparse tensor in compressed form.
|
|
15
14
|
|
|
16
15
|
Note:
|
|
@@ -20,7 +19,7 @@ def transpose(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: in
|
|
|
20
19
|
x (BlksprsTensor): A block-sparse tensor in compressed form.
|
|
21
20
|
sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
|
|
22
21
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
23
|
-
|
|
22
|
+
lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
|
|
24
23
|
|
|
25
24
|
Returns:
|
|
26
25
|
BlksprsTensor: The transposed block-sparse tensor in compressed form.
|
|
@@ -28,133 +27,72 @@ def transpose(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: in
|
|
|
28
27
|
|
|
29
28
|
"""
|
|
30
29
|
x = x.contiguous()
|
|
30
|
+
x_t = x.transpose(-1, -2).contiguous()
|
|
31
31
|
|
|
32
32
|
validate_dimensions(x)
|
|
33
33
|
validate_contiguous(x)
|
|
34
34
|
validate_device(x)
|
|
35
35
|
validate_sparsity(sparsity_block_size, (x, sparsity_layout))
|
|
36
36
|
validate_sparsity_block_size(sparsity_block_size, x)
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
@staticmethod
|
|
101
|
-
def backward(ctx, grad_output):
|
|
102
|
-
sparsity_layout = ctx.saved_tensors[0]
|
|
103
|
-
sparsity_block_size = ctx.sparsity_block_size
|
|
104
|
-
triton_block_size = ctx.triton_block_size
|
|
105
|
-
|
|
106
|
-
return transpose(grad_output, sparsity_layout, sparsity_block_size, triton_block_size)[
|
|
107
|
-
0], None, None, None, None, None, None
|
|
108
|
-
|
|
109
|
-
@staticmethod
|
|
110
|
-
@triton.jit
|
|
111
|
-
def kernel_blocksparse_transpose(x,
|
|
112
|
-
x_b, x_b_s, x_r_s, x_c_s,
|
|
113
|
-
s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
|
|
114
|
-
s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
115
|
-
r_lut,
|
|
116
|
-
o,
|
|
117
|
-
o_b, o_b_s,
|
|
118
|
-
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
119
|
-
# Get triton block indices
|
|
120
|
-
pid_blk = tl.program_id(axis=0)
|
|
121
|
-
pid_row = tl.program_id(axis=1)
|
|
122
|
-
pid_col = tl.program_id(axis=2)
|
|
123
|
-
|
|
124
|
-
# Get sparsity index of current output block consisting of its batch, row, and column index
|
|
125
|
-
spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
|
|
126
|
-
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
|
|
127
|
-
spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
|
|
128
|
-
|
|
129
|
-
spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
|
|
130
|
-
spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
|
|
131
|
-
spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
|
|
132
|
-
|
|
133
|
-
spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
|
|
134
|
-
spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
|
|
135
|
-
spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
|
|
136
|
-
|
|
137
|
-
# Get reverse sparsity index
|
|
138
|
-
rev_idx_spa_idx = (spa_bat * s_l_b_s +
|
|
139
|
-
spa_row * s_l_r_s +
|
|
140
|
-
spa_col * s_l_c_s)
|
|
141
|
-
rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_b * s_l_b_s)
|
|
142
|
-
rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
|
|
143
|
-
|
|
144
|
-
if rev_idx_spa == -1:
|
|
145
|
-
tl.device_assert(False)
|
|
146
|
-
return
|
|
147
|
-
|
|
148
|
-
blk_x_idx = (rev_idx_spa * x_b_s +
|
|
149
|
-
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
150
|
-
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
151
|
-
blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
|
|
152
|
-
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
153
|
-
|
|
154
|
-
blk_x_t = tl.trans(blk_x)
|
|
155
|
-
|
|
156
|
-
blk_o_idx = (pid_blk * o_b_s +
|
|
157
|
-
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
158
|
-
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
159
|
-
blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
|
|
160
|
-
tl.store(o + blk_o_idx, blk_x_t, mask=blk_o_msk)
|
|
37
|
+
|
|
38
|
+
lut = transpose_build_lut(lut, sparsity_layout)
|
|
39
|
+
|
|
40
|
+
return BlksprsTensor(transpose_forward(x_t, lut["sparsity_layout_t"],
|
|
41
|
+
lut["sparsity_lut"], lut["sparsity_reverse_lut"],
|
|
42
|
+
sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_t"]
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@triton_op("blksprs::transpose", mutates_args={})
|
|
46
|
+
def transpose_forward(x: Tensor, sparsity_layout_o: Tensor,
|
|
47
|
+
sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
48
|
+
sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
|
|
49
|
+
return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
|
|
50
|
+
sparsity_block_size, n_sparse_blocks)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def transpose_backward(ctx, grad_output):
|
|
54
|
+
sparsity_layout = ctx.saved_tensors[0]
|
|
55
|
+
sparsity_block_size = ctx.sparsity_block_size
|
|
56
|
+
|
|
57
|
+
return transpose(grad_output, sparsity_layout, sparsity_block_size)[
|
|
58
|
+
0], None, None, None, None, None
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def transpose_build_lut(lut: dict, sparsity_layout: Tensor):
|
|
62
|
+
if lut is None:
|
|
63
|
+
lut = dict()
|
|
64
|
+
|
|
65
|
+
if "sparsity_layout_t" not in lut:
|
|
66
|
+
sparsity_layout_t = sparsity_layout.transpose(-1, -2).contiguous()
|
|
67
|
+
lut["sparsity_layout_t"] = sparsity_layout_t
|
|
68
|
+
|
|
69
|
+
if "sparsity_lut" not in lut:
|
|
70
|
+
sparsity_lut = torch.nonzero(lut["sparsity_layout_t"]).contiguous()
|
|
71
|
+
lut["sparsity_lut"] = sparsity_lut
|
|
72
|
+
|
|
73
|
+
if "sparsity_reverse_lut" not in lut:
|
|
74
|
+
sparsity_layout_flat = sparsity_layout.reshape(-1)
|
|
75
|
+
sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
|
|
76
|
+
(sparsity_layout_flat == 1) -
|
|
77
|
+
(1 * (sparsity_layout_flat == 0)))
|
|
78
|
+
.reshape(sparsity_layout.size()).transpose(-1, -2).contiguous().reshape(-1))
|
|
79
|
+
lut["sparsity_reverse_lut"] = sparsity_reverse_lut
|
|
80
|
+
|
|
81
|
+
if "n_sparse_blocks" not in lut:
|
|
82
|
+
n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
|
|
83
|
+
lut["n_sparse_blocks"] = n_sparse_blocks
|
|
84
|
+
|
|
85
|
+
validate_contiguous(lut["sparsity_layout_t"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])
|
|
86
|
+
|
|
87
|
+
return lut
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
# noinspection PyUnusedLocal
|
|
91
|
+
def transpose_setup_context(ctx, inputs, output):
|
|
92
|
+
(_, sparsity_layout_o, _, _, sparsity_block_size, _) = inputs
|
|
93
|
+
|
|
94
|
+
ctx.save_for_backward(sparsity_layout_o)
|
|
95
|
+
ctx.sparsity_block_size = sparsity_block_size
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
transpose_forward.register_autograd(transpose_backward, setup_context=transpose_setup_context)
|
blksprs/utils/benchmarking.py
CHANGED
|
@@ -5,13 +5,13 @@ from matplotlib import pyplot as plt
|
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
def benchmark(method_labels: list[str], func_input_generator: Callable,
|
|
8
|
-
matrix_sizes: list[int], sparsity_block_sizes: list[int],
|
|
8
|
+
matrix_sizes: list[int], sparsity_block_sizes: list[int],
|
|
9
9
|
*funcs_test_subject: Callable, y_lim_top: int = None):
|
|
10
10
|
quantiles = [0.5, 0.2, 0.8]
|
|
11
11
|
results = {}
|
|
12
12
|
|
|
13
|
-
for matrix_size, sparsity_block_size
|
|
14
|
-
arguments = func_input_generator(matrix_size, sparsity_block_size
|
|
13
|
+
for matrix_size, sparsity_block_size in zip(matrix_sizes, sparsity_block_sizes):
|
|
14
|
+
arguments = func_input_generator(matrix_size, sparsity_block_size)
|
|
15
15
|
|
|
16
16
|
for i, func_test_subject in enumerate(funcs_test_subject):
|
|
17
17
|
func_ms_avg, func_ms_min, func_ms_max = triton.testing.do_bench(
|
blksprs/utils/tools.py
CHANGED
|
@@ -1,6 +1,10 @@
|
|
|
1
1
|
import torch
|
|
2
|
+
import triton
|
|
2
3
|
from torch import Tensor, Size
|
|
3
4
|
|
|
5
|
+
# Capture scalar outputs for JIT compilation
|
|
6
|
+
torch._dynamo.config.capture_scalar_outputs = True
|
|
7
|
+
|
|
4
8
|
|
|
5
9
|
def do_shape_blocksparse(x: Tensor):
|
|
6
10
|
if x.dim() == 3:
|
|
@@ -16,10 +20,6 @@ def undo_shape_blocksparse(x: Tensor, shape: Size):
|
|
|
16
20
|
return x.reshape((*shape[:-2], *x.shape[-2:]))
|
|
17
21
|
|
|
18
22
|
|
|
19
|
-
def get_triton_block_size(sparsity_block_size: int, limit: int = 128):
|
|
20
|
-
return min(sparsity_block_size, limit)
|
|
21
|
-
|
|
22
|
-
|
|
23
23
|
def stride(x: Tensor):
|
|
24
24
|
if x.dim() == 2:
|
|
25
25
|
return x.size(1), 1
|
|
@@ -27,3 +27,30 @@ def stride(x: Tensor):
|
|
|
27
27
|
return x.size(1) * x.size(2), x.size(2), 1
|
|
28
28
|
else:
|
|
29
29
|
raise NotImplementedError
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@torch.compile
|
|
33
|
+
def get_autotune_configs():
|
|
34
|
+
configs = []
|
|
35
|
+
config_parameters = [
|
|
36
|
+
(16, 3, 8),
|
|
37
|
+
(16, 4, 4),
|
|
38
|
+
(16, 5, 2),
|
|
39
|
+
|
|
40
|
+
(32, 3, 8),
|
|
41
|
+
(32, 4, 4),
|
|
42
|
+
(32, 5, 2),
|
|
43
|
+
|
|
44
|
+
(64, 3, 8),
|
|
45
|
+
(64, 4, 4),
|
|
46
|
+
(64, 5, 2),
|
|
47
|
+
|
|
48
|
+
(128, 3, 8),
|
|
49
|
+
(128, 4, 4),
|
|
50
|
+
(128, 5, 2),
|
|
51
|
+
]
|
|
52
|
+
|
|
53
|
+
for block_size, num_stages, num_warps in config_parameters:
|
|
54
|
+
configs.append(triton.Config({"TRITON_BLOCK_SIZE": block_size}, num_stages=num_stages, num_warps=num_warps))
|
|
55
|
+
|
|
56
|
+
return configs
|
blksprs/utils/validation.py
CHANGED
|
@@ -104,20 +104,6 @@ def validate_sparsity_block_size(sparsity_block_size: int, *tensors):
|
|
|
104
104
|
raise ValueError("Tensor sizes must be divisible by sparsity block size")
|
|
105
105
|
|
|
106
106
|
|
|
107
|
-
def validate_triton_block_size(triton_block_size: int, sparsity_block_size: int):
|
|
108
|
-
if _check_skip_validation():
|
|
109
|
-
return
|
|
110
|
-
|
|
111
|
-
if triton_block_size is None:
|
|
112
|
-
return
|
|
113
|
-
|
|
114
|
-
if not (triton_block_size & (triton_block_size - 1)) == 0:
|
|
115
|
-
raise ValueError("Triton block size must be a power of 2")
|
|
116
|
-
|
|
117
|
-
if triton_block_size > sparsity_block_size:
|
|
118
|
-
raise ValueError("Triton block size cannot be larger than sparsity block size")
|
|
119
|
-
|
|
120
|
-
|
|
121
107
|
def _check_skip_validation():
|
|
122
108
|
return not VALIDATION
|
|
123
109
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: blksprs
|
|
3
|
-
Version:
|
|
3
|
+
Version: 2.0rc1
|
|
4
4
|
Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
|
|
5
5
|
Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
|
|
6
6
|
Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
|
|
@@ -14,9 +14,8 @@ Requires-Dist: pytest; extra == "test"
|
|
|
14
14
|
Requires-Dist: pytest-xdist; extra == "test"
|
|
15
15
|
Requires-Dist: pytest-cov; extra == "test"
|
|
16
16
|
Requires-Dist: coverage; extra == "test"
|
|
17
|
+
Requires-Dist: build; extra == "test"
|
|
17
18
|
Requires-Dist: matplotlib; extra == "test"
|
|
18
|
-
Provides-Extra: build
|
|
19
|
-
Requires-Dist: build; extra == "build"
|
|
20
19
|
|
|
21
20
|
# blksprs
|
|
22
21
|
|
|
@@ -25,6 +24,13 @@ Requires-Dist: build; extra == "build"
|
|
|
25
24
|
|
|
26
25
|
## Overview
|
|
27
26
|
|
|
27
|
+
### News
|
|
28
|
+
|
|
29
|
+
🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
|
|
30
|
+
LUTs, and makes use of `torch.library.triton_op()`!
|
|
31
|
+
|
|
32
|
+
---
|
|
33
|
+
|
|
28
34
|
A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
|
|
29
35
|
|
|
30
36
|
Currently supported operations (includes gradient calculation):
|
|
@@ -52,23 +58,25 @@ These include, e.g.,
|
|
|
52
58
|
Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
|
|
53
59
|
match.
|
|
54
60
|
|
|
55
|
-
Further helpful operations (included in the ``bs.ops.misc`` module) that do **not** support gradient calculation
|
|
61
|
+
Further helpful operations (included in the ``bs.ops.misc`` module) that do **not** support gradient calculation
|
|
62
|
+
include:
|
|
56
63
|
|
|
57
64
|
- Row-wise sum, max, addition, and subtraction
|
|
58
65
|
- Broadcast addition and subtraction between slices
|
|
59
66
|
|
|
60
|
-
Furthermore, the library provides a set of utility functions
|
|
67
|
+
Furthermore, the library provides a set of utility functions
|
|
61
68
|
|
|
62
69
|
- for the creation of sparsity layouts based on existing
|
|
63
|
-
dense tensors and for the scatter operation (module ``bs.layouting``),
|
|
70
|
+
dense tensors and for the scatter operation (module ``bs.layouting``),
|
|
64
71
|
- for the application of ``nn.Linear``, ``nn.Dropout``, and ``nn.LayerNorm`` layers to block-sparse tensors,
|
|
65
72
|
- as well as utility functions to ensure correct input dimensionality, and validate input (module ``bs.utils``).
|
|
66
73
|
|
|
67
|
-
_* see the [Roadmap](#roadmap) section for more information_
|
|
74
|
+
_* see the [Roadmap](#roadmap) section for more information_
|
|
68
75
|
|
|
69
76
|
## Installation
|
|
70
77
|
|
|
71
|
-
Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible
|
|
78
|
+
Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible
|
|
79
|
+
with
|
|
72
80
|
the Linux platform**.
|
|
73
81
|
Keep track of this [issue](https://github.com/triton-lang/triton/issues/1640) for updates.
|
|
74
82
|
|
|
@@ -78,8 +86,8 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u
|
|
|
78
86
|
|
|
79
87
|
### Dependencies
|
|
80
88
|
|
|
81
|
-
- [PyTorch](https://pytorch.org/) (built with v2.
|
|
82
|
-
- _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.2.
|
|
89
|
+
- [PyTorch](https://pytorch.org/) (built with v2.6)
|
|
90
|
+
- _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.2.4)_
|
|
83
91
|
- _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_
|
|
84
92
|
|
|
85
93
|
## Changelog
|
|
@@ -89,12 +97,14 @@ See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.
|
|
|
89
97
|
## Roadmap
|
|
90
98
|
|
|
91
99
|
Note that since this library covers all our current needs it is in a **bugfix-only** state.
|
|
92
|
-
This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and
|
|
100
|
+
This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and
|
|
101
|
+
``merge`` operations.
|
|
93
102
|
We will continue to maintain the library and fix any issues that arise.
|
|
94
103
|
Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
|
|
95
104
|
We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
|
|
96
105
|
|
|
97
|
-
It might be that this changes with future projects, but as of
|
|
106
|
+
It might be that this changes with future projects, but as of March 2025, we are content with the current state of the
|
|
107
|
+
library.
|
|
98
108
|
|
|
99
109
|
## Usage
|
|
100
110
|
|
|
@@ -120,10 +130,6 @@ def test_readme():
|
|
|
120
130
|
# Must be a power of two, greater than or equal to 16 for matmul, and divide m, n, and k
|
|
121
131
|
sparsity_block_size = 16
|
|
122
132
|
|
|
123
|
-
# Must be a power of two and smaller than or equal to sparsity_block_size
|
|
124
|
-
# If it is set to ``none`` a value will be chosen automatically
|
|
125
|
-
triton_block_size = None
|
|
126
|
-
|
|
127
133
|
# Initialise random (dense) tensors
|
|
128
134
|
x = torch.randn(size=(b, h, m, k), device="cuda")
|
|
129
135
|
y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()
|
|
@@ -133,53 +139,53 @@ def test_readme():
|
|
|
133
139
|
y_dense, y_shape_original = bs.utils.do_shape_blocksparse(y)
|
|
134
140
|
|
|
135
141
|
# Create sparsity layouts from existing tensors
|
|
136
|
-
sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size
|
|
137
|
-
|
|
138
|
-
sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size,
|
|
139
|
-
triton_block_size=triton_block_size)
|
|
142
|
+
sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size)
|
|
143
|
+
sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size)
|
|
140
144
|
|
|
141
145
|
# Create random sparsity layout for output tensor
|
|
142
146
|
sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)
|
|
143
147
|
|
|
144
148
|
# Convert tensors to sparse tensors for matrix multiplication
|
|
145
|
-
x_sparse = bs.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size
|
|
146
|
-
y_sparse = bs.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size
|
|
149
|
+
x_sparse = bs.ops.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size)
|
|
150
|
+
y_sparse = bs.ops.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size)
|
|
151
|
+
|
|
152
|
+
# As of version 2.0, blksprs supports JIT compilation
|
|
153
|
+
matmul_compiled = torch.compile(bs.ops.matmul)
|
|
147
154
|
|
|
148
155
|
# Perform matrix multiplication
|
|
149
|
-
o_sparse =
|
|
150
|
-
|
|
151
|
-
|
|
156
|
+
o_sparse = matmul_compiled(x_sparse, sparsity_layout_x,
|
|
157
|
+
y_sparse, sparsity_layout_y,
|
|
158
|
+
sparsity_layout_o, sparsity_block_size)
|
|
152
159
|
|
|
153
160
|
# Apply element-wise operation
|
|
154
161
|
o_sparse = torch.add(o_sparse, 1)
|
|
155
162
|
|
|
156
|
-
o_dense = bs.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size
|
|
163
|
+
o_dense = bs.ops.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size)
|
|
157
164
|
|
|
158
165
|
# Sanity check
|
|
159
166
|
o_torch = torch.matmul(x_dense, y_dense)
|
|
160
167
|
o_torch = torch.add(o_torch, 1)
|
|
161
168
|
|
|
162
169
|
# Perform round trip to set sparse blocks to 0
|
|
163
|
-
o_torch_round_trip = bs.to_dense(
|
|
164
|
-
bs.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size
|
|
165
|
-
sparsity_layout_o, sparsity_block_size, fill_value=0
|
|
170
|
+
o_torch_round_trip = bs.ops.to_dense(
|
|
171
|
+
bs.ops.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size),
|
|
172
|
+
sparsity_layout_o, sparsity_block_size, fill_value=0)
|
|
166
173
|
|
|
167
174
|
# Assert that the output is correct
|
|
168
175
|
assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2) # Note that small numerical differences are expected
|
|
169
176
|
|
|
170
177
|
# Assert that the output has the correct sparsity layout
|
|
171
|
-
actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size
|
|
172
|
-
triton_block_size=triton_block_size)
|
|
178
|
+
actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size)
|
|
173
179
|
assert torch.allclose(actual_sparsity_layout_o.to(torch.int), sparsity_layout_o)
|
|
174
180
|
|
|
175
181
|
# Convert output tensor back to original shape
|
|
176
182
|
o = bs.utils.undo_shape_blocksparse(o_dense, x_shape_original)
|
|
177
183
|
|
|
178
184
|
# Other available functions
|
|
179
|
-
bs.transpose(o_sparse, sparsity_layout_o, sparsity_block_size
|
|
180
|
-
bs.softmax(o_sparse, sparsity_layout_o, sparsity_block_size
|
|
181
|
-
bs.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size
|
|
182
|
-
bs.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size
|
|
185
|
+
bs.ops.transpose(o_sparse, sparsity_layout_o, sparsity_block_size)
|
|
186
|
+
bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size)
|
|
187
|
+
bs.ops.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size)
|
|
188
|
+
bs.ops.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size)
|
|
183
189
|
|
|
184
190
|
|
|
185
191
|
def _get_random_sparsity_layout(b, m, n, sparsity_block_size, sparsity_percentage):
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
blksprs/__init__.py,sha256=OHfpwJCZWGUfpT-DVfC1YSaeZl4aCMNt9CrzMPymywU,1577
|
|
2
|
+
blksprs/layouting/distribution_layout.py,sha256=0glIteoY5oDkiEu5rjLIC-BB_oC4sa3rFWVkohsAG00,5329
|
|
3
|
+
blksprs/layouting/sparsity_layout.py,sha256=UzMcdW7l4zoiLB_LMEbBR1JBdqVSgINDGYvoCYIOulk,10283
|
|
4
|
+
blksprs/ops/conversion.py,sha256=_JKOovDZOmYJLcurJGhgNt5iQB9kOKp3fufFxD8QCZs,22204
|
|
5
|
+
blksprs/ops/distribution.py,sha256=5gE19kPQGQljVbRpDZeqNaOe8ehRhxdQS7PiJp6mMug,21352
|
|
6
|
+
blksprs/ops/flow.py,sha256=G8L_sMAWIM77gv-YLJtyutEzXqyaaofnSX2QKvmDr44,8409
|
|
7
|
+
blksprs/ops/matmul.py,sha256=YAurJcXa_39gRdh2nWUOmbhm8h99arLoO-SN-l134II,11879
|
|
8
|
+
blksprs/ops/partitioning.py,sha256=AooYZOw0oZgA9zXSu09O60hkJcnpWT1OTosr2T2wdQo,9700
|
|
9
|
+
blksprs/ops/repeat.py,sha256=qty0qIFcfiWzROV2A2FB2KiPCC2Pe4q5TwJyGuDBAQE,8839
|
|
10
|
+
blksprs/ops/softmax.py,sha256=eaZ8pfCpNZCX6Gk5Tk-lhNIrBQDhvfHqNNPltqxp91k,12793
|
|
11
|
+
blksprs/ops/transpose.py,sha256=30pGCSjZs42Sg6TEXUdJNCDgmlN1n8aN88uNbV5wOtA,3941
|
|
12
|
+
blksprs/ops/misc/broadcast_ops.py,sha256=lZ5bBIftUKffzeYz77SWB1xmtZTRGMvjF-tG9rqkOXA,6018
|
|
13
|
+
blksprs/ops/misc/row_wise.py,sha256=iwOrHU8HiJGxq2hEmgJGZ60asRm72WLi10-PrpNrdeQ,19532
|
|
14
|
+
blksprs/utils/benchmarking.py,sha256=dLabDscTFn5NkmOI1g7DnKeTneUYW3RIVv9MDF-8BKc,1271
|
|
15
|
+
blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
|
|
16
|
+
blksprs/utils/processing.py,sha256=WLuMJQ8v-YovXwcDjhlDn3N31WMZXrtyeeyKSgq_zn4,3642
|
|
17
|
+
blksprs/utils/tools.py,sha256=RL18P4NAj7d8gXTTKbMZt4SHCynsw1wPu9yvlrnBQlo,1220
|
|
18
|
+
blksprs/utils/validation.py,sha256=_Ee6bqu7CxdYLFSy4WZOFoXJgd0p_RBMumCwGCk2_Hw,3763
|
|
19
|
+
blksprs-2.0rc1.dist-info/METADATA,sha256=zXzVOvuwgYSyx-lCBycdFvRUmHUD_qYbK8sFkKWZnp8,8601
|
|
20
|
+
blksprs-2.0rc1.dist-info/WHEEL,sha256=1tXe9gY0PYatrMPMDd6jXqjfpz_B-Wqm32CPfRC58XU,91
|
|
21
|
+
blksprs-2.0rc1.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
|
|
22
|
+
blksprs-2.0rc1.dist-info/RECORD,,
|
blksprs/ops/misc/exp.py
DELETED
|
@@ -1,104 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
import triton
|
|
3
|
-
from torch import Tensor
|
|
4
|
-
from triton import language as tl
|
|
5
|
-
|
|
6
|
-
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
7
|
-
from blksprs.utils.tools import get_triton_block_size, stride
|
|
8
|
-
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
9
|
-
validate_sparsity_block_size, validate_triton_block_size
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
def exp(x: BlksprsTensor, sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
|
|
13
|
-
"""Applies the element-wise exponential function to a block-sparse tensor.
|
|
14
|
-
|
|
15
|
-
Note:
|
|
16
|
-
This operation does not consider sparse blocks, i.e., these will not be set to ``e^0``.
|
|
17
|
-
Consider this when converting back to tensors in regular form.
|
|
18
|
-
|
|
19
|
-
Args:
|
|
20
|
-
x (BlksprsTensor): A block-sparse tensor in compressed form.
|
|
21
|
-
sparsity_block_size (int): The size of the sparsity blocks.
|
|
22
|
-
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
23
|
-
|
|
24
|
-
Returns:
|
|
25
|
-
BlksprsTensor: The exponential function applied to all elements of the input tensor as a block-sparse tensor in
|
|
26
|
-
compressed form.
|
|
27
|
-
|
|
28
|
-
"""
|
|
29
|
-
x = x.contiguous()
|
|
30
|
-
|
|
31
|
-
validate_dimensions(x)
|
|
32
|
-
validate_contiguous(x)
|
|
33
|
-
validate_device(x)
|
|
34
|
-
validate_sparsity_block_size(sparsity_block_size, x)
|
|
35
|
-
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
36
|
-
|
|
37
|
-
return BlksprsTensor(_BlocksparseExp.apply(x, sparsity_block_size, triton_block_size))
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
class _BlocksparseExp(torch.autograd.Function):
|
|
41
|
-
|
|
42
|
-
@staticmethod
|
|
43
|
-
def forward(ctx, x: Tensor, sparsity_block_size: int, triton_block_size: int) -> Tensor:
|
|
44
|
-
output = torch.empty_like(x)
|
|
45
|
-
|
|
46
|
-
x_b, x_r, x_c = x.shape
|
|
47
|
-
x_b_s, x_r_s, x_c_s = stride(x)
|
|
48
|
-
o_b, o_r, o_c = output.shape
|
|
49
|
-
o_b_s, o_r_s, o_c_s = stride(output)
|
|
50
|
-
|
|
51
|
-
if triton_block_size is None:
|
|
52
|
-
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
53
|
-
|
|
54
|
-
triton_grid = lambda meta: [o_b,
|
|
55
|
-
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
56
|
-
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
57
|
-
|
|
58
|
-
(_BlocksparseExp.kernel_blocksparse_exp[triton_grid]
|
|
59
|
-
(x,
|
|
60
|
-
x_b, x_b_s, x_r_s, x_c_s,
|
|
61
|
-
output,
|
|
62
|
-
o_b, o_b_s, o_r_s, o_c_s,
|
|
63
|
-
triton_block_size))
|
|
64
|
-
|
|
65
|
-
ctx.save_for_backward(output)
|
|
66
|
-
|
|
67
|
-
return output
|
|
68
|
-
|
|
69
|
-
@staticmethod
|
|
70
|
-
def backward(ctx, grad_output):
|
|
71
|
-
o = ctx.saved_tensors[0]
|
|
72
|
-
|
|
73
|
-
grad_x = torch.mul(grad_output, o)
|
|
74
|
-
|
|
75
|
-
return grad_x, None, None
|
|
76
|
-
|
|
77
|
-
@staticmethod
|
|
78
|
-
@triton.jit
|
|
79
|
-
def kernel_blocksparse_exp(x,
|
|
80
|
-
x_b, x_b_s, x_r_s, x_c_s,
|
|
81
|
-
o,
|
|
82
|
-
o_b, o_b_s, o_r_s, o_c_s,
|
|
83
|
-
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
84
|
-
# Get triton block indices
|
|
85
|
-
pid_blk = tl.program_id(axis=0)
|
|
86
|
-
pid_row = tl.program_id(axis=1)
|
|
87
|
-
pid_col = tl.program_id(axis=2)
|
|
88
|
-
|
|
89
|
-
# Load block
|
|
90
|
-
blk_x_idx = ((pid_blk * x_b_s) +
|
|
91
|
-
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
92
|
-
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
93
|
-
blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
|
|
94
|
-
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
95
|
-
|
|
96
|
-
# Compute exp
|
|
97
|
-
buf = tl.exp(blk_x)
|
|
98
|
-
|
|
99
|
-
# Store block
|
|
100
|
-
blk_o_idx = ((pid_blk * o_b_s) +
|
|
101
|
-
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
102
|
-
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
103
|
-
blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
|
|
104
|
-
tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
|
blksprs/utils/layout_utils.py
DELETED
|
@@ -1,17 +0,0 @@
|
|
|
1
|
-
import math
|
|
2
|
-
|
|
3
|
-
import torch
|
|
4
|
-
import triton
|
|
5
|
-
from torch import Tensor
|
|
6
|
-
from torch.xpu import device
|
|
7
|
-
from triton import language as tl
|
|
8
|
-
|
|
9
|
-
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
10
|
-
from blksprs.utils.tools import get_triton_block_size, stride
|
|
11
|
-
from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
|
|
12
|
-
validate_contiguous, validate_sparsity, validate_sparsity_block_size
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
def build_full_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
|
|
16
|
-
return torch.ones(size=(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size),
|
|
17
|
-
dtype=torch.bool, device=x.device)
|