blksprs 0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs-0.1/PKG-INFO ADDED
@@ -0,0 +1,26 @@
1
+ Metadata-Version: 2.1
2
+ Name: blksprs
3
+ Version: 0.1
4
+ Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
5
+ Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
6
+ Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
7
+ Project-URL: Bugtracker, https://github.com/FelixSchoen/blksprs/issues
8
+ Requires-Python: >=3.12
9
+ Description-Content-Type: text/markdown
10
+ Requires-Dist: torch
11
+ Provides-Extra: test
12
+ Requires-Dist: pytest; extra == "test"
13
+ Requires-Dist: pytest-xdist; extra == "test"
14
+ Requires-Dist: pytest-cov; extra == "test"
15
+ Requires-Dist: coverage; extra == "test"
16
+ Requires-Dist: matplotlib; extra == "test"
17
+ Provides-Extra: deploy
18
+ Requires-Dist: build; extra == "deploy"
19
+ Requires-Dist: twine; extra == "deploy"
20
+ Requires-Dist: pdoc3; extra == "deploy"
21
+
22
+ # blksprs
23
+
24
+ ## Overview
25
+
26
+ A lightweight library for operations on blocksparse matrices in PyTorch.
blksprs-0.1/README.md ADDED
@@ -0,0 +1,5 @@
1
+ # blksprs
2
+
3
+ ## Overview
4
+
5
+ A lightweight library for operations on blocksparse matrices in PyTorch.
@@ -0,0 +1,479 @@
1
+ from abc import ABC
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+ from torch import Tensor, Size
7
+ from torch.nn import Module
8
+
9
+
10
class BaseBlocksparse(Module, ABC):
    """Common base for blocksparse operations.

    Holds the sparsity block size and the target device, and offers shared
    input-validation helpers used by every concrete blocksparse module.
    """

    def __init__(self, sparsity_block_size: int, device: torch.device) -> None:
        super().__init__()

        # Edge length of one square sparse block.
        self.sparsity_block_size = sparsity_block_size
        self.device = device

    def validate(self, *tensors: Tensor) -> None:
        """Assert that every tensor is a contiguous 3-dim float32 tensor on this device."""
        for tensor in tensors:
            assert tensor.dim() == 3, "Input tensors must have 3 dimensions"
            assert tensor.is_contiguous(), "Input tensors must be contiguous"
            assert tensor.dtype == torch.float32, "Input tensors must be of type float32"
            assert tensor.device == self.device, "Input tensors must be on the same device"

    def validate_sparsity(self, *tensors: Tensor) -> None:
        """Assert that the last two dimensions of every tensor equal the block size."""
        for tensor in tensors:
            block_shape = (tensor.size(-2), tensor.size(-1))
            assert block_shape == (self.sparsity_block_size,
                                   self.sparsity_block_size), "Tensor not conforming to sparsity specification"

    @staticmethod
    def get_triton_block_size(sparsity_block_size: int) -> int:
        """Cap the triton tile size at 128 elements per dimension."""
        return min(sparsity_block_size, 128)
33
+
34
+
35
+ # --- Matmul SSS ---
36
+
37
class BlocksparseMatmulSSS(BaseBlocksparse):
    """Blocksparse matmul where both operands and the output are sparse (SSS)."""

    def __init__(self, sparsity_block_size: int, device: torch.device) -> None:
        super().__init__(sparsity_block_size, device)

    @staticmethod
    def _reverse_lut(sparsity_layout: Tensor) -> Tensor:
        # Maps each flattened layout position to the index of its block in the
        # sparse tensor, or to -1 where the layout marks the block as absent.
        layout_flat = sparsity_layout.reshape(-1)
        return ((torch.cumsum(layout_flat, dim=-1) - 1) *
                (layout_flat == 1) -
                (1 * (layout_flat == 0)))

    def forward(self, x: Tensor, y: Tensor,
                sparsity_layout_x: Tensor, sparsity_layout_y: Tensor, sparsity_layout_output: Tensor) -> Tensor:
        """Multiply two blocksparse tensors, producing a blocksparse result."""
        self.validate(x, y)
        self.validate_sparsity(x, y)
        assert x.size(2) == y.size(1), "Inner dimensions must match"

        output_n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
        sparsity_reverse_lut_x = self._reverse_lut(sparsity_layout_x)
        sparsity_reverse_lut_y = self._reverse_lut(sparsity_layout_y)
        # One (batch, row, col) coordinate triple per present output block.
        sparsity_lut_output = torch.nonzero(sparsity_layout_output)

        return _BlocksparseMatmulSSS.apply(x, y,
                                           sparsity_layout_x, sparsity_reverse_lut_x,
                                           sparsity_layout_y, sparsity_reverse_lut_y,
                                           sparsity_layout_output, sparsity_lut_output,
                                           self.sparsity_block_size,
                                           output_n_sparse_blocks,
                                           self.device)
66
+
67
+
68
class _BlocksparseMatmulSSS(torch.autograd.Function):
    """Autograd wrapper for the sparse x sparse -> sparse matmul triton kernel.

    NOTE(review): no ``backward`` is defined on this Function, so
    differentiating through it will fail at runtime — confirm whether this is
    intentional (siblings raise ``NotImplementedError`` explicitly).
    """

    @staticmethod
    def forward(ctx, x: Tensor, y: Tensor,
                sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
                sparsity_layout_y: Tensor, sparsity_reverse_lut_y: Tensor,
                sparsity_layout_output: Tensor, sparsity_lut_output: Tensor,
                sparsity_block_size: int, output_n_sparse_blocks: int, device: torch.device) -> Tensor:
        # One square block per present output block, zero-initialised so that
        # blocks whose inputs are all absent stay zero.
        output = torch.zeros(size=(output_n_sparse_blocks, sparsity_block_size, sparsity_block_size), device=device)

        # Sizes and strides are passed to the kernel as individual scalars.
        x_b, x_r, x_c = x.size()
        x_b_s, x_r_s, x_c_s = x.stride()
        s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_x.stride()
        y_b, y_r, y_c = y.size()
        y_b_s, y_r_s, y_c_s = y.stride()
        s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_y.size()
        s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = sparsity_layout_y.stride()
        o_b, o_r, o_c = output.size()
        o_b_s, o_r_s, o_c_s = output.stride()
        s_lut_o_r, s_lut_o_c = sparsity_lut_output.size()
        s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_output.stride()

        triton_block_size = BaseBlocksparse.get_triton_block_size(sparsity_block_size)

        # Grid: one program per (output block, tile row, tile column).
        triton_grid = lambda meta: [o_b,
                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

        _BlocksparseMatmulSSS.kernel_blocksparse_matmul_sss[triton_grid](x,
                                                                         x_b, x_b_s, x_r, x_r_s, x_c, x_c_s,
                                                                         s_l_x_b, s_l_x_b_s,
                                                                         s_l_x_r, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
                                                                         sparsity_reverse_lut_x,
                                                                         y,
                                                                         y_b, y_b_s, y_r, y_r_s, y_c, y_c_s,
                                                                         s_l_y_b, s_l_y_b_s,
                                                                         s_l_y_r, s_l_y_r_s, s_l_y_c, s_l_y_c_s,
                                                                         sparsity_reverse_lut_y,
                                                                         output,
                                                                         o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
                                                                         sparsity_lut_output,
                                                                         s_lut_o_r, s_lut_o_r_s,
                                                                         s_lut_o_c, s_lut_o_c_s,
                                                                         sparsity_block_size,
                                                                         triton_block_size)

        return output

    @staticmethod
    @triton.jit
    def kernel_blocksparse_matmul_sss(x,
                                      x_b, x_b_s, x_r, x_r_s, x_c, x_c_s,
                                      s_l_x_b, s_l_x_b_s, s_l_x_r, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
                                      r_lut_x,
                                      y,
                                      y_b, y_b_s, y_r, y_r_s, y_c, y_c_s,
                                      s_l_y_b, s_l_y_b_s, s_l_y_r, s_l_y_r_s, s_l_y_c, s_l_y_c_s,
                                      r_lut_y,
                                      o,
                                      o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
                                      s_lut_o,
                                      s_lut_o_r, s_lut_o_r_s,
                                      s_lut_o_c, s_lut_o_c_s,
                                      sparsity_block_size,
                                      TRITON_BLOCK_SIZE: tl.constexpr) -> None:
        # Get triton block indices
        pid_blk = tl.program_id(axis=0)
        pid_row = tl.program_id(axis=1)
        pid_col = tl.program_id(axis=2)

        # Get sparsity index of current output block consisting of its batch, row, and column index
        # (columns 0/1/2 of the output LUT row addressed by pid_blk).
        spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
        spa_bat_o_msk = (spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s + s_lut_o_c * s_lut_o_c_s)
        spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)

        spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
        spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s + s_lut_o_c * s_lut_o_c_s)
        spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)

        spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
        spa_col_o_msk = (spa_col_o_idx < s_lut_o_r * s_lut_o_r_s + s_lut_o_c * s_lut_o_c_s)
        spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)

        # Setup buffer (float32 accumulator for the output tile)
        buf = tl.zeros(shape=(TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), dtype=tl.float32)

        # Slide over triton block sized segments of input tensors
        # (the inner/reduction dimension, in units of TRITON_BLOCK_SIZE).
        for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size, TRITON_BLOCK_SIZE)):
            # Convert to segment index of sparsity layout
            i_seg_spa = (i_seg_tri * TRITON_BLOCK_SIZE) // sparsity_block_size
            # Calculate the triton segment index within a block
            i_seg_tri_mod = i_seg_tri % (sparsity_block_size // TRITON_BLOCK_SIZE)

            # Get reverse sparsity indices for input tensors.
            # These are either -1 if the block is empty or equal to the index of the block in the sparse tensor.
            rev_idx_spa_x_idx = (spa_bat_o * s_l_x_b_s + spa_row_o * s_l_x_r_s + i_seg_spa * s_l_x_c_s)
            rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s + s_l_x_r * s_l_x_r_s + s_l_x_c * s_l_x_c_s)
            rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)

            rev_idx_spa_y_idx = (spa_bat_o * s_l_y_b_s + i_seg_spa * s_l_y_r_s + spa_col_o * s_l_y_c_s)
            rev_idx_spa_y_msk = (rev_idx_spa_y_idx < s_l_y_b * s_l_y_b_s + s_l_y_r * s_l_y_r_s + s_l_y_c * s_l_y_c_s)
            rev_idx_spa_y = tl.load(r_lut_y + rev_idx_spa_y_idx, mask=rev_idx_spa_y_msk).to(tl.int32)

            # If both blocks are present commence calculation
            # (a missing block on either side contributes zero and is skipped).
            if rev_idx_spa_x >= 0 and rev_idx_spa_y >= 0:
                blk_x_idx = ((rev_idx_spa_x * x_b_s) +
                             ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                             ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
                               tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
                blk_x_msk = (blk_x_idx < x_b * x_b_s + x_r * x_r_s + x_c * x_c_s)
                blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

                blk_y_idx = ((rev_idx_spa_y * y_b_s) +
                             ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
                               tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
                             ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
                blk_y_msk = (blk_y_idx < y_b * y_b_s + y_r * y_r_s + y_c * y_c_s)
                blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)

                # Perform matrix multiplication
                buf += tl.dot(blk_x, blk_y)

        # Store output
        blk_o_idx = ((pid_blk * o_b_s) +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
        blk_o_msk = (blk_o_idx < o_b * o_b_s + o_r * o_r_s + o_c * o_c_s)
        tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
197
+
198
+
199
+ # --- Softmax ---
200
+
201
class BlocksparseSoftmax(BaseBlocksparse):
    """Row-wise softmax over a blocksparse tensor.

    TODO At the moment uses standard softmax instead of blocksparse improvements
    """

    def __init__(self, sparsity_block_size: int, device: torch.device) -> None:
        super().__init__(sparsity_block_size, device)

        # Round-trip converters used to fall back to the dense softmax.
        self.blksprs_to_dense = BlocksparseToDense(sparsity_block_size, device)
        self.blksprs_to_sparse = BlocksparseToSparse(sparsity_block_size, device)

    def forward(self, x: Tensor, sparsity_layout: Tensor) -> Tensor:
        """Apply softmax along the last dimension of the blocksparse tensor."""
        self.validate(x)

        # Densify with -inf so absent blocks receive zero probability mass,
        # run a regular softmax, then convert back to the sparse form.
        dense = self.blksprs_to_dense(x, sparsity_layout, fill_value=float('-inf'))
        dense_softmax = torch.softmax(dense, dim=-1)
        return self.blksprs_to_sparse(dense_softmax, sparsity_layout)
218
+
219
+
220
+ # --- Transpose ---
221
+
222
class BlocksparseTranspose(BaseBlocksparse):
    """Transpose of a blocksparse tensor.

    Transposes each sparse block in place and, optionally, reorders the blocks
    so that they follow the row-major order of the transposed sparsity layout.
    """

    def __init__(self, sparsity_block_size: int, device: torch.device) -> None:
        super().__init__(sparsity_block_size, device)

    def forward(self, x: Tensor, sparsity_layout: Tensor, shuffle_blocks: bool = True) -> (Tensor, Tensor):
        """Return the transposed sparse tensor and its transposed sparsity layout.

        Args:
            x: Blocksparse tensor of shape (n_blocks, block, block).
            sparsity_layout: Sparsity layout describing ``x``.
            shuffle_blocks: If True (default), reorder the blocks to match the
                row-major order of the transposed layout.

        Returns:
            Tuple of the transposed sparse tensor and the transposed layout.
        """
        self.validate(x)

        x_t = x.transpose(1, 2).contiguous()
        sparsity_layout_t = sparsity_layout.transpose(-1, -2).contiguous()

        # Bug fix: ``shuffle_blocks`` was previously ignored and the shuffle was
        # applied unconditionally; it is now applied only when requested. The
        # default (True) preserves the previous behavior.
        if shuffle_blocks:
            # For each position of the transposed layout, the index of the
            # corresponding block in the original block order (cumsum - 1).
            # NOTE(review): this yields one index per layout position,
            # presumably assuming a fully occupied layout — verify.
            shuffle_layout = (torch.cumsum(sparsity_layout.reshape(-1), dim=-1)
                              .reshape(sparsity_layout.size()).transpose(-1, -2)
                              .reshape(-1).to(torch.int) - 1)
            x_t = x_t[shuffle_layout, :, :]

        return x_t, sparsity_layout_t
240
+
241
+
242
+ # --- To Dense ---
243
+
244
class BlocksparseToDense(BaseBlocksparse):
    """Conversion from blocksparse to dense representation."""

    def __init__(self, sparsity_block_size: int, device: torch.device) -> None:
        super().__init__(sparsity_block_size, device)

    def forward(self, x: Tensor, sparsity_layout: Tensor, fill_value: int = 0) -> Tensor:
        """Scatter the sparse blocks of ``x`` into a dense tensor.

        Positions not covered by any block are set to ``fill_value``.
        """
        self.validate(x)

        # Reverse LUT: flattened layout position -> block index in x, or -1
        # where the layout marks the block as absent.
        layout_flat = sparsity_layout.reshape(-1)
        sparsity_reverse_lut = ((torch.cumsum(layout_flat, dim=-1) - 1) *
                                (layout_flat == 1) -
                                (1 * (layout_flat == 0)))

        return _BlocksparseToDense.apply(x,
                                         sparsity_layout, sparsity_reverse_lut,
                                         self.sparsity_block_size, fill_value, self.device)
260
+
261
+
262
class _BlocksparseToDense(torch.autograd.Function):
    """Autograd wrapper for the blocksparse -> dense triton kernel."""

    @staticmethod
    def forward(ctx, x: Tensor,
                sparsity_layout: Tensor, sparsity_reverse_lut: Tensor,
                sparsity_block_size: int, fill_value: float, device: torch.device) -> Tensor:
        # Dense output sized from the layout; positions not covered by any
        # sparse block keep fill_value.
        output = torch.full(size=(sparsity_layout.size(0), sparsity_layout.size(1) * sparsity_block_size,
                                  sparsity_layout.size(2) * sparsity_block_size), fill_value=fill_value,
                            dtype=x.dtype, device=device)

        # Sizes and strides are passed to the kernel as individual scalars.
        x_b, x_r, x_c = x.shape
        x_b_s, x_r_s, x_c_s = x.stride()
        s_l_b, s_l_r, s_l_c = sparsity_layout.size()
        s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout.stride()
        o_b, o_r, o_c = output.size()
        o_b_s, o_r_s, o_c_s = output.stride()

        triton_block_size = BaseBlocksparse.get_triton_block_size(sparsity_block_size)

        # Grid: one program per (batch, tile row, tile column) of the dense output.
        triton_grid = lambda meta: [o_b,
                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

        _BlocksparseToDense.kernel_blocksparse_to_dense[triton_grid](x,
                                                                     x_b, x_b_s, x_r, x_r_s, x_c, x_c_s,
                                                                     s_l_b, s_l_b_s, s_l_r, s_l_r_s, s_l_c, s_l_c_s,
                                                                     sparsity_reverse_lut,
                                                                     output,
                                                                     o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
                                                                     sparsity_block_size,
                                                                     triton_block_size)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient is intentionally unsupported for this conversion.
        raise NotImplementedError

    @staticmethod
    @triton.jit
    def kernel_blocksparse_to_dense(x,
                                    x_b, x_b_s, x_r, x_r_s, x_c, x_c_s,
                                    s_l_b, s_l_b_s, s_l_r, s_l_r_s, s_l_c, s_l_c_s,
                                    sparsity_reverse_lut,
                                    o,
                                    o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
                                    sparsity_block_size,
                                    TRITON_BLOCK_SIZE: tl.constexpr):
        # Get triton block indices
        pid_bat = tl.program_id(axis=0)
        pid_row = tl.program_id(axis=1)
        pid_col = tl.program_id(axis=2)

        # Get sparsity index of current block (layout coordinates of this tile)
        spa_row = (pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size
        spa_col = (pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size

        # Get reverse sparsity index for current block
        # (block index in x, or -1 if the block is absent).
        rev_idx_spa_idx = (pid_bat * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)
        rev_idx_spa_msk = (rev_idx_spa_idx < s_l_b * s_l_b_s + s_l_r * s_l_r_s + s_l_c * s_l_c_s)
        rev_idx_spa = tl.load(sparsity_reverse_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

        # If block is present commence operations
        # (absent blocks keep the fill_value written at allocation time).
        if rev_idx_spa >= 0:
            blk_idx = (rev_idx_spa * x_b_s +
                       (((pid_row % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
                         tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                       (((pid_col % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
                         tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
            blk_msk = (blk_idx < x_b * x_b_s + x_r * x_r_s + x_c * x_c_s)
            blk = tl.load(x + blk_idx, mask=blk_msk)

            o_idx = (pid_bat * o_b_s +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
            o_msk = (o_idx < o_b * o_b_s + o_r * o_r_s + o_c * o_c_s)
            tl.store(o + o_idx, blk, o_msk)
339
+
340
+
341
+ # --- To Sparse ---
342
+
343
class BlocksparseToSparse(BaseBlocksparse):
    """Conversion from dense to blocksparse representation."""

    def __init__(self, sparsity_block_size: int, device: torch.device) -> None:
        super().__init__(sparsity_block_size, device)

    def forward(self, x: Tensor, sparsity_layout: Tensor) -> Tensor:
        """Gather the blocks selected by ``sparsity_layout`` out of the dense ``x``."""
        self.validate(x)

        # One (batch, row, col) coordinate triple per present block.
        sparsity_lut = torch.nonzero(sparsity_layout)
        n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()

        return _BlocksparseToSparse.apply(x,
                                          sparsity_layout, sparsity_lut,
                                          self.sparsity_block_size, n_sparse_blocks, self.device)
357
+
358
+
359
class _BlocksparseToSparse(torch.autograd.Function):
    """Autograd wrapper for the dense -> blocksparse triton kernel."""

    @staticmethod
    def forward(ctx, x: Tensor,
                sparsity_layout: Tensor, sparsity_lut: Tensor,
                sparsity_block_size: int, n_sparse_blocks: int, device: torch.device) -> Tensor:
        # One square block per present layout entry.
        output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size), device=device)

        # Sizes and strides are passed to the kernel as individual scalars.
        x_b, x_r, x_c = x.size()
        x_b_s, x_r_s, x_c_s = x.stride()
        o_b, o_r, o_c = output.size()
        o_b_s, o_r_s, o_c_s = output.stride()
        s_lut_r, s_lut_c = sparsity_lut.size()
        s_lut_r_s, s_lut_c_s = sparsity_lut.stride()

        triton_block_size = BaseBlocksparse.get_triton_block_size(sparsity_block_size)

        # Grid: one program per (sparse block, tile row, tile column).
        triton_grid = lambda meta: [o_b,
                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

        _BlocksparseToSparse.kernel_blocksparse_to_sparse[triton_grid](x, x_b, x_b_s, x_r, x_r_s, x_c, x_c_s,
                                                                       sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c,
                                                                       s_lut_c_s,
                                                                       output, o_b_s, o_r_s, o_c_s,
                                                                       sparsity_block_size,
                                                                       triton_block_size)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient is intentionally unsupported for this conversion.
        raise NotImplementedError

    @staticmethod
    @triton.jit
    def kernel_blocksparse_to_sparse(x,
                                     x_b, x_b_s, x_r, x_r_s, x_c: tl.constexpr, x_c_s,
                                     s_lut, s_lut_r, s_lut_r_s, s_lut_c, s_lut_c_s,
                                     o,
                                     o_b_s, o_r_s, o_c_s,
                                     sparsity_block_size,
                                     TRITON_BLOCK_SIZE: tl.constexpr):
        # Get triton block indices
        pid_blk = tl.program_id(axis=0)
        pid_row = tl.program_id(axis=1)
        pid_col = tl.program_id(axis=2)

        # Get sparsity index of current output block consisting of its batch, row, and column index
        # (columns 0/1/2 of the LUT row addressed by pid_blk).
        spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
        spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s + s_lut_c * s_lut_c_s)
        spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)

        spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
        spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s + s_lut_c * s_lut_c_s)
        spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)

        spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
        spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s + s_lut_c * s_lut_c_s)
        spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)

        # Load block from dense tensor
        blk_d_idx = (spa_bat * x_b_s +
                     ((spa_row * sparsity_block_size + pid_row * TRITON_BLOCK_SIZE +
                       tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                     ((spa_col * sparsity_block_size + pid_col * TRITON_BLOCK_SIZE +
                       tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
        blk_d_msk = (blk_d_idx < x_b * x_b_s + x_r * x_r_s + x_c * x_c_s)
        blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)

        # Store block in sparse tensor.
        # Bug fix: the column stride o_c_s was previously multiplied only into
        # tl.arange(...) instead of the full column offset
        # (pid_col * TRITON_BLOCK_SIZE + tl.arange(...)), unlike every other
        # kernel in this module. Latent for the contiguous output allocated in
        # forward (o_c_s == 1), but wrong for any non-unit column stride.
        blk_o_idx = ((pid_blk * o_b_s) +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
        blk_o_msk = (blk_o_idx < (pid_blk + 1) * o_b_s)
        tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
435
+
436
+
437
class BlocksparseTools:
    """Pure-PyTorch reference helpers for blocksparse tensors (no triton)."""

    @staticmethod
    def do_shape_blocksparse(x: Tensor):
        """Collapse all leading dimensions of ``x`` into one batch dimension."""
        if x.dim() == 3:
            return x

        return x.reshape(-1, x.size(-2), x.size(-1))

    @staticmethod
    def undo_shape_blocksparse(x: Tensor, shape: Size):
        """Restore the original ``shape`` collapsed by :meth:`do_shape_blocksparse`.

        Bug fix: previously any 3-dim input was returned unchanged — but
        ``do_shape_blocksparse`` always yields a 3-dim tensor, so the original
        shape was never restored. The tensor is now reshaped whenever its
        shape differs from the target ``shape``.
        """
        if x.shape == shape:
            return x

        return x.reshape(shape)

    @staticmethod
    def to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int):
        """Slow reference conversion from blocksparse to dense (python loop).

        Absent blocks are filled with zeros.
        """
        output = torch.zeros(size=(sparsity_layout.size(0), sparsity_layout.size(1) * sparsity_block_size,
                                   sparsity_layout.size(2) * sparsity_block_size), device=x.device)
        indices_sparse_blocks = sparsity_layout.nonzero(as_tuple=True)

        # Blocks in x are ordered row-major over the layout's nonzero entries.
        for idx, (b, r, c) in enumerate(zip(*indices_sparse_blocks)):
            t_r = r * sparsity_block_size
            t_c = c * sparsity_block_size
            output[b, t_r:t_r + sparsity_block_size, t_c:t_c + sparsity_block_size] = x[idx]

        return output

    @staticmethod
    def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int):
        """Slow reference conversion from dense to blocksparse (python loop)."""
        # Renamed from the original, which reused one variable name for both
        # the block count and the nonzero indices.
        n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
        output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size), device=x.device)
        indices_sparse_blocks = sparsity_layout.nonzero(as_tuple=True)

        for idx, (b, r, c) in enumerate(zip(*indices_sparse_blocks)):
            t_r = r * sparsity_block_size
            t_c = c * sparsity_block_size
            output[idx] = x[b, t_r:t_r + sparsity_block_size, t_c:t_c + sparsity_block_size]

        return output
@@ -0,0 +1,26 @@
1
+ Metadata-Version: 2.1
2
+ Name: blksprs
3
+ Version: 0.1
4
+ Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
5
+ Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
6
+ Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
7
+ Project-URL: Bugtracker, https://github.com/FelixSchoen/blksprs/issues
8
+ Requires-Python: >=3.12
9
+ Description-Content-Type: text/markdown
10
+ Requires-Dist: torch
11
+ Provides-Extra: test
12
+ Requires-Dist: pytest; extra == "test"
13
+ Requires-Dist: pytest-xdist; extra == "test"
14
+ Requires-Dist: pytest-cov; extra == "test"
15
+ Requires-Dist: coverage; extra == "test"
16
+ Requires-Dist: matplotlib; extra == "test"
17
+ Provides-Extra: deploy
18
+ Requires-Dist: build; extra == "deploy"
19
+ Requires-Dist: twine; extra == "deploy"
20
+ Requires-Dist: pdoc3; extra == "deploy"
21
+
22
+ # blksprs
23
+
24
+ ## Overview
25
+
26
+ A lightweight library for operations on blocksparse matrices in PyTorch.
@@ -0,0 +1,8 @@
1
+ README.md
2
+ pyproject.toml
3
+ blksprs/blocksparse.py
4
+ blksprs.egg-info/PKG-INFO
5
+ blksprs.egg-info/SOURCES.txt
6
+ blksprs.egg-info/dependency_links.txt
7
+ blksprs.egg-info/requires.txt
8
+ blksprs.egg-info/top_level.txt
@@ -0,0 +1,13 @@
1
+ torch
2
+
3
+ [deploy]
4
+ build
5
+ twine
6
+ pdoc3
7
+
8
+ [test]
9
+ pytest
10
+ pytest-xdist
11
+ pytest-cov
12
+ coverage
13
+ matplotlib
@@ -0,0 +1 @@
1
+ blksprs
@@ -0,0 +1,36 @@
1
+ [project]
2
+ name = "blksprs"
3
+ version = "0.1"
4
+ authors = [{ name = "Felix Schön", email = "schoen@kr.tuwien.ac.at" }]
5
+ description = "A lightweight library for operations on blocksparse matrices in PyTorch."
6
+ readme = "README.md"
7
+ requires-python = ">=3.12"
8
+ license = { file = "LICENSE.md" }
9
+ dependencies = [
10
+ "torch"
11
+ ]
12
+
13
+ [project.urls]
14
+ "Homepage" = "https://github.com/FelixSchoen/blksprs"
15
+ "Bugtracker" = "https://github.com/FelixSchoen/blksprs/issues"
16
+
17
+ [project.optional-dependencies]
18
+ test = [
19
+ "pytest",
20
+ "pytest-xdist",
21
+ "pytest-cov",
22
+ "coverage",
23
+ "matplotlib"
24
+ ]
25
+ deploy = [
26
+ "build",
27
+ "twine",
28
+ "pdoc3"
29
+ ]
30
+
31
+ [build-system]
32
+ requires = ["setuptools", "wheel"]
33
+ build-backend = "setuptools.build_meta"
34
+
35
+ [tool.setuptools.package-data]
36
+ "*" = ["*.json", "*.conf"]
blksprs-0.1/setup.cfg ADDED
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+