blksprs 1.4.1__py3-none-any.whl → 1.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/misc/broadcast_ops.py +1 -1
- blksprs/misc/row_wise.py +1 -0
- blksprs/ops/conversion.py +2 -2
- blksprs/ops/matmul.py +2 -1
- blksprs/ops/softmax.py +1 -1
- blksprs/ops/transpose.py +4 -2
- blksprs/utils/tools.py +1 -2
- blksprs/utils/validation.py +6 -3
- {blksprs-1.4.1.dist-info → blksprs-1.4.2.dist-info}/METADATA +1 -1
- blksprs-1.4.2.dist-info/RECORD +19 -0
- {blksprs-1.4.1.dist-info → blksprs-1.4.2.dist-info}/WHEEL +1 -1
- blksprs-1.4.1.dist-info/RECORD +0 -19
- {blksprs-1.4.1.dist-info → blksprs-1.4.2.dist-info}/top_level.txt +0 -0
blksprs/misc/broadcast_ops.py
CHANGED
|
@@ -41,7 +41,7 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
|
|
|
41
41
|
|
|
42
42
|
validate_contiguous(sparsity_layout_output, sparsity_lut_o)
|
|
43
43
|
|
|
44
|
-
output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, device=x.device)
|
|
44
|
+
output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, dtype=x.dtype, device=x.device)
|
|
45
45
|
|
|
46
46
|
x_b, x_c = x.size()
|
|
47
47
|
x_b_s, x_c_s = x.stride()
|
blksprs/misc/row_wise.py
CHANGED
|
@@ -56,6 +56,7 @@ def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
|
56
56
|
output = torch.zeros(size=(n_sparse_blocks_output,
|
|
57
57
|
sparsity_block_size,
|
|
58
58
|
1 if flag_slice_only else sparsity_block_size),
|
|
59
|
+
dtype=x.dtype,
|
|
59
60
|
device=x.device)
|
|
60
61
|
|
|
61
62
|
x_b, x_r, x_c = x.size()
|
blksprs/ops/conversion.py
CHANGED
|
@@ -186,8 +186,8 @@ class _BlocksparseToSparse(torch.autograd.Function):
|
|
|
186
186
|
def forward(ctx, x: Tensor,
|
|
187
187
|
sparsity_layout: Tensor, sparsity_lut: Tensor,
|
|
188
188
|
sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
|
|
189
|
-
output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
|
|
190
|
-
device=x.device)
|
|
189
|
+
output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
|
|
190
|
+
dtype=x.dtype, device=x.device)
|
|
191
191
|
|
|
192
192
|
x_b, x_r, x_c = x.size()
|
|
193
193
|
x_b_s, x_r_s, x_c_s = x.stride()
|
blksprs/ops/matmul.py
CHANGED
|
@@ -78,7 +78,8 @@ class _BlocksparseMatmulSSS(torch.autograd.Function):
|
|
|
78
78
|
sparsity_layout_y: Tensor, sparsity_reverse_lut_y: Tensor,
|
|
79
79
|
sparsity_layout_o: Tensor, sparsity_lut_o: Tensor,
|
|
80
80
|
sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
|
|
81
|
-
output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
|
|
81
|
+
output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
|
|
82
|
+
dtype=x.dtype, device=x.device)
|
|
82
83
|
|
|
83
84
|
x_b, x_r, x_c = x.size()
|
|
84
85
|
x_b_s, x_r_s, x_c_s = x.stride()
|
blksprs/ops/softmax.py
CHANGED
|
@@ -127,7 +127,7 @@ class _BlocksparseSoftmax(torch.autograd.Function):
|
|
|
127
127
|
s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
|
|
128
128
|
s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = sparsity_layout_s.stride()
|
|
129
129
|
|
|
130
|
-
grad_x = torch.empty_like(o)
|
|
130
|
+
grad_x = torch.empty_like(o, dtype=torch.float)
|
|
131
131
|
|
|
132
132
|
triton_grid = lambda meta: [o_b,
|
|
133
133
|
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
blksprs/ops/transpose.py
CHANGED
|
@@ -59,7 +59,8 @@ class _BlocksparseTranspose(torch.autograd.Function):
|
|
|
59
59
|
def forward(ctx, x: Tensor,
|
|
60
60
|
sparsity_layout: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor, sparsity_block_size: int,
|
|
61
61
|
n_sparse_blocks: int, triton_block_size: int) -> (Tensor, Tensor):
|
|
62
|
-
output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
|
|
62
|
+
output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
|
|
63
|
+
dtype=x.dtype, device=x.device)
|
|
63
64
|
|
|
64
65
|
x_b, x_r, x_c = x.size()
|
|
65
66
|
x_b_s, x_r_s, x_c_s = x.stride()
|
|
@@ -101,7 +102,8 @@ class _BlocksparseTranspose(torch.autograd.Function):
|
|
|
101
102
|
sparsity_block_size = ctx.sparsity_block_size
|
|
102
103
|
triton_block_size = ctx.triton_block_size
|
|
103
104
|
|
|
104
|
-
return transpose(grad_output, sparsity_layout, sparsity_block_size, triton_block_size)[
|
|
105
|
+
return transpose(grad_output, sparsity_layout, sparsity_block_size, triton_block_size)[
|
|
106
|
+
0], None, None, None, None, None, None
|
|
105
107
|
|
|
106
108
|
@staticmethod
|
|
107
109
|
@triton.jit
|
blksprs/utils/tools.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
import torch
|
|
2
1
|
from torch import Tensor, Size
|
|
3
2
|
|
|
4
3
|
from blksprs.utils.validation import _set_skip_validation
|
|
@@ -8,7 +7,7 @@ def do_shape_blocksparse(x: Tensor):
|
|
|
8
7
|
if x.dim() == 3:
|
|
9
8
|
return x.contiguous(), x.size()
|
|
10
9
|
|
|
11
|
-
return x.reshape(-1, x.size(-2), x.size(-1)), x.size()
|
|
10
|
+
return x.reshape(-1, x.size(-2), x.size(-1)).contiguous(), x.size()
|
|
12
11
|
|
|
13
12
|
|
|
14
13
|
def undo_shape_blocksparse(x: Tensor, shape: Size):
|
blksprs/utils/validation.py
CHANGED
|
@@ -3,13 +3,13 @@ from torch import Tensor
|
|
|
3
3
|
|
|
4
4
|
VALIDATION = True
|
|
5
5
|
|
|
6
|
-
def validate_dimensions(*tensors: Tensor) -> None:
|
|
6
|
+
def validate_dimensions(*tensors: Tensor, dims=3) -> None:
|
|
7
7
|
if _check_skip_validation():
|
|
8
8
|
return
|
|
9
9
|
|
|
10
10
|
for tensor in tensors:
|
|
11
|
-
if tensor.dim() != 3:
|
|
12
|
-
raise ValueError("Tensor must have 3 dimensions")
|
|
11
|
+
if tensor.dim() != dims:
|
|
12
|
+
raise ValueError(f"Tensor must have {dims} dimensions")
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
def validate_contiguous(*tensors: Tensor) -> None:
|
|
@@ -91,6 +91,9 @@ def validate_triton_block_size(triton_block_size: int, sparsity_block_size: int)
|
|
|
91
91
|
if triton_block_size is None:
|
|
92
92
|
return
|
|
93
93
|
|
|
94
|
+
if not (triton_block_size & (triton_block_size - 1)) == 0:
|
|
95
|
+
raise ValueError("Triton block size must be a power of 2")
|
|
96
|
+
|
|
94
97
|
if triton_block_size > sparsity_block_size:
|
|
95
98
|
raise ValueError("Triton block size cannot be larger than sparsity block size")
|
|
96
99
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: blksprs
|
|
3
|
-
Version: 1.4.1
|
|
3
|
+
Version: 1.4.2
|
|
4
4
|
Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
|
|
5
5
|
Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
|
|
6
6
|
Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
blksprs/__init__.py,sha256=ORAVhGR91G1wyIOs9Wecv-xfmjju3bJ4Jynq_SGOVY4,833
|
|
2
|
+
blksprs/layouting/distribution_layout.py,sha256=Xd8KjZwI87L9EL1Bw5SGUW9YztFD5q0Ygr99sffvdak,4939
|
|
3
|
+
blksprs/layouting/sparsity_layout.py,sha256=vZL8r5LkMwILYYqTYPZcN_NYFJuVFIB6mmBkdtRyXmI,7893
|
|
4
|
+
blksprs/misc/broadcast_ops.py,sha256=ahm7_lI12bJ6VTKRuSkwEeaEYWRY-BeMIOhtei35zpQ,5323
|
|
5
|
+
blksprs/misc/repeat_interleave.py,sha256=KJeapmxbpA7zGFfa5hUhCGrk4aFmhOhlMw-hbTh9PLI,5668
|
|
6
|
+
blksprs/misc/row_wise.py,sha256=1UtjLplrGx1FkxhzQ2hjSBBY11ToLQs0JiLaXKRAkL4,16893
|
|
7
|
+
blksprs/ops/conversion.py,sha256=vuiNwrwyuGI6H4PKrS_UHI7OKWJwNZd2i3LSjf6RetU,21332
|
|
8
|
+
blksprs/ops/distribution.py,sha256=KhtHRVcv4_woyNlldAjIWF-7021-KX-xyIcN6rE-UgE,16879
|
|
9
|
+
blksprs/ops/exp.py,sha256=CVWVq_emO2CnS_xk6Unx67P7EI7IL26dwtsmBJZOLzQ,3698
|
|
10
|
+
blksprs/ops/matmul.py,sha256=743XeD5M4iUv28sYf7q6mVXDd4jZpV04JAx8bF7hWkw,11254
|
|
11
|
+
blksprs/ops/softmax.py,sha256=cs1utM6UCzHhdJpf-ZysBr6CwbjI-5aQG0ahYY37Zy0,11991
|
|
12
|
+
blksprs/ops/transpose.py,sha256=Ru4YKyg796WT6OnDSTCYG45tMmdgvju3hMFzkwsJnO8,6801
|
|
13
|
+
blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
|
|
14
|
+
blksprs/utils/tools.py,sha256=JAuwsLISr_hcvxIgUVvKz5ZPf9M5ycquplsBU5dVfDc,596
|
|
15
|
+
blksprs/utils/validation.py,sha256=rP6yr-C2ghXfJEERry_pfvVJ0g0VyqV4sL4HkBRlJg8,3345
|
|
16
|
+
blksprs-1.4.2.dist-info/METADATA,sha256=wpv1H29xlts3Muvlg_dtA1KW3TUeBtlD4rr4MHRZm5c,7609
|
|
17
|
+
blksprs-1.4.2.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
|
|
18
|
+
blksprs-1.4.2.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
|
|
19
|
+
blksprs-1.4.2.dist-info/RECORD,,
|
blksprs-1.4.1.dist-info/RECORD
DELETED
|
@@ -1,19 +0,0 @@
|
|
|
1
|
-
blksprs/__init__.py,sha256=ORAVhGR91G1wyIOs9Wecv-xfmjju3bJ4Jynq_SGOVY4,833
|
|
2
|
-
blksprs/layouting/distribution_layout.py,sha256=Xd8KjZwI87L9EL1Bw5SGUW9YztFD5q0Ygr99sffvdak,4939
|
|
3
|
-
blksprs/layouting/sparsity_layout.py,sha256=vZL8r5LkMwILYYqTYPZcN_NYFJuVFIB6mmBkdtRyXmI,7893
|
|
4
|
-
blksprs/misc/broadcast_ops.py,sha256=RTcqvx6X_THRBb55jipeEe63YSLIAh27jdpuze0aSek,5308
|
|
5
|
-
blksprs/misc/repeat_interleave.py,sha256=KJeapmxbpA7zGFfa5hUhCGrk4aFmhOhlMw-hbTh9PLI,5668
|
|
6
|
-
blksprs/misc/row_wise.py,sha256=KCDO5ry5TkjI88LLD_QINZwBkzfmjoQpOOvYLfpUn5I,16853
|
|
7
|
-
blksprs/ops/conversion.py,sha256=h1c5T74rQjqYgY9dwWXfPTXRpgzy0dtAhCmtUp8-6uo,21332
|
|
8
|
-
blksprs/ops/distribution.py,sha256=KhtHRVcv4_woyNlldAjIWF-7021-KX-xyIcN6rE-UgE,16879
|
|
9
|
-
blksprs/ops/exp.py,sha256=CVWVq_emO2CnS_xk6Unx67P7EI7IL26dwtsmBJZOLzQ,3698
|
|
10
|
-
blksprs/ops/matmul.py,sha256=6DaYxecJgwiW8L-UISkgyNyzQ31AAkmDL-Oq1EjHt98,11210
|
|
11
|
-
blksprs/ops/softmax.py,sha256=cSTxDnNmMRlJGOlCSpdg1U5KUIFpVtHulz8fteJFeh0,11972
|
|
12
|
-
blksprs/ops/transpose.py,sha256=et8R124L29TUqihci18ms_hBoYXTtPu5LXgEA8sxk_w,6744
|
|
13
|
-
blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
|
|
14
|
-
blksprs/utils/tools.py,sha256=RKGWCGd5h1qFOIoShsdJObx4-QsS0RxCyzFie0geNxo,596
|
|
15
|
-
blksprs/utils/validation.py,sha256=Gsx3aah6355bWXRPpbFuZ1p0fOrYduIqaM3ON9d5NiI,3197
|
|
16
|
-
blksprs-1.4.1.dist-info/METADATA,sha256=3xRmBFHv2U2KnrW3_QX3003SHLkQ1JCaSqh4AUBsJD4,7609
|
|
17
|
-
blksprs-1.4.1.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
|
|
18
|
-
blksprs-1.4.1.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
|
|
19
|
-
blksprs-1.4.1.dist-info/RECORD,,
|
|
File without changes
|