blksprs 1.4__tar.gz → 1.4.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {blksprs-1.4 → blksprs-1.4.2}/PKG-INFO +4 -6
- {blksprs-1.4 → blksprs-1.4.2}/README.md +1 -1
- {blksprs-1.4 → blksprs-1.4.2}/blksprs/__init__.py +1 -1
- {blksprs-1.4 → blksprs-1.4.2}/blksprs/layouting/distribution_layout.py +1 -1
- {blksprs-1.4 → blksprs-1.4.2}/blksprs/layouting/sparsity_layout.py +2 -2
- {blksprs-1.4 → blksprs-1.4.2}/blksprs/misc/broadcast_ops.py +4 -1
- {blksprs-1.4 → blksprs-1.4.2}/blksprs/misc/repeat_interleave.py +2 -0
- {blksprs-1.4 → blksprs-1.4.2}/blksprs/misc/row_wise.py +5 -0
- {blksprs-1.4 → blksprs-1.4.2}/blksprs/ops/conversion.py +8 -2
- {blksprs-1.4 → blksprs-1.4.2}/blksprs/ops/distribution.py +6 -0
- {blksprs-1.4 → blksprs-1.4.2}/blksprs/ops/exp.py +2 -0
- {blksprs-1.4 → blksprs-1.4.2}/blksprs/ops/matmul.py +8 -3
- {blksprs-1.4 → blksprs-1.4.2}/blksprs/ops/softmax.py +3 -1
- {blksprs-1.4 → blksprs-1.4.2}/blksprs/ops/transpose.py +6 -2
- {blksprs-1.4 → blksprs-1.4.2}/blksprs/utils/tools.py +8 -3
- {blksprs-1.4 → blksprs-1.4.2}/blksprs/utils/validation.py +21 -13
- {blksprs-1.4 → blksprs-1.4.2}/blksprs.egg-info/PKG-INFO +4 -6
- {blksprs-1.4 → blksprs-1.4.2}/blksprs.egg-info/requires.txt +1 -3
- {blksprs-1.4 → blksprs-1.4.2}/pyproject.toml +3 -5
- {blksprs-1.4 → blksprs-1.4.2}/blksprs/utils/benchmarking.py +0 -0
- {blksprs-1.4 → blksprs-1.4.2}/blksprs.egg-info/SOURCES.txt +0 -0
- {blksprs-1.4 → blksprs-1.4.2}/blksprs.egg-info/dependency_links.txt +0 -0
- {blksprs-1.4 → blksprs-1.4.2}/blksprs.egg-info/top_level.txt +0 -0
- {blksprs-1.4 → blksprs-1.4.2}/setup.cfg +0 -0

{blksprs-1.4 → blksprs-1.4.2}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.4
+Version: 1.4.2
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -14,10 +14,8 @@ Requires-Dist: pytest-xdist; extra == "test"
 Requires-Dist: pytest-cov; extra == "test"
 Requires-Dist: coverage; extra == "test"
 Requires-Dist: matplotlib; extra == "test"
-Provides-Extra: deploy
-Requires-Dist: build; extra == "deploy"
-Requires-Dist: twine; extra == "deploy"
-Requires-Dist: pdoc3; extra == "deploy"
+Provides-Extra: build
+Requires-Dist: build; extra == "build"
 
 # blksprs
 
@@ -146,7 +144,7 @@ def test_readme():
     # Assert that the output has the correct sparsity layout
     actual_sparsity_layout_o = bs.layout.build_sparsity_layout(o_dense, sparsity_block_size,
                                                                triton_block_size=triton_block_size)
-    assert torch.allclose(actual_sparsity_layout_o, sparsity_layout_o)
+    assert torch.allclose(actual_sparsity_layout_o.to(torch.int), sparsity_layout_o)
 
     # Convert output tensor back to original shape
     o = bs.util.undo_shape_blocksparse(o_dense, x_shape_original)

{blksprs-1.4 → blksprs-1.4.2}/README.md
@@ -125,7 +125,7 @@ def test_readme():
     # Assert that the output has the correct sparsity layout
     actual_sparsity_layout_o = bs.layout.build_sparsity_layout(o_dense, sparsity_block_size,
                                                                triton_block_size=triton_block_size)
-    assert torch.allclose(actual_sparsity_layout_o, sparsity_layout_o)
+    assert torch.allclose(actual_sparsity_layout_o.to(torch.int), sparsity_layout_o)
 
     # Convert output tensor back to original shape
     o = bs.util.undo_shape_blocksparse(o_dense, x_shape_original)

{blksprs-1.4 → blksprs-1.4.2}/blksprs/__init__.py
@@ -15,4 +15,4 @@ class misc:
     from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_add, row_wise_sub
 
 class util:
-    from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
+    from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse, disable_validation

{blksprs-1.4 → blksprs-1.4.2}/blksprs/layouting/distribution_layout.py
@@ -31,7 +31,7 @@ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
     sparsity_lut_i = torch.nonzero(sparsity_layout_indices).contiguous()
 
     output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size, size_target[2] // sparsity_block_size,
-                         device=indices.device)
+                         dtype=torch.bool, device=indices.device)
 
     i_b, i_r, i_c = indices.size()
     i_b_s, i_r_s, i_c_s = indices.stride()

{blksprs-1.4 → blksprs-1.4.2}/blksprs/layouting/sparsity_layout.py
@@ -27,7 +27,7 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_size
     validate_device(x)
 
     output = torch.zeros(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size,
-                         device=x.device)
+                         dtype=torch.bool, device=x.device)
 
     x_b, x_r, x_c = x.size()
     x_b_s, x_r_s, x_c_s = x.stride()
@@ -117,7 +117,7 @@ def build_sparsity_layout_adaption(x: Tensor, sparsity_layout_from: Tensor,
     o_r = math.ceil(sparsity_layout_from.size(1) * sparsity_block_size_from // sparsity_block_size_to)
     o_c = math.ceil(sparsity_layout_from.size(2) * sparsity_block_size_from // sparsity_block_size_to)
 
-    output = torch.zeros(o_b, o_r, o_c, device=x.device)
+    output = torch.zeros(o_b, o_r, o_c, dtype=torch.bool, device=x.device)
 
     x_b, x_r, x_c = x.size()
     x_b_s, x_r_s, x_c_s = x.stride()
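
As the layouting hunks above show, generated sparsity layouts are now allocated as `torch.bool`, which is also why the README assertion earlier casts with `.to(torch.int)` before comparing. A minimal sketch of that interaction, assuming blksprs 1.4.2 on a CUDA device (the tensors, shapes and call defaults here are illustrative, not taken from the package):

```python
import torch
import blksprs as bs

# Hypothetical dense input: 2 batches of 64x64 where every block is populated
x_dense = torch.randn(2, 64, 64, device="cuda")
sparsity_block_size = 32

# build_sparsity_layout now returns a boolean layout (dtype=torch.bool above)
layout = bs.layout.build_sparsity_layout(x_dense, sparsity_block_size)
assert layout.dtype == torch.bool

# Comparing against an integer reference layout therefore needs an explicit cast,
# mirroring the updated README assertion
reference = torch.ones(2, 2, 2, dtype=torch.int, device="cuda")
assert torch.allclose(layout.to(torch.int), reference)
```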

{blksprs-1.4 → blksprs-1.4.2}/blksprs/misc/broadcast_ops.py
@@ -25,6 +25,9 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
         output tensor corresponds to x(i) + y(j).
 
     """
+    x = x.contiguous()
+    y = y.contiguous()
+
     validate_device(x, y)
     validate_contiguous(x, y)
     if x.size(-1) != y.size(-1):
@@ -38,7 +41,7 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
 
     validate_contiguous(sparsity_layout_output, sparsity_lut_o)
 
-    output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, device=x.device)
+    output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, dtype=x.dtype, device=x.device)
 
     x_b, x_c = x.size()
     x_b_s, x_c_s = x.stride()
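
The `dtype=x.dtype` added here, and in the row_wise, conversion, matmul and transpose hunks below, keeps intermediate buffers in the input's dtype instead of the global default. A small standalone illustration of the difference (plain PyTorch, no blksprs involved):

```python
import torch

x = torch.randn(4, 32, 32, dtype=torch.float16)

# Former pattern: the buffer falls back to the default dtype (float32),
# silently up-casting half-precision block-sparse data
out_old = torch.zeros(4, 32, 32, device=x.device)
assert out_old.dtype == torch.float32

# 1.4.2 pattern: the buffer inherits the input dtype
out_new = torch.zeros(4, 32, 32, dtype=x.dtype, device=x.device)
assert out_new.dtype == torch.float16
```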

{blksprs-1.4 → blksprs-1.4.2}/blksprs/misc/row_wise.py
@@ -31,6 +31,8 @@ def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
         of the input and the sparsity layout of the output tensor.
 
     """
+    x = x.contiguous()
+
     validate_dimensions(x)
     validate_contiguous(x)
     validate_device(x)
@@ -54,6 +56,7 @@ def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
     output = torch.zeros(size=(n_sparse_blocks_output,
                                sparsity_block_size,
                                1 if flag_slice_only else sparsity_block_size),
+                         dtype=x.dtype,
                          device=x.device)
 
     x_b, x_r, x_c = x.size()
@@ -151,6 +154,8 @@ def row_wise_max(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
         of the input and the sparsity layout of the output tensor.
 
     """
+    x = x.contiguous()
+
     validate_dimensions(x)
     validate_contiguous(x)
     validate_device(x)

{blksprs-1.4 → blksprs-1.4.2}/blksprs/ops/conversion.py
@@ -28,6 +28,8 @@ def to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_
         Tensor: The block-sparse tensor converted to regular form.
 
     """
+    x = x.contiguous()
+
     validate_dimensions(x)
     validate_contiguous(x, sparsity_layout)
     validate_device(x)
@@ -156,6 +158,8 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
         Tensor: The block-sparse tensor converted to compressed form.
 
     """
+    x = x.contiguous()
+
     validate_dimensions(x)
     validate_contiguous(x)
     validate_device(x)
@@ -182,8 +186,8 @@ class _BlocksparseToSparse(torch.autograd.Function):
     def forward(ctx, x: Tensor,
                 sparsity_layout: Tensor, sparsity_lut: Tensor,
                 sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
-        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
-                             device=x.device)
+        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                             dtype=x.dtype, device=x.device)
 
         x_b, x_r, x_c = x.size()
         x_b_s, x_r_s, x_c_s = x.stride()
@@ -282,6 +286,8 @@ def adapt_layout(x: Tensor, sparsity_layout_from: Tensor, sparsity_block_size_fr
         Tensor: The block-sparse tensor in compressed form with the adapted sparsity layout and sparsity block size.
 
     """
+    x = x.contiguous()
+
     validate_dimensions(x)
     validate_contiguous(x, sparsity_layout_from)
     validate_device(x)

{blksprs-1.4 → blksprs-1.4.2}/blksprs/ops/distribution.py
@@ -24,6 +24,9 @@ def gather(src: Tensor, sparsity_layout_src: Tensor, idx: Tensor, sparsity_layou
         Tensor: The result of the gather operation as a block-sparse tensor in compressed form.
 
     """
+    src = src.contiguous()
+    idx = idx.contiguous()
+
     validate_dimensions(src, idx)
     validate_contiguous(src, idx)
     validate_dtype_int(idx)
@@ -200,6 +203,9 @@ def scatter_reduce(src: Tensor, sparsity_layout_src: Tensor,
         Tensor: The result of the scatter operation as a block-sparse tensor in compressed form.
 
     """
+    src = src.contiguous()
+    idx = idx.contiguous()
+
     validate_dimensions(src, idx)
     validate_contiguous(src, idx)
     validate_dtype_int(idx)

{blksprs-1.4 → blksprs-1.4.2}/blksprs/ops/matmul.py
@@ -6,7 +6,7 @@ from triton import language as tl
 from blksprs.ops.transpose import transpose
 from blksprs.utils.tools import get_triton_block_size
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-    validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
+    validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_dtype_float
 
 
 def matmul(x: Tensor, sparsity_layout_x: Tensor,
@@ -30,8 +30,12 @@ def matmul(x: Tensor, sparsity_layout_x: Tensor,
         Tensor: The result of the matrix multiplication as a block-sparse tensor in compressed form.
 
     """
+    x = x.contiguous()
+    y = y.contiguous()
+
     validate_dimensions(x, y)
     validate_contiguous(x, y)
+    validate_dtype_float(x, y)
     validate_device(x, y)
     validate_sparsity(sparsity_block_size, (x, sparsity_layout_x), (y, sparsity_layout_y))
     if sparsity_layout_x.size(-1) != sparsity_layout_y.size(-2):
@@ -74,7 +78,8 @@ class _BlocksparseMatmulSSS(torch.autograd.Function):
                 sparsity_layout_y: Tensor, sparsity_reverse_lut_y: Tensor,
                 sparsity_layout_o: Tensor, sparsity_lut_o: Tensor,
                 sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
-        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size), device=x.device)
+        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                             dtype=x.dtype, device=x.device)
 
         x_b, x_r, x_c = x.size()
         x_b_s, x_r_s, x_c_s = x.stride()
@@ -211,7 +216,7 @@ class _BlocksparseMatmulSSS(torch.autograd.Function):
             blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
 
             # Perform matrix multiplication
-            buf += tl.dot(blk_x, blk_y)
+            buf += tl.dot(blk_x, blk_y, input_precision="tf32")
 
         # Store output
         blk_o_idx = ((pid_blk * o_b_s) +
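
For context on the new `input_precision="tf32"` argument: in recent Triton releases `tl.dot` accepts an `input_precision` keyword that chooses between TF32 tensor-core math and full IEEE float32 for float32 inputs. The following standalone kernel is only a sketch of that call pattern (it is not blksprs code; names, shapes and the tolerance are made up) and assumes a CUDA device with a recent Triton install:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def square_matmul_kernel(x_ptr, y_ptr, o_ptr, BLOCK: tl.constexpr):
    # Single-tile matrix multiply: load one BLOCK x BLOCK tile of x and y
    offs = tl.arange(0, BLOCK)
    idx = offs[:, None] * BLOCK + offs[None, :]
    blk_x = tl.load(x_ptr + idx)
    blk_y = tl.load(y_ptr + idx)
    # "tf32" trades a little float32 accuracy for tensor-core throughput;
    # "ieee" would request full float32 precision instead
    blk_o = tl.dot(blk_x, blk_y, input_precision="tf32")
    tl.store(o_ptr + idx, blk_o)


x = torch.randn(64, 64, device="cuda")
y = torch.randn(64, 64, device="cuda")
o = torch.empty(64, 64, device="cuda")
square_matmul_kernel[(1,)](x, y, o, BLOCK=64)
assert torch.allclose(o, x @ y, atol=1e-1, rtol=1e-2)
```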

{blksprs-1.4 → blksprs-1.4.2}/blksprs/ops/softmax.py
@@ -26,6 +26,8 @@ def softmax(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton
         Tensor: The result of the softmax operation as a block-sparse tensor in compressed form.
 
     """
+    x = x.contiguous()
+
     validate_dimensions(x)
     validate_contiguous(x)
     validate_device(x)
@@ -125,7 +127,7 @@ class _BlocksparseSoftmax(torch.autograd.Function):
         s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
         s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = sparsity_layout_s.stride()
 
-        grad_x = torch.empty_like(o)
+        grad_x = torch.empty_like(o, dtype=torch.float)
 
         triton_grid = lambda meta: [o_b,
                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),

{blksprs-1.4 → blksprs-1.4.2}/blksprs/ops/transpose.py
@@ -26,6 +26,8 @@ def transpose(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
         Tensor: The sparsity layout of the transposed tensor.
 
     """
+    x = x.contiguous()
+
     validate_dimensions(x)
     validate_contiguous(x)
     validate_device(x)
@@ -57,7 +59,8 @@ class _BlocksparseTranspose(torch.autograd.Function):
     def forward(ctx, x: Tensor,
                 sparsity_layout: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor, sparsity_block_size: int,
                 n_sparse_blocks: int, triton_block_size: int) -> (Tensor, Tensor):
-        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size), device=x.device)
+        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                             dtype=x.dtype, device=x.device)
 
         x_b, x_r, x_c = x.size()
         x_b_s, x_r_s, x_c_s = x.stride()
@@ -99,7 +102,8 @@ class _BlocksparseTranspose(torch.autograd.Function):
         sparsity_block_size = ctx.sparsity_block_size
         triton_block_size = ctx.triton_block_size
 
-        return transpose(grad_output, sparsity_layout, sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None
+        return transpose(grad_output, sparsity_layout, sparsity_block_size, triton_block_size)[
+            0], None, None, None, None, None, None
 
     @staticmethod
     @triton.jit

{blksprs-1.4 → blksprs-1.4.2}/blksprs/utils/tools.py
@@ -1,12 +1,13 @@
-import torch
 from torch import Tensor, Size
 
+from blksprs.utils.validation import _set_skip_validation
+
 
 def do_shape_blocksparse(x: Tensor):
     if x.dim() == 3:
-        return x, x.size()
+        return x.contiguous(), x.size()
 
-    return x.reshape(-1, x.size(-2), x.size(-1)), x.size()
+    return x.reshape(-1, x.size(-2), x.size(-1)).contiguous(), x.size()
 
 
 def undo_shape_blocksparse(x: Tensor, shape: Size):
@@ -18,3 +19,7 @@ def undo_shape_blocksparse(x: Tensor, shape: Size):
 
 def get_triton_block_size(sparsity_block_size: int, limit: int = 128):
     return min(sparsity_block_size, limit)
+
+
+def disable_validation():
+    _set_skip_validation(True)
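
Combined with the `__init__.py` change above, the new helper is reachable as `bs.util.disable_validation()`. A brief usage sketch (purely illustrative):

```python
import blksprs as bs

# Validation (dimension, contiguity, dtype and device checks) is on by default.
# Once inputs are known to be well-formed, it can be switched off globally to
# skip the validate_* checks shown in validation.py below.
bs.util.disable_validation()
```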

{blksprs-1.4 → blksprs-1.4.2}/blksprs/utils/validation.py
@@ -1,18 +1,19 @@
 import torch
 from torch import Tensor
 
+VALIDATION = True
 
-def validate_dimensions(*tensors: Tensor) -> None:
-    if
+def validate_dimensions(*tensors: Tensor, dims=3) -> None:
+    if _check_skip_validation():
         return
 
     for tensor in tensors:
-        if tensor.dim() != 3:
-            raise ValueError("Tensor must have 3 dimensions")
+        if tensor.dim() != dims:
+            raise ValueError(f"Tensor must have {dims} dimensions")
 
 
 def validate_contiguous(*tensors: Tensor) -> None:
-    if
+    if _check_skip_validation():
         return
 
     for tensor in tensors:
@@ -21,7 +22,7 @@ def validate_contiguous(*tensors: Tensor) -> None:
 
 
 def validate_dtype_float(*tensors: Tensor) -> None:
-    if
+    if _check_skip_validation():
         return
 
     for tensor in tensors:
@@ -30,7 +31,7 @@ def validate_dtype_float(*tensors: Tensor) -> None:
 
 
 def validate_dtype_int(*tensors: Tensor) -> None:
-    if
+    if _check_skip_validation():
         return
 
     for tensor in tensors:
@@ -39,7 +40,7 @@ def validate_dtype_int(*tensors: Tensor) -> None:
 
 
 def validate_device(*tensors: Tensor) -> None:
-    if
+    if _check_skip_validation():
         return
 
     device = None
@@ -56,7 +57,7 @@ def validate_device(*tensors: Tensor) -> None:
 
 
 def validate_sparsity(sparsity_block_size: int, *tensor_sparsity_layout_tuples: tuple[Tensor, Tensor]) -> None:
-    if
+    if _check_skip_validation():
         return
 
     for (tensor, sparsity_layout) in tensor_sparsity_layout_tuples:
@@ -73,7 +74,7 @@ def _validate_sparsity_layout_values(sparsity_layout: Tensor):
         raise ValueError("Sparsity layout values must be either 0 or 1")
 
 def validate_sparsity_block_size(sparsity_block_size: int, *tensors):
-    if
+    if _check_skip_validation():
         return
 
     if not (sparsity_block_size & (sparsity_block_size - 1)) == 0:
@@ -84,14 +85,21 @@ def validate_sparsity_block_size(sparsity_block_size: int, *tensors):
             raise ValueError("Tensor sizes must be divisible by sparsity block size")
 
 def validate_triton_block_size(triton_block_size: int, sparsity_block_size: int):
-    if
+    if _check_skip_validation():
         return
 
     if triton_block_size is None:
         return
 
+    if not (triton_block_size & (triton_block_size - 1)) == 0:
+        raise ValueError("Triton block size must be a power of 2")
+
     if triton_block_size > sparsity_block_size:
         raise ValueError("Triton block size cannot be larger than sparsity block size")
 
-def
-    return
+def _check_skip_validation():
+    return not VALIDATION
+
+def _set_skip_validation(skip_validation: bool):
+    global VALIDATION
+    VALIDATION = not skip_validation
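
The new guard in `validate_triton_block_size` reuses the power-of-two bit trick already applied to `sparsity_block_size`: for a positive integer n, `n & (n - 1)` clears the lowest set bit, so the result is zero exactly when a single bit is set. A quick standalone check of that property:

```python
def is_power_of_two(n: int) -> bool:
    # Same test as the added check: n & (n - 1) == 0, restricted to n > 0
    return n > 0 and (n & (n - 1)) == 0


assert all(is_power_of_two(n) for n in (1, 2, 16, 64, 128))
assert not any(is_power_of_two(n) for n in (0, 3, 24, 100))
```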

{blksprs-1.4 → blksprs-1.4.2}/blksprs.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.4
+Version: 1.4.2
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -14,10 +14,8 @@ Requires-Dist: pytest-xdist; extra == "test"
 Requires-Dist: pytest-cov; extra == "test"
 Requires-Dist: coverage; extra == "test"
 Requires-Dist: matplotlib; extra == "test"
-Provides-Extra: deploy
-Requires-Dist: build; extra == "deploy"
-Requires-Dist: twine; extra == "deploy"
-Requires-Dist: pdoc3; extra == "deploy"
+Provides-Extra: build
+Requires-Dist: build; extra == "build"
 
 # blksprs
 
@@ -146,7 +144,7 @@ def test_readme():
     # Assert that the output has the correct sparsity layout
     actual_sparsity_layout_o = bs.layout.build_sparsity_layout(o_dense, sparsity_block_size,
                                                                triton_block_size=triton_block_size)
-    assert torch.allclose(actual_sparsity_layout_o, sparsity_layout_o)
+    assert torch.allclose(actual_sparsity_layout_o.to(torch.int), sparsity_layout_o)
 
     # Convert output tensor back to original shape
     o = bs.util.undo_shape_blocksparse(o_dense, x_shape_original)

{blksprs-1.4 → blksprs-1.4.2}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "blksprs"
-version = "1.4"
+version = "1.4.2"
 authors = [{ name = "Felix Schön", email = "schoen@kr.tuwien.ac.at" }]
 description = "A lightweight library for operations on blocksparse matrices in PyTorch."
 readme = "README.md"
@@ -22,10 +22,8 @@ test = [
     "coverage",
     "matplotlib"
 ]
-deploy = [
-    "build",
-    "twine",
-    "pdoc3"
+build = [
+    "build"
 ]
 
 [build-system]

{blksprs-1.4 → blksprs-1.4.2}/blksprs/utils/benchmarking.py (file without changes)
{blksprs-1.4 → blksprs-1.4.2}/blksprs.egg-info/SOURCES.txt (file without changes)
{blksprs-1.4 → blksprs-1.4.2}/blksprs.egg-info/dependency_links.txt (file without changes)
{blksprs-1.4 → blksprs-1.4.2}/blksprs.egg-info/top_level.txt (file without changes)
{blksprs-1.4 → blksprs-1.4.2}/setup.cfg (file without changes)