blksprs 1.2.1__py3-none-any.whl → 1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/misc/repeat_interleave.py +130 -0
- blksprs/ops/transpose.py +1 -1
- blksprs/utils/tools.py +1 -1
- {blksprs-1.2.1.dist-info → blksprs-1.3.dist-info}/METADATA +1 -1
- {blksprs-1.2.1.dist-info → blksprs-1.3.dist-info}/RECORD +7 -6
- {blksprs-1.2.1.dist-info → blksprs-1.3.dist-info}/WHEEL +0 -0
- {blksprs-1.2.1.dist-info → blksprs-1.3.dist-info}/top_level.txt +0 -0
blksprs/misc/repeat_interleave.py
ADDED
@@ -0,0 +1,130 @@
import torch
import triton
from torch import Tensor
from triton import language as tl

from blksprs.utils.tools import get_triton_block_size
from blksprs.utils.validation import validate_contiguous, validate_device, \
    validate_sparsity_block_size, validate_triton_block_size, validate_dimensions


def repeat_interleave(x: Tensor, sparsity_layout: Tensor, repeats: int,
                      sparsity_block_size: int, triton_block_size: int = None) -> tuple[Tensor, Tensor]:
    """Repeats and interleaves the block-sparse tensor in compressed form.

    Repeats each matrix contained in the tensor ``repeats`` times and places the copies consecutively in the output
    tensor.

    Args:
        x (Tensor): A block-sparse tensor in compressed form.
        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
        repeats (int): The number of times to repeat the matrices.
        sparsity_block_size (int): The size of the sparsity blocks.
        triton_block_size (int): The block size to use for the triton kernel (default ``None``).

    Returns:
        Tensor: A block-sparse tensor in compressed form containing the repeated and interleaved matrices.
        Tensor: The sparsity layout of the resulting output tensor.

    """
    validate_dimensions(x)
    validate_contiguous(x)
    validate_device(x)
    validate_sparsity_block_size(sparsity_block_size, x)
    validate_triton_block_size(triton_block_size, sparsity_block_size)

    sparsity_layout_output = torch.repeat_interleave(sparsity_layout, repeats, dim=0).contiguous()

    sparsity_lut = torch.nonzero(sparsity_layout).contiguous()

    sparsity_layout_output_flat = sparsity_layout_output.reshape(-1)
    sparsity_output_reverse_lut = ((torch.cumsum(sparsity_layout_output_flat, dim=-1) - 1) *
                                   (sparsity_layout_output_flat == 1) -
                                   (1 * (sparsity_layout_output_flat == 0)))

    n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()

    validate_contiguous(sparsity_layout, sparsity_lut, sparsity_layout_output, sparsity_output_reverse_lut)

    output = torch.empty(n_sparse_blocks * repeats, sparsity_block_size, sparsity_block_size,
                         dtype=x.dtype, device=x.device)

    x_b, x_r, x_c = x.size()
    x_b_s, x_r_s, x_c_s = x.stride()
    s_lut_r, s_lut_c = sparsity_lut.size()
    s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
    o_b, o_r, o_c = output.size()
    o_b_s, o_r_s, o_c_s = output.stride()
    s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_output.stride()

    if triton_block_size is None:
        triton_block_size = get_triton_block_size(sparsity_block_size)

    triton_grid = lambda meta: [x_b,
                                triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
                                triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]

    (kernel_repeat_interleave[triton_grid]
     (x,
      x_b, x_b_s, x_r_s, x_c_s,
      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
      output,
      o_b, o_b_s, o_r_s, o_c_s,
      s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
      sparsity_output_reverse_lut,
      repeats,
      triton_block_size))

    return output, sparsity_layout_output


@triton.jit
def kernel_repeat_interleave(x,
                             x_b, x_b_s, x_r_s, x_c_s,
                             s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
                             o,
                             o_b, o_b_s, o_r_s, o_c_s,
                             s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
                             r_lut_o,
                             repeats,
                             TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    # Get triton block indices
    pid_blk = tl.program_id(axis=0)
    pid_row = tl.program_id(axis=1)
    pid_col = tl.program_id(axis=2)

    # Get sparsity index of the current input block consisting of its batch, row, and column index
    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
    spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)

    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
    spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)

    spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
    spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
    spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)

    # Load block
    blk_x_idx = ((pid_blk * x_b_s) +
                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
    blk_x_msk = (blk_x_idx < x_b * x_b_s)
    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

    for repeat in range(repeats):
        # Get reverse sparsity index
        rev_idx_spa_idx = ((spa_bat * repeats + repeat) * s_l_o_b_s +
                           spa_row * s_l_o_r_s +
                           spa_col * s_l_o_c_s)
        rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
        rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

        # Store block
        blk_o_idx = ((rev_idx_spa * o_b_s) +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
        blk_o_msk = (blk_o_idx < o_b * o_b_s)
        tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
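For orientation, a minimal usage sketch of the new operation (not part of the wheel contents): it assumes a CUDA device, a hand-built layout, and a compressed tensor whose blocks are ordered like torch.nonzero(sparsity_layout), which is the ordering the LUT construction above relies on. Shapes and values are illustrative only.

import torch
from blksprs.misc.repeat_interleave import repeat_interleave

# Hand-built sparsity layout: 2 batch matrices, each a 4x4 grid of
# 16x16 blocks, with 5 non-zero blocks in total (illustrative values).
sparsity_block_size = 16
sparsity_layout = torch.zeros(2, 4, 4, dtype=torch.int, device="cuda")
sparsity_layout[0, 0, 0] = 1
sparsity_layout[0, 1, 2] = 1
sparsity_layout[0, 3, 3] = 1
sparsity_layout[1, 0, 1] = 1
sparsity_layout[1, 2, 2] = 1

# Compressed form: one (sparsity_block_size, sparsity_block_size) slice per
# non-zero layout entry, in torch.nonzero order.
n_blocks = int(sparsity_layout.sum())
x = torch.randn(n_blocks, sparsity_block_size, sparsity_block_size, device="cuda")

# Each batch matrix is repeated 3 times along the batch dimension; the output
# therefore holds n_blocks * 3 compressed blocks and the layout grows from
# 2 to 6 batch entries.
out, out_layout = repeat_interleave(x, sparsity_layout, repeats=3,
                                    sparsity_block_size=sparsity_block_size)
assert out.shape == (n_blocks * 3, sparsity_block_size, sparsity_block_size)
assert out_layout.shape == (6, 4, 4)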
blksprs/ops/transpose.py
CHANGED
@@ -129,7 +129,7 @@ class _BlocksparseTranspose(torch.autograd.Function):
         spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
         spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
 
-        # Get reverse sparsity
+        # Get reverse sparsity index
         rev_idx_spa_idx = (spa_bat * s_l_b_s +
                            spa_row * s_l_r_s +
                            spa_col * s_l_c_s)
blksprs/utils/tools.py
CHANGED

{blksprs-1.2.1.dist-info → blksprs-1.3.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.2.1
+Version: 1.3
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
{blksprs-1.2.1.dist-info → blksprs-1.3.dist-info}/RECORD
CHANGED
@@ -1,17 +1,18 @@
 blksprs/layouting/distribution_layout.py,sha256=GQ-ZRXbeImiLcbaqnL2FuUZ6DoFwmB0naT_YrOpD84Q,4940
 blksprs/layouting/sparsity_layout.py,sha256=TtADT_WWcZpW3zyGy6KAgkAo44gDryXZqdJLZGEX2V8,7895
 blksprs/misc/broadcast_addition.py,sha256=vf1Hdqz9Uyqykto3DCjmdyepMzpMXL238SpANQqRAwI,5297
+blksprs/misc/repeat_interleave.py,sha256=WrIp7uJsnvjIhFeLYPfkL2j5vXyKmDQGrJ69b3Y0lQ8,5644
 blksprs/ops/conversion.py,sha256=-AOzj_j3WrBLGIgd2oVPvYS8XKfzlvGtSIWzW_qP1lk,21260
 blksprs/ops/distribution.py,sha256=_fQb6fWpLxocAh86D74ATahChi0EK0eBb4eUOUEBVps,16769
 blksprs/ops/exp.py,sha256=qs8fVtCzxl4CKT4GepaqurjEL62jyi8VjMY12JFrFAU,3674
 blksprs/ops/matmul.py,sha256=x3lrYg4g8fIf5PeMtZY_SEpi11kP9RFcRoemCIxcSDE,11086
 blksprs/ops/row_wise_sum.py,sha256=ojuSejV37cLtRNS3lBfknA5KY3TEg8EHxOqVT6JZzoM,11387
 blksprs/ops/softmax.py,sha256=ZyeAVqmG_VzJ72FArGrpUSFfoSM4GPxyubrmNKERVIA,11654
-blksprs/ops/transpose.py,sha256=
+blksprs/ops/transpose.py,sha256=cX_E3b-QMhsUDNn9D8HVkYesc2JBc-EcVBUZfCWExM8,6720
 blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
-blksprs/utils/tools.py,sha256=
+blksprs/utils/tools.py,sha256=DwophH01AeNTZAo0B1uWbKFSGBQjI5z0WmFnYKh-BBk,465
 blksprs/utils/validation.py,sha256=gJYZO5C48YUrXV3Fy_Z_lCaOpiFj951FT-Od7sKfprg,3007
-blksprs-1.
-blksprs-1.
-blksprs-1.
-blksprs-1.
+blksprs-1.3.dist-info/METADATA,sha256=bs4_e4DjSYyAQ354tLVNIKcGLkww_-C2AfHnJIMdjA8,7515
+blksprs-1.3.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+blksprs-1.3.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
+blksprs-1.3.dist-info/RECORD,,
{blksprs-1.2.1.dist-info → blksprs-1.3.dist-info}/WHEEL
File without changes

{blksprs-1.2.1.dist-info → blksprs-1.3.dist-info}/top_level.txt
File without changes