blksprs 0.2b4__py3-none-any.whl → 1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,589 +0,0 @@
- from abc import ABC
-
- import torch
- import triton
- import triton.language as tl
- from torch import Tensor, Size
- from torch.nn import Module
-
-
- class BaseBlocksparse(Module, ABC):
-     _validate = None
-
-     def __init__(self, sparsity_block_size: int, device: torch.device, triton_block_size: int = None) -> None:
-         super().__init__()
-
-         self.sparsity_block_size = sparsity_block_size
-         self.device = device
-
-         self.triton_block_size = triton_block_size
-
-         if BaseBlocksparse._validate is None:
-             BaseBlocksparse._validate = True
-             # print(
-             #     f"{'\033[93m'}Blocksparse validation is activated. Consider deactivating for production use.{'\033[0m'}")
-
-     def validate_tensors(self, *tensors: Tensor, flag_dim: bool = True, flag_contiguous: bool = True,
-                          flag_dtype: bool = True,
-                          flag_device: bool = True) -> None:
-         if not BaseBlocksparse._validate:
-             return
-
-         for tensor in tensors:
-             if flag_dim:
-                 assert tensor.dim() == 3, "Input tensors must have 3 dimensions"
-             if flag_contiguous:
-                 assert tensor.is_contiguous(), "Input tensors must be contiguous"
-             if flag_dtype:
-                 assert tensor.dtype == torch.float32, "Input tensors must be of type float32"
-             if flag_device:
-                 assert tensor.device == self.device, "Input tensors must be on the same device"
-
-     def validate_sparsity(self, *tensor_sparsity_layout_tuples: tuple[Tensor, Tensor]) -> None:
-         if not BaseBlocksparse._validate:
-             return
-
-         for tensor_sparsity_layout_tuple in tensor_sparsity_layout_tuples:
-             tensor, sparsity_layout = tensor_sparsity_layout_tuple
-
-             assert tensor.size(-1) == tensor.size(-2) == self.sparsity_block_size, \
-                 "Tensor not conforming to sparsity specification"
-             assert tensor.size(0) == torch.sum(sparsity_layout.reshape(-1))
-
-     @staticmethod
-     def get_triton_block_size(sparsity_block_size: int, limit: int = 128):
-         return min(sparsity_block_size, limit)
-
-     @staticmethod
-     def disable_validation():
-         BaseBlocksparse._validate = False
-
-
- # --- Matmul SSS ---
-
- class BlocksparseMatmulSSS(BaseBlocksparse):
-
-     def __init__(self, sparsity_block_size: int, device: torch.device, triton_block_size: int = None) -> None:
-         super().__init__(sparsity_block_size, device, triton_block_size=triton_block_size)
-
-     def forward(self, x: Tensor, y: Tensor,
-                 sparsity_layout_x: Tensor, sparsity_layout_y: Tensor, sparsity_layout_output: Tensor) -> Tensor:
-         self.validate_tensors(x, y)
-         self.validate_sparsity((x, sparsity_layout_x), (y, sparsity_layout_y))
-         assert x.size(2) == y.size(1), "Inner dimensions must match"
-
-         o_n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
-
-         sparsity_layout_x_flat = sparsity_layout_x.reshape(-1)
-         sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
-                                   (sparsity_layout_x_flat == 1) -
-                                   (1 * (sparsity_layout_x_flat == 0)))
-
-         sparsity_layout_y_flat = sparsity_layout_y.reshape(-1)
-         sparsity_reverse_lut_y = ((torch.cumsum(sparsity_layout_y_flat, dim=-1) - 1) *
-                                   (sparsity_layout_y_flat == 1) -
-                                   (1 * (sparsity_layout_y_flat == 0)))
-
-         sparsity_lut_o = torch.nonzero(sparsity_layout_output)
-
-         return _BlocksparseMatmulSSS.apply(x, y,
-                                            sparsity_layout_x, sparsity_reverse_lut_x,
-                                            sparsity_layout_y, sparsity_reverse_lut_y,
-                                            sparsity_layout_output, sparsity_lut_o,
-                                            self.sparsity_block_size,
-                                            o_n_sparse_blocks,
-                                            self.triton_block_size,
-                                            self.device)
-
-
- class _BlocksparseMatmulSSS(torch.autograd.Function):
-
-     @staticmethod
-     def forward(ctx, x: Tensor, y: Tensor,
-                 sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
-                 sparsity_layout_y: Tensor, sparsity_reverse_lut_y: Tensor,
-                 sparsity_layout_o: Tensor, sparsity_lut_o: Tensor,
-                 sparsity_block_size: int, o_n_sparse_blocks: int, triton_block_size: int,
-                 device: torch.device) -> Tensor:
-         output = torch.zeros(size=(o_n_sparse_blocks, sparsity_block_size, sparsity_block_size), device=device)
-
-         x_b, x_r, x_c = x.size()
-         x_b_s, x_r_s, x_c_s = x.stride()
-         s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
-         s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_x.stride()
-         y_b, y_r, y_c = y.size()
-         y_b_s, y_r_s, y_c_s = y.stride()
-         s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_y.size()
-         s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = sparsity_layout_y.stride()
-         o_b, o_r, o_c = output.size()
-         o_b_s, o_r_s, o_c_s = output.stride()
-         s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
-         s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_o.stride()
-
-         if triton_block_size is None:
-             triton_block_size = BaseBlocksparse.get_triton_block_size(sparsity_block_size)
-
-         triton_grid = lambda meta: [o_b,
-                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-         (_BlocksparseMatmulSSS.kernel_blocksparse_matmul_sss[triton_grid]
-          (x,
-           x_b, x_b_s, x_r, x_r_s, x_c, x_c_s,
-           s_l_x_b, s_l_x_b_s, s_l_x_r, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
-           sparsity_reverse_lut_x,
-           y,
-           y_b, y_b_s, y_r, y_r_s, y_c, y_c_s,
-           s_l_y_b, s_l_y_b_s, s_l_y_r, s_l_y_r_s, s_l_y_c, s_l_y_c_s,
-           sparsity_reverse_lut_y,
-           output,
-           o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
-           sparsity_lut_o,
-           s_lut_o_r, s_lut_o_r_s,
-           s_lut_o_c, s_lut_o_c_s,
-           sparsity_block_size,
-           triton_block_size))
-
-         ctx.save_for_backward(x, y)
-         ctx.sparsity_layout_x = sparsity_layout_x
-         ctx.sparsity_layout_y = sparsity_layout_y
-         ctx.sparsity_layout_o = sparsity_layout_o
-         ctx.sparsity_block_size = sparsity_block_size
-         ctx.triton_block_size = triton_block_size
-         ctx.device = device
-
-         return output
-
-     @staticmethod
-     def backward(ctx, grad_output):
-         x, y = ctx.saved_tensors
-         sparsity_layout_x = ctx.sparsity_layout_x
-         sparsity_layout_y = ctx.sparsity_layout_y
-         sparsity_layout_o = ctx.sparsity_layout_o
-         sparsity_block_size = ctx.sparsity_block_size
-         triton_block_size = ctx.triton_block_size
-         device = ctx.device
-
-         blksprs_transpose = BlocksparseTranspose(sparsity_block_size, device, triton_block_size)
-
-         x_t, sparsity_layout_x_t = blksprs_transpose(x, sparsity_layout_x)
-         y_t, sparsity_layout_y_t = blksprs_transpose(y, sparsity_layout_y)
-
-         grad_x = BlocksparseMatmulSSS(sparsity_block_size, device, triton_block_size)(grad_output, y_t,
-                                                                                       sparsity_layout_o,
-                                                                                       sparsity_layout_y_t,
-                                                                                       sparsity_layout_x)
-         grad_y = BlocksparseMatmulSSS(sparsity_block_size, device, triton_block_size)(x_t, grad_output,
-                                                                                       sparsity_layout_x_t,
-                                                                                       sparsity_layout_o,
-                                                                                       sparsity_layout_y)
-
-         return grad_x, grad_y, None, None, None, None, None, None, None, None, None, None
-
-     @staticmethod
-     @triton.jit
-     def kernel_blocksparse_matmul_sss(x,
-                                       x_b, x_b_s, x_r, x_r_s, x_c, x_c_s,
-                                       s_l_x_b, s_l_x_b_s, s_l_x_r, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
-                                       r_lut_x,
-                                       y,
-                                       y_b, y_b_s, y_r, y_r_s, y_c, y_c_s,
-                                       s_l_y_b, s_l_y_b_s, s_l_y_r, s_l_y_r_s, s_l_y_c, s_l_y_c_s,
-                                       r_lut_y,
-                                       o,
-                                       o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
-                                       s_lut_o,
-                                       s_lut_o_r, s_lut_o_r_s,
-                                       s_lut_o_c, s_lut_o_c_s,
-                                       sparsity_block_size,
-                                       TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-         # Get triton block indices
-         pid_blk = tl.program_id(axis=0)
-         pid_row = tl.program_id(axis=1)
-         pid_col = tl.program_id(axis=2)
-
-         # Get position of current sparsity block consisting of its batch, row, and column index
-         spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
-         spa_bat_o_msk = (spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s + s_lut_o_c * s_lut_o_c_s)
-         spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
-
-         spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
-         spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s + s_lut_o_c * s_lut_o_c_s)
-         spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
-
-         spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
-         spa_col_o_msk = (spa_col_o_idx < s_lut_o_r * s_lut_o_r_s + s_lut_o_c * s_lut_o_c_s)
-         spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
-
-         # Setup buffer
-         buf = tl.zeros(shape=(TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), dtype=tl.float32)
-
-         # Slide over triton block sized segments of input tensors
-         for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size, TRITON_BLOCK_SIZE)):
-             # Convert to segment index of sparsity layout
-             i_seg_spa = (i_seg_tri * TRITON_BLOCK_SIZE) // sparsity_block_size
-             # Calculate the triton segment index within a block
-             i_seg_tri_mod = i_seg_tri % (sparsity_block_size // TRITON_BLOCK_SIZE)
-
-             # Get reverse sparsity indices for input tensors x and y
-             # These are either -1 if the block is empty or equal to the index of the block in the sparse tensor
-
-             # Get reverse sparsity indices for x
-             rev_idx_spa_x_idx = (spa_bat_o * s_l_x_b_s +
-                                  spa_row_o * s_l_x_r_s +
-                                  i_seg_spa * s_l_x_c_s)
-             rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s + s_l_x_r * s_l_x_r_s + s_l_x_c * s_l_x_c_s)
-             rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
-
-             # Get reverse sparsity indices for y
-             rev_idx_spa_y_idx = (spa_bat_o * s_l_y_b_s + i_seg_spa * s_l_y_r_s + spa_col_o * s_l_y_c_s)
-             rev_idx_spa_y_msk = (rev_idx_spa_y_idx < s_l_y_b * s_l_y_b_s + s_l_y_r * s_l_y_r_s + s_l_y_c * s_l_y_c_s)
-             rev_idx_spa_y = tl.load(r_lut_y + rev_idx_spa_y_idx, mask=rev_idx_spa_y_msk).to(tl.int32)
-
-             # If both blocks are present commence calculation
-             if rev_idx_spa_x >= 0 and rev_idx_spa_y >= 0:
-                 blk_x_idx = ((rev_idx_spa_x * x_b_s) +
-                              ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                              ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
-                                tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-                 blk_x_msk = (blk_x_idx < x_b * x_b_s + x_r * x_r_s + x_c * x_c_s)
-                 blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
-                 blk_y_idx = ((rev_idx_spa_y * y_b_s) +
-                              ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
-                                tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
-                              ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
-                 blk_y_msk = (blk_y_idx < y_b * y_b_s + y_r * y_r_s + y_c * y_c_s)
-                 blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
-
-                 # Perform matrix multiplication
-                 buf += tl.dot(blk_x, blk_y)
-
-         # Store output
-         blk_o_idx = ((pid_blk * o_b_s) +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-         blk_o_msk = (blk_o_idx < o_b * o_b_s + o_r * o_r_s + o_c * o_c_s)
-         tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
-
-
- # --- Softmax ---
-
- class BlocksparseSoftmax(BaseBlocksparse):
-     # TODO At the moment uses standard softmax instead of blocksparse improvements
-
-     def __init__(self, sparsity_block_size: int, device: torch.device, triton_block_size: int = None) -> None:
-         super().__init__(sparsity_block_size, device, triton_block_size=triton_block_size)
-
-         self.blksprs_to_dense = BlocksparseToDense(sparsity_block_size, device)
-         self.blksprs_to_sparse = BlocksparseToSparse(sparsity_block_size, device)
-
-     def forward(self, x: Tensor, sparsity_layout: Tensor, fill_value: float = float("-inf")) -> Tensor:
-         self.validate_tensors(x)
-
-         x_dense = self.blksprs_to_dense(x, sparsity_layout, fill_value=fill_value)
-         x_softmax = torch.softmax(x_dense, dim=-1)
-         x_sparse = self.blksprs_to_sparse(x_softmax, sparsity_layout)
-
-         return x_sparse
-
-
- # --- Transpose ---
-
- class BlocksparseTranspose(BaseBlocksparse):
-
-     def __init__(self, sparsity_block_size: int, device: torch.device, triton_block_size: int = None) -> None:
-         super().__init__(sparsity_block_size, device, triton_block_size=triton_block_size)
-
-     def forward(self, x: Tensor, sparsity_layout: Tensor, shuffle_blocks: bool = True) -> (Tensor, Tensor):
-         self.validate_tensors(x)
-
-         x_t = x.transpose(1, 2).contiguous()
-         sparsity_layout_t = sparsity_layout.transpose(-1, -2).contiguous()
-
-         if shuffle_blocks:
-             sparsity_layout_t_flat = sparsity_layout.reshape(-1)
-             shuffle_layout = ((torch.cumsum(sparsity_layout_t_flat, dim=-1) - 1) *
-                               (sparsity_layout_t_flat == 1) -
-                               (1 * (sparsity_layout_t_flat == 0)))
-             shuffle_layout = (shuffle_layout.reshape(sparsity_layout.size()).transpose(-1, -2).contiguous()
-                               .reshape(-1).to(torch.int))
-             shuffle_layout = shuffle_layout[shuffle_layout >= 0]
-             x_t = x_t[shuffle_layout, :, :]
-
-         return x_t, sparsity_layout_t
-
-
- # --- To Dense ---
-
- class BlocksparseToDense(BaseBlocksparse):
-
-     def __init__(self, sparsity_block_size: int, device: torch.device, triton_block_size: int = None) -> None:
-         super().__init__(sparsity_block_size, device, triton_block_size=triton_block_size)
-
-     def forward(self, x: Tensor, sparsity_layout: Tensor, fill_value: int = 0) -> Tensor:
-         self.validate_tensors(x)
-
-         sparsity_layout_flat = sparsity_layout.reshape(-1)
-         sparsity_reverse_lut = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                                 (sparsity_layout_flat == 1) -
-                                 (1 * (sparsity_layout_flat == 0)))
-
-         return _BlocksparseToDense.apply(x,
-                                          sparsity_layout, sparsity_reverse_lut,
-                                          self.sparsity_block_size, fill_value,
-                                          self.triton_block_size, self.device)
-
-
- class _BlocksparseToDense(torch.autograd.Function):
-
-     @staticmethod
-     def forward(ctx, x: Tensor,
-                 sparsity_layout: Tensor, sparsity_reverse_lut: Tensor,
-                 sparsity_block_size: int, fill_value: int,
-                 triton_block_size: int, device: torch.device) -> Tensor:
-         output = torch.full(size=(sparsity_layout.size(0), sparsity_layout.size(1) * sparsity_block_size,
-                                   sparsity_layout.size(2) * sparsity_block_size), fill_value=fill_value,
-                             dtype=x.dtype, device=device)
-
-         x_b, x_r, x_c = x.shape
-         x_b_s, x_r_s, x_c_s = x.stride()
-         s_l_b, s_l_r, s_l_c = sparsity_layout.size()
-         s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout.stride()
-         o_b, o_r, o_c = output.size()
-         o_b_s, o_r_s, o_c_s = output.stride()
-
-         if triton_block_size is None:
-             triton_block_size = BaseBlocksparse.get_triton_block_size(sparsity_block_size)
-
-         triton_grid = lambda meta: [o_b,
-                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-         (_BlocksparseToDense.kernel_blocksparse_to_dense[triton_grid]
-          (x,
-           x_b, x_b_s, x_r, x_r_s, x_c, x_c_s,
-           s_l_b, s_l_b_s, s_l_r, s_l_r_s, s_l_c, s_l_c_s,
-           sparsity_reverse_lut,
-           output,
-           o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
-           sparsity_block_size,
-           triton_block_size))
-
-         ctx.sparsity_layout = sparsity_layout
-         ctx.sparsity_block_size = sparsity_block_size
-         ctx.triton_block_size = triton_block_size
-         ctx.device = device
-
-         return output
-
-     @staticmethod
-     def backward(ctx, grad_output):
-         sparsity_layout = ctx.sparsity_layout
-         sparsity_block_size = ctx.sparsity_block_size
-         triton_block_size = ctx.triton_block_size
-         device = ctx.device
-
-         return BlocksparseToSparse(sparsity_block_size, device, triton_block_size)(grad_output,
-                                                                                    sparsity_layout), None, None, None, None, None, None
-
-     @staticmethod
-     @triton.jit
-     def kernel_blocksparse_to_dense(x,
-                                     x_b, x_b_s, x_r, x_r_s, x_c, x_c_s,
-                                     s_l_b, s_l_b_s, s_l_r, s_l_r_s, s_l_c, s_l_c_s,
-                                     sparsity_reverse_lut,
-                                     o,
-                                     o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
-                                     sparsity_block_size,
-                                     TRITON_BLOCK_SIZE: tl.constexpr):
-         # Get triton block indices
-         pid_bat = tl.program_id(axis=0)
-         pid_row = tl.program_id(axis=1)
-         pid_col = tl.program_id(axis=2)
-
-         # Get sparsity index of current block
-         spa_row = (pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size
-         spa_col = (pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size
-
-         # Get reverse sparsity index for current block
-         rev_idx_spa_idx = (pid_bat * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)
-         rev_idx_spa_msk = (rev_idx_spa_idx < s_l_b * s_l_b_s + s_l_r * s_l_r_s + s_l_c * s_l_c_s)
-         rev_idx_spa = tl.load(sparsity_reverse_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
-
-         # If block is present commence operations
-         if rev_idx_spa >= 0:
-             blk_idx = (rev_idx_spa * x_b_s +
-                        (((pid_row % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
-                          tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                        (((pid_col % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
-                          tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-             blk_msk = (blk_idx < x_b * x_b_s + x_r * x_r_s + x_c * x_c_s)
-             blk = tl.load(x + blk_idx, mask=blk_msk)
-
-             o_idx = (pid_bat * o_b_s +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-             o_msk = (o_idx < o_b * o_b_s + o_r * o_r_s + o_c * o_c_s)
-             tl.store(o + o_idx, blk, o_msk)
-
-
- # --- To Sparse ---
-
- class BlocksparseToSparse(BaseBlocksparse):
-
-     def __init__(self, sparsity_block_size: int, device: torch.device, triton_block_size: int = None) -> None:
-         super().__init__(sparsity_block_size, device, triton_block_size=triton_block_size)
-
-     def forward(self, x: Tensor, sparsity_layout: Tensor) -> Tensor:
-         self.validate_tensors(x)
-
-         sparsity_lut = torch.nonzero(sparsity_layout)
-         n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
-
-         return _BlocksparseToSparse.apply(x,
-                                           sparsity_layout, sparsity_lut,
-                                           self.sparsity_block_size, n_sparse_blocks,
-                                           self.triton_block_size, self.device)
-
-
- class _BlocksparseToSparse(torch.autograd.Function):
-
-     @staticmethod
-     def forward(ctx, x: Tensor,
-                 sparsity_layout: Tensor, sparsity_lut: Tensor,
-                 sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int, device: torch.device) -> Tensor:
-         output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size), device=device)
-
-         x_b, x_r, x_c = x.size()
-         x_b_s, x_r_s, x_c_s = x.stride()
-         o_b, o_r, o_c = output.size()
-         o_b_s, o_r_s, o_c_s = output.stride()
-         s_lut_r, s_lut_c = sparsity_lut.size()
-         s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
-
-         if triton_block_size is None:
-             triton_block_size = BaseBlocksparse.get_triton_block_size(sparsity_block_size)
-
-         triton_grid = lambda meta: [o_b,
-                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-         (_BlocksparseToSparse.kernel_blocksparse_to_sparse[triton_grid]
-          (x, x_b, x_b_s, x_r, x_r_s, x_c, x_c_s,
-           sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c,
-           s_lut_c_s,
-           output, o_b_s, o_r_s, o_c_s,
-           sparsity_block_size,
-           triton_block_size))
-
-         ctx.sparsity_layout = sparsity_layout
-         ctx.sparsity_block_size = sparsity_block_size
-         ctx.triton_block_size = triton_block_size
-         ctx.device = device
-
-         return output
-
-     @staticmethod
-     def backward(ctx, grad_output):
-         sparsity_layout = ctx.sparsity_layout
-         sparsity_block_size = ctx.sparsity_block_size
-         triton_block_size = ctx.triton_block_size
-         device = ctx.device
-
-         # return _BlocksparseToDense.apply(grad_output,
-         #                                  sparsity_layout, sparsity_lut,
-         #                                  sparsity_block_size, 0,
-         #                                  triton_block_size, device), None, None, None, None, None, None
-         return BlocksparseToDense(sparsity_block_size, device, triton_block_size)(grad_output,
-                                                                                   sparsity_layout), None, None, None, None, None, None
-
-     @staticmethod
-     @triton.jit
-     def kernel_blocksparse_to_sparse(x,
-                                      x_b, x_b_s, x_r, x_r_s, x_c: tl.constexpr, x_c_s,
-                                      s_lut, s_lut_r, s_lut_r_s, s_lut_c, s_lut_c_s,
-                                      o,
-                                      o_b_s, o_r_s, o_c_s,
-                                      sparsity_block_size,
-                                      TRITON_BLOCK_SIZE: tl.constexpr):
-         # Get triton block indices
-         pid_blk = tl.program_id(axis=0)
-         pid_row = tl.program_id(axis=1)
-         pid_col = tl.program_id(axis=2)
-
-         # Get sparsity index of current output block consisting of its batch, row, and column index
-         spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-         spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s + s_lut_c * s_lut_c_s)
-         spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
-
-         spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-         spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s + s_lut_c * s_lut_c_s)
-         spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
-
-         spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
-         spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s + s_lut_c * s_lut_c_s)
-         spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
-
-         # Load block from dense tensor
-         blk_d_idx = (spa_bat * x_b_s +
-                      ((spa_row * sparsity_block_size + pid_row * TRITON_BLOCK_SIZE +
-                        tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                      ((spa_col * sparsity_block_size + pid_col * TRITON_BLOCK_SIZE +
-                        tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-         blk_d_msk = (blk_d_idx < x_b * x_b_s + x_r * x_r_s + x_c * x_c_s)
-         blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)
-
-         # Store block in sparse tensor
-         blk_o_idx = ((pid_blk * o_b_s) +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
-         blk_o_msk = (blk_o_idx < (pid_blk + 1) * o_b_s)
-         tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
-
-
- class BlocksparseTools:
-
-     @staticmethod
-     def do_shape_blocksparse(x: Tensor):
-         if x.dim() == 3:
-             return x
-
-         return x.reshape(-1, x.size(-2), x.size(-1))
-
-     @staticmethod
-     def undo_shape_blocksparse(x: Tensor, shape: Size):
-         if x.dim() == 3:
-             return x
-
-         return x.reshape(shape)
-
-     # Methods used for verification
-
-     @staticmethod
-     def slow_to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int):
-         output = torch.zeros(size=(sparsity_layout.size(0), sparsity_layout.size(1) * sparsity_block_size,
-                                    sparsity_layout.size(2) * sparsity_block_size), device=x.device)
-         indices_sparse_blocks = sparsity_layout.nonzero(as_tuple=True)
-
-         for idx, (b, r, c) in enumerate(zip(*indices_sparse_blocks)):
-             t_r = r * sparsity_block_size
-             t_c = c * sparsity_block_size
-             to_insert = x[idx]
-             output[b, t_r:t_r + sparsity_block_size, t_c:t_c + sparsity_block_size] = to_insert
-
-         return output
-
-     @staticmethod
-     def slow_to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int):
-         indices_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
-         output = torch.zeros(size=(indices_sparse_blocks, sparsity_block_size, sparsity_block_size), device=x.device)
-         indices_sparse_blocks = sparsity_layout.nonzero(as_tuple=True)
-
-         for idx, (b, r, c) in enumerate(zip(*indices_sparse_blocks)):
-             t_r = r * sparsity_block_size
-             t_c = c * sparsity_block_size
-             to_insert = x[b, t_r:t_r + sparsity_block_size, t_c:t_c + sparsity_block_size]
-             output[idx] = to_insert
-
-         return output
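For reference: the reverse look-up tables built in BlocksparseMatmulSSS.forward, BlocksparseToDense.forward, and BlocksparseTranspose.forward above all use the same cumulative-sum expression. For each position of the flattened sparsity layout it yields either the index of that block within the compacted sparse tensor or -1 if the block is absent, which is exactly how the Triton kernels interpret it. A minimal sketch in plain PyTorch with a made-up layout (this snippet is an illustration, not code from the package):

```python
import torch

# Hypothetical layout: one batch with a 2x3 grid of blocks, three of which are present.
sparsity_layout = torch.tensor([[[1, 0, 1],
                                 [1, 0, 0]]])
layout_flat = sparsity_layout.reshape(-1)                          # tensor([1, 0, 1, 1, 0, 0])
sparsity_reverse_lut = ((torch.cumsum(layout_flat, dim=-1) - 1) *  # running count of blocks seen so far, minus one
                        (layout_flat == 1) -                       # keep that index only where a block exists
                        (1 * (layout_flat == 0)))                  # mark absent blocks with -1
print(sparsity_reverse_lut)                                        # tensor([ 0, -1,  1,  2, -1, -1])
```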
@@ -1,26 +0,0 @@
- Metadata-Version: 2.1
- Name: blksprs
- Version: 0.2b4
- Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
- Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
- Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
- Project-URL: Bugtracker, https://github.com/FelixSchoen/blksprs/issues
- Requires-Python: >=3.12
- Description-Content-Type: text/markdown
- Requires-Dist: torch
- Provides-Extra: deploy
- Requires-Dist: build; extra == "deploy"
- Requires-Dist: twine; extra == "deploy"
- Requires-Dist: pdoc3; extra == "deploy"
- Provides-Extra: test
- Requires-Dist: pytest; extra == "test"
- Requires-Dist: pytest-xdist; extra == "test"
- Requires-Dist: pytest-cov; extra == "test"
- Requires-Dist: coverage; extra == "test"
- Requires-Dist: matplotlib; extra == "test"
-
- # blksprs
-
- ## Overview
-
- A lightweight library for operations on blocksparse matrices in PyTorch.
@@ -1,6 +0,0 @@
- blksprs/ops/blocksparse.py,sha256=4vATdQicjMgEmULct-955vyJ4rRoIqk572tIGu5RjPU,27630
- blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
- blksprs-0.2b4.dist-info/METADATA,sha256=PGuf_WUjS7KT7dvkPoiApvSasWGKVZtH-EF_XX_Ffos,876
- blksprs-0.2b4.dist-info/WHEEL,sha256=UvcQYKBHoFqaQd6LKyqHw9fxEolWLQnlzP0h_LgJAfI,91
- blksprs-0.2b4.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
- blksprs-0.2b4.dist-info/RECORD,,
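
For context, a rough usage sketch of the removed 0.2b4 interface, reconstructed from blksprs/ops/blocksparse.py as shown above. The class names, constructor arguments, and forward signatures are taken from that file; the shapes, block size, and CUDA device are assumptions for illustration, and the 1.1 release no longer ships this exact module:

```python
import torch

# Import path follows the RECORD entry blksprs/ops/blocksparse.py (0.2b4 only).
from blksprs.ops.blocksparse import (BlocksparseMatmulSSS, BlocksparseToDense,
                                     BlocksparseToSparse, BlocksparseTools)

device = torch.device("cuda")
sparsity_block_size = 16  # assumed block size; inputs must be contiguous float32 3D tensors per validate_tensors

# Two batches of 32x32 matrices, i.e. a 2x2 grid of 16x16 blocks per batch; all blocks present.
x = torch.randn(2, 32, 32, device=device)
y = torch.randn(2, 32, 32, device=device)
sparsity_layout = torch.ones(2, 2, 2, dtype=torch.int, device=device)

to_sparse = BlocksparseToSparse(sparsity_block_size, device)
to_dense = BlocksparseToDense(sparsity_block_size, device)
matmul_sss = BlocksparseMatmulSSS(sparsity_block_size, device)

x_s = to_sparse(x, sparsity_layout)   # (8, 16, 16): one slice per non-zero layout entry
y_s = to_sparse(y, sparsity_layout)
o_s = matmul_sss(x_s, y_s, sparsity_layout, sparsity_layout, sparsity_layout)
o = to_dense(o_s, sparsity_layout)    # back to (2, 32, 32)

# With a fully populated layout the result should match a dense batched matmul,
# and the kernel-based conversion should match the slow reference helper from the same file.
assert torch.allclose(o, torch.bmm(x, y), atol=1e-2)
assert torch.allclose(to_dense(x_s, sparsity_layout),
                      BlocksparseTools.slow_to_dense(x_s, sparsity_layout, sparsity_block_size))
```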