blksprs 2.0rc6__py3-none-any.whl → 2.0rc7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/ops/distribution.py +1 -1
- blksprs/ops/misc/broadcast_ops.py +1 -0
- blksprs/ops/misc/row_wise.py +7 -3
- blksprs/utils/tools.py +0 -9
- {blksprs-2.0rc6.dist-info → blksprs-2.0rc7.dist-info}/METADATA +7 -3
- {blksprs-2.0rc6.dist-info → blksprs-2.0rc7.dist-info}/RECORD +8 -8
- {blksprs-2.0rc6.dist-info → blksprs-2.0rc7.dist-info}/WHEEL +0 -0
- {blksprs-2.0rc6.dist-info → blksprs-2.0rc7.dist-info}/top_level.txt +0 -0
blksprs/ops/distribution.py
CHANGED
|
@@ -240,7 +240,7 @@ def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
|
|
|
240
240
|
reduce_op="none", lut=lut)
|
|
241
241
|
|
|
242
242
|
|
|
243
|
-
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.
|
|
243
|
+
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
|
|
244
244
|
def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
|
|
245
245
|
dim: int,
|
|
246
246
|
idx: BlksprsTensor,
|
|
@@ -12,6 +12,7 @@ from blksprs.utils.validation import validate_contiguous, validate_device, \
|
|
|
12
12
|
validate_sparsity_block_size
|
|
13
13
|
|
|
14
14
|
|
|
15
|
+
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
|
|
15
16
|
def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
|
|
16
17
|
sparsity_block_size: int) -> BlksprsTensor:
|
|
17
18
|
"""Performs a broadcast and subsequent addition of two dense tensors x and y. Returns a block-sparse tensor in
|
blksprs/ops/misc/row_wise.py
CHANGED
|
@@ -4,9 +4,9 @@ from torch import Tensor
|
|
|
4
4
|
from torch._library.triton import wrap_triton, triton_op
|
|
5
5
|
from triton import language as tl
|
|
6
6
|
|
|
7
|
-
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
8
|
-
from blksprs.utils.tools import stride, get_autocast_min_val
|
|
9
7
|
from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
|
|
8
|
+
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
9
|
+
from blksprs.utils.tools import stride
|
|
10
10
|
from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, validate_sparsity, \
|
|
11
11
|
validate_sparsity_block_size
|
|
12
12
|
|
|
@@ -95,6 +95,7 @@ def row_wise_sum_forward(x: Tensor, sparsity_lut: Tensor,
|
|
|
95
95
|
return output
|
|
96
96
|
|
|
97
97
|
|
|
98
|
+
# noinspection PyUnusedLocal
|
|
98
99
|
@triton.autotune(
|
|
99
100
|
configs=get_autotune_configs(),
|
|
100
101
|
key=["sparsity_block_size"],
|
|
@@ -175,6 +176,8 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
|
|
|
175
176
|
of the input and the sparsity layout of the output tensor.
|
|
176
177
|
|
|
177
178
|
"""
|
|
179
|
+
# TODO Fix for triton bug, see https://github.com/triton-lang/triton/issues/6376
|
|
180
|
+
x = torch.where(x == -0.0, torch.tensor(0.0), x)
|
|
178
181
|
x = x.contiguous()
|
|
179
182
|
|
|
180
183
|
validate_dimensions(x)
|
|
@@ -209,7 +212,7 @@ def row_wise_max_forward(x: Tensor, sparsity_lut: Tensor,
|
|
|
209
212
|
output = torch.full(size=(n_sparse_blocks_output,
|
|
210
213
|
sparsity_block_size,
|
|
211
214
|
1 if flag_slice_only else sparsity_block_size),
|
|
212
|
-
fill_value=
|
|
215
|
+
fill_value=torch.finfo(x.dtype).min,
|
|
213
216
|
device=x.device)
|
|
214
217
|
|
|
215
218
|
x_b, x_r, x_c = x.size()
|
|
@@ -238,6 +241,7 @@ def row_wise_max_forward(x: Tensor, sparsity_lut: Tensor,
|
|
|
238
241
|
return output
|
|
239
242
|
|
|
240
243
|
|
|
244
|
+
# noinspection PyUnusedLocal
|
|
241
245
|
@triton.autotune(
|
|
242
246
|
configs=get_autotune_configs(),
|
|
243
247
|
key=["sparsity_block_size"],
|
blksprs/utils/tools.py
CHANGED
|
@@ -26,12 +26,3 @@ def stride(x: Tensor):
|
|
|
26
26
|
return x.size(1) * x.size(2), x.size(2), 1
|
|
27
27
|
else:
|
|
28
28
|
raise NotImplementedError
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
def get_autocast_min_val():
|
|
32
|
-
if torch.is_autocast_enabled():
|
|
33
|
-
dtype = torch.get_autocast_dtype("cuda")
|
|
34
|
-
else:
|
|
35
|
-
dtype = torch.float
|
|
36
|
-
|
|
37
|
-
return torch.finfo(dtype).min
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: blksprs
|
|
3
|
-
Version: 2.
|
|
3
|
+
Version: 2.0rc7
|
|
4
4
|
Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
|
|
5
5
|
Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
|
|
6
6
|
Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
|
|
@@ -108,12 +108,16 @@ library.
|
|
|
108
108
|
|
|
109
109
|
## Known Limitations and Issues
|
|
110
110
|
|
|
111
|
+
- Triton has a bug with `tl.atomic_max()` used for the row-wise max operation.
|
|
112
|
+
In order to work around this bug, a manual conversion of some values is needed, (slightly) negatively impacting
|
|
113
|
+
performance.
|
|
114
|
+
Watch the [issue](https://github.com/triton-lang/triton/issues/6376) on Triton's issue tracker for more information.
|
|
111
115
|
- PyTorch's `wrap_triton()` currently does not support config pruning. It thus cannot be used for some of the kernels,
|
|
112
116
|
which could impact graph compilation.
|
|
113
117
|
- There seem to be some issues with autocasting, forcing some operations to manually cast.
|
|
114
118
|
- There will be some slight numerical differences between vanilla and blksprs operations.
|
|
115
|
-
These instabilities are due to Triton and thus cannot be fixed by this library alone.
|
|
116
|
-
However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.
|
|
119
|
+
These instabilities are due to Triton and thus cannot be fixed by this library alone.
|
|
120
|
+
However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.
|
|
117
121
|
|
|
118
122
|
## Usage
|
|
119
123
|
|
|
@@ -2,22 +2,22 @@ blksprs/__init__.py,sha256=OHfpwJCZWGUfpT-DVfC1YSaeZl4aCMNt9CrzMPymywU,1577
|
|
|
2
2
|
blksprs/layouting/distribution_layout.py,sha256=TkMh_DYKX56Cb8Vq7EHyupMRvzm0XbUNP8QP7afv9wM,5122
|
|
3
3
|
blksprs/layouting/sparsity_layout.py,sha256=6GOjwllDUK9L8jEQNu2i17Pp1BIIQm8fv3xVuiR0zIw,10228
|
|
4
4
|
blksprs/ops/conversion.py,sha256=2zAdbaZ1iP2lisLVeG-k-f571G4HJapADhSwpY0Zd3o,21503
|
|
5
|
-
blksprs/ops/distribution.py,sha256=
|
|
5
|
+
blksprs/ops/distribution.py,sha256=6joac_zl3ZnRkPqLPQ0d88r7IbcrWAg0HiV93LOZw-w,20453
|
|
6
6
|
blksprs/ops/flow.py,sha256=UO5ba5TFgVpEyT7r0hnWYw3vhRDpBOxyPHUBeNOAYPs,7935
|
|
7
7
|
blksprs/ops/matmul.py,sha256=02hujXMtFgF7ohepM3v6h9okrfcU-J3mQZV17B-qvh0,12235
|
|
8
8
|
blksprs/ops/partitioning.py,sha256=nAV28f3NtvT4OFvDtnE0A-VxpDQmMXS0pZw4CJwzqGA,9838
|
|
9
9
|
blksprs/ops/repeat.py,sha256=bQpJuwtt8aRdSzxT78lJ8f8fLDhPkYK5UvMfJ-PQrkc,8977
|
|
10
10
|
blksprs/ops/softmax.py,sha256=-NoTf1Cpuku9C99N0LuMydT_ObozWTnZJGDZxseXEXI,12209
|
|
11
11
|
blksprs/ops/transpose.py,sha256=PQKteFnzNAOEC7voO7wh_dq9c54UjCboJz889aBCwKc,4010
|
|
12
|
-
blksprs/ops/misc/broadcast_ops.py,sha256=
|
|
13
|
-
blksprs/ops/misc/row_wise.py,sha256=
|
|
12
|
+
blksprs/ops/misc/broadcast_ops.py,sha256=DhUbliT9TBT6zlEjutBmY1EAEUPmYOt2mKQ5i46vN1c,5880
|
|
13
|
+
blksprs/ops/misc/row_wise.py,sha256=5u_J8WOTepvf6XtZ8r0lLPofYrI5fGB7mxSmGC81IR0,19167
|
|
14
14
|
blksprs/utils/autotuning.py,sha256=tDfMWklm2rvbo0-ahH81C3Gg0U6LHjPn3d_3pEOzmJs,2053
|
|
15
15
|
blksprs/utils/benchmarking.py,sha256=dLabDscTFn5NkmOI1g7DnKeTneUYW3RIVv9MDF-8BKc,1271
|
|
16
16
|
blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
|
|
17
17
|
blksprs/utils/processing.py,sha256=xuu9iDpwTvsqI_WKMSD8QCNuvPnfcKMRcuF2L4Zs6Ts,3808
|
|
18
|
-
blksprs/utils/tools.py,sha256=
|
|
18
|
+
blksprs/utils/tools.py,sha256=3_2IBbd54vVU4-6m2KtAN7qjU6jeF4UfPkbjeFqMpYo,664
|
|
19
19
|
blksprs/utils/validation.py,sha256=G8eQlvJVMKfEX3k2AwBD0A6Ck-gFoRLpLNY6HXsB3fA,4348
|
|
20
|
-
blksprs-2.
|
|
21
|
-
blksprs-2.
|
|
22
|
-
blksprs-2.
|
|
23
|
-
blksprs-2.
|
|
20
|
+
blksprs-2.0rc7.dist-info/METADATA,sha256=ER9DHdVeYUZUsjE-2bEB9fePw0FVI1vknwPNrj7mDPE,9509
|
|
21
|
+
blksprs-2.0rc7.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
|
22
|
+
blksprs-2.0rc7.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
|
|
23
|
+
blksprs-2.0rc7.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|