blksprs 2.1.2__py3-none-any.whl → 2.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/ops/conversion.py +1 -1
- blksprs/ops/flow.py +1 -1
- blksprs/ops/matmul.py +1 -1
- blksprs/ops/softmax.py +15 -13
- {blksprs-2.1.2.dist-info → blksprs-2.1.3.dist-info}/METADATA +3 -3
- {blksprs-2.1.2.dist-info → blksprs-2.1.3.dist-info}/RECORD +8 -8
- {blksprs-2.1.2.dist-info → blksprs-2.1.3.dist-info}/WHEEL +0 -0
- {blksprs-2.1.2.dist-info → blksprs-2.1.3.dist-info}/top_level.txt +0 -0
blksprs/ops/conversion.py
CHANGED
@@ -56,7 +56,7 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor,
 def to_sparse_forward(x: Tensor, _: Tensor,
                       sparsity_lut: Tensor, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
     with torch.no_grad():
-        output = torch.
+        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                              dtype=x.dtype, device=x.device)
 
         x_b, x_r, x_c = x.size()
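The same one-line change recurs in flow.py and matmul.py below: the output buffer is now allocated with torch.empty (the removed side of each hunk is truncated in this diff, so the previous allocator is not visible). A minimal sketch of the new allocation pattern, using placeholder sizes and a placeholder device in place of the real x.dtype and x.device:

    import torch

    n_sparse_blocks, sparsity_block_size = 4, 64  # placeholder values

    # torch.empty returns uninitialized memory; this is safe only because the
    # kernel that runs afterwards writes every element of every sparse block.
    output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                         dtype=torch.float32, device="cpu")

If the previous code zero-initialized the buffer (e.g. with torch.zeros), dropping the fill saves one full write pass over the output per call.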
blksprs/ops/flow.py
CHANGED
@@ -14,7 +14,7 @@ def flow_pull_forward(x: Tensor, sparsity_layout_o: Tensor,
                       sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
                       sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
     with torch.no_grad():
-        output = torch.
+        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                              dtype=x.dtype, device=x.device)
 
         x_b, x_r, x_c = x.size()
blksprs/ops/matmul.py
CHANGED
@@ -62,7 +62,7 @@ def matmul_forward(x: Tensor, y: Tensor,
                    _: Tensor, sparsity_lut_o: Tensor,
                    sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
     with torch.no_grad():
-        output = torch.
+        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                              dtype=x.dtype, device=x.device)
 
         x_b, x_r, x_c = x.size()
blksprs/ops/softmax.py
CHANGED
@@ -66,7 +66,7 @@ def softmax_forward(x: Tensor, sparsity_layout: Tensor,
                     sparsity_lut: Tensor,
                     sparsity_reverse_lut_rws: Tensor,
                     sparsity_block_size: int) -> Tensor:
-    output = torch.
+    output = torch.empty_like(x)
 
     x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,
                                                        flag_slice_only=True)
@@ -114,7 +114,7 @@ def softmax_backward_wrapper(ctx, grad_output):
 def softmax_backward(grad_output: Tensor, o: Tensor, sparsity_lut: Tensor, sparsity_layout: Tensor,
                      sparsity_block_size: int) -> Tensor:
     with torch.no_grad():
-        grad_x = torch.
+        grad_x = torch.empty_like(o, dtype=torch.float)
 
         s, sparsity_layout_s = row_wise_sum(grad_output * o, sparsity_layout, sparsity_block_size, flag_slice_only=True)
 
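The backward buffer gets the same uninitialized allocation; note that torch.empty_like can override the source tensor's dtype, so grad_x here is float32 even if o is half precision. A small self-contained illustration:

    import torch

    o = torch.ones(2, 4, 4, dtype=torch.float16)
    grad_x = torch.empty_like(o, dtype=torch.float)  # same shape/device as o, but float32, uninitialized
    assert grad_x.shape == o.shape and grad_x.dtype == torch.float32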
@@ -359,7 +359,7 @@ def softmax_fused_forward(x: Tensor, sparsity_layout: Tensor,
                           sparsity_reverse_lut_sorted: Tensor,
                           max_blocks_line: int,
                           sparsity_block_size: int) -> Tensor:
-    output = torch.
+    output = torch.empty_like(x)
 
     x_b, x_r, x_c = x.size()
     x_b_s, x_r_s, x_c_s = stride(x)
@@ -374,7 +374,7 @@ def softmax_fused_forward(x: Tensor, sparsity_layout: Tensor,
          (x,
           x_b, x_b_s, x_r_s, x_c_s,
           output,
-          s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+          s_l_b, s_l_b_s, s_l_r_s, s_l_c, s_l_c_s,
           sparsity_reverse_lut_sorted,
           max_blocks_line,
           sparsity_block_size))
@@ -399,7 +399,7 @@ def softmax_fused_backward(grad_output: Tensor,
                            max_blocks_line: int,
                            sparsity_block_size: int) -> Tensor:
     with torch.no_grad():
-        grad_x = torch.
+        grad_x = torch.empty_like(o)
 
         g_b, g_r, g_c = grad_output.size()
         g_b_s, g_r_s, g_c_s = stride(grad_output)
@@ -417,7 +417,7 @@ def softmax_fused_backward(grad_output: Tensor,
           g_b, g_b_s, g_r_s, g_c_s,
           o,
           o_b, o_b_s, o_r_s, o_c_s,
-          s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+          s_l_b, s_l_b_s, s_l_r_s, s_l_c, s_l_c_s,
           sparsity_reverse_lut_sorted,
           grad_x,
           max_blocks_line,
@@ -437,7 +437,7 @@ def softmax_fused_backward(grad_output: Tensor,
 def softmax_fused_kernel(x,
                          x_b, x_b_s, x_r_s, x_c_s,
                          o,
-                         s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+                         s_l_b, s_l_b_s, s_l_r_s, s_l_c, s_l_c_s,
                          r_lut_s,
                          mbs: tl.constexpr,
                          sparsity_block_size: tl.constexpr,
@@ -451,8 +451,9 @@ def softmax_fused_kernel(x,
     blk_rev_idx = (pid_bat * s_l_b_s +
                    pid_row * s_l_r_s +
                    (tl.arange(0, mbs) * s_l_c_s))
-    blk_rev_msk = (blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s)
-
+    blk_rev_msk = ((blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s) and
+                   (tl.arange(0, mbs) < s_l_c))
+    blk_rev = tl.load(r_lut_s + blk_rev_idx, mask=blk_rev_msk, other=-1).to(tl.int32)
 
     if (not (tl.min(blk_rev) == -1 and
              tl.max(blk_rev) == -1)):
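This mask extension is the substantive fix of the release: mbs is a power-of-two compile-time constant (see the ceil_pow2 change below), so tl.arange(0, mbs) can produce more lane indices than the sparsity layout has columns. Without the added (tl.arange(0, mbs) < s_l_c) term, overshooting lanes could still pass the coarse s_l_b * s_l_b_s bound and load entries belonging to other rows of the lookup table instead of receiving the -1 sentinel. A stripped-down, hypothetical Triton kernel showing the same masked-load pattern (names and shapes are illustrative, not blksprs API):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def masked_lut_load(lut, out, n_cols, MBS: tl.constexpr):
        idx = tl.arange(0, MBS)                       # MBS must be a power of two
        msk = idx < n_cols                            # mask off lanes past the real row length
        val = tl.load(lut + idx, mask=msk, other=-1)  # masked lanes read the -1 sentinel
        tl.store(out + idx, val)

    lut = torch.arange(5, dtype=torch.int32, device="cuda")  # 5 real entries
    out = torch.empty(8, dtype=torch.int32, device="cuda")
    masked_lut_load[(1,)](lut, out, lut.numel(), MBS=8)      # out[5:] comes back as -1

The identical fix is applied to softmax_fused_kernel_grad in the hunks that follow.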
@@ -488,7 +489,7 @@ def softmax_fused_kernel_grad(g,
                               g_b, g_b_s, g_r_s, g_c_s,
                               x,
                               x_b, x_b_s, x_r_s, x_c_s,
-                              s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+                              s_l_b, s_l_b_s, s_l_r_s, s_l_c, s_l_c_s,
                               r_lut_s,
                               o,
                               mbs: tl.constexpr,
@@ -503,8 +504,9 @@ def softmax_fused_kernel_grad(g,
     blk_rev_idx = (pid_bat * s_l_b_s +
                    pid_row * s_l_r_s +
                    (tl.arange(0, mbs) * s_l_c_s))
-    blk_rev_msk = (blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s)
-
+    blk_rev_msk = ((blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s) and
+                   (tl.arange(0, mbs) < s_l_c))
+    blk_rev = tl.load(r_lut_s + blk_rev_idx, mask=blk_rev_msk, other=-1).to(tl.int32)
 
     if (not (tl.min(blk_rev) == -1 and
              tl.max(blk_rev) == -1)):
@@ -557,7 +559,7 @@ def softmax_fused_build_lut(lut: dict, sparsity_layout: Tensor):
                        .sum(dim=-1)
                        .max()
                        .item())
-    lut["max_blocks_line"] =
+    lut["max_blocks_line"] = ceil_pow2(max(max_blocks_line, 2))
 
     validate_contiguous(sparsity_layout, lut["sparsity_reverse_lut_sorted"])
 
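The rounding matters because Triton requires tl.arange extents (here mbs) to be compile-time powers of two, and the max(max_blocks_line, 2) guard keeps the extent at least 2. ceil_pow2 is defined in blksprs' utils and is not part of this diff; a plausible sketch of such a helper, for illustration only:

    def ceil_pow2(n: int) -> int:
        # Round n up to the next power of two (identity for exact powers of two).
        return 1 << (n - 1).bit_length()

    assert [ceil_pow2(n) for n in (2, 3, 5, 8)] == [2, 4, 8, 8]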
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: blksprs
|
|
3
|
-
Version: 2.1.
|
|
3
|
+
Version: 2.1.3
|
|
4
4
|
Summary: A lightweight library for operations on block-sparse matrices in PyTorch.
|
|
5
5
|
Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
|
|
6
6
|
Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
|
|
@@ -196,8 +196,8 @@ def test_readme():
 
     # Other available functions
     bs.ops.transpose(o_sparse, sparsity_layout_o, sparsity_block_size)
-    bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size)
-    bs.ops.softmax_fused(o_sparse, sparsity_layout_o, sparsity_block_size)  # Significantly faster version that requires that rows of matrix fit into memory
+    bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size, flag_fused=False)
+    bs.ops.softmax_fused(o_sparse, sparsity_layout_o, sparsity_block_size)  # Significantly faster version that requires that rows of matrix fit into memory (default if flag is not set)
     bs.ops.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size)
     bs.ops.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size)
 
{blksprs-2.1.2.dist-info → blksprs-2.1.3.dist-info}/RECORD
CHANGED

@@ -1,13 +1,13 @@
 blksprs/__init__.py,sha256=NRxydw4i9jg7WeDuojfEePdtdbughV9AZsEcT9yywK4,1615
 blksprs/layouting/distribution_layout.py,sha256=ur1ty_2U-Hfj78hMWsLZvu7ZuGhzW3qGLKMc72DfTZM,5861
 blksprs/layouting/sparsity_layout.py,sha256=eXHmu2h7K5Q-YUpfOxocJoeP_5ZoQFZf_eHLxRZQbYU,11207
-blksprs/ops/conversion.py,sha256=
+blksprs/ops/conversion.py,sha256=nv5gXiyZkUtk1kCIlPr0Vpaj4G8G6dJdW7StlbV3nDw,21914
 blksprs/ops/distribution.py,sha256=0tPldv0ARzmCV1CU2jvfqpHBgOuHPrDFiCtqsLs7CZc,20789
-blksprs/ops/flow.py,sha256=
-blksprs/ops/matmul.py,sha256=
+blksprs/ops/flow.py,sha256=oUn_xDT74220-EmnBnB8bRNtbS1mjbxWpm76PFsK22o,8246
+blksprs/ops/matmul.py,sha256=ES9bpiCIRBxaynNIL5ftDP0c9LSArbj8YJqkPEzBaIU,11879
 blksprs/ops/partitioning.py,sha256=cfQmY9BZqGTvvJorIhtb-EyuGRJGPraWR-wTKdb47aI,9954
 blksprs/ops/repeat.py,sha256=TLYNxwPuT9y5K9xyM41WK5gnggAJF3lI61Q2K7zWjns,9035
-blksprs/ops/softmax.py,sha256=
+blksprs/ops/softmax.py,sha256=tfC_jaAKrA956rxGeb57klMuYRKTiyMCd5Zg5DIH3fc,23649
 blksprs/ops/transpose.py,sha256=U-VAyLRT6_NDv9qYSFzBqfVlDeIpTqAMEXkqto0VF6w,4072
 blksprs/ops/misc/broadcast_ops.py,sha256=-PrHiSJikZh8nXUmXxSCtFEP27TTxFr4wcrNxBjnimk,5987
 blksprs/ops/misc/row_wise.py,sha256=n5FJjAuOd8BHBJQx4bsQwr-HmXkR9PYVAqfk77wjOFU,19653
@@ -17,7 +17,7 @@ blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4
 blksprs/utils/processing.py,sha256=RNkEDc0g-sNHRuMPkRzNWU13d3_lIkXMJdoqES4yQTM,3738
 blksprs/utils/tools.py,sha256=TKygEKge4wJtJnXXDg8BTL8vzBpqIJsQ_A3_5FmLpcE,859
 blksprs/utils/validation.py,sha256=G8eQlvJVMKfEX3k2AwBD0A6Ck-gFoRLpLNY6HXsB3fA,4348
-blksprs-2.1.2.dist-info/METADATA,sha256=
-blksprs-2.1.2.dist-info/WHEEL,sha256=
-blksprs-2.1.2.dist-info/top_level.txt,sha256=
-blksprs-2.1.2.dist-info/RECORD,,
+blksprs-2.1.3.dist-info/METADATA,sha256=6ZrxPPpkLwXgmq1d-4VQBNPNjlRm76dEMI-LJyiqlfI,9712
+blksprs-2.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+blksprs-2.1.3.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
+blksprs-2.1.3.dist-info/RECORD,,
{blksprs-2.1.2.dist-info → blksprs-2.1.3.dist-info}/WHEEL
File without changes

{blksprs-2.1.2.dist-info → blksprs-2.1.3.dist-info}/top_level.txt
File without changes