blksprs-1.11-py3-none-any.whl → blksprs-2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +4 -5
- blksprs/layouting/distribution_layout.py +64 -48
- blksprs/layouting/sparsity_layout.py +96 -72
- blksprs/ops/conversion.py +349 -338
- blksprs/ops/distribution.py +318 -294
- blksprs/ops/flow.py +113 -100
- blksprs/ops/matmul.py +187 -172
- blksprs/ops/misc/broadcast_ops.py +68 -53
- blksprs/ops/misc/row_wise.py +223 -176
- blksprs/ops/partitioning.py +140 -132
- blksprs/ops/repeat.py +118 -120
- blksprs/ops/softmax.py +240 -214
- blksprs/ops/transpose.py +55 -52
- blksprs/utils/autotuning.py +78 -0
- blksprs/utils/benchmarking.py +3 -3
- blksprs/utils/processing.py +2 -1
- blksprs/utils/tools.py +5 -6
- blksprs/utils/validation.py +22 -16
- {blksprs-1.11.dist-info → blksprs-2.0.dist-info}/METADATA +55 -36
- blksprs-2.0.dist-info/RECORD +23 -0
- {blksprs-1.11.dist-info → blksprs-2.0.dist-info}/WHEEL +1 -1
- blksprs/utils/layout_utils.py +0 -17
- blksprs-1.11.dist-info/RECORD +0 -23
- {blksprs-1.11.dist-info → blksprs-2.0.dist-info}/top_level.txt +0 -0
blksprs/ops/transpose.py
CHANGED
@@ -1,17 +1,16 @@
 import torch
-import triton
 from torch import Tensor
-from
+from torch._library import triton_op
 
-from blksprs.ops.flow import
+from blksprs.ops.flow import flow_pull_forward
 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
-    validate_sparsity, validate_sparsity_block_size
+    validate_sparsity, validate_sparsity_block_size
 
 
-
-
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
+def transpose(x: BlksprsTensor, sparsity_layout: Tensor,
+              sparsity_block_size: int, lut: dict = None) -> (BlksprsTensor, Tensor):
     """Transposes a block-sparse tensor in compressed form.
 
     Note:
@@ -21,7 +20,6 @@ def transpose(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: in
         x (BlksprsTensor): A block-sparse tensor in compressed form.
         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
-        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
         lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
 
     Returns:
@@ -30,68 +28,73 @@ def transpose(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: in
 
     """
     x = x.contiguous()
-    x_t = x.transpose(-1, -2).contiguous()
 
     validate_dimensions(x)
     validate_contiguous(x)
     validate_device(x)
     validate_sparsity(sparsity_block_size, (x, sparsity_layout))
     validate_sparsity_block_size(sparsity_block_size, x)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)
 
-    lut =
+    lut = transpose_build_lut(lut, sparsity_layout)
 
-    return BlksprsTensor(
-
-
-        lut["n_sparse_blocks"], triton_block_size)), lut["sparsity_layout_t"]
+    return BlksprsTensor(transpose_forward(x, lut["sparsity_layout_t"],
+                                           lut["sparsity_lut"], lut["sparsity_reverse_lut"],
+                                           sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_t"]
 
 
-
+@triton_op("blksprs::transpose_forward", mutates_args={})
+def transpose_forward(x: Tensor, sparsity_layout_o: Tensor,
+                      sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                      sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+    with torch.no_grad():
+        x_t = x.transpose(-1, -2).contiguous()
+        return flow_pull_forward(x_t, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
+                                 sparsity_block_size, n_sparse_blocks)
 
-    @staticmethod
-    def build_lut(lut: dict, sparsity_layout: Tensor):
-        if lut is None:
-            lut = dict()
 
-
-
-
+def transpose_wrapper_backward(ctx, grad_output):
+    sparsity_layout = ctx.saved_tensors[0]
+    sparsity_block_size = ctx.sparsity_block_size
 
-
-
-            lut["sparsity_lut"] = sparsity_lut
+    return transpose(grad_output, sparsity_layout, sparsity_block_size)[
+        0], None, None, None, None, None
 
-        if "sparsity_reverse_lut" not in lut:
-            sparsity_layout_flat = sparsity_layout.reshape(-1)
-            sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                                     (sparsity_layout_flat == 1) -
-                                     (1 * (sparsity_layout_flat == 0)))
-                                    .reshape(sparsity_layout.size()).transpose(-1, -2).contiguous().reshape(-1))
-            lut["sparsity_reverse_lut"] = sparsity_reverse_lut
 
-
-
-
+def transpose_build_lut(lut: dict, sparsity_layout: Tensor):
+    if lut is None:
+        lut = dict()
 
-
+    if "sparsity_layout_t" not in lut:
+        sparsity_layout_t = sparsity_layout.transpose(-1, -2).contiguous()
+        lut["sparsity_layout_t"] = sparsity_layout_t
 
-
+    if "sparsity_lut" not in lut:
+        sparsity_lut = torch.nonzero(lut["sparsity_layout_t"]).contiguous()
+        lut["sparsity_lut"] = sparsity_lut
 
-
-
-
-
-
+    if "sparsity_reverse_lut" not in lut:
+        sparsity_layout_flat = sparsity_layout.reshape(-1)
+        sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                 (sparsity_layout_flat == 1) -
+                                 (1 * (sparsity_layout_flat == 0)))
+                                .reshape(sparsity_layout.size()).transpose(-1, -2).contiguous().reshape(-1))
+        lut["sparsity_reverse_lut"] = sparsity_reverse_lut
 
-
-
+    if "n_sparse_blocks" not in lut:
+        n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
+        lut["n_sparse_blocks"] = n_sparse_blocks
 
-
-    def backward(ctx, grad_output):
-        sparsity_layout = ctx.saved_tensors[0]
-        sparsity_block_size = ctx.sparsity_block_size
-        triton_block_size = ctx.triton_block_size
+    validate_contiguous(lut["sparsity_layout_t"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])
 
-
-
+    return lut
+
+
+# noinspection PyUnusedLocal
+def transpose_setup_context(ctx, inputs, output):
+    (_, sparsity_layout_o, _, _, sparsity_block_size, _) = inputs
+
+    ctx.save_for_backward(sparsity_layout_o)
+    ctx.sparsity_block_size = sparsity_block_size
+
+
+transpose_forward.register_autograd(transpose_wrapper_backward, setup_context=transpose_setup_context)
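For orientation, a brief usage sketch of the reworked `transpose` API shown above follows. Only the function names and signatures (`bs.ops.transpose`, `transpose_build_lut`) come from the diff; the tensor shapes and the all-ones sparsity layout are illustrative assumptions.

```python
# Hypothetical usage sketch of the blksprs 2.0 transpose API from the diff above.
# The shapes and the all-ones sparsity layout are assumptions for illustration.
import torch
import blksprs as bs
from blksprs.ops.transpose import transpose_build_lut

sparsity_block_size = 16
x = torch.randn(2, 64, 64, device="cuda")
sparsity_layout = torch.ones(2, 4, 4, dtype=torch.int, device="cuda")

# Compress the dense tensor into block-sparse form.
x_sparse = bs.ops.to_sparse(x, sparsity_layout, sparsity_block_size)

# Optionally pre-compute the look-up tables once and reuse them across calls.
lut = transpose_build_lut(None, sparsity_layout)
x_t_sparse, sparsity_layout_t = bs.ops.transpose(x_sparse, sparsity_layout,
                                                 sparsity_block_size, lut=lut)
```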
blksprs/utils/autotuning.py
ADDED
@@ -0,0 +1,78 @@
+import os
+
+blksprs_autotune_mode = os.getenv("BLKSPRS_AUTOTUNE", "DEFAULT")
+
+if blksprs_autotune_mode == "DEFAULT":
+    autotune_parameters = [
+        (16, 3, 8),
+        (16, 4, 4),
+        (16, 5, 2),
+
+        (32, 3, 8),
+        (32, 4, 4),
+        (32, 5, 2),
+
+        (64, 3, 8),
+        (64, 4, 4),
+        (64, 5, 2),
+
+        (128, 3, 8),
+        (128, 4, 4),
+        (128, 5, 2),
+    ]
+elif blksprs_autotune_mode == "TEST":
+    autotune_parameters = [
+        (16, 3, 8),
+
+        (32, 3, 8),
+
+        (64, 3, 8),
+    ]
+else:
+    raise NotImplementedError(f"Unknown autotune mode: {blksprs_autotune_mode}")
+
+import torch
+import triton
+
+
+def prune_autotune_configs(autotune_configs, kernel_args, **kwargs):
+    sparsity_block_size = kernel_args["sparsity_block_size"]
+
+    pruned_configs = []
+
+    for config in autotune_configs:
+        if config.kwargs["TRITON_BLOCK_SIZE"] <= sparsity_block_size:
+            pruned_configs.append(config)
+
+    assert len(pruned_configs) > 0, f"No valid autotune configs found for sparsity block size {sparsity_block_size}"
+
+    return pruned_configs
+
+
+def prune_autotune_configs_conversion(autotune_configs, kernel_args, **kwargs):
+    sparsity_block_size_from = kernel_args["sparsity_block_size_from"]
+    sparsity_block_size_to = kernel_args["sparsity_block_size_to"]
+    sparsity_block_size = min(sparsity_block_size_from, sparsity_block_size_to)
+
+    pruned_configs = []
+
+    for config in autotune_configs:
+        if config.kwargs["TRITON_BLOCK_SIZE"] <= sparsity_block_size:
+            pruned_configs.append(config)
+
+    assert len(pruned_configs) > 0, f"No valid autotune configs found for sparsity block size {sparsity_block_size}"
+
+    return pruned_configs
+
+
+@torch.compile
+def get_autotune_configs():
+    global autotune_parameters
+
+    autotune_configs = []
+
+    for block_size, num_stages, num_warps in autotune_parameters:
+        autotune_configs.append(
+            triton.Config({"TRITON_BLOCK_SIZE": block_size}, num_stages=num_stages, num_warps=num_warps))
+
+    return autotune_configs
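The new `autotuning` module defines the configuration list and the pruning hooks but not how they attach to a kernel; the sketch below shows one plausible wiring using standard Triton autotuning (`triton.autotune` with `prune_configs_by`). The kernel itself and its arguments are hypothetical, not taken from the package.

```python
# Sketch only: one way the helpers above could be attached to a Triton kernel.
# The kernel and its arguments are hypothetical.
import triton
import triton.language as tl

from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs


@triton.autotune(
    configs=get_autotune_configs(),
    key=["sparsity_block_size"],
    prune_configs_by={"early_config_prune": prune_autotune_configs},
)
@triton.jit
def example_copy_kernel(x_ptr, o_ptr, n_elements, sparsity_block_size,
                        TRITON_BLOCK_SIZE: tl.constexpr):
    # Each program instance copies one TRITON_BLOCK_SIZE-sized chunk.
    pid = tl.program_id(axis=0)
    offsets = pid * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)
    mask = offsets < n_elements
    tl.store(o_ptr + offsets, tl.load(x_ptr + offsets, mask=mask), mask=mask)
```

With a wiring like this, `prune_autotune_configs` drops every config whose `TRITON_BLOCK_SIZE` exceeds the kernel's `sparsity_block_size` argument before benchmarking, which is why the config list only needs the block sizes 16 through 128.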
blksprs/utils/benchmarking.py
CHANGED
@@ -5,13 +5,13 @@ from matplotlib import pyplot as plt
 
 
 def benchmark(method_labels: list[str], func_input_generator: Callable,
-              matrix_sizes: list[int], sparsity_block_sizes: list[int],
+              matrix_sizes: list[int], sparsity_block_sizes: list[int],
               *funcs_test_subject: Callable, y_lim_top: int = None):
     quantiles = [0.5, 0.2, 0.8]
     results = {}
 
-    for matrix_size, sparsity_block_size
-        arguments = func_input_generator(matrix_size, sparsity_block_size
+    for matrix_size, sparsity_block_size in zip(matrix_sizes, sparsity_block_sizes):
+        arguments = func_input_generator(matrix_size, sparsity_block_size)
 
         for i, func_test_subject in enumerate(funcs_test_subject):
             func_ms_avg, func_ms_min, func_ms_max = triton.testing.do_bench(
blksprs/utils/processing.py
CHANGED
@@ -11,6 +11,7 @@ from blksprs.ops.repeat import repeat
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def apply_torch_linear(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
                        linear: nn.Linear, bias: nn.Parameter = None) -> (BlksprsTensor, Tensor):
     # Extract weight and bias
@@ -25,7 +26,7 @@ def apply_torch_linear(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block
 
     # Apply weights
     sparsity_layout_xw = build_sparsity_layout_matmul_fast(sparsity_layout, sparsity_layout_w_t)
-    xw = matmul(x, sparsity_layout, w_t_bs, sparsity_layout_w_t, sparsity_layout_xw, sparsity_block_size)
+    xw = matmul(x, sparsity_layout, BlksprsTensor(w_t_bs.to(x.dtype)), sparsity_layout_w_t, sparsity_layout_xw, sparsity_block_size)
     interim = xw
 
     # Apply bias
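Given the `custom_fwd(cast_inputs=torch.float16)` decorator and the explicit weight cast added above, `apply_torch_linear` appears intended for use under CUDA autocast. The call below is a hedged sketch of that; layer sizes, shapes, and the all-ones layout are assumptions.

```python
# Illustrative sketch only: calling apply_torch_linear under autocast.
# Sizes, the all-ones layout, and the Linear layer are assumptions.
import torch
from torch import nn
import blksprs as bs
from blksprs.utils.processing import apply_torch_linear

sparsity_block_size = 16
x = torch.randn(2, 64, 64, device="cuda")
sparsity_layout = torch.ones(2, 4, 4, dtype=torch.int, device="cuda")
x_sparse = bs.ops.to_sparse(x, sparsity_layout, sparsity_block_size)

linear = nn.Linear(64, 64, device="cuda")
with torch.autocast(device_type="cuda"):
    # custom_fwd(cast_inputs=torch.float16) casts the inputs to float16 here;
    # the weight blocks are additionally cast to x.dtype inside the function.
    xw_sparse, sparsity_layout_xw = apply_torch_linear(x_sparse, sparsity_layout,
                                                       sparsity_block_size, linear)
```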
blksprs/utils/tools.py
CHANGED
@@ -1,25 +1,24 @@
 import torch
 from torch import Tensor, Size
 
+# Capture scalar outputs for JIT compilation
+torch._dynamo.config.capture_scalar_outputs = True
 
-
+
+def do_shape_blocksparse(x: Tensor) -> tuple[Tensor, Size]:
     if x.dim() == 3:
         return x.contiguous(), x.size()
 
     return x.reshape(-1, x.size(-2), x.size(-1)).contiguous(), x.size()
 
 
-def undo_shape_blocksparse(x: Tensor, shape: Size):
+def undo_shape_blocksparse(x: Tensor, shape: Size | tuple[int, ...]) -> Tensor:
     if x.shape[:-2] == shape[:-2]:
         return x
 
     return x.reshape((*shape[:-2], *x.shape[-2:]))
 
 
-def get_triton_block_size(sparsity_block_size: int, limit: int = 128):
-    return min(sparsity_block_size, limit)
-
-
 def stride(x: Tensor):
     if x.dim() == 2:
         return x.size(1), 1
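As a quick illustration of the shaping helpers kept in `tools.py`, the snippet below round-trips a 4-D tensor through `do_shape_blocksparse` and `undo_shape_blocksparse`; the input shape is an arbitrary example.

```python
# Sketch of the shaping helpers above; the 4-D input shape is an arbitrary example.
import torch
from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse

x = torch.randn(2, 3, 64, 64)  # (b, h, m, n)

# Flatten all leading dimensions into a single batch dimension for the kernels ...
x_flat, original_shape = do_shape_blocksparse(x)  # -> shape (6, 64, 64)

# ... and restore the original leading dimensions afterwards.
x_back = undo_shape_blocksparse(x_flat, original_shape)
assert x_back.shape == x.shape
```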
blksprs/utils/validation.py
CHANGED
@@ -26,6 +26,23 @@ def validate_dtype_float(*tensors: Tensor) -> None:
     if _check_skip_validation():
         return
 
+    dtype = None
+
+    for i, tensor in enumerate(tensors):
+        if i == 0:
+            dtype = tensor.dtype
+
+        if tensor.dtype != torch.float16 and tensor.dtype != torch.float32:
+            raise ValueError("Tensor must have either float16 or float32 dtype")
+
+        if tensor.dtype != dtype:
+            raise ValueError("Tensors must have same dtype")
+
+
+def validate_dtype_float_32(*tensors: Tensor) -> None:
+    if _check_skip_validation():
+        return
+
     for tensor in tensors:
         if tensor.dtype != torch.float32:
             raise ValueError("Tensor must have float32 dtype")
@@ -38,7 +55,7 @@ def validate_dtype_int(*tensors: Tensor) -> None:
     for tensor in tensors:
         if (tensor.dtype !=
                 torch.int32 and tensor.dtype != torch.int64):
-            raise ValueError("Tensor must have int32 or int64 dtype")
+            raise ValueError("Tensor must have either int32 or int64 dtype")
 
 
 def validate_device(*tensors: Tensor) -> None:
@@ -51,7 +68,7 @@ def validate_device(*tensors: Tensor) -> None:
         if i == 0:
             device = tensor.device
 
-        if not device.type ==
+        if not device.type == "cuda":
             raise ValueError("Tensors must be on GPU")
 
         if tensor.device != device:
@@ -96,6 +113,9 @@ def validate_sparsity_block_size(sparsity_block_size: int, *tensors):
     if _check_skip_validation():
         return
 
+    if not sparsity_block_size >= 16:
+        raise ValueError("Sparsity block size must be at least 16")
+
     if not (sparsity_block_size & (sparsity_block_size - 1)) == 0:
         raise ValueError("Sparsity block size must be a power of 2")
 
@@ -104,20 +124,6 @@ def validate_sparsity_block_size(sparsity_block_size: int, *tensors):
             raise ValueError("Tensor sizes must be divisible by sparsity block size")
 
 
-def validate_triton_block_size(triton_block_size: int, sparsity_block_size: int):
-    if _check_skip_validation():
-        return
-
-    if triton_block_size is None:
-        return
-
-    if not (triton_block_size & (triton_block_size - 1)) == 0:
-        raise ValueError("Triton block size must be a power of 2")
-
-    if triton_block_size > sparsity_block_size:
-        raise ValueError("Triton block size cannot be larger than sparsity block size")
-
-
 def _check_skip_validation():
     return not VALIDATION
 
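The snippet below sketches the behaviour implied by the validation changes above, namely the new minimum sparsity block size of 16 and the float16/float32 dtype check; the concrete tensors are illustrative only.

```python
# Sketch of the 2.0 validation behaviour; the tensors are illustrative only.
import torch
from blksprs.utils.validation import validate_dtype_float, validate_sparsity_block_size

validate_sparsity_block_size(16)       # ok: a power of two and >= 16

try:
    validate_sparsity_block_size(8)    # now rejected
except ValueError as exc:
    print(exc)                         # Sparsity block size must be at least 16

a = torch.randn(4, 4, dtype=torch.float16)
b = torch.randn(4, 4, dtype=torch.float32)
try:
    validate_dtype_float(a, b)         # float16 and float32 are both allowed, but not mixed
except ValueError as exc:
    print(exc)                         # Tensors must have same dtype
```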
{blksprs-1.11.dist-info → blksprs-2.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
-Metadata-Version: 2.
+Metadata-Version: 2.4
 Name: blksprs
-Version: 1.11
+Version: 2.0
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -14,9 +14,8 @@ Requires-Dist: pytest; extra == "test"
 Requires-Dist: pytest-xdist; extra == "test"
 Requires-Dist: pytest-cov; extra == "test"
 Requires-Dist: coverage; extra == "test"
+Requires-Dist: build; extra == "test"
 Requires-Dist: matplotlib; extra == "test"
-Provides-Extra: build
-Requires-Dist: build; extra == "build"
 
 # blksprs
 
@@ -25,6 +24,13 @@ Requires-Dist: build; extra == "build"
 
 ## Overview
 
+### News
+
+🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
+LUTs, autocasting, and makes use of `torch.library.triton_op()`!
+
+---
+
 A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
 
 Currently supported operations (includes gradient calculation):
@@ -52,23 +58,25 @@ These include, e.g.,
 Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
 match.
 
-Further helpful operations (included in the ``bs.ops.misc`` module) that do **not** support gradient calculation
+Further helpful operations (included in the ``bs.ops.misc`` module) that do **not** support gradient calculation
+include:
 
 - Row-wise sum, max, addition, and subtraction
 - Broadcast addition and subtraction between slices
 
-Furthermore, the library provides a set of utility functions
+Furthermore, the library provides a set of utility functions
 
 - for the creation of sparsity layouts based on existing
-  dense tensors and for the scatter operation (module ``bs.layouting``),
+  dense tensors and for the scatter operation (module ``bs.layouting``),
 - for the application of ``nn.Linear``, ``nn.Dropout``, and ``nn.LayerNorm`` layers to block-sparse tensors,
 - as well as utility functions to ensure correct input dimensionality, and validate input (module ``bs.utils``).
 
-_* see the [Roadmap](#roadmap) section for more information_
+_* see the [Roadmap](#roadmap) section for more information_
 
 ## Installation
 
-Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible
+Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible
+with
 the Linux platform**.
 Keep track of this [issue](https://github.com/triton-lang/triton/issues/1640) for updates.
 
@@ -78,8 +86,8 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u
 
 ### Dependencies
 
-- [PyTorch](https://pytorch.org/) (built with v2.
-- _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.2.
+- [PyTorch](https://pytorch.org/) (built with v2.6)
+- _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.2.4)_
 - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_
 
 ## Changelog
@@ -89,12 +97,27 @@ See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.
 ## Roadmap
 
 Note that since this library covers all our current needs it is in a **bugfix-only** state.
-This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and
+This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and
+``merge`` operations.
 We will continue to maintain the library and fix any issues that arise.
 Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
 We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
 
-It might be that this changes with future projects, but as of
+It might be that this changes with future projects, but as of March 2025, we are content with the current state of the
+library.
+
+## Known Limitations and Issues
+
+- Triton has a bug with `tl.atomic_max()` used for the row-wise max operation.
+  In order to work around this bug a manual conversion of some values is needed, (slightly) negatively impacting
+  performance.
+  Watch the [issue](https://github.com/triton-lang/triton/issues/6376) on Triton's issue tracker for more information.
+- PyTorch's `wrap_triton()` currently does not support config pruning. It thus cannot be used for some of the kernels,
+  which could impact graph compilation.
+- There seem to be some issues with autocasting, forcing some operations to manually cast.
+- There will be some slight numerical differences between vanilla and blksprs operations.
+  These instabilities are due to Triton and thus cannot be fixed by this library alone.
+  However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.
 
 ## Usage
 
@@ -120,10 +143,6 @@ def test_readme():
     # Must be a power of two, greater than or equal to 16 for matmul, and divide m, n, and k
     sparsity_block_size = 16
 
-    # Must be a power of two and smaller than or equal to sparsity_block_size
-    # If it is set to ``none`` a value will be chosen automatically
-    triton_block_size = None
-
     # Initialise random (dense) tensors
     x = torch.randn(size=(b, h, m, k), device="cuda")
     y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()
@@ -133,53 +152,53 @@ def test_readme():
     y_dense, y_shape_original = bs.utils.do_shape_blocksparse(y)
 
     # Create sparsity layouts from existing tensors
-    sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size
-
-    sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size,
-                                                           triton_block_size=triton_block_size)
+    sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size)
+    sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size)
 
     # Create random sparsity layout for output tensor
     sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)
 
     # Convert tensors to sparse tensors for matrix multiplication
-    x_sparse = bs.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size
-    y_sparse = bs.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size
+    x_sparse = bs.ops.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size)
+    y_sparse = bs.ops.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size)
+
+    # As of version 2.0, blksprs supports JIT compilation
+    matmul_compiled = torch.compile(bs.ops.matmul)
 
     # Perform matrix multiplication
-    o_sparse =
-
-
+    o_sparse = matmul_compiled(x_sparse, sparsity_layout_x,
+                               y_sparse, sparsity_layout_y,
+                               sparsity_layout_o, sparsity_block_size)
 
     # Apply element-wise operation
     o_sparse = torch.add(o_sparse, 1)
 
-    o_dense = bs.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size
+    o_dense = bs.ops.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size)
 
     # Sanity check
     o_torch = torch.matmul(x_dense, y_dense)
     o_torch = torch.add(o_torch, 1)
 
     # Perform round trip to set sparse blocks to 0
-    o_torch_round_trip = bs.to_dense(
-        bs.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size
-        sparsity_layout_o, sparsity_block_size, fill_value=0
+    o_torch_round_trip = bs.ops.to_dense(
+        bs.ops.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size),
+        sparsity_layout_o, sparsity_block_size, fill_value=0)
 
     # Assert that the output is correct
     assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2)  # Note that small numerical differences are expected
 
     # Assert that the output has the correct sparsity layout
-    actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size
-                                                                  triton_block_size=triton_block_size)
+    actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size)
     assert torch.allclose(actual_sparsity_layout_o.to(torch.int), sparsity_layout_o)
 
     # Convert output tensor back to original shape
    o = bs.utils.undo_shape_blocksparse(o_dense, x_shape_original)
 
     # Other available functions
-    bs.transpose(o_sparse, sparsity_layout_o, sparsity_block_size
-    bs.softmax(o_sparse, sparsity_layout_o, sparsity_block_size
-    bs.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size
-    bs.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size
+    bs.ops.transpose(o_sparse, sparsity_layout_o, sparsity_block_size)
+    bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size)
+    bs.ops.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size)
+    bs.ops.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size)
 
 
 def _get_random_sparsity_layout(b, m, n, sparsity_block_size, sparsity_percentage):
blksprs-2.0.dist-info/RECORD
ADDED
@@ -0,0 +1,23 @@
+blksprs/__init__.py,sha256=QXYiDairqn5LzvpoNyBXHdQfyac-aX3IDjTQUMkaH-A,1598
+blksprs/layouting/distribution_layout.py,sha256=ur1ty_2U-Hfj78hMWsLZvu7ZuGhzW3qGLKMc72DfTZM,5861
+blksprs/layouting/sparsity_layout.py,sha256=eXHmu2h7K5Q-YUpfOxocJoeP_5ZoQFZf_eHLxRZQbYU,11207
+blksprs/ops/conversion.py,sha256=kf5HKofZ4nVeHCIqQoYKiIlgsAhq33Tnmnr1c17Fkqs,21906
+blksprs/ops/distribution.py,sha256=0tPldv0ARzmCV1CU2jvfqpHBgOuHPrDFiCtqsLs7CZc,20789
+blksprs/ops/flow.py,sha256=PDZAD8u4y9qW1IXERki6ItKbEKnm_ChG8SKWM3_P9Oc,8245
+blksprs/ops/matmul.py,sha256=5tVBKU_lglUjaLDi6J_dscdqlmzRz38OGxqAxZxZXDs,11879
+blksprs/ops/partitioning.py,sha256=cfQmY9BZqGTvvJorIhtb-EyuGRJGPraWR-wTKdb47aI,9954
+blksprs/ops/repeat.py,sha256=TLYNxwPuT9y5K9xyM41WK5gnggAJF3lI61Q2K7zWjns,9035
+blksprs/ops/softmax.py,sha256=BwrRQdtRdkiSvl2mf5bpsTmyIxWiJOpa1HFg0st5yGU,12778
+blksprs/ops/transpose.py,sha256=U-VAyLRT6_NDv9qYSFzBqfVlDeIpTqAMEXkqto0VF6w,4072
+blksprs/ops/misc/broadcast_ops.py,sha256=-PrHiSJikZh8nXUmXxSCtFEP27TTxFr4wcrNxBjnimk,5987
+blksprs/ops/misc/row_wise.py,sha256=n5FJjAuOd8BHBJQx4bsQwr-HmXkR9PYVAqfk77wjOFU,19653
+blksprs/utils/autotuning.py,sha256=a-kmWRjJ3eED2XbjkQeOJSyW8bdIs27HgKMPvAKqWeU,2052
+blksprs/utils/benchmarking.py,sha256=dLabDscTFn5NkmOI1g7DnKeTneUYW3RIVv9MDF-8BKc,1271
+blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
+blksprs/utils/processing.py,sha256=RNkEDc0g-sNHRuMPkRzNWU13d3_lIkXMJdoqES4yQTM,3738
+blksprs/utils/tools.py,sha256=CPf7viQ2OTcZFrB1aSL8_us4VE9M6YEfDz2dE30jr9I,715
+blksprs/utils/validation.py,sha256=G8eQlvJVMKfEX3k2AwBD0A6Ck-gFoRLpLNY6HXsB3fA,4348
+blksprs-2.0.dist-info/METADATA,sha256=7YVN_akf-ewrAW5thDZzhT3hogn0dVxV73lvlPMk59c,9506
+blksprs-2.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+blksprs-2.0.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
+blksprs-2.0.dist-info/RECORD,,
blksprs/utils/layout_utils.py
DELETED
@@ -1,17 +0,0 @@
-import math
-
-import torch
-import triton
-from torch import Tensor
-from torch.xpu import device
-from triton import language as tl
-
-from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import get_triton_block_size, stride
-from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
-    validate_contiguous, validate_sparsity, validate_sparsity_block_size
-
-
-def build_full_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
-    return torch.ones(size=(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size),
-                      dtype=torch.bool, device=x.device)
blksprs-1.11.dist-info/RECORD
DELETED
@@ -1,23 +0,0 @@
-blksprs/__init__.py,sha256=AJYVfR40nOfE5F3waHPVSuajwYDcoGkiEQc8HhQbUBU,1721
-blksprs/layouting/distribution_layout.py,sha256=xDGY5-J7uSD8oenlf8bEJ2amMiQG3NBf2klTTydbTJE,5140
-blksprs/layouting/sparsity_layout.py,sha256=IVtHc_nN3ZM2y4GFcys70PqDWmWc7tkHlVGlToErANk,9894
-blksprs/ops/conversion.py,sha256=QFtZ-nmY2JAWutheiO07vatXqz3eSZBP5Ym_U2Q1oWk,23299
-blksprs/ops/distribution.py,sha256=nHTuE7Tq0Q404VN8bWNC2sEwmmdAtgZI6I7auRICdps,21749
-blksprs/ops/flow.py,sha256=7tOXfTBKOAixYmDa_VXg7TwviLV5ZQMHQjtbyOjqA00,7879
-blksprs/ops/matmul.py,sha256=eVj_BGj78bJkXYuvw4KctMfcfveQBt5OdYmeXzdpO88,12631
-blksprs/ops/partitioning.py,sha256=qMv9w3yFWXwXIhIppdcJ_JMsoZ25HCH38vb6GRneoLM,10416
-blksprs/ops/repeat.py,sha256=i824ijprfYpCaEjiSG5FTUZz7wMS5ksVy_-vY7ZX8Fg,9729
-blksprs/ops/softmax.py,sha256=_mGkA2jHN8cXwtWXYswobEPyM7UC0JyzRszoE4ZYs7w,13063
-blksprs/ops/transpose.py,sha256=O1XhGIGiVkhOSKcBD0HrYaeK6HmpvEEzLb7zJl7xsIM,4246
-blksprs/ops/misc/broadcast_ops.py,sha256=pv0nssSDOdDbQFttpqUIs2ZXShqfm2RYCfJH-C5x3H0,5544
-blksprs/ops/misc/row_wise.py,sha256=DnV5-xEJUbqZlK2fETwHiPQDUMwT-lkc0VUhBlnJ5Y0,17458
-blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
-blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
-blksprs/utils/layout_utils.py,sha256=49ZdPS_gMn_IrWty3FARbi2rda5a8g5DmAEL8LOrC30,670
-blksprs/utils/processing.py,sha256=WLuMJQ8v-YovXwcDjhlDn3N31WMZXrtyeeyKSgq_zn4,3642
-blksprs/utils/tools.py,sha256=k2OfEplbQiAwVjP84zZf7SNB8FzvMtOFBL9sC98OCbI,683
-blksprs/utils/validation.py,sha256=CbxBbeQWJo8wox5eMoVzaTlP9FVBwt3-gxUOmi3EUgw,4213
-blksprs-1.11.dist-info/METADATA,sha256=NUEiHexWiFNbMxQI2TUEzMw9iGBhxqflhWr2xCgOw28,9105
-blksprs-1.11.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
-blksprs-1.11.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
-blksprs-1.11.dist-info/RECORD,,
{blksprs-1.11.dist-info → blksprs-2.0.dist-info}/top_level.txt
File without changes