blksprs 0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs-0.1/PKG-INFO +26 -0
- blksprs-0.1/README.md +5 -0
- blksprs-0.1/blksprs/blocksparse.py +479 -0
- blksprs-0.1/blksprs.egg-info/PKG-INFO +26 -0
- blksprs-0.1/blksprs.egg-info/SOURCES.txt +8 -0
- blksprs-0.1/blksprs.egg-info/dependency_links.txt +1 -0
- blksprs-0.1/blksprs.egg-info/requires.txt +13 -0
- blksprs-0.1/blksprs.egg-info/top_level.txt +1 -0
- blksprs-0.1/pyproject.toml +36 -0
- blksprs-0.1/setup.cfg +4 -0
blksprs-0.1/PKG-INFO
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: blksprs
|
|
3
|
+
Version: 0.1
|
|
4
|
+
Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
|
|
5
|
+
Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
|
|
6
|
+
Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
|
|
7
|
+
Project-URL: Bugtracker, https://github.com/FelixSchoen/blksprs/issues
|
|
8
|
+
Requires-Python: >=3.12
|
|
9
|
+
Description-Content-Type: text/markdown
|
|
10
|
+
Requires-Dist: torch
|
|
11
|
+
Provides-Extra: test
|
|
12
|
+
Requires-Dist: pytest; extra == "test"
|
|
13
|
+
Requires-Dist: pytest-xdist; extra == "test"
|
|
14
|
+
Requires-Dist: pytest-cov; extra == "test"
|
|
15
|
+
Requires-Dist: coverage; extra == "test"
|
|
16
|
+
Requires-Dist: matplotlib; extra == "test"
|
|
17
|
+
Provides-Extra: deploy
|
|
18
|
+
Requires-Dist: build; extra == "deploy"
|
|
19
|
+
Requires-Dist: twine; extra == "deploy"
|
|
20
|
+
Requires-Dist: pdoc3; extra == "deploy"
|
|
21
|
+
|
|
22
|
+
# blksprs
|
|
23
|
+
|
|
24
|
+
## Overview
|
|
25
|
+
|
|
26
|
+
A lightweight library for operations on blocksparse matrices in PyTorch.
|
blksprs-0.1/blksprs/blocksparse.py
ADDED
|
@@ -0,0 +1,479 @@
|
|
|
1
|
+
from abc import ABC
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import triton
|
|
5
|
+
import triton.language as tl
|
|
6
|
+
from torch import Tensor, Size
|
|
7
|
+
from torch.nn import Module
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class BaseBlocksparse(Module, ABC):
    """Common base for all blocksparse modules.

    Stores the sparsity block size and the target device, and provides the
    shared validation helpers used by the concrete operations.
    """

    def __init__(self, sparsity_block_size: int, device: torch.device) -> None:
        super().__init__()

        # Edge length of a (square) sparsity block and the device all input
        # tensors are expected to live on.
        self.sparsity_block_size = sparsity_block_size
        self.device = device

    def validate(self, *tensors: Tensor) -> None:
        """Assert every tensor is a contiguous 3-d float32 tensor on ``self.device``."""
        for candidate in tensors:
            assert candidate.dim() == 3, "Input tensors must have 3 dimensions"
            assert candidate.is_contiguous(), "Input tensors must be contiguous"
            assert candidate.dtype == torch.float32, "Input tensors must be of type float32"
            assert candidate.device == self.device, "Input tensors must be on the same device"

    def validate_sparsity(self, *tensors: Tensor) -> None:
        """Assert the two trailing dimensions of every tensor equal the sparsity block size."""
        for candidate in tensors:
            rows = candidate.size(-2)
            cols = candidate.size(-1)
            assert cols == rows == self.sparsity_block_size, "Tensor not conforming to sparsity specification"

    @staticmethod
    def get_triton_block_size(sparsity_block_size):
        """Return the triton tile size for kernels: the sparsity block size capped at 128."""
        return min(sparsity_block_size, 128)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# --- Matmul SSS ---
|
|
36
|
+
|
|
37
|
+
class BlocksparseMatmulSSS(BaseBlocksparse):
    """Blockwise matrix multiplication: sparse = sparse @ sparse."""

    def __init__(self, sparsity_block_size: int, device: torch.device) -> None:
        super().__init__(sparsity_block_size, device)

    @staticmethod
    def _build_reverse_lut(sparsity_layout: Tensor) -> Tensor:
        # Flattened lookup table mapping each layout position to the index of
        # its block in the compacted sparse tensor, or -1 for absent blocks.
        flat = sparsity_layout.reshape(-1)
        return (torch.cumsum(flat, dim=-1) - 1) * (flat == 1) - (1 * (flat == 0))

    def forward(self, x: Tensor, y: Tensor,
                sparsity_layout_x: Tensor, sparsity_layout_y: Tensor, sparsity_layout_output: Tensor) -> Tensor:
        """Multiply two blocksparse tensors and return a blocksparse result."""
        self.validate(x, y)
        self.validate_sparsity(x, y)
        assert x.size(2) == y.size(1), "Inner dimensions must match"

        # Number of blocks the compacted output tensor will hold.
        output_n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
        sparsity_reverse_lut_x = self._build_reverse_lut(sparsity_layout_x)
        sparsity_reverse_lut_y = self._build_reverse_lut(sparsity_layout_y)
        # One (batch, row, col) triple per non-zero output block.
        sparsity_lut_output = torch.nonzero(sparsity_layout_output)

        return _BlocksparseMatmulSSS.apply(x, y,
                                           sparsity_layout_x, sparsity_reverse_lut_x,
                                           sparsity_layout_y, sparsity_reverse_lut_y,
                                           sparsity_layout_output, sparsity_lut_output,
                                           self.sparsity_block_size,
                                           output_n_sparse_blocks,
                                           self.device)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class _BlocksparseMatmulSSS(torch.autograd.Function):
    """Autograd function launching the triton kernel for sparse @ sparse matmul.

    NOTE(review): unlike the sibling autograd functions in this file, no
    ``backward`` is defined here — confirm whether gradients are required.
    """

    @staticmethod
    def forward(ctx, x: Tensor, y: Tensor,
                sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
                sparsity_layout_y: Tensor, sparsity_reverse_lut_y: Tensor,
                sparsity_layout_output: Tensor, sparsity_lut_output: Tensor,
                sparsity_block_size: int, output_n_sparse_blocks: int, device: torch.device) -> Tensor:
        """Launch the matmul kernel and return the compacted blocksparse product."""
        # One (block_size x block_size) slot per non-zero output block.
        output = torch.zeros(size=(output_n_sparse_blocks, sparsity_block_size, sparsity_block_size), device=device)

        # Sizes and strides of the operands, their layouts, and the output,
        # passed individually for flat pointer arithmetic inside the kernel.
        x_b, x_r, x_c = x.size()
        x_b_s, x_r_s, x_c_s = x.stride()
        s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_x.stride()
        y_b, y_r, y_c = y.size()
        y_b_s, y_r_s, y_c_s = y.stride()
        s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_y.size()
        s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = sparsity_layout_y.stride()
        o_b, o_r, o_c = output.size()
        o_b_s, o_r_s, o_c_s = output.stride()
        s_lut_o_r, s_lut_o_c = sparsity_lut_output.size()
        s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_output.stride()

        triton_block_size = BaseBlocksparse.get_triton_block_size(sparsity_block_size)

        # One program per output block and per triton tile within that block.
        triton_grid = lambda meta: [o_b,
                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

        _BlocksparseMatmulSSS.kernel_blocksparse_matmul_sss[triton_grid](x,
                                                                         x_b, x_b_s, x_r, x_r_s, x_c, x_c_s,
                                                                         s_l_x_b, s_l_x_b_s,
                                                                         s_l_x_r, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
                                                                         sparsity_reverse_lut_x,
                                                                         y,
                                                                         y_b, y_b_s, y_r, y_r_s, y_c, y_c_s,
                                                                         s_l_y_b, s_l_y_b_s,
                                                                         s_l_y_r, s_l_y_r_s, s_l_y_c, s_l_y_c_s,
                                                                         sparsity_reverse_lut_y,
                                                                         output,
                                                                         o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
                                                                         sparsity_lut_output,
                                                                         s_lut_o_r, s_lut_o_r_s,
                                                                         s_lut_o_c, s_lut_o_c_s,
                                                                         sparsity_block_size,
                                                                         triton_block_size)

        return output

    @staticmethod
    @triton.jit
    def kernel_blocksparse_matmul_sss(x,
                                      x_b, x_b_s, x_r, x_r_s, x_c, x_c_s,
                                      s_l_x_b, s_l_x_b_s, s_l_x_r, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
                                      r_lut_x,
                                      y,
                                      y_b, y_b_s, y_r, y_r_s, y_c, y_c_s,
                                      s_l_y_b, s_l_y_b_s, s_l_y_r, s_l_y_r_s, s_l_y_c, s_l_y_c_s,
                                      r_lut_y,
                                      o,
                                      o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
                                      s_lut_o,
                                      s_lut_o_r, s_lut_o_r_s,
                                      s_lut_o_c, s_lut_o_c_s,
                                      sparsity_block_size,
                                      TRITON_BLOCK_SIZE: tl.constexpr) -> None:
        # Get triton block indices
        pid_blk = tl.program_id(axis=0)
        pid_row = tl.program_id(axis=1)
        pid_col = tl.program_id(axis=2)

        # Get sparsity index of current output block consisting of its batch, row, and column index
        spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
        spa_bat_o_msk = (spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s + s_lut_o_c * s_lut_o_c_s)
        spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)

        spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
        spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s + s_lut_o_c * s_lut_o_c_s)
        spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)

        spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
        spa_col_o_msk = (spa_col_o_idx < s_lut_o_r * s_lut_o_r_s + s_lut_o_c * s_lut_o_c_s)
        spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)

        # Setup buffer
        # Accumulator for the partial products of this output tile.
        buf = tl.zeros(shape=(TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), dtype=tl.float32)

        # Slide over triton block sized segments of input tensors
        # NOTE(review): the loop assumes sparsity_block_size >= TRITON_BLOCK_SIZE
        # (integer division below) — guaranteed by get_triton_block_size only
        # when sparsity_block_size is a power-of-two multiple; confirm.
        for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size, TRITON_BLOCK_SIZE)):
            # Convert to segment index of sparsity layout
            i_seg_spa = (i_seg_tri * TRITON_BLOCK_SIZE) // sparsity_block_size
            # Calculate the triton segment index within a block
            i_seg_tri_mod = i_seg_tri % (sparsity_block_size // TRITON_BLOCK_SIZE)

            # Get reverse sparsity indices for input tensors.
            # These are either -1 if the block is empty or equal to the index of the block in the sparse tensor.
            rev_idx_spa_x_idx = (spa_bat_o * s_l_x_b_s + spa_row_o * s_l_x_r_s + i_seg_spa * s_l_x_c_s)
            rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s + s_l_x_r * s_l_x_r_s + s_l_x_c * s_l_x_c_s)
            rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)

            rev_idx_spa_y_idx = (spa_bat_o * s_l_y_b_s + i_seg_spa * s_l_y_r_s + spa_col_o * s_l_y_c_s)
            rev_idx_spa_y_msk = (rev_idx_spa_y_idx < s_l_y_b * s_l_y_b_s + s_l_y_r * s_l_y_r_s + s_l_y_c * s_l_y_c_s)
            rev_idx_spa_y = tl.load(r_lut_y + rev_idx_spa_y_idx, mask=rev_idx_spa_y_msk).to(tl.int32)

            # If both blocks are present commence calculation
            if rev_idx_spa_x >= 0 and rev_idx_spa_y >= 0:
                # Flat offsets of the x tile: block base + row offsets + column offsets.
                blk_x_idx = ((rev_idx_spa_x * x_b_s) +
                             ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                             ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
                               tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
                blk_x_msk = (blk_x_idx < x_b * x_b_s + x_r * x_r_s + x_c * x_c_s)
                blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

                blk_y_idx = ((rev_idx_spa_y * y_b_s) +
                             ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
                               tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
                             ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
                blk_y_msk = (blk_y_idx < y_b * y_b_s + y_r * y_r_s + y_c * y_c_s)
                blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)

                # Perform matrix multiplication
                buf += tl.dot(blk_x, blk_y)

        # Store output
        blk_o_idx = ((pid_blk * o_b_s) +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
        blk_o_msk = (blk_o_idx < o_b * o_b_s + o_r * o_r_s + o_c * o_c_s)
        tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
# --- Softmax ---
|
|
200
|
+
|
|
201
|
+
class BlocksparseSoftmax(BaseBlocksparse):
    """Softmax over the last dimension of a blocksparse tensor."""
    # TODO At the moment uses standard softmax instead of blocksparse improvements

    def __init__(self, sparsity_block_size: int, device: torch.device) -> None:
        super().__init__(sparsity_block_size, device)

        # Conversion helpers used to round-trip through the dense representation.
        self.blksprs_to_dense = BlocksparseToDense(sparsity_block_size, device)
        self.blksprs_to_sparse = BlocksparseToSparse(sparsity_block_size, device)

    def forward(self, x: Tensor, sparsity_layout: Tensor) -> Tensor:
        """Return the softmax of ``x`` along the last dimension, in blocksparse form."""
        self.validate(x)

        # Absent blocks are filled with -inf so they contribute zero weight.
        dense = self.blksprs_to_dense(x, sparsity_layout, fill_value=float('-inf'))
        dense_softmax = torch.softmax(dense, dim=-1)

        return self.blksprs_to_sparse(dense_softmax, sparsity_layout)
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
# --- Transpose ---
|
|
221
|
+
|
|
222
|
+
class BlocksparseTranspose(BaseBlocksparse):
    """Transpose a blocksparse tensor together with its sparsity layout."""

    def __init__(self, sparsity_block_size: int, device: torch.device) -> None:
        super().__init__(sparsity_block_size, device)

    def forward(self, x: Tensor, sparsity_layout: Tensor, shuffle_blocks: bool = True) -> (Tensor, Tensor):
        """Return the transposed blocks and the transposed sparsity layout.

        Each block is transposed in place; when ``shuffle_blocks`` is True the
        blocks are additionally reordered to match the row-major order of the
        transposed layout.  Fix: the flag was previously accepted but ignored
        and the shuffle always ran; the default (True) preserves the old
        behavior while False now skips the reordering.
        """
        self.validate(x)

        x_t = x.transpose(1, 2).contiguous()
        sparsity_layout_t = sparsity_layout.transpose(-1, -2).contiguous()

        if shuffle_blocks:
            # Map each position of the transposed layout to the index of the
            # corresponding block in the original (untransposed) ordering.
            shuffle_layout = (torch.cumsum(sparsity_layout.reshape(-1), dim=-1)
                              .reshape(sparsity_layout.size()).transpose(-1, -2)
                              .reshape(-1).to(torch.int) - 1)

            x_t = x_t[shuffle_layout, :, :]

        return x_t, sparsity_layout_t
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
# --- To Dense ---
|
|
243
|
+
|
|
244
|
+
class BlocksparseToDense(BaseBlocksparse):
    """Expand a compacted blocksparse tensor into its dense representation."""

    def __init__(self, sparsity_block_size: int, device: torch.device) -> None:
        super().__init__(sparsity_block_size, device)

    def forward(self, x: Tensor, sparsity_layout: Tensor, fill_value: int = 0) -> Tensor:
        """Return the dense tensor; absent blocks are filled with ``fill_value``."""
        self.validate(x)

        # Reverse lookup table: layout position -> index of the block in the
        # compacted tensor, or -1 where the layout marks the block as absent.
        flat_layout = sparsity_layout.reshape(-1)
        present = flat_layout == 1
        absent = flat_layout == 0
        sparsity_reverse_lut = (torch.cumsum(flat_layout, dim=-1) - 1) * present - (1 * absent)

        return _BlocksparseToDense.apply(x,
                                         sparsity_layout, sparsity_reverse_lut,
                                         self.sparsity_block_size, fill_value, self.device)
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
class _BlocksparseToDense(torch.autograd.Function):
    """Autograd function launching the triton kernel that scatters compacted
    blocksparse blocks into a dense tensor."""

    @staticmethod
    def forward(ctx, x: Tensor,
                sparsity_layout: Tensor, sparsity_reverse_lut: Tensor,
                sparsity_block_size: int, fill_value: int, device: torch.device) -> Tensor:
        """Return a dense tensor pre-filled with ``fill_value`` and overwritten
        with the present blocks of ``x``."""
        # Dense shape: one sparsity block per layout cell in each dimension.
        output = torch.full(size=(sparsity_layout.size(0), sparsity_layout.size(1) * sparsity_block_size,
                                  sparsity_layout.size(2) * sparsity_block_size), fill_value=fill_value,
                            dtype=x.dtype, device=device)

        # Sizes and strides for flat pointer arithmetic inside the kernel.
        x_b, x_r, x_c = x.shape
        x_b_s, x_r_s, x_c_s = x.stride()
        s_l_b, s_l_r, s_l_c = sparsity_layout.size()
        s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout.stride()
        o_b, o_r, o_c = output.size()
        o_b_s, o_r_s, o_c_s = output.stride()

        triton_block_size = BaseBlocksparse.get_triton_block_size(sparsity_block_size)

        # One program per batch and per triton tile of the dense output.
        triton_grid = lambda meta: [o_b,
                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

        _BlocksparseToDense.kernel_blocksparse_to_dense[triton_grid](x,
                                                                     x_b, x_b_s, x_r, x_r_s, x_c, x_c_s,
                                                                     s_l_b, s_l_b_s, s_l_r, s_l_r_s, s_l_c, s_l_c_s,
                                                                     sparsity_reverse_lut,
                                                                     output,
                                                                     o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
                                                                     sparsity_block_size,
                                                                     triton_block_size)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Gradients are not supported for this operation.
        raise NotImplementedError

    @staticmethod
    @triton.jit
    def kernel_blocksparse_to_dense(x,
                                    x_b, x_b_s, x_r, x_r_s, x_c, x_c_s,
                                    s_l_b, s_l_b_s, s_l_r, s_l_r_s, s_l_c, s_l_c_s,
                                    sparsity_reverse_lut,
                                    o,
                                    o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
                                    sparsity_block_size,
                                    TRITON_BLOCK_SIZE: tl.constexpr):
        # Get triton block indices
        pid_bat = tl.program_id(axis=0)
        pid_row = tl.program_id(axis=1)
        pid_col = tl.program_id(axis=2)

        # Get sparsity index of current block
        # NOTE(review): assumes TRITON_BLOCK_SIZE divides sparsity_block_size — confirm.
        spa_row = (pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size
        spa_col = (pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size

        # Get reverse sparsity index for current block
        rev_idx_spa_idx = (pid_bat * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)
        rev_idx_spa_msk = (rev_idx_spa_idx < s_l_b * s_l_b_s + s_l_r * s_l_r_s + s_l_c * s_l_c_s)
        rev_idx_spa = tl.load(sparsity_reverse_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

        # If block is present commence operations
        if rev_idx_spa >= 0:
            # Source tile inside the compacted block: block base + row/col offsets.
            blk_idx = (rev_idx_spa * x_b_s +
                       (((pid_row % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
                         tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                       (((pid_col % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
                         tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
            blk_msk = (blk_idx < x_b * x_b_s + x_r * x_r_s + x_c * x_c_s)
            blk = tl.load(x + blk_idx, mask=blk_msk)

            # Destination tile in the dense output.
            o_idx = (pid_bat * o_b_s +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
            o_msk = (o_idx < o_b * o_b_s + o_r * o_r_s + o_c * o_c_s)
            # NOTE(review): mask is passed positionally here (mask= elsewhere) — confirm intent.
            tl.store(o + o_idx, blk, o_msk)
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
# --- To Sparse ---
|
|
342
|
+
|
|
343
|
+
class BlocksparseToSparse(BaseBlocksparse):
    """Compact a dense tensor into blocksparse form according to a sparsity layout."""

    def __init__(self, sparsity_block_size: int, device: torch.device) -> None:
        super().__init__(sparsity_block_size, device)

    def forward(self, x: Tensor, sparsity_layout: Tensor) -> Tensor:
        """Return the (n_blocks, block_size, block_size) compacted tensor."""
        self.validate(x)

        # Number of blocks to extract and their (batch, row, col) coordinates.
        n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
        sparsity_lut = torch.nonzero(sparsity_layout)

        return _BlocksparseToSparse.apply(x,
                                          sparsity_layout, sparsity_lut,
                                          self.sparsity_block_size, n_sparse_blocks, self.device)
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
class _BlocksparseToSparse(torch.autograd.Function):
    """Autograd function launching the triton kernel that gathers the non-zero
    blocks of a dense tensor into compacted blocksparse form."""

    @staticmethod
    def forward(ctx, x: Tensor,
                sparsity_layout: Tensor, sparsity_lut: Tensor,
                sparsity_block_size: int, n_sparse_blocks: int, device: torch.device) -> Tensor:
        """Gather the blocks listed in ``sparsity_lut`` into a (n_blocks, bs, bs) tensor."""
        output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size), device=device)

        # Sizes and strides for flat pointer arithmetic inside the kernel.
        x_b, x_r, x_c = x.size()
        x_b_s, x_r_s, x_c_s = x.stride()
        o_b, o_r, o_c = output.size()
        o_b_s, o_r_s, o_c_s = output.stride()
        s_lut_r, s_lut_c = sparsity_lut.size()
        s_lut_r_s, s_lut_c_s = sparsity_lut.stride()

        triton_block_size = BaseBlocksparse.get_triton_block_size(sparsity_block_size)

        # One program per output block and per triton tile within that block.
        triton_grid = lambda meta: [o_b,
                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

        _BlocksparseToSparse.kernel_blocksparse_to_sparse[triton_grid](x, x_b, x_b_s, x_r, x_r_s, x_c, x_c_s,
                                                                       sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c,
                                                                       s_lut_c_s,
                                                                       output, o_b_s, o_r_s, o_c_s,
                                                                       sparsity_block_size,
                                                                       triton_block_size)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Gradients are not supported for this operation.
        raise NotImplementedError

    @staticmethod
    @triton.jit
    def kernel_blocksparse_to_sparse(x,
                                     x_b, x_b_s, x_r, x_r_s, x_c: tl.constexpr, x_c_s,
                                     s_lut, s_lut_r, s_lut_r_s, s_lut_c, s_lut_c_s,
                                     o,
                                     o_b_s, o_r_s, o_c_s,
                                     sparsity_block_size,
                                     TRITON_BLOCK_SIZE: tl.constexpr):
        # Get triton block indices
        pid_blk = tl.program_id(axis=0)
        pid_row = tl.program_id(axis=1)
        pid_col = tl.program_id(axis=2)

        # Get sparsity index of current output block consisting of its batch, row, and column index
        spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
        spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s + s_lut_c * s_lut_c_s)
        spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)

        spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
        spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s + s_lut_c * s_lut_c_s)
        spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)

        spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
        spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s + s_lut_c * s_lut_c_s)
        spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)

        # Load block from dense tensor
        blk_d_idx = (spa_bat * x_b_s +
                     ((spa_row * sparsity_block_size + pid_row * TRITON_BLOCK_SIZE +
                       tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                     ((spa_col * sparsity_block_size + pid_col * TRITON_BLOCK_SIZE +
                       tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
        blk_d_msk = (blk_d_idx < x_b * x_b_s + x_r * x_r_s + x_c * x_c_s)
        blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)

        # Store block in sparse tensor
        # Fix: the column stride must scale the entire column offset; previously a
        # misplaced parenthesis applied ``o_c_s`` only to ``tl.arange(...)``
        # (compare the row term above and the matmul/to_dense kernels).
        blk_o_idx = ((pid_blk * o_b_s) +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
        blk_o_msk = (blk_o_idx < (pid_blk + 1) * o_b_s)
        tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
class BlocksparseTools:
    """Reference (non-triton) helpers for reshaping and converting blocksparse tensors."""

    @staticmethod
    def do_shape_blocksparse(x: Tensor):
        """Collapse all leading dimensions so ``x`` has the 3-d blocksparse shape."""
        if x.dim() == 3:
            return x

        return x.reshape(-1, x.size(-2), x.size(-1))

    @staticmethod
    def undo_shape_blocksparse(x: Tensor, shape: Size):
        """Restore ``x`` to the original ``shape`` captured before ``do_shape_blocksparse``.

        Fix: the previous check (``x.dim() == 3``) returned any 3-d tensor
        unchanged — but blocksparse tensors are always 3-d, so the function
        never restored higher-dimensional shapes. Compare against the target
        shape instead.
        """
        if x.shape == shape:
            return x

        return x.reshape(shape)

    @staticmethod
    def to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int):
        """Scatter the compacted blocks of ``x`` into a zero-filled dense tensor."""
        output = torch.zeros(size=(sparsity_layout.size(0), sparsity_layout.size(1) * sparsity_block_size,
                                   sparsity_layout.size(2) * sparsity_block_size), device=x.device)
        indices_sparse_blocks = sparsity_layout.nonzero(as_tuple=True)

        for idx, (b, r, c) in enumerate(zip(*indices_sparse_blocks)):
            t_r = r * sparsity_block_size
            t_c = c * sparsity_block_size
            output[b, t_r:t_r + sparsity_block_size, t_c:t_c + sparsity_block_size] = x[idx]

        return output

    @staticmethod
    def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int):
        """Gather the non-zero blocks of dense ``x`` into a (n_blocks, bs, bs) tensor."""
        # Count blocks first; keep a separate name for the nonzero coordinates
        # (previously one variable was reused for both, shadowing the count).
        n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
        output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size), device=x.device)
        indices_sparse_blocks = sparsity_layout.nonzero(as_tuple=True)

        for idx, (b, r, c) in enumerate(zip(*indices_sparse_blocks)):
            t_r = r * sparsity_block_size
            t_c = c * sparsity_block_size
            output[idx] = x[b, t_r:t_r + sparsity_block_size, t_c:t_c + sparsity_block_size]

        return output
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: blksprs
|
|
3
|
+
Version: 0.1
|
|
4
|
+
Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
|
|
5
|
+
Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
|
|
6
|
+
Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
|
|
7
|
+
Project-URL: Bugtracker, https://github.com/FelixSchoen/blksprs/issues
|
|
8
|
+
Requires-Python: >=3.12
|
|
9
|
+
Description-Content-Type: text/markdown
|
|
10
|
+
Requires-Dist: torch
|
|
11
|
+
Provides-Extra: test
|
|
12
|
+
Requires-Dist: pytest; extra == "test"
|
|
13
|
+
Requires-Dist: pytest-xdist; extra == "test"
|
|
14
|
+
Requires-Dist: pytest-cov; extra == "test"
|
|
15
|
+
Requires-Dist: coverage; extra == "test"
|
|
16
|
+
Requires-Dist: matplotlib; extra == "test"
|
|
17
|
+
Provides-Extra: deploy
|
|
18
|
+
Requires-Dist: build; extra == "deploy"
|
|
19
|
+
Requires-Dist: twine; extra == "deploy"
|
|
20
|
+
Requires-Dist: pdoc3; extra == "deploy"
|
|
21
|
+
|
|
22
|
+
# blksprs
|
|
23
|
+
|
|
24
|
+
## Overview
|
|
25
|
+
|
|
26
|
+
A lightweight library for operations on blocksparse matrices in PyTorch.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
blksprs
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "blksprs"
|
|
3
|
+
version = "0.1"
|
|
4
|
+
authors = [{ name = "Felix Schön", email = "schoen@kr.tuwien.ac.at" }]
|
|
5
|
+
description = "A lightweight library for operations on blocksparse matrices in PyTorch."
|
|
6
|
+
readme = "README.md"
|
|
7
|
+
requires-python = ">=3.12"
|
|
8
|
+
license = { file = "LICENSE.md" }
|
|
9
|
+
dependencies = [
|
|
10
|
+
"torch"
|
|
11
|
+
]
|
|
12
|
+
|
|
13
|
+
[project.urls]
|
|
14
|
+
"Homepage" = "https://github.com/FelixSchoen/blksprs"
|
|
15
|
+
"Bugtracker" = "https://github.com/FelixSchoen/blksprs/issues"
|
|
16
|
+
|
|
17
|
+
[project.optional-dependencies]
|
|
18
|
+
test = [
|
|
19
|
+
"pytest",
|
|
20
|
+
"pytest-xdist",
|
|
21
|
+
"pytest-cov",
|
|
22
|
+
"coverage",
|
|
23
|
+
"matplotlib"
|
|
24
|
+
]
|
|
25
|
+
deploy = [
|
|
26
|
+
"build",
|
|
27
|
+
"twine",
|
|
28
|
+
"pdoc3"
|
|
29
|
+
]
|
|
30
|
+
|
|
31
|
+
[build-system]
|
|
32
|
+
requires = ["setuptools", "wheel"]
|
|
33
|
+
build-backend = "setuptools.build_meta"
|
|
34
|
+
|
|
35
|
+
[tool.setuptools.package-data]
|
|
36
|
+
"*" = ["*.json", "*.conf"]
|
blksprs-0.1/setup.cfg
ADDED