blksprs-2.0rc4-py3-none-any.whl → blksprs-2.0rc7-py3-none-any.whl
This diff shows the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registries.
- blksprs/layouting/distribution_layout.py +11 -15
- blksprs/layouting/sparsity_layout.py +26 -31
- blksprs/ops/conversion.py +45 -63
- blksprs/ops/distribution.py +38 -57
- blksprs/ops/flow.py +22 -33
- blksprs/ops/matmul.py +19 -20
- blksprs/ops/misc/broadcast_ops.py +15 -19
- blksprs/ops/misc/row_wise.py +39 -54
- blksprs/ops/softmax.py +30 -44
- blksprs/utils/autotuning.py +78 -0
- blksprs/utils/tools.py +0 -28
- blksprs/utils/validation.py +3 -0
- {blksprs-2.0rc4.dist-info → blksprs-2.0rc7.dist-info}/METADATA +18 -5
- blksprs-2.0rc7.dist-info/RECORD +23 -0
- blksprs-2.0rc4.dist-info/RECORD +0 -22
- {blksprs-2.0rc4.dist-info → blksprs-2.0rc7.dist-info}/WHEEL +0 -0
- {blksprs-2.0rc4.dist-info → blksprs-2.0rc7.dist-info}/top_level.txt +0 -0
blksprs/utils/autotuning.py
ADDED
@@ -0,0 +1,78 @@
+import os
+
+blksprs_autotune_mode = os.getenv("BLKSPRS_AUTOTUNE", "DEFAULT")
+
+if blksprs_autotune_mode == "TEST":
+    autotune_parameters = [
+        (16, 3, 8),
+
+        (32, 3, 8),
+
+        (64, 3, 8),
+    ]
+elif blksprs_autotune_mode == "DEFAULT":
+    autotune_parameters = [
+        (16, 3, 8),
+        (16, 4, 4),
+        (16, 5, 2),
+
+        (32, 3, 8),
+        (32, 4, 4),
+        (32, 5, 2),
+
+        (64, 3, 8),
+        (64, 4, 4),
+        (64, 5, 2),
+
+        (128, 3, 8),
+        (128, 4, 4),
+        (128, 5, 2),
+    ]
+else:
+    raise NotImplementedError(f"Unknown autotune mode: {blksprs_autotune_mode}")
+
+import torch
+import triton
+
+
+def prune_autotune_configs(autotune_configs, kernel_args, **kwargs):
+    sparsity_block_size = kernel_args["sparsity_block_size"]
+
+    pruned_configs = []
+
+    for config in autotune_configs:
+        if config.kwargs["TRITON_BLOCK_SIZE"] <= sparsity_block_size:
+            pruned_configs.append(config)
+
+    assert len(pruned_configs) > 0, f"No valid autotune configs found for sparsity block size {sparsity_block_size}"
+
+    return pruned_configs
+
+
+def prune_autotune_configs_conversion(autotune_configs, kernel_args, **kwargs):
+    sparsity_block_size_from = kernel_args["sparsity_block_size_from"]
+    sparsity_block_size_to = kernel_args["sparsity_block_size_to"]
+    sparsity_block_size = min(sparsity_block_size_from, sparsity_block_size_to)
+
+    pruned_configs = []
+
+    for config in autotune_configs:
+        if config.kwargs["TRITON_BLOCK_SIZE"] <= sparsity_block_size:
+            pruned_configs.append(config)
+
+    assert len(pruned_configs) > 0, f"No valid autotune configs found for sparsity block size {sparsity_block_size}"
+
+    return pruned_configs
+
+
+@torch.compile
+def get_autotune_configs():
+    global autotune_parameters
+
+    autotune_configs = []
+
+    for block_size, num_stages, num_warps in autotune_parameters:
+        autotune_configs.append(
+            triton.Config({"TRITON_BLOCK_SIZE": block_size}, num_stages=num_stages, num_warps=num_warps))
+
+    return autotune_configs
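For context, a minimal sketch of how these helpers plug into a Triton kernel, assuming the standard triton.autotune() API: get_autotune_configs() supplies the candidate configs, and prune_autotune_configs() serves as the early_config_prune hook that drops any config whose TRITON_BLOCK_SIZE exceeds the kernel's sparsity_block_size. The kernel itself is a hypothetical stand-in, not actual blksprs code.

import triton
import triton.language as tl

from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs


# Hypothetical kernel for illustration; only the decorator wiring reflects
# the module added above.
@triton.autotune(
    configs=get_autotune_configs(),
    key=["sparsity_block_size"],
    prune_configs_by={"early_config_prune": prune_autotune_configs},
)
@triton.jit
def _demo_kernel(x_ptr, o_ptr, n_elements, sparsity_block_size,
                 TRITON_BLOCK_SIZE: tl.constexpr):
    # Copy x to o in TRITON_BLOCK_SIZE-sized chunks.
    pid = tl.program_id(axis=0)
    offsets = pid * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)
    mask = offsets < n_elements
    tl.store(o_ptr + offsets, tl.load(x_ptr + offsets, mask=mask), mask=mask)

Since the parameter table is selected at import time from the BLKSPRS_AUTOTUNE environment variable, the variable must be set before blksprs is first imported (e.g. BLKSPRS_AUTOTUNE=TEST for the reduced test grid).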
blksprs/utils/tools.py
CHANGED
@@ -1,5 +1,4 @@
 import torch
-import triton
 from torch import Tensor, Size
 
 # Capture scalar outputs for JIT compilation
@@ -27,30 +26,3 @@ def stride(x: Tensor):
         return x.size(1) * x.size(2), x.size(2), 1
     else:
         raise NotImplementedError
-
-
-@torch.compile
-def get_autotune_configs():
-    configs = []
-    config_parameters = [
-        (16, 3, 8),
-        (16, 4, 4),
-        (16, 5, 2),
-
-        (32, 3, 8),
-        (32, 4, 4),
-        (32, 5, 2),
-
-        (64, 3, 8),
-        (64, 4, 4),
-        (64, 5, 2),
-
-        (128, 3, 8),
-        (128, 4, 4),
-        (128, 5, 2),
-    ]
-
-    for block_size, num_stages, num_warps in config_parameters:
-        configs.append(triton.Config({"TRITON_BLOCK_SIZE": block_size}, num_stages=num_stages, num_warps=num_warps))
-
-    return configs
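With get_autotune_configs() gone (it now lives in blksprs/utils/autotuning.py), tools.py retains only small helpers such as stride(). A minimal sketch of the three-dimensional branch visible in the context lines above, assuming a contiguous tensor:

import torch
from blksprs.utils.tools import stride

x = torch.empty(2, 3, 4)
# Contiguous 3-d case: returns (size(1) * size(2), size(2), 1),
# which matches torch's own strides for this layout.
assert stride(x) == x.stride()  # (12, 4, 1)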
blksprs/utils/validation.py
CHANGED
@@ -113,6 +113,9 @@ def validate_sparsity_block_size(sparsity_block_size: int, *tensors):
     if _check_skip_validation():
         return
 
+    if not sparsity_block_size >= 16:
+        raise ValueError("Sparsity block size must be at least 16")
+
     if not (sparsity_block_size & (sparsity_block_size - 1)) == 0:
         raise ValueError("Sparsity block size must be a power of 2")
 
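A minimal sketch of the tightened check, assuming validation is not disabled via the library's skip-validation switch: a sparsity block size must now be a power of two and at least 16. The power-of-two test uses the standard bit trick, since n & (n - 1) clears the lowest set bit and is therefore zero exactly for powers of two.

from blksprs.utils.validation import validate_sparsity_block_size

validate_sparsity_block_size(32)  # power of two and >= 16: passes

try:
    validate_sparsity_block_size(8)  # power of two, but below the new minimum
except ValueError as err:
    print(err)  # Sparsity block size must be at least 16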
{blksprs-2.0rc4.dist-info → blksprs-2.0rc7.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: blksprs
-Version: 2.0rc4
+Version: 2.0rc7
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -24,10 +24,10 @@ Requires-Dist: matplotlib; extra == "test"
 
 ## Overview
 
-
-
-
-
+### News
+
+🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
+LUTs, autocasting, and makes use of `torch.library.triton_op()`!
 
 ---
 
@@ -106,6 +106,19 @@ We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
 It might be that this changes with future projects, but as of March 2025, we are content with the current state of the
 library.
 
+## Known Limitations and Issues
+
+- Triton has a bug with `tl.atomic_max()` used for the row-wise max operation.
+  In order to work around this bug a manual conversion of some values is needed, (slightly) negatively impacting
+  performance.
+  Watch the [issue](https://github.com/triton-lang/triton/issues/6376) on Triton's issue tracker for more information.
+- PyTorch's `wrap_triton()` currently does not support config pruning. It thus cannot be used for some of the kernels,
+  which could impact graph compilation.
+- There seem to be some issues with autocasting, forcing some operations to manually cast.
+- There will be some slight numerical differences between vanilla and blksprs operations.
+  These instabilities are due to Triton and thus cannot be fixed by this library alone.
+  However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.
+
 ## Usage
 
 We provide an example below to demonstrate the usage of the library.
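The torch.library.triton_op() mentioned in the news entry is PyTorch's mechanism for registering Triton kernels as custom operators that compose with torch.compile. Below is a hedged, self-contained sketch of the pattern; the kernel and the blksprs_demo::copy name are illustrative, not blksprs code, and the wrap_triton() pruning limitation listed above applies to autotuned kernels that rely on prune_configs_by.

import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton


@triton.jit
def _copy_kernel(x_ptr, o_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    tl.store(o_ptr + offsets, tl.load(x_ptr + offsets, mask=mask), mask=mask)


@triton_op("blksprs_demo::copy", mutates_args={})
def copy(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    # wrap_triton() makes the raw kernel traceable, so torch.compile can see
    # through the custom op instead of treating it as opaque.
    wrap_triton(_copy_kernel)[grid](x, out, n_elements, BLOCK=1024)
    return out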
blksprs-2.0rc7.dist-info/RECORD
ADDED
@@ -0,0 +1,23 @@
+blksprs/__init__.py,sha256=OHfpwJCZWGUfpT-DVfC1YSaeZl4aCMNt9CrzMPymywU,1577
+blksprs/layouting/distribution_layout.py,sha256=TkMh_DYKX56Cb8Vq7EHyupMRvzm0XbUNP8QP7afv9wM,5122
+blksprs/layouting/sparsity_layout.py,sha256=6GOjwllDUK9L8jEQNu2i17Pp1BIIQm8fv3xVuiR0zIw,10228
+blksprs/ops/conversion.py,sha256=2zAdbaZ1iP2lisLVeG-k-f571G4HJapADhSwpY0Zd3o,21503
+blksprs/ops/distribution.py,sha256=6joac_zl3ZnRkPqLPQ0d88r7IbcrWAg0HiV93LOZw-w,20453
+blksprs/ops/flow.py,sha256=UO5ba5TFgVpEyT7r0hnWYw3vhRDpBOxyPHUBeNOAYPs,7935
+blksprs/ops/matmul.py,sha256=02hujXMtFgF7ohepM3v6h9okrfcU-J3mQZV17B-qvh0,12235
+blksprs/ops/partitioning.py,sha256=nAV28f3NtvT4OFvDtnE0A-VxpDQmMXS0pZw4CJwzqGA,9838
+blksprs/ops/repeat.py,sha256=bQpJuwtt8aRdSzxT78lJ8f8fLDhPkYK5UvMfJ-PQrkc,8977
+blksprs/ops/softmax.py,sha256=-NoTf1Cpuku9C99N0LuMydT_ObozWTnZJGDZxseXEXI,12209
+blksprs/ops/transpose.py,sha256=PQKteFnzNAOEC7voO7wh_dq9c54UjCboJz889aBCwKc,4010
+blksprs/ops/misc/broadcast_ops.py,sha256=DhUbliT9TBT6zlEjutBmY1EAEUPmYOt2mKQ5i46vN1c,5880
+blksprs/ops/misc/row_wise.py,sha256=5u_J8WOTepvf6XtZ8r0lLPofYrI5fGB7mxSmGC81IR0,19167
+blksprs/utils/autotuning.py,sha256=tDfMWklm2rvbo0-ahH81C3Gg0U6LHjPn3d_3pEOzmJs,2053
+blksprs/utils/benchmarking.py,sha256=dLabDscTFn5NkmOI1g7DnKeTneUYW3RIVv9MDF-8BKc,1271
+blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
+blksprs/utils/processing.py,sha256=xuu9iDpwTvsqI_WKMSD8QCNuvPnfcKMRcuF2L4Zs6Ts,3808
+blksprs/utils/tools.py,sha256=3_2IBbd54vVU4-6m2KtAN7qjU6jeF4UfPkbjeFqMpYo,664
+blksprs/utils/validation.py,sha256=G8eQlvJVMKfEX3k2AwBD0A6Ck-gFoRLpLNY6HXsB3fA,4348
+blksprs-2.0rc7.dist-info/METADATA,sha256=ER9DHdVeYUZUsjE-2bEB9fePw0FVI1vknwPNrj7mDPE,9509
+blksprs-2.0rc7.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+blksprs-2.0rc7.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
+blksprs-2.0rc7.dist-info/RECORD,,
blksprs-2.0rc4.dist-info/RECORD
DELETED
@@ -1,22 +0,0 @@
-blksprs/__init__.py,sha256=OHfpwJCZWGUfpT-DVfC1YSaeZl4aCMNt9CrzMPymywU,1577
-blksprs/layouting/distribution_layout.py,sha256=0glIteoY5oDkiEu5rjLIC-BB_oC4sa3rFWVkohsAG00,5329
-blksprs/layouting/sparsity_layout.py,sha256=ZUhJm1jJn-npiJWFjsVyzjXDQOp8z-Wjjv0MPQOXRvg,10490
-blksprs/ops/conversion.py,sha256=FsujfUH3R8ijSti_ifsTQihB0djK8Snny2fbGRruzRw,22459
-blksprs/ops/distribution.py,sha256=CTcDcUx8vwe-9F9Y25B7ea7tcvy5gR2Pyk0Ko48MWFo,21514
-blksprs/ops/flow.py,sha256=MY1ypGLIAlkZty5iQINip5mDIQxu9pP1D1dIae4sKJg,8433
-blksprs/ops/matmul.py,sha256=xFxWSCy9NwPDTxfSUOyQU_X4sHp3HrJtohlUCc1WO8g,12028
-blksprs/ops/partitioning.py,sha256=nAV28f3NtvT4OFvDtnE0A-VxpDQmMXS0pZw4CJwzqGA,9838
-blksprs/ops/repeat.py,sha256=bQpJuwtt8aRdSzxT78lJ8f8fLDhPkYK5UvMfJ-PQrkc,8977
-blksprs/ops/softmax.py,sha256=PdRPAkCJahtGBO5W-aqF_Dxi9X8RJ621XmYfVo2I0OM,12968
-blksprs/ops/transpose.py,sha256=PQKteFnzNAOEC7voO7wh_dq9c54UjCboJz889aBCwKc,4010
-blksprs/ops/misc/broadcast_ops.py,sha256=lZ5bBIftUKffzeYz77SWB1xmtZTRGMvjF-tG9rqkOXA,6018
-blksprs/ops/misc/row_wise.py,sha256=FOy73-I5_OuCugiq0xQxtre9-ytfBQPDaXQv8tssuXg,19764
-blksprs/utils/benchmarking.py,sha256=dLabDscTFn5NkmOI1g7DnKeTneUYW3RIVv9MDF-8BKc,1271
-blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
-blksprs/utils/processing.py,sha256=xuu9iDpwTvsqI_WKMSD8QCNuvPnfcKMRcuF2L4Zs6Ts,3808
-blksprs/utils/tools.py,sha256=RL18P4NAj7d8gXTTKbMZt4SHCynsw1wPu9yvlrnBQlo,1220
-blksprs/utils/validation.py,sha256=7ks9hdNKbov1JE9y1bpnIfjWCVhqINTZOIZPi6d7k8E,4241
-blksprs-2.0rc4.dist-info/METADATA,sha256=uM3Ssh-i170VnuaaPf-kjM4EwztirvAXlU7xINY6YhM,8614
-blksprs-2.0rc4.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
-blksprs-2.0rc4.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
-blksprs-2.0rc4.dist-info/RECORD,,
{blksprs-2.0rc4.dist-info → blksprs-2.0rc7.dist-info}/WHEEL
File without changes
{blksprs-2.0rc4.dist-info → blksprs-2.0rc7.dist-info}/top_level.txt
File without changes