blksprs 2.0rc4__py3-none-any.whl → 2.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,78 @@
1
+ import os
2
+
3
+ blksprs_autotune_mode = os.getenv("BLKSPRS_AUTOTUNE", "DEFAULT")
4
+
5
+ if blksprs_autotune_mode == "TEST":
6
+ autotune_parameters = [
7
+ (16, 3, 8),
8
+
9
+ (32, 3, 8),
10
+
11
+ (64, 3, 8),
12
+ ]
13
+ elif blksprs_autotune_mode == "DEFAULT":
14
+ autotune_parameters = [
15
+ (16, 3, 8),
16
+ (16, 4, 4),
17
+ (16, 5, 2),
18
+
19
+ (32, 3, 8),
20
+ (32, 4, 4),
21
+ (32, 5, 2),
22
+
23
+ (64, 3, 8),
24
+ (64, 4, 4),
25
+ (64, 5, 2),
26
+
27
+ (128, 3, 8),
28
+ (128, 4, 4),
29
+ (128, 5, 2),
30
+ ]
31
+ else:
32
+ raise NotImplementedError(f"Unknown autotune mode: {blksprs_autotune_mode}")
33
+
34
+ import torch
35
+ import triton
36
+
37
+
38
+ def prune_autotune_configs(autotune_configs, kernel_args, **kwargs):
39
+ sparsity_block_size = kernel_args["sparsity_block_size"]
40
+
41
+ pruned_configs = []
42
+
43
+ for config in autotune_configs:
44
+ if config.kwargs["TRITON_BLOCK_SIZE"] <= sparsity_block_size:
45
+ pruned_configs.append(config)
46
+
47
+ assert len(pruned_configs) > 0, f"No valid autotune configs found for sparsity block size {sparsity_block_size}"
48
+
49
+ return pruned_configs
50
+
51
+
52
+ def prune_autotune_configs_conversion(autotune_configs, kernel_args, **kwargs):
53
+ sparsity_block_size_from = kernel_args["sparsity_block_size_from"]
54
+ sparsity_block_size_to = kernel_args["sparsity_block_size_to"]
55
+ sparsity_block_size = min(sparsity_block_size_from, sparsity_block_size_to)
56
+
57
+ pruned_configs = []
58
+
59
+ for config in autotune_configs:
60
+ if config.kwargs["TRITON_BLOCK_SIZE"] <= sparsity_block_size:
61
+ pruned_configs.append(config)
62
+
63
+ assert len(pruned_configs) > 0, f"No valid autotune configs found for sparsity block size {sparsity_block_size}"
64
+
65
+ return pruned_configs
66
+
67
+
68
+ @torch.compile
69
+ def get_autotune_configs():
70
+ global autotune_parameters
71
+
72
+ autotune_configs = []
73
+
74
+ for block_size, num_stages, num_warps in autotune_parameters:
75
+ autotune_configs.append(
76
+ triton.Config({"TRITON_BLOCK_SIZE": block_size}, num_stages=num_stages, num_warps=num_warps))
77
+
78
+ return autotune_configs
blksprs/utils/tools.py CHANGED
@@ -1,5 +1,4 @@
1
1
  import torch
2
- import triton
3
2
  from torch import Tensor, Size
4
3
 
5
4
  # Capture scalar outputs for JIT compilation
@@ -27,30 +26,3 @@ def stride(x: Tensor):
27
26
  return x.size(1) * x.size(2), x.size(2), 1
28
27
  else:
29
28
  raise NotImplementedError
30
-
31
-
32
- @torch.compile
33
- def get_autotune_configs():
34
- configs = []
35
- config_parameters = [
36
- (16, 3, 8),
37
- (16, 4, 4),
38
- (16, 5, 2),
39
-
40
- (32, 3, 8),
41
- (32, 4, 4),
42
- (32, 5, 2),
43
-
44
- (64, 3, 8),
45
- (64, 4, 4),
46
- (64, 5, 2),
47
-
48
- (128, 3, 8),
49
- (128, 4, 4),
50
- (128, 5, 2),
51
- ]
52
-
53
- for block_size, num_stages, num_warps in config_parameters:
54
- configs.append(triton.Config({"TRITON_BLOCK_SIZE": block_size}, num_stages=num_stages, num_warps=num_warps))
55
-
56
- return configs
@@ -113,6 +113,9 @@ def validate_sparsity_block_size(sparsity_block_size: int, *tensors):
113
113
  if _check_skip_validation():
114
114
  return
115
115
 
116
+ if not sparsity_block_size >= 16:
117
+ raise ValueError("Sparsity block size must be at least 16")
118
+
116
119
  if not (sparsity_block_size & (sparsity_block_size - 1)) == 0:
117
120
  raise ValueError("Sparsity block size must be a power of 2")
118
121
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: blksprs
3
- Version: 2.0rc4
3
+ Version: 2.0rc7
4
4
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
5
5
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
6
6
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -24,10 +24,10 @@ Requires-Dist: matplotlib; extra == "test"
24
24
 
25
25
  ## Overview
26
26
 
27
- ### News
28
-
29
- 🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
30
- LUTs, autocasting, and makes use of `torch.library.triton_op()`!
27
+ ### News
28
+
29
+ 🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
30
+ LUTs, autocasting, and makes use of `torch.library.triton_op()`!
31
31
 
32
32
  ---
33
33
 
@@ -106,6 +106,19 @@ We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
106
106
  It might be that this changes with future projects, but as of March 2025, we are content with the current state of the
107
107
  library.
108
108
 
109
+ ## Known Limitations and Issues
110
+
111
+ - Triton has a bug with `tl.atomic_max()` used for the row-wise max operation.
112
+ In order to work around this bug, a manual conversion of some values is needed, (slightly) negatively impacting
113
+ performance.
114
+ Watch the [issue](https://github.com/triton-lang/triton/issues/6376) on Triton's issue tracker for more information.
115
+ - PyTorch's `wrap_triton()` currently does not support config pruning. It thus cannot be used for some of the kernels,
116
+ which could impact graph compilation.
117
+ - There seem to be some issues with autocasting, forcing some operations to manually cast.
118
+ - There will be some slight numerical differences between vanilla and blksprs operations.
119
+ These instabilities are due to Triton and thus cannot be fixed by this library alone.
120
+ However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.
121
+
109
122
  ## Usage
110
123
 
111
124
  We provide an example below to demonstrate the usage of the library.
@@ -0,0 +1,23 @@
1
+ blksprs/__init__.py,sha256=OHfpwJCZWGUfpT-DVfC1YSaeZl4aCMNt9CrzMPymywU,1577
2
+ blksprs/layouting/distribution_layout.py,sha256=TkMh_DYKX56Cb8Vq7EHyupMRvzm0XbUNP8QP7afv9wM,5122
3
+ blksprs/layouting/sparsity_layout.py,sha256=6GOjwllDUK9L8jEQNu2i17Pp1BIIQm8fv3xVuiR0zIw,10228
4
+ blksprs/ops/conversion.py,sha256=2zAdbaZ1iP2lisLVeG-k-f571G4HJapADhSwpY0Zd3o,21503
5
+ blksprs/ops/distribution.py,sha256=6joac_zl3ZnRkPqLPQ0d88r7IbcrWAg0HiV93LOZw-w,20453
6
+ blksprs/ops/flow.py,sha256=UO5ba5TFgVpEyT7r0hnWYw3vhRDpBOxyPHUBeNOAYPs,7935
7
+ blksprs/ops/matmul.py,sha256=02hujXMtFgF7ohepM3v6h9okrfcU-J3mQZV17B-qvh0,12235
8
+ blksprs/ops/partitioning.py,sha256=nAV28f3NtvT4OFvDtnE0A-VxpDQmMXS0pZw4CJwzqGA,9838
9
+ blksprs/ops/repeat.py,sha256=bQpJuwtt8aRdSzxT78lJ8f8fLDhPkYK5UvMfJ-PQrkc,8977
10
+ blksprs/ops/softmax.py,sha256=-NoTf1Cpuku9C99N0LuMydT_ObozWTnZJGDZxseXEXI,12209
11
+ blksprs/ops/transpose.py,sha256=PQKteFnzNAOEC7voO7wh_dq9c54UjCboJz889aBCwKc,4010
12
+ blksprs/ops/misc/broadcast_ops.py,sha256=DhUbliT9TBT6zlEjutBmY1EAEUPmYOt2mKQ5i46vN1c,5880
13
+ blksprs/ops/misc/row_wise.py,sha256=5u_J8WOTepvf6XtZ8r0lLPofYrI5fGB7mxSmGC81IR0,19167
14
+ blksprs/utils/autotuning.py,sha256=tDfMWklm2rvbo0-ahH81C3Gg0U6LHjPn3d_3pEOzmJs,2053
15
+ blksprs/utils/benchmarking.py,sha256=dLabDscTFn5NkmOI1g7DnKeTneUYW3RIVv9MDF-8BKc,1271
16
+ blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
17
+ blksprs/utils/processing.py,sha256=xuu9iDpwTvsqI_WKMSD8QCNuvPnfcKMRcuF2L4Zs6Ts,3808
18
+ blksprs/utils/tools.py,sha256=3_2IBbd54vVU4-6m2KtAN7qjU6jeF4UfPkbjeFqMpYo,664
19
+ blksprs/utils/validation.py,sha256=G8eQlvJVMKfEX3k2AwBD0A6Ck-gFoRLpLNY6HXsB3fA,4348
20
+ blksprs-2.0rc7.dist-info/METADATA,sha256=ER9DHdVeYUZUsjE-2bEB9fePw0FVI1vknwPNrj7mDPE,9509
21
+ blksprs-2.0rc7.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
22
+ blksprs-2.0rc7.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
23
+ blksprs-2.0rc7.dist-info/RECORD,,
@@ -1,22 +0,0 @@
1
- blksprs/__init__.py,sha256=OHfpwJCZWGUfpT-DVfC1YSaeZl4aCMNt9CrzMPymywU,1577
2
- blksprs/layouting/distribution_layout.py,sha256=0glIteoY5oDkiEu5rjLIC-BB_oC4sa3rFWVkohsAG00,5329
3
- blksprs/layouting/sparsity_layout.py,sha256=ZUhJm1jJn-npiJWFjsVyzjXDQOp8z-Wjjv0MPQOXRvg,10490
4
- blksprs/ops/conversion.py,sha256=FsujfUH3R8ijSti_ifsTQihB0djK8Snny2fbGRruzRw,22459
5
- blksprs/ops/distribution.py,sha256=CTcDcUx8vwe-9F9Y25B7ea7tcvy5gR2Pyk0Ko48MWFo,21514
6
- blksprs/ops/flow.py,sha256=MY1ypGLIAlkZty5iQINip5mDIQxu9pP1D1dIae4sKJg,8433
7
- blksprs/ops/matmul.py,sha256=xFxWSCy9NwPDTxfSUOyQU_X4sHp3HrJtohlUCc1WO8g,12028
8
- blksprs/ops/partitioning.py,sha256=nAV28f3NtvT4OFvDtnE0A-VxpDQmMXS0pZw4CJwzqGA,9838
9
- blksprs/ops/repeat.py,sha256=bQpJuwtt8aRdSzxT78lJ8f8fLDhPkYK5UvMfJ-PQrkc,8977
10
- blksprs/ops/softmax.py,sha256=PdRPAkCJahtGBO5W-aqF_Dxi9X8RJ621XmYfVo2I0OM,12968
11
- blksprs/ops/transpose.py,sha256=PQKteFnzNAOEC7voO7wh_dq9c54UjCboJz889aBCwKc,4010
12
- blksprs/ops/misc/broadcast_ops.py,sha256=lZ5bBIftUKffzeYz77SWB1xmtZTRGMvjF-tG9rqkOXA,6018
13
- blksprs/ops/misc/row_wise.py,sha256=FOy73-I5_OuCugiq0xQxtre9-ytfBQPDaXQv8tssuXg,19764
14
- blksprs/utils/benchmarking.py,sha256=dLabDscTFn5NkmOI1g7DnKeTneUYW3RIVv9MDF-8BKc,1271
15
- blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
16
- blksprs/utils/processing.py,sha256=xuu9iDpwTvsqI_WKMSD8QCNuvPnfcKMRcuF2L4Zs6Ts,3808
17
- blksprs/utils/tools.py,sha256=RL18P4NAj7d8gXTTKbMZt4SHCynsw1wPu9yvlrnBQlo,1220
18
- blksprs/utils/validation.py,sha256=7ks9hdNKbov1JE9y1bpnIfjWCVhqINTZOIZPi6d7k8E,4241
19
- blksprs-2.0rc4.dist-info/METADATA,sha256=uM3Ssh-i170VnuaaPf-kjM4EwztirvAXlU7xINY6YhM,8614
20
- blksprs-2.0rc4.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
21
- blksprs-2.0rc4.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
22
- blksprs-2.0rc4.dist-info/RECORD,,