blksprs 2.1.7-py3-none-any.whl → 2.1.8-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +10 -2
- blksprs/layouting/distribution_layout.py +1 -1
- blksprs/layouting/sparsity_layout.py +1 -1
- blksprs/ops/conversion.py +13 -13
- blksprs/ops/distribution.py +9 -9
- blksprs/ops/flow.py +1 -1
- blksprs/ops/matmul.py +6 -6
- blksprs/ops/misc/broadcast_ops.py +2 -2
- blksprs/ops/misc/row_wise.py +4 -6
- blksprs/ops/partitioning.py +2 -2
- blksprs/ops/repeat.py +2 -2
- blksprs/ops/softmax.py +9 -11
- blksprs/ops/transpose.py +3 -3
- blksprs/utils/autotuning.py +1 -1
- blksprs/utils/blksprs_tensor.py +10 -1
- blksprs/utils/processing.py +2 -1
- blksprs/utils/tools.py +2 -5
- {blksprs-2.1.7.dist-info → blksprs-2.1.8.dist-info}/METADATA +2 -2
- blksprs-2.1.8.dist-info/RECORD +23 -0
- blksprs-2.1.7.dist-info/RECORD +0 -23
- {blksprs-2.1.7.dist-info → blksprs-2.1.8.dist-info}/WHEEL +0 -0
- {blksprs-2.1.7.dist-info → blksprs-2.1.8.dist-info}/top_level.txt +0 -0
blksprs/__init__.py
CHANGED
@@ -1,6 +1,14 @@
-
+# Settings
+import torch
+
+# Capture scalar outputs for JIT compilation
+torch._dynamo.config.capture_scalar_outputs = True
+# Set version
+__version__ = "2.1.8"
 
-
+# Imports
+
+from blksprs.utils.blksprs_tensor import BlksprsTensor
 
 
 class ops:
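As the hunk above shows, the package now performs its Dynamo configuration at import time (previously done in blksprs/utils/tools.py, see the diff of that file below) and records the release version. A minimal, illustrative sketch of the observable effect after installing this release:

    import torch
    import blksprs

    # Importing blksprs 2.1.8 runs its __init__, which enables scalar-output
    # capture for torch.compile tracing and exposes the version string.
    assert torch._dynamo.config.capture_scalar_outputs is True
    assert blksprs.__version__ == "2.1.8"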
blksprs/layouting/distribution_layout.py
CHANGED
@@ -7,9 +7,9 @@ from torch._library import triton_op
 from torch._library.triton import wrap_triton
 from triton import language as tl
 
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import stride
-from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_dimensions, validate_device, \
     validate_contiguous
 
blksprs/layouting/sparsity_layout.py
CHANGED
@@ -6,9 +6,9 @@ from torch import Tensor
 from torch._library.triton import wrap_triton, triton_op
 from triton import language as tl
 
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs, prune_autotune_configs_conversion
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import stride
-from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs, prune_autotune_configs_conversion
 from blksprs.utils.validation import validate_dimensions, validate_device, \
     validate_contiguous, validate_sparsity, validate_sparsity_block_size
 
blksprs/ops/conversion.py
CHANGED
@@ -5,9 +5,9 @@ from torch._library.triton import wrap_triton, triton_op
 from triton import language as tl
 
 from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs, prune_autotune_configs_conversion
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import stride
-from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs, prune_autotune_configs_conversion
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_sparsity_dense, ensure_contiguous
 
@@ -46,10 +46,10 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor,
     lut = to_sparse_build_lut(lut, sparsity_layout)
 
     if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
-        return BlksprsTensor(x)
+        return BlksprsTensor.wrap(x)
 
-    return BlksprsTensor(to_sparse_forward(x, sparsity_layout,
-
+    return BlksprsTensor.wrap(to_sparse_forward(x, sparsity_layout,
+                                                lut["sparsity_lut"], sparsity_block_size, lut["n_sparse_blocks"]))
 
 
 @triton_op("blksprs::to_sparse_forward", mutates_args={})
@@ -201,7 +201,7 @@ def to_dense(x: BlksprsTensor, sparsity_layout: Tensor,
         return x
 
     return Tensor(to_dense_forward(x, sparsity_layout,
-
+                                   lut["sparsity_reverse_lut"], sparsity_block_size, fill_value))
 
 
 @triton_op("blksprs::to_dense_forward", mutates_args={})
@@ -360,14 +360,14 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
     validate_contiguous(sparsity_reverse_lut_from, sparsity_layout_to, sparsity_lut_to)
 
     if (sparsity_block_size_from == sparsity_block_size_to) and torch.equal(sparsity_layout_from, sparsity_layout_to):
-        return BlksprsTensor(x), sparsity_layout_to
-
-    return BlksprsTensor(adapt_layout_forward(x,
-
-
-
-
-
+        return BlksprsTensor.wrap(x), sparsity_layout_to
+
+    return BlksprsTensor.wrap(adapt_layout_forward(x,
+                                                   sparsity_layout_from, sparsity_reverse_lut_from,
+                                                   sparsity_block_size_from,
+                                                   sparsity_layout_to, sparsity_lut_to,
+                                                   sparsity_block_size_to,
+                                                   n_sparse_blocks_to)), sparsity_layout_to
 
 
 @triton_op("blksprs::adapt_layout_forward", mutates_args={})
blksprs/ops/distribution.py
CHANGED
@@ -5,9 +5,9 @@ from torch._library import triton_op
 from torch._library.triton import wrap_triton
 from triton import language as tl
 
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import stride
-from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_dtype_int, validate_sparsity_block_size, ensure_contiguous
 
@@ -45,9 +45,9 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
 
     lut = gather_build_lut(lut, sparsity_layout_src, sparsity_layout_idx)
 
-    return BlksprsTensor(gather_forward(src, sparsity_layout_src, lut["sparsity_reverse_lut_x"],
-
-
+    return BlksprsTensor.wrap(gather_forward(src, sparsity_layout_src, lut["sparsity_reverse_lut_x"],
+                                             adjusted_dim, idx, sparsity_layout_idx, lut["sparsity_lut_i"],
+                                             sparsity_block_size))
 
 
 @triton_op("blksprs::gather_forward", mutates_args={})
@@ -276,11 +276,11 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
 
     lut = scatter_reduce_build_lut(lut, sparsity_layout_src, sparsity_layout_tgt)
 
-    return BlksprsTensor(scatter_reduce_forward(src, sparsity_layout_src, lut["sparsity_lut_x"],
-
-
-
-
+    return BlksprsTensor.wrap(scatter_reduce_forward(src, sparsity_layout_src, lut["sparsity_lut_x"],
+                                                     adjusted_dim, idx,
+                                                     sparsity_layout_tgt, lut["sparsity_reverse_lut_o"],
+                                                     sparsity_block_size, lut["n_sparse_blocks"],
+                                                     reduce_op))
 
 
 @triton_op("blksprs::scatter_reduce_forward", mutates_args={})
blksprs/ops/flow.py
CHANGED
@@ -5,8 +5,8 @@ from torch._library import triton_op
 from torch._library.triton import wrap_triton
 from triton import language as tl
 
-from blksprs.utils.tools import stride
 from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
+from blksprs.utils.tools import stride
 
 
 @triton_op("blksprs::flow_pull_forward", mutates_args={})
blksprs/ops/matmul.py
CHANGED
@@ -5,9 +5,9 @@ from torch.library import triton_op, wrap_triton
 from triton import language as tl
 
 from blksprs.ops.transpose import transpose
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import stride
-from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_dtype_float, ensure_contiguous
 
@@ -47,11 +47,11 @@ def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
 
     lut = matmul_build_lut(lut, sparsity_layout_x, sparsity_layout_y, sparsity_layout_output)
 
-    return BlksprsTensor(matmul_forward(x, y,
-
-
-
-
+    return BlksprsTensor.wrap(matmul_forward(x, y,
+                                             sparsity_layout_x, lut["sparsity_reverse_lut_x"],
+                                             sparsity_layout_y, lut["sparsity_reverse_lut_y"],
+                                             sparsity_layout_output, lut["sparsity_lut_o"],
+                                             sparsity_block_size, lut["n_sparse_blocks"]))
 
 
 @triton_op("blksprs::matmul_forward", mutates_args={})
|
@@ -5,9 +5,9 @@ from torch._library import triton_op
|
|
|
5
5
|
from torch._library.triton import wrap_triton
|
|
6
6
|
from triton import language as tl
|
|
7
7
|
|
|
8
|
+
from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
|
|
8
9
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
9
10
|
from blksprs.utils.tools import stride
|
|
10
|
-
from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
|
|
11
11
|
from blksprs.utils.validation import validate_contiguous, validate_device, \
|
|
12
12
|
validate_sparsity_block_size, ensure_contiguous
|
|
13
13
|
|
|
@@ -43,7 +43,7 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
|
|
|
43
43
|
|
|
44
44
|
validate_contiguous(sparsity_layout_output, sparsity_lut_o)
|
|
45
45
|
|
|
46
|
-
return BlksprsTensor(broadcast_add_forward(x, y, sparsity_lut_o, sparsity_block_size, n_sparse_blocks))
|
|
46
|
+
return BlksprsTensor.wrap(broadcast_add_forward(x, y, sparsity_lut_o, sparsity_block_size, n_sparse_blocks))
|
|
47
47
|
|
|
48
48
|
|
|
49
49
|
def broadcast_sub(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
|
blksprs/ops/misc/row_wise.py
CHANGED
@@ -55,7 +55,7 @@ def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
     validate_contiguous(sparsity_layout, sparsity_lut,
                         sparsity_layout_output, sparsity_reverse_lut_output)
 
-    return BlksprsTensor(row_wise_sum_forward(
+    return BlksprsTensor.wrap(row_wise_sum_forward(
         x, sparsity_lut, sparsity_layout_output, sparsity_reverse_lut_output,
         sparsity_block_size, n_sparse_blocks_output, flag_slice_only)), sparsity_layout_output
 
@@ -174,8 +174,6 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
         of the input and the sparsity layout of the output tensor.
 
     """
-    # TODO Fix for triton bug, see https://github.com/triton-lang/triton/issues/6376, should be fixed with the upcoming 3.4.0 release
-    x = torch.where(x == -0.0, torch.tensor(0.0), x)
     x = ensure_contiguous(x)
 
     validate_dimensions(x)
@@ -197,7 +195,7 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
     validate_contiguous(sparsity_layout, sparsity_lut,
                         sparsity_layout_output, sparsity_reverse_lut_output)
 
-    return BlksprsTensor(
+    return BlksprsTensor.wrap(
         row_wise_max_forward(x, sparsity_lut, sparsity_layout_output, sparsity_reverse_lut_output, sparsity_block_size,
                              n_sparse_blocks_output, flag_slice_only)), sparsity_layout_output
 
@@ -329,8 +327,8 @@ def row_wise_add(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
 
     validate_contiguous(sparsity_layout_x, sparsity_lut_x, sparsity_reverse_lut_rwm)
 
-    return BlksprsTensor(row_wise_add_forward(x, sparsity_lut_x, sparsity_layout_rwm,
-
+    return BlksprsTensor.wrap(row_wise_add_forward(x, sparsity_lut_x, sparsity_layout_rwm,
+                                                   sparsity_reverse_lut_rwm, y, sparsity_block_size))
 
 
 def row_wise_sub(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
blksprs/ops/partitioning.py
CHANGED
@@ -41,7 +41,7 @@ def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
 
     lut = split_build_lut(lut, sparsity_layout, partitions)
 
-    return BlksprsTensor(split_forward(
+    return BlksprsTensor.wrap(split_forward(
         x, lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
         partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_output"]
 
@@ -146,7 +146,7 @@ def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
 
     lut = merge_build_lut(lut, sparsity_layout, partitions)
 
-    return BlksprsTensor(merge_forward(
+    return BlksprsTensor.wrap(merge_forward(
         x, lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
         partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_output"]
 
blksprs/ops/repeat.py
CHANGED
@@ -46,7 +46,7 @@ def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int,
 
     lut = repeat_build_lut(lut, sparsity_layout_x, repeats, sparsity_layout_output)
 
-    return BlksprsTensor(repeat_forward(
+    return BlksprsTensor.wrap(repeat_forward(
         x, sparsity_layout_x, lut["sparsity_layout_o"], lut["sparsity_lut"],
         lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_o"]
 
@@ -87,7 +87,7 @@ def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
 
     lut = repeat_interleave_build_lut(lut, sparsity_layout_x, repeats, sparsity_layout_output)
 
-    return BlksprsTensor(repeat_forward(
+    return BlksprsTensor.wrap(repeat_forward(
         x, sparsity_layout_x, lut["sparsity_layout_o"], lut["sparsity_lut"],
         lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_o"]
 
blksprs/ops/softmax.py
CHANGED
@@ -1,5 +1,3 @@
-import pdb
-
 import torch
 import triton
 from torch import Tensor
@@ -8,9 +6,9 @@ from torch._library.triton import wrap_triton
 from triton import language as tl
 
 from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import stride, ceil_pow2
-from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_dtype_float_32, ensure_contiguous
 
@@ -55,10 +53,10 @@ def softmax_regular(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_si
 
     lut = softmax_build_lut(lut, sparsity_layout)
 
-    return BlksprsTensor(softmax_forward(x, sparsity_layout,
-
-
-
+    return BlksprsTensor.wrap(softmax_forward(x, sparsity_layout,
+                                              lut["sparsity_lut"],
+                                              lut["sparsity_reverse_lut_rws"],
+                                              sparsity_block_size))
 
 
 @triton_op("blksprs::softmax_forward", mutates_args={})
@@ -346,10 +344,10 @@ def softmax_fused(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size
 
     lut = softmax_fused_build_lut(lut, sparsity_layout)
 
-    return BlksprsTensor(softmax_fused_forward(x, sparsity_layout,
-
-
-
+    return BlksprsTensor.wrap(softmax_fused_forward(x, sparsity_layout,
+                                                    lut["sparsity_reverse_lut_sorted"],
+                                                    lut["max_blocks_line"],
+                                                    sparsity_block_size))
 
 
 @triton_op("blksprs::softmax_fused_forward", mutates_args={})
blksprs/ops/transpose.py
CHANGED
@@ -37,9 +37,9 @@ def transpose(x: BlksprsTensor, sparsity_layout: Tensor,
 
     lut = transpose_build_lut(lut, sparsity_layout)
 
-    return BlksprsTensor(transpose_forward(x, lut["sparsity_layout_t"],
-
-
+    return BlksprsTensor.wrap(transpose_forward(x, lut["sparsity_layout_t"],
+                                                lut["sparsity_lut"], lut["sparsity_reverse_lut"],
+                                                sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_t"]
 
 
 @triton_op("blksprs::transpose_forward", mutates_args={})
blksprs/utils/autotuning.py
CHANGED
blksprs/utils/blksprs_tensor.py
CHANGED
@@ -1,3 +1,5 @@
+from typing import Union
+
 import torch
 from torch import Tensor
 
@@ -7,4 +9,11 @@ class BlksprsTensor(Tensor):
     """
 
     def __repr__(self):
-        return f"BlksprsTensor({torch.Tensor(self).__repr__()})"
+        return f"BlksprsTensor({torch.Tensor(self).__repr__()})"
+
+    @staticmethod
+    def wrap(tensor: Tensor) -> Union[Tensor, "BlksprsTensor"]:
+        if torch._dynamo.is_compiling():
+            return tensor
+        else:
+            return BlksprsTensor(tensor)
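The new wrap helper is the reason for the recurring BlksprsTensor(...) → BlksprsTensor.wrap(...) change across the ops modules: while torch._dynamo is tracing, wrap returns the plain Tensor instead of constructing the subclass, and in eager mode it behaves like the old constructor. A minimal, illustrative sketch of the eager-mode behavior:

    import torch
    from blksprs.utils.blksprs_tensor import BlksprsTensor

    x = torch.randn(2, 64, 64)

    # Eager mode: torch._dynamo.is_compiling() is False, so wrap returns the
    # BlksprsTensor subclass, just like the previous direct construction.
    wrapped = BlksprsTensor.wrap(x)
    assert isinstance(wrapped, BlksprsTensor)

    # Under torch.compile tracing, wrap would instead return x unchanged.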
blksprs/utils/processing.py
CHANGED
@@ -26,7 +26,8 @@ def apply_torch_linear(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block
 
     # Apply weights
     sparsity_layout_xw = build_sparsity_layout_matmul_fast(sparsity_layout, sparsity_layout_w_t)
-    xw = matmul(x, sparsity_layout, BlksprsTensor(w_t_bs.to(x.dtype)), sparsity_layout_w_t, sparsity_layout_xw,
+    xw = matmul(x, sparsity_layout, BlksprsTensor.wrap(w_t_bs.to(x.dtype)), sparsity_layout_w_t, sparsity_layout_xw,
+                sparsity_block_size)
     interim = xw
 
     # Apply bias
blksprs/utils/tools.py
CHANGED
@@ -1,9 +1,5 @@
-import torch
 from torch import Tensor, Size
 
-# Capture scalar outputs for JIT compilation
-torch._dynamo.config.capture_scalar_outputs = True
-
 
 def do_shape_blocksparse(x: Tensor) -> tuple[Tensor, Size]:
     if x.dim() == 3:
@@ -27,7 +23,8 @@ def stride(x: Tensor):
     else:
         raise NotImplementedError
 
+
 def ceil_pow2(x: int) -> int:
     if x <= 0:
         raise ValueError("Input must be a positive integer.")
-    return 1 << (x - 1).bit_length()
+    return 1 << (x - 1).bit_length()
{blksprs-2.1.7.dist-info → blksprs-2.1.8.dist-info}/METADATA
CHANGED
@@ -1,13 +1,13 @@
 Metadata-Version: 2.4
 Name: blksprs
-Version: 2.1.7
+Version: 2.1.8
 Summary: A lightweight library for operations on block-sparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
 Project-URL: Bugtracker, https://github.com/FelixSchoen/blksprs/issues
 Requires-Python: >=3.11
 Description-Content-Type: text/markdown
-Requires-Dist: torch
+Requires-Dist: torch>=2.8.0
 Requires-Dist: numpy
 Provides-Extra: test
 Requires-Dist: pytest; extra == "test"
blksprs-2.1.8.dist-info/RECORD
ADDED
@@ -0,0 +1,23 @@
+blksprs/__init__.py,sha256=JoMjkaouJY54Z0v85diwErpJRnT9fFIH1wQvg0djOME,1777
+blksprs/layouting/distribution_layout.py,sha256=GtiJH0IcIQodWZnkXqo7ZQ0TLl1HJe30eSuMs86CKPg,5861
+blksprs/layouting/sparsity_layout.py,sha256=hsgXUbPYfLVTjHUwXyMEx2XrbF0OqUWbUg07KKsQG_g,11207
+blksprs/ops/conversion.py,sha256=6gXHMnkh1-XzjnjCorDBNTANcEWM6hbTRsI9e-VRWjw,21462
+blksprs/ops/distribution.py,sha256=9s2V78Og_p7BvTqWm3jW-7ep_0hfcmZtkDtF7HBEUU4,20220
+blksprs/ops/flow.py,sha256=ziR3q1Da25k6xcFhlks1BWzzuOvJxvO_Bm9XZNuyJJQ,7760
+blksprs/ops/matmul.py,sha256=_bF357SM_38JTIaZih9yXSbFjfmHU2cMCyXj4aseKFE,11609
+blksprs/ops/partitioning.py,sha256=88TU77uDbvZTcYdTah9oChJrbgqZdkj4tNPylf9IS1c,9995
+blksprs/ops/repeat.py,sha256=2Ilr0qf0Glow-lNwxV0mW5iuKPT8Dt10D0tpOOe5EGs,9090
+blksprs/ops/softmax.py,sha256=LMrkef7OxMhd6eclj6975KAAklFT2hm7h_0ewjYgHx4,23498
+blksprs/ops/transpose.py,sha256=IaNdqWDZ2rNSaO8kwpQyoSUpVpsoxMREgEXzhVBTsaY,4112
+blksprs/ops/misc/broadcast_ops.py,sha256=iosw2dsbvEVmq3THP3vZ98_NqtAf9P_u5nTXFmf3sLA,5697
+blksprs/ops/misc/row_wise.py,sha256=ldfdy9h_ERhCtJJvu4tXr9tqkNB5QoiSa7NUNo2MQgA,19253
+blksprs/utils/autotuning.py,sha256=SOvVesXmVDbAprKtVGlfqKQ-JyRHfynvxtsbH7Qjem0,2053
+blksprs/utils/benchmarking.py,sha256=dLabDscTFn5NkmOI1g7DnKeTneUYW3RIVv9MDF-8BKc,1271
+blksprs/utils/blksprs_tensor.py,sha256=Y8YnsFPifvdCf5Khsm8bDVv-589U0N8IsCFlnDETfzE,476
+blksprs/utils/processing.py,sha256=GcsUl54DDrEoZ0iuWZV5Q0BR2ZML3jWOhypOMxDCsrs,3759
+blksprs/utils/tools.py,sha256=vlIH89TzMxotKeqts0Pipr09uf0HDQN9oQYGSGfAdk4,751
+blksprs/utils/validation.py,sha256=P98sCk6PZCQB0wO3scGTJIXfkv5EpHFM_uNHBXr42n4,4844
+blksprs-2.1.8.dist-info/METADATA,sha256=Yb0YfVMqDLx9PfPBaPfTsrCAe20X-sZWgz2AuniK-Hg,9597
+blksprs-2.1.8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+blksprs-2.1.8.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
+blksprs-2.1.8.dist-info/RECORD,,
blksprs-2.1.7.dist-info/RECORD
DELETED
@@ -1,23 +0,0 @@
-blksprs/__init__.py,sha256=Kq4YytK5Im-70RGrH1hj39qgc77tL_DH0Gc3KpZgijQ,1631
-blksprs/layouting/distribution_layout.py,sha256=ur1ty_2U-Hfj78hMWsLZvu7ZuGhzW3qGLKMc72DfTZM,5861
-blksprs/layouting/sparsity_layout.py,sha256=eXHmu2h7K5Q-YUpfOxocJoeP_5ZoQFZf_eHLxRZQbYU,11207
-blksprs/ops/conversion.py,sha256=NcBxWWWzMkjQx_fEfh14RWt688X6J82FzDqByAd3Pj4,21405
-blksprs/ops/distribution.py,sha256=pabgyw0m3A4A0osfnOoKffk-b2BKXCn-lC6BU26ocKY,20180
-blksprs/ops/flow.py,sha256=JEGES5ZbMqxR02rwi2Ym4j3VDxkcRxhFO1f-5nNUlM8,7760
-blksprs/ops/matmul.py,sha256=9XPsKbYBw0cdmZY6i4T3Phbx00LXIuA6KI0EIcyGo9U,11584
-blksprs/ops/partitioning.py,sha256=67_a9a5ZpsRmB4BVTOks0stFWp34cb0nk28zQFkXEZc,9985
-blksprs/ops/repeat.py,sha256=Eo7L-TcrrXb_I6xKXLVklp1EuCuA0sfhPaOzw_8y1eU,9080
-blksprs/ops/softmax.py,sha256=YcoZpdC1BdL4zKRQOSjIRtfGgDoQvUZabgNmjbeY8-4,23470
-blksprs/ops/transpose.py,sha256=AyIPuiMAtUAPJPs9eK-Apz6vjZdmnJO9RF6_yH6u6Fk,4097
-blksprs/ops/misc/broadcast_ops.py,sha256=ro7K2ZMOsscxNEp2HY_6efqJ4Wrf-QCFL4NLeDqvah8,5692
-blksprs/ops/misc/row_wise.py,sha256=dfhuXexyFBaNvfZjOt9w3s29ih19JhWIy04_FhUnHgk,19420
-blksprs/utils/autotuning.py,sha256=xalNP3sWdRn8XiVG4jE1-_iy2QhUmIJvTGM83YwgKA0,2052
-blksprs/utils/benchmarking.py,sha256=dLabDscTFn5NkmOI1g7DnKeTneUYW3RIVv9MDF-8BKc,1271
-blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
-blksprs/utils/processing.py,sha256=RNkEDc0g-sNHRuMPkRzNWU13d3_lIkXMJdoqES4yQTM,3738
-blksprs/utils/tools.py,sha256=TKygEKge4wJtJnXXDg8BTL8vzBpqIJsQ_A3_5FmLpcE,859
-blksprs/utils/validation.py,sha256=P98sCk6PZCQB0wO3scGTJIXfkv5EpHFM_uNHBXr42n4,4844
-blksprs-2.1.7.dist-info/METADATA,sha256=A7ZUYLyq7D8A243kuLSUVu2GKrPEp8-Bi5EYVFgbMdU,9590
-blksprs-2.1.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-blksprs-2.1.7.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
-blksprs-2.1.7.dist-info/RECORD,,
{blksprs-2.1.7.dist-info → blksprs-2.1.8.dist-info}/WHEEL
File without changes
{blksprs-2.1.7.dist-info → blksprs-2.1.8.dist-info}/top_level.txt
File without changes