blksprs 1.5__py3-none-any.whl → 1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/__init__.py CHANGED
@@ -4,10 +4,11 @@ from blksprs.ops.exp import exp
4
4
  from blksprs.ops.matmul import matmul
5
5
  from blksprs.ops.softmax import softmax
6
6
  from blksprs.ops.transpose import transpose
7
+ from blksprs.ops.partitioning import split, merge
7
8
 
8
9
  class layout:
9
10
  from blksprs.layouting.distribution_layout import build_distribution_layout
10
- from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption
11
+ from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption, build_sparsity_layout_matmul
11
12
 
12
13
  class misc:
13
14
  from blksprs.misc.broadcast_ops import broadcast_add, broadcast_sub
@@ -18,4 +19,4 @@ class util:
18
19
  from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse, disable_validation
19
20
 
20
21
  class experimental:
21
- from blksprs.experimental.distribution_mdi import gather_mdi
22
+ from blksprs.experimental.distribution_mdi import gather_mdi, scatter_reduce_mdi
@@ -3,7 +3,7 @@ import triton
3
3
  from torch import Tensor
4
4
  from triton import language as tl
5
5
 
6
- from blksprs.utils.tools import get_triton_block_size
6
+ from blksprs.utils.tools import get_triton_block_size, stride
7
7
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
8
8
  validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
9
9
 
@@ -51,15 +51,15 @@ class _BlocksparseGatherMDI(torch.autograd.Function):
51
51
  output = torch.empty_like(idx_col, dtype=x.dtype)
52
52
 
53
53
  x_b, x_r, x_c = x.size()
54
- x_b_s, x_r_s, x_c_s = x.stride()
54
+ x_b_s, x_r_s, x_c_s = stride(x)
55
55
  s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
56
- s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_x.stride()
56
+ s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
57
57
  i_b, i_r, i_c = idx_col.size()
58
- i_b_s, i_r_s, i_c_s = idx_col.stride()
58
+ i_b_s, i_r_s, i_c_s = stride(idx_col)
59
59
  s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
60
- s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
60
+ s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
61
61
  o_b, o_r, o_c = output.size()
62
- o_b_s, o_r_s, o_c_s = output.stride()
62
+ o_b_s, o_r_s, o_c_s = stride(output)
63
63
 
64
64
  if triton_block_size is None:
65
65
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -224,15 +224,15 @@ class _BlocksparseScatterReduceMDI(torch.autograd.Function):
224
224
  dtype=x.dtype, device=x.device)
225
225
 
226
226
  x_b, x_r, x_c = x.size()
227
- x_b_s, x_r_s, x_c_s = x.stride()
227
+ x_b_s, x_r_s, x_c_s = stride(x)
228
228
  s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
229
- s_lut_x_r_s, s_lut_x_c_s = sparsity_lut_x.stride()
229
+ s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut_x)
230
230
  i_b, i_r, i_c = idx_col.size()
231
- i_b_s, i_r_s, i_c_s = idx_col.stride()
231
+ i_b_s, i_r_s, i_c_s = stride(idx_col)
232
232
  o_b, o_r, o_c = output.size()
233
- o_b_s, o_r_s, o_c_s = output.stride()
233
+ o_b_s, o_r_s, o_c_s = stride(output)
234
234
  s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
235
- s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_o.stride()
235
+ s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
236
236
 
237
237
  if triton_block_size is None:
238
238
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -366,11 +366,11 @@ def build_distribution_layout_mdi(idx_bat: Tensor, idx_row: Tensor, idx_col: Ten
366
366
  dtype=torch.bool, device=idx_col.device)
367
367
 
368
368
  i_b, i_r, i_c = idx_col.size()
369
- i_b_s, i_r_s, i_c_s = idx_col.stride()
369
+ i_b_s, i_r_s, i_c_s = stride(idx_col)
370
370
  s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
371
- s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
371
+ s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
372
372
  o_b, o_r, o_c = output.size()
373
- o_b_s, o_r_s, o_c_s = output.stride()
373
+ o_b_s, o_r_s, o_c_s = stride(output)
374
374
 
375
375
  if triton_block_size is None:
376
376
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -3,7 +3,7 @@ import triton
3
3
  from torch import Tensor
4
4
  from triton import language as tl
5
5
 
6
- from blksprs.utils.tools import get_triton_block_size
6
+ from blksprs.utils.tools import get_triton_block_size, stride
7
7
  from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
8
8
  validate_contiguous
9
9
 
@@ -34,11 +34,11 @@ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
34
34
  dtype=torch.bool, device=indices.device)
35
35
 
36
36
  i_b, i_r, i_c = indices.size()
37
- i_b_s, i_r_s, i_c_s = indices.stride()
37
+ i_b_s, i_r_s, i_c_s = stride(indices)
38
38
  s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
39
- s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
39
+ s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
40
40
  o_b, o_r, o_c = output.size()
41
- o_b_s, o_r_s, o_c_s = output.stride()
41
+ o_b_s, o_r_s, o_c_s = stride(output)
42
42
 
43
43
  if triton_block_size is None:
44
44
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -5,7 +5,7 @@ import triton
5
5
  from torch import Tensor
6
6
  from triton import language as tl
7
7
 
8
- from blksprs.utils.tools import get_triton_block_size
8
+ from blksprs.utils.tools import get_triton_block_size, stride
9
9
  from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
10
10
  validate_contiguous, validate_sparsity, validate_sparsity_block_size
11
11
 
@@ -30,9 +30,9 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_size
30
30
  dtype=torch.bool, device=x.device)
31
31
 
32
32
  x_b, x_r, x_c = x.size()
33
- x_b_s, x_r_s, x_c_s = x.stride()
33
+ x_b_s, x_r_s, x_c_s = stride(x)
34
34
  o_b, o_r, o_c = output.size()
35
- o_b_s, o_r_s, o_c_s = output.stride()
35
+ o_b_s, o_r_s, o_c_s = stride(output)
36
36
 
37
37
  if triton_block_size is None:
38
38
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -120,10 +120,10 @@ def build_sparsity_layout_adaption(x: Tensor, sparsity_layout_from: Tensor,
120
120
  output = torch.zeros(o_b, o_r, o_c, dtype=torch.bool, device=x.device)
121
121
 
122
122
  x_b, x_r, x_c = x.size()
123
- x_b_s, x_r_s, x_c_s = x.stride()
123
+ x_b_s, x_r_s, x_c_s = stride(x)
124
124
  s_lut_r, s_lut_c = sparsity_lut.size()
125
- s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
126
- o_b_s, o_r_s, o_c_s = output.stride()
125
+ s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
126
+ o_b_s, o_r_s, o_c_s = stride(output)
127
127
 
128
128
  if triton_block_size is None:
129
129
  triton_block_size = get_triton_block_size(sparsity_block_size_from)
@@ -188,3 +188,39 @@ def kernel_sparsity_layout_adaption(x,
188
188
  // sparsity_block_size_to) * o_c_s))
189
189
  blk_o_msk = (blk_o_idx < o_b * o_b_s)
190
190
  tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
191
+
192
+
193
+ def build_sparsity_layout_matmul(sparsity_layout_x: Tensor, sparsity_layout_y: Tensor):
194
+ """Builds the precise sparsity layout of the result of a matrix multiplication between the two input tensors.
195
+
196
+ Args:
197
+ sparsity_layout_x (Tensor): The sparsity layout of the first block-sparse tensor.
198
+ sparsity_layout_y (Tensor): The sparsity layout of the second block-sparse tensor.
199
+
200
+ Returns:
201
+ Tensor: The precise sparsity layout of the result of a matrix multiplication between the two input tensors.
202
+
203
+ """
204
+ return torch.matmul(sparsity_layout_x.to(torch.float), sparsity_layout_y.to(torch.float)).to(torch.bool)
205
+
206
+
207
+ def build_sparsity_layout_matmul_fast(sparsity_layout_x: Tensor, sparsity_layout_y: Tensor):
208
+ """Builds the approximate sparsity layout of the result of a matrix multiplication between the two input tensors.
209
+
210
+ Note:
211
+ This function is faster than the ``build_sparsity_layout_matmul`` function due to the fact that it only checks
212
+ whether at least one of the blocks in either of the vectors participating in the matmul is non-sparse. The
213
+ resulting sparsity layout may thus overestimate the actual sparsity of the result.
214
+
215
+ Args:
216
+ sparsity_layout_x (Tensor): The sparsity layout of the first block-sparse tensor.
217
+ sparsity_layout_y (Tensor): The sparsity layout of the second block-sparse tensor.
218
+
219
+ Returns:
220
+ Tensor: The approximate sparsity layout of the result of a matrix multiplication between the two input tensors.
221
+
222
+ """
223
+ sparsity_layout_x_slice = torch.max(sparsity_layout_x, dim=-1).values.unsqueeze(-1)
224
+ sparsity_layout_y_slice = torch.max(sparsity_layout_y, dim=-2).values.unsqueeze(1)
225
+
226
+ return torch.logical_or(sparsity_layout_x_slice, sparsity_layout_y_slice)
@@ -3,7 +3,7 @@ import triton
3
3
  from torch import Tensor
4
4
  from triton import language as tl
5
5
 
6
- from blksprs.utils.tools import get_triton_block_size
6
+ from blksprs.utils.tools import get_triton_block_size, stride
7
7
  from blksprs.utils.validation import validate_contiguous, validate_device, \
8
8
  validate_sparsity_block_size, validate_triton_block_size
9
9
 
@@ -44,13 +44,13 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
44
44
  output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, dtype=x.dtype, device=x.device)
45
45
 
46
46
  x_b, x_c = x.size()
47
- x_b_s, x_c_s = x.stride()
47
+ x_b_s, x_c_s = stride(x)
48
48
  y_b, y_c = y.size()
49
- y_b_s, y_c_s = y.stride()
49
+ y_b_s, y_c_s = stride(y)
50
50
  o_b, o_r, o_c = output.size()
51
- o_b_s, o_r_s, o_c_s = output.stride()
51
+ o_b_s, o_r_s, o_c_s = stride(output)
52
52
  s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
53
- s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_o.stride()
53
+ s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
54
54
 
55
55
  if triton_block_size is None:
56
56
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -3,7 +3,7 @@ import triton
3
3
  from torch import Tensor
4
4
  from triton import language as tl
5
5
 
6
- from blksprs.utils.tools import get_triton_block_size
6
+ from blksprs.utils.tools import get_triton_block_size, stride
7
7
  from blksprs.utils.validation import validate_contiguous, validate_device, \
8
8
  validate_sparsity_block_size, validate_triton_block_size, validate_dimensions
9
9
 
@@ -35,7 +35,7 @@ def repeat_interleave(x: Tensor, sparsity_layout: Tensor, repeats: int,
35
35
  validate_sparsity_block_size(sparsity_block_size, x)
36
36
  validate_triton_block_size(triton_block_size, sparsity_block_size)
37
37
 
38
- sparsity_layout_output = torch.repeat_interleave(sparsity_layout, 3, dim=0).contiguous()
38
+ sparsity_layout_output = torch.repeat_interleave(sparsity_layout, repeats, dim=0).contiguous()
39
39
 
40
40
  sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
41
41
 
@@ -52,13 +52,13 @@ def repeat_interleave(x: Tensor, sparsity_layout: Tensor, repeats: int,
52
52
  dtype=x.dtype, device=x.device)
53
53
 
54
54
  x_b, x_r, x_c = x.size()
55
- x_b_s, x_r_s, x_c_s = x.stride()
55
+ x_b_s, x_r_s, x_c_s = stride(x)
56
56
  s_lut_r, s_lut_c = sparsity_lut.size()
57
- s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
57
+ s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
58
58
  o_b, o_r, o_c = output.size()
59
- o_b_s, o_r_s, o_c_s = output.stride()
59
+ o_b_s, o_r_s, o_c_s = stride(output)
60
60
  s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
61
- s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_output.stride()
61
+ s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
62
62
 
63
63
  if triton_block_size is None:
64
64
  triton_block_size = get_triton_block_size(sparsity_block_size)
blksprs/misc/row_wise.py CHANGED
@@ -3,7 +3,7 @@ import triton
3
3
  from torch import Tensor
4
4
  from triton import language as tl
5
5
 
6
- from blksprs.utils.tools import get_triton_block_size
6
+ from blksprs.utils.tools import get_triton_block_size, stride
7
7
  from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, validate_sparsity, \
8
8
  validate_sparsity_block_size, validate_triton_block_size
9
9
 
@@ -60,13 +60,13 @@ def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
60
60
  device=x.device)
61
61
 
62
62
  x_b, x_r, x_c = x.size()
63
- x_b_s, x_r_s, x_c_s = x.stride()
63
+ x_b_s, x_r_s, x_c_s = stride(x)
64
64
  s_lut_x_r, s_lut_x_c = sparsity_lut.size()
65
- s_lut_x_r_s, s_lut_x_c_s = sparsity_lut.stride()
65
+ s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
66
66
  o_b, o_r, o_c = output.size()
67
- o_b_s, o_r_s, o_c_s = output.stride()
67
+ o_b_s, o_r_s, o_c_s = stride(output)
68
68
  s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
69
- s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_output.stride()
69
+ s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
70
70
 
71
71
  if triton_block_size is None:
72
72
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -183,13 +183,13 @@ def row_wise_max(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
183
183
  device=x.device)
184
184
 
185
185
  x_b, x_r, x_c = x.size()
186
- x_b_s, x_r_s, x_c_s = x.stride()
186
+ x_b_s, x_r_s, x_c_s = stride(x)
187
187
  s_lut_x_r, s_lut_x_c = sparsity_lut.size()
188
- s_lut_x_r_s, s_lut_x_c_s = sparsity_lut.stride()
188
+ s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
189
189
  o_b, o_r, o_c = output.size()
190
- o_b_s, o_r_s, o_c_s = output.stride()
190
+ o_b_s, o_r_s, o_c_s = stride(output)
191
191
  s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
192
- s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_output.stride()
192
+ s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
193
193
 
194
194
  if triton_block_size is None:
195
195
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -290,15 +290,15 @@ def row_wise_add(x: Tensor, sparsity_layout_x: Tensor, y: Tensor,
290
290
  output = torch.empty_like(x)
291
291
 
292
292
  x_b, x_r, x_c = x.size()
293
- x_b_s, x_r_s, x_c_s = x.stride()
293
+ x_b_s, x_r_s, x_c_s = stride(x)
294
294
  s_lut_r, s_lut_c = sparsity_lut.size()
295
- s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
295
+ s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
296
296
  y_b, y_r, y_c = y.size()
297
- y_b_s, y_r_s, y_c_s = y.stride()
297
+ y_b_s, y_r_s, y_c_s = stride(y)
298
298
  s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_rwm.size()
299
- s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = sparsity_layout_rwm.stride()
299
+ s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_rwm)
300
300
  o_b, o_r, o_c = output.size()
301
- o_b_s, o_r_s, o_c_s = output.stride()
301
+ o_b_s, o_r_s, o_c_s = stride(output)
302
302
 
303
303
  if triton_block_size is None:
304
304
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -361,7 +361,8 @@ def kernel_blocksparse_row_wise_add(x,
361
361
  rev_idx_spa_s = tl.load(r_lut_y + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
362
362
 
363
363
  if rev_idx_spa_s == -1:
364
- assert False, "Invalid sparsity block"
364
+ tl.device_assert(False)
365
+ return
365
366
 
366
367
  # Load x block
367
368
  blk_x_idx = ((pid_blk * x_b_s) +
blksprs/ops/conversion.py CHANGED
@@ -6,7 +6,7 @@ from torch import Tensor
6
6
  from triton import language as tl
7
7
 
8
8
  from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
9
- from blksprs.utils.tools import get_triton_block_size
9
+ from blksprs.utils.tools import get_triton_block_size, stride
10
10
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
11
11
  validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
12
12
 
@@ -65,11 +65,11 @@ class _BlocksparseToDense(torch.autograd.Function):
65
65
  dtype=x.dtype, device=x.device)
66
66
 
67
67
  x_b, x_r, x_c = x.shape
68
- x_b_s, x_r_s, x_c_s = x.stride()
68
+ x_b_s, x_r_s, x_c_s = stride(x)
69
69
  s_l_b, s_l_r, s_l_c = sparsity_layout.size()
70
- s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout.stride()
70
+ s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout)
71
71
  o_b, o_r, o_c = output.size()
72
- o_b_s, o_r_s, o_c_s = output.stride()
72
+ o_b_s, o_r_s, o_c_s = stride(output)
73
73
 
74
74
  if triton_block_size is None:
75
75
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -190,11 +190,11 @@ class _BlocksparseToSparse(torch.autograd.Function):
190
190
  dtype=x.dtype, device=x.device)
191
191
 
192
192
  x_b, x_r, x_c = x.size()
193
- x_b_s, x_r_s, x_c_s = x.stride()
193
+ x_b_s, x_r_s, x_c_s = stride(x)
194
194
  s_lut_r, s_lut_c = sparsity_lut.size()
195
- s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
195
+ s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
196
196
  o_b, o_r, o_c = output.size()
197
- o_b_s, o_r_s, o_c_s = output.stride()
197
+ o_b_s, o_r_s, o_c_s = stride(output)
198
198
 
199
199
  if triton_block_size is None:
200
200
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -347,13 +347,13 @@ class _BlocksparseAdaptLayout(torch.autograd.Function):
347
347
  dtype=x.dtype, device=x.device)
348
348
 
349
349
  x_b, x_r, x_c = x.size()
350
- x_b_s, x_r_s, x_c_s = x.stride()
350
+ x_b_s, x_r_s, x_c_s = stride(x)
351
351
  s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_from.size()
352
- s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_from.stride()
352
+ s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_from)
353
353
  o_b, o_r, o_c = output.size()
354
- o_b_s, o_r_s, o_c_s = output.stride()
354
+ o_b_s, o_r_s, o_c_s = stride(output)
355
355
  s_lut_o_r, s_lut_o_c = sparsity_lut_to.size()
356
- s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_to.stride()
356
+ s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_to)
357
357
 
358
358
  if triton_block_size is None:
359
359
  triton_block_size = get_triton_block_size(min_sparsity_block_size)
@@ -3,7 +3,7 @@ import triton
3
3
  from torch import Tensor
4
4
  from triton import language as tl
5
5
 
6
- from blksprs.utils.tools import get_triton_block_size
6
+ from blksprs.utils.tools import get_triton_block_size, stride
7
7
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
8
8
  validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
9
9
 
@@ -59,15 +59,15 @@ class _BlocksparseGather(torch.autograd.Function):
59
59
  output = torch.empty_like(i, dtype=x.dtype)
60
60
 
61
61
  x_b, x_r, x_c = x.size()
62
- x_b_s, x_r_s, x_c_s = x.stride()
62
+ x_b_s, x_r_s, x_c_s = stride(x)
63
63
  s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
64
- s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_x.stride()
64
+ s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
65
65
  i_b, i_r, i_c = i.size()
66
- i_b_s, i_r_s, i_c_s = i.stride()
66
+ i_b_s, i_r_s, i_c_s = stride(i)
67
67
  s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
68
- s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
68
+ s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
69
69
  o_b, o_r, o_c = output.size()
70
- o_b_s, o_r_s, o_c_s = output.stride()
70
+ o_b_s, o_r_s, o_c_s = stride(output)
71
71
 
72
72
  if triton_block_size is None:
73
73
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -248,15 +248,15 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
248
248
  dtype=x.dtype, device=x.device)
249
249
 
250
250
  x_b, x_r, x_c = x.size()
251
- x_b_s, x_r_s, x_c_s = x.stride()
251
+ x_b_s, x_r_s, x_c_s = stride(x)
252
252
  s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
253
- s_lut_x_r_s, s_lut_x_c_s = sparsity_lut_x.stride()
253
+ s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut_x)
254
254
  i_b, i_r, i_c = i.size()
255
- i_b_s, i_r_s, i_c_s = i.stride()
255
+ i_b_s, i_r_s, i_c_s = stride(i)
256
256
  o_b, o_r, o_c = output.size()
257
- o_b_s, o_r_s, o_c_s = output.stride()
257
+ o_b_s, o_r_s, o_c_s = stride(output)
258
258
  s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
259
- s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_o.stride()
259
+ s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
260
260
 
261
261
  if triton_block_size is None:
262
262
  triton_block_size = get_triton_block_size(sparsity_block_size)
blksprs/ops/exp.py CHANGED
@@ -3,7 +3,7 @@ import triton
3
3
  from torch import Tensor
4
4
  from triton import language as tl
5
5
 
6
- from blksprs.utils.tools import get_triton_block_size
6
+ from blksprs.utils.tools import get_triton_block_size, stride
7
7
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
8
8
  validate_sparsity_block_size, validate_triton_block_size
9
9
 
@@ -43,9 +43,9 @@ class _BlocksparseExp(torch.autograd.Function):
43
43
  output = torch.empty_like(x)
44
44
 
45
45
  x_b, x_r, x_c = x.shape
46
- x_b_s, x_r_s, x_c_s = x.stride()
46
+ x_b_s, x_r_s, x_c_s = stride(x)
47
47
  o_b, o_r, o_c = output.shape
48
- o_b_s, o_r_s, o_c_s = output.stride()
48
+ o_b_s, o_r_s, o_c_s = stride(output)
49
49
 
50
50
  if triton_block_size is None:
51
51
  triton_block_size = get_triton_block_size(sparsity_block_size)
blksprs/ops/matmul.py CHANGED
@@ -4,7 +4,7 @@ from torch import Tensor
4
4
  from triton import language as tl
5
5
 
6
6
  from blksprs.ops.transpose import transpose
7
- from blksprs.utils.tools import get_triton_block_size
7
+ from blksprs.utils.tools import get_triton_block_size, stride
8
8
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
9
9
  validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_dtype_float
10
10
 
@@ -82,17 +82,17 @@ class _BlocksparseMatmulSSS(torch.autograd.Function):
82
82
  dtype=x.dtype, device=x.device)
83
83
 
84
84
  x_b, x_r, x_c = x.size()
85
- x_b_s, x_r_s, x_c_s = x.stride()
85
+ x_b_s, x_r_s, x_c_s = stride(x)
86
86
  s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
87
- s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_x.stride()
87
+ s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
88
88
  y_b, y_r, y_c = y.size()
89
- y_b_s, y_r_s, y_c_s = y.stride()
89
+ y_b_s, y_r_s, y_c_s = stride(y)
90
90
  s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_y.size()
91
- s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = sparsity_layout_y.stride()
91
+ s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_y)
92
92
  o_b, o_r, o_c = output.size()
93
- o_b_s, o_r_s, o_c_s = output.stride()
93
+ o_b_s, o_r_s, o_c_s = stride(output)
94
94
  s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
95
- s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_o.stride()
95
+ s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
96
96
 
97
97
  if triton_block_size is None:
98
98
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -0,0 +1,155 @@
1
+ import torch
2
+ from torch import Tensor
3
+
4
+ from blksprs.ops.repeat import forward_flow
5
+
6
+ from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
7
+ validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
8
+
9
+
10
+ def split(x: Tensor, sparsity_layout: Tensor, partitions: int,
11
+ sparsity_block_size: int, triton_block_size: int = None) -> (Tensor, Tensor):
12
+ """Splits a block-sparse tensor in compressed form along the last dimension into partitions.
13
+
14
+ Args:
15
+ x (Tensor): A block-sparse tensor in compressed form.
16
+ sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
17
+ partitions (int): The number of partitions to split the block-sparse tensor into.
18
+ sparsity_block_size (int): The size of the sparsity blocks.
19
+ triton_block_size (int): The block size to use for the triton kernel (default ``None``).
20
+
21
+ Returns:
22
+ Tensor: The block-sparse tensor split into partitions in compressed form.
23
+ Tensor: The sparsity layout of the output tensor.
24
+
25
+ """
26
+ x = x.contiguous()
27
+
28
+ validate_dimensions(x)
29
+ validate_contiguous(x)
30
+ validate_device(x)
31
+ validate_sparsity(sparsity_block_size, (x, sparsity_layout))
32
+ validate_sparsity_block_size(sparsity_block_size, x)
33
+ validate_triton_block_size(triton_block_size, sparsity_block_size)
34
+
35
+ sparsity_layout_output = (sparsity_layout
36
+ .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
37
+ sparsity_layout.size(2) // partitions)
38
+ .permute(0, 2, 1, 3)
39
+ .reshape(sparsity_layout.size(0) * partitions, sparsity_layout.size(1),
40
+ sparsity_layout.size(2) // partitions).contiguous())
41
+
42
+ sparsity_lut = torch.nonzero(sparsity_layout_output).contiguous()
43
+
44
+ sparsity_layout_flat = sparsity_layout.reshape(-1)
45
+ sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
46
+ (sparsity_layout_flat == 1) -
47
+ (1 * (sparsity_layout_flat == 0)))
48
+ .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
49
+ sparsity_layout.size(2) // partitions)
50
+ .permute(0, 2, 1, 3).reshape(-1).contiguous())
51
+
52
+ n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
53
+
54
+ validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
55
+
56
+ return _BlocksparseSplit.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
57
+ sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_output
58
+
59
+
60
+ class _BlocksparseSplit(torch.autograd.Function):
61
+
62
+ @staticmethod
63
+ def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
64
+ num_partitions: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
65
+ ctx.save_for_backward(sparsity_layout_o)
66
+ ctx.num_partitions = num_partitions
67
+
68
+ return forward_flow(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
69
+ n_sparse_blocks, triton_block_size)
70
+
71
+ @staticmethod
72
+ def backward(ctx, grad_output):
73
+ sparsity_layout = ctx.saved_tensors[0]
74
+ num_partitions = ctx.num_partitions
75
+ sparsity_block_size = ctx.sparsity_block_size
76
+ triton_block_size = ctx.triton_block_size
77
+
78
+ return merge(grad_output, sparsity_layout, num_partitions,
79
+ sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None
80
+
81
+
82
+ def merge(x: Tensor, sparsity_layout: Tensor, partitions: int,
83
+ sparsity_block_size: int, triton_block_size: int = None) -> (Tensor, Tensor):
84
+ """Merges the specified partitions of a block-sparse tensor in compressed form along the last dimension.
85
+
86
+ Args:
87
+ x (Tensor): A block-sparse tensor in compressed form.
88
+ sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
89
+ partitions (int): The number of partitions to be merged.
90
+ sparsity_block_size (int): The size of the sparsity blocks.
91
+ triton_block_size (int): The block size to use for the triton kernel (default ``None``).
92
+
93
+ Returns:
94
+ Tensor: The merged block-sparse tensor in compressed form.
95
+ Tensor: The sparsity layout of the output tensor.
96
+
97
+ """
98
+ x = x.contiguous()
99
+
100
+ validate_dimensions(x)
101
+ validate_contiguous(x)
102
+ validate_device(x)
103
+ validate_sparsity(sparsity_block_size, (x, sparsity_layout))
104
+ validate_sparsity_block_size(sparsity_block_size, x)
105
+ validate_triton_block_size(triton_block_size, sparsity_block_size)
106
+
107
+ sparsity_layout_output = (sparsity_layout.reshape(sparsity_layout.size(0) // partitions, partitions,
108
+ sparsity_layout.size(1), sparsity_layout.size(2))
109
+ .permute(0, 2, 1, 3)
110
+ .reshape(sparsity_layout.size(0) // partitions,
111
+ sparsity_layout.size(1), sparsity_layout.size(2) * partitions).contiguous())
112
+
113
+ sparsity_lut = torch.nonzero(sparsity_layout_output).contiguous()
114
+
115
+ sparsity_layout_flat = sparsity_layout.reshape(-1)
116
+ sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
117
+ (sparsity_layout_flat == 1) -
118
+ (1 * (sparsity_layout_flat == 0)))
119
+ .reshape(sparsity_layout.size(0) // partitions, partitions,
120
+ sparsity_layout.size(1), sparsity_layout.size(2))
121
+ .permute(0, 2, 1, 3)
122
+ .reshape(sparsity_layout.size(0) // partitions,
123
+ sparsity_layout.size(1), sparsity_layout.size(2) * partitions)
124
+ .reshape(-1).contiguous())
125
+
126
+ n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
127
+
128
+ validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
129
+
130
+ return _BlocksparseMerge.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
131
+ sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_output
132
+
133
+
134
+ class _BlocksparseMerge(torch.autograd.Function):
135
+
136
+ @staticmethod
137
+ def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
138
+ num_partitions: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
139
+ ctx.save_for_backward(sparsity_layout_o)
140
+ ctx.num_partitions = num_partitions
141
+
142
+ return forward_flow(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
143
+ n_sparse_blocks, triton_block_size)
144
+
145
+ @staticmethod
146
+ def backward(ctx, grad_output):
147
+ sparsity_layout = ctx.saved_tensors[0]
148
+ num_partitions = ctx.num_partitions
149
+ sparsity_block_size = ctx.sparsity_block_size
150
+ triton_block_size = ctx.triton_block_size
151
+
152
+ return split(grad_output, sparsity_layout, num_partitions,
153
+ sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None
154
+
155
+
blksprs/ops/repeat.py ADDED
@@ -0,0 +1,241 @@
1
+ import torch
2
+ import triton
3
+ from triton import language as tl
4
+ from torch import Tensor
5
+
6
+ from blksprs.utils.tools import get_triton_block_size, stride
7
+ from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
8
+ validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
9
+
10
+
11
def repeat(x: Tensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
           sparsity_block_size: int, sparsity_layout_output: Tensor = None, triton_block_size: int = None) -> (
        Tensor, Tensor):
    """Repeat a blocksparse tensor along the batch, row, and column dimensions.

    Returns a tuple of the repeated blocksparse tensor and its sparsity layout.
    If ``sparsity_layout_output`` is given, the result is restricted to blocks
    present in both the repeated layout and the supplied target layout.
    """
    x = x.contiguous()

    validate_dimensions(x)
    validate_contiguous(x)
    validate_device(x)
    validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
    validate_sparsity_block_size(sparsity_block_size, x)
    validate_triton_block_size(triton_block_size, sparsity_block_size)

    # Layout of the repeated output, optionally intersected with a target layout.
    layout_out = sparsity_layout_x.repeat(*repeats)
    if sparsity_layout_output is not None:
        layout_out = torch.logical_and(layout_out, sparsity_layout_output)

    lut = torch.nonzero(layout_out).contiguous()

    # Reverse LUT: for every position of the repeated layout, the index of the
    # source block in the compacted x, or -1 where the source layout is empty.
    layout_flat = sparsity_layout_x.reshape(-1)
    reverse_lut = ((torch.cumsum(layout_flat, dim=-1) - 1) *
                   (layout_flat == 1) -
                   (1 * (layout_flat == 0)))
    reverse_lut = (reverse_lut
                   .reshape(sparsity_layout_x.size())
                   .repeat(*repeats)
                   .reshape(-1)
                   .contiguous())

    n_sparse_blocks = torch.sum(layout_out.to(torch.int)).item()

    validate_contiguous(layout_out, lut, reverse_lut)

    return _BlocksparseRepeat.apply(x, sparsity_layout_x, layout_out, lut, reverse_lut,
                                    sparsity_block_size, n_sparse_blocks, triton_block_size), layout_out
44
+
45
+
46
class _BlocksparseRepeat(torch.autograd.Function):
    """Autograd wrapper for the blocksparse repeat operation.

    Forward gathers each output block from its source block (pull kernel);
    backward scatters gradients back, accumulating the contributions of all
    repeated copies of a block via atomic adds (push kernel).
    """

    @staticmethod
    def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor,
                sparsity_reverse_lut: Tensor,
                sparsity_block_size: int, n_sparse_blocks: int,
                triton_block_size: int) -> Tensor:
        ctx.save_for_backward(sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
        # Geometry of the compacted input, needed to size the gradient tensor.
        ctx.x_size = x.size()
        ctx.x_stride = stride(x)

        return forward_flow(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
                            n_sparse_blocks, triton_block_size)

    @staticmethod
    def backward(ctx, grad_output):
        sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut = ctx.saved_tensors
        sparsity_block_size = ctx.sparsity_block_size
        triton_block_size = ctx.triton_block_size

        # The gradient w.r.t. x carries one block per non-zero entry of the
        # original (un-repeated) layout; zero-initialised as an accumulator.
        num_x_blocks = torch.sum(sparsity_layout_x.to(torch.int)).item()
        grad_x = torch.zeros(size=(num_x_blocks, sparsity_block_size, sparsity_block_size),
                             dtype=grad_output.dtype, device=grad_output.device)

        g_b, g_r, g_c = grad_output.size()
        g_b_s, g_r_s, g_c_s = stride(grad_output)
        lay_b, lay_r, lay_c = sparsity_layout_o.size()
        lay_b_s, lay_r_s, lay_c_s = stride(sparsity_layout_o)
        lut_r, lut_c = sparsity_lut.size()
        lut_r_s, lut_c_s = stride(sparsity_lut)
        # Destination geometry equals the forward input's compacted geometry.
        dst_b, dst_r, dst_c = ctx.x_size
        dst_b_s, dst_r_s, dst_c_s = ctx.x_stride

        if triton_block_size is None:
            triton_block_size = get_triton_block_size(sparsity_block_size)

        grid = lambda meta: [g_b,
                             triton.cdiv(g_r, meta["TRITON_BLOCK_SIZE"]),
                             triton.cdiv(g_c, meta["TRITON_BLOCK_SIZE"])]

        # Push every gradient block onto its source; the kernel's atomic adds
        # accumulate contributions from all repeated copies.
        (kernel_blocksparse_flow_push[grid]
         (grad_output,
          g_b, g_b_s, g_r_s, g_c_s,
          lay_b, lay_b_s, lay_r_s, lay_c_s,
          sparsity_lut, lut_r, lut_r_s, lut_c_s,
          sparsity_reverse_lut,
          grad_x,
          dst_b, dst_b_s, dst_r_s, dst_c_s,
          triton_block_size))

        return grad_x, None, None, None, None, None, None, None
100
+
101
+
102
@triton.jit
def kernel_blocksparse_flow_pull(x,
                                 x_b, x_b_s, x_r_s, x_c_s,
                                 o,
                                 o_b, o_b_s, o_r_s, o_c_s,
                                 s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
                                 s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
                                 r_lut,
                                 TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    # Gather ("pull") kernel: each program copies one tile of one output block
    # from the source block in x designated by the reverse LUT.

    # Program ids: sparse output block, tile row, tile column.
    blk_pid = tl.program_id(axis=0)
    row_pid = tl.program_id(axis=1)
    col_pid = tl.program_id(axis=2)

    # Look up the (batch, row, col) coordinates of this output block.
    lut_base = blk_pid * s_lut_r_s
    lut_limit = s_lut_r * s_lut_r_s

    bat_ptr = lut_base + 0 * s_lut_c_s
    blk_bat = tl.load(s_lut + bat_ptr, mask=bat_ptr < lut_limit)

    row_ptr = lut_base + 1 * s_lut_c_s
    blk_row = tl.load(s_lut + row_ptr, mask=row_ptr < lut_limit)

    col_ptr = lut_base + 2 * s_lut_c_s
    blk_col = tl.load(s_lut + col_ptr, mask=col_ptr < lut_limit)

    # Reverse LUT yields the index of the source block in x (-1 = no source).
    rev_ptr = (blk_bat * s_l_o_b_s +
               blk_row * s_l_o_r_s +
               blk_col * s_l_o_c_s)
    src_blk = tl.load(r_lut + rev_ptr, mask=rev_ptr < s_l_o_b * s_l_o_b_s).to(tl.int32)

    # A -1 entry means the layouts are inconsistent; abort this program.
    if src_blk == -1:
        tl.device_assert(False)
        return

    # Element offsets of this tile within a block.
    row_offs = row_pid * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)
    col_offs = col_pid * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)

    src_idx = (src_blk * x_b_s +
               (row_offs * x_r_s)[:, None] +
               (col_offs * x_c_s)[None, :])
    tile = tl.load(x + src_idx, mask=src_idx < x_b * x_b_s)

    dst_idx = (blk_pid * o_b_s +
               (row_offs * o_r_s)[:, None] +
               (col_offs * o_c_s)[None, :])
    tl.store(o + dst_idx, tile, mask=dst_idx < o_b * o_b_s)
151
+
152
+
153
@triton.jit
def kernel_blocksparse_flow_push(x,
                                 x_b, x_b_s, x_r_s, x_c_s,
                                 s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
                                 s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
                                 r_lut,
                                 o,
                                 o_b, o_b_s, o_r_s, o_c_s,
                                 TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    # Scatter ("push") kernel: each program adds one tile of one input block
    # onto the destination block in o designated by the reverse LUT.

    # Program ids: sparse input block, tile row, tile column.
    blk_pid = tl.program_id(axis=0)
    row_pid = tl.program_id(axis=1)
    col_pid = tl.program_id(axis=2)

    # Look up the (batch, row, col) coordinates of this input block.
    lut_base = blk_pid * s_lut_r_s
    lut_limit = s_lut_r * s_lut_r_s

    bat_ptr = lut_base + 0 * s_lut_c_s
    blk_bat = tl.load(s_lut + bat_ptr, mask=bat_ptr < lut_limit)

    row_ptr = lut_base + 1 * s_lut_c_s
    blk_row = tl.load(s_lut + row_ptr, mask=row_ptr < lut_limit)

    col_ptr = lut_base + 2 * s_lut_c_s
    blk_col = tl.load(s_lut + col_ptr, mask=col_ptr < lut_limit)

    # Reverse LUT yields the index of the destination block (-1 = none).
    rev_ptr = (blk_bat * s_l_x_b_s +
               blk_row * s_l_x_r_s +
               blk_col * s_l_x_c_s)
    dst_blk = tl.load(r_lut + rev_ptr, mask=rev_ptr < s_l_x_b * s_l_x_b_s).to(tl.int32)

    # A -1 entry means the layouts are inconsistent; abort this program.
    if dst_blk == -1:
        tl.device_assert(False)
        return

    # Element offsets of this tile within a block.
    row_offs = row_pid * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)
    col_offs = col_pid * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)

    src_idx = (blk_pid * x_b_s +
               (row_offs * x_r_s)[:, None] +
               (col_offs * x_c_s)[None, :])
    tile = tl.load(x + src_idx, mask=src_idx < x_b * x_b_s)

    dst_idx = (dst_blk * o_b_s +
               (row_offs * o_r_s)[:, None] +
               (col_offs * o_c_s)[None, :])
    # Atomic add: several source blocks may target the same destination block.
    tl.atomic_add(o + dst_idx, tile, mask=dst_idx < o_b * o_b_s)
202
+
203
+
204
def forward_flow(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
                 sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
    """Shared forward driver: gather output blocks from ``x`` via the pull kernel.

    Allocates the compacted output tensor, launches
    ``kernel_blocksparse_flow_pull`` over all output blocks, and records
    ``sparsity_block_size`` / ``triton_block_size`` on ``ctx`` for the
    backward pass of the calling autograd function.
    """
    # Zero-initialise directly; the previous empty() + zeros_like() pair
    # allocated the output twice for no benefit.
    output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                         dtype=x.dtype, device=x.device)

    x_b, x_r, x_c = x.size()
    x_b_s, x_r_s, x_c_s = stride(x)
    o_b, o_r, o_c = output.size()
    o_b_s, o_r_s, o_c_s = stride(output)
    s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
    s_lut_r, s_lut_c = sparsity_lut.size()
    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
    # Removed dead debug statement `asdf = torch.tensor(sparsity_lut).stride()`:
    # the value was never used and torch.tensor() on a tensor performs a copy
    # (and warns that detach()/clone() should be used instead).

    if triton_block_size is None:
        triton_block_size = get_triton_block_size(sparsity_block_size)

    # One program per (output block, tile row, tile column).
    triton_grid = lambda meta: [o_b,
                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

    (kernel_blocksparse_flow_pull[triton_grid]
     (x,
      x_b, x_b_s, x_r_s, x_c_s,
      output,
      o_b, o_b_s, o_r_s, o_c_s,
      s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
      sparsity_reverse_lut,
      triton_block_size))

    # Save for backward pass
    ctx.sparsity_block_size = sparsity_block_size
    ctx.triton_block_size = triton_block_size

    return output
blksprs/ops/softmax.py CHANGED
@@ -5,7 +5,7 @@ from triton import language as tl
5
5
 
6
6
  from blksprs.ops.exp import exp
7
7
  from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
8
- from blksprs.utils.tools import get_triton_block_size
8
+ from blksprs.utils.tools import get_triton_block_size, stride
9
9
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
10
10
  validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
11
11
 
@@ -61,9 +61,9 @@ class _BlocksparseSoftmax(torch.autograd.Function):
61
61
  output = torch.empty_like(x)
62
62
 
63
63
  x_b, x_r, x_c = x.size()
64
- x_b_s, x_r_s, x_c_s = x.stride()
64
+ x_b_s, x_r_s, x_c_s = stride(x)
65
65
  s_lut_r, s_lut_c = sparsity_lut.size()
66
- s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
66
+ s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
67
67
  o_b, o_r, o_c = output.size()
68
68
 
69
69
  x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,
@@ -76,9 +76,9 @@ class _BlocksparseSoftmax(torch.autograd.Function):
76
76
  triton_block_size=triton_block_size)
77
77
 
78
78
  s_b, s_r, s_c = x_exp_row_wise_sum.shape
79
- s_b_s, s_r_s, s_c_s = x_exp_row_wise_sum.stride()
79
+ s_b_s, s_r_s, s_c_s = stride(x_exp_row_wise_sum)
80
80
  s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_rws.shape
81
- s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = sparsity_layout_rws.stride()
81
+ s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_rws)
82
82
 
83
83
  if triton_block_size is None:
84
84
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -119,13 +119,13 @@ class _BlocksparseSoftmax(torch.autograd.Function):
119
119
  (1 * (sparsity_layout_s_flat == 0)))
120
120
 
121
121
  o_b, o_r, o_c = o.size()
122
- o_b_s, o_r_s, o_c_s = o.stride()
122
+ o_b_s, o_r_s, o_c_s = stride(o)
123
123
  s_lut_r, s_lut_c = sparsity_lut.size()
124
- s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
124
+ s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
125
125
  s_b, s_r, s_c = s.size()
126
- s_b_s, s_r_s, s_c_s = s.stride()
126
+ s_b_s, s_r_s, s_c_s = stride(s)
127
127
  s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
128
- s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = sparsity_layout_s.stride()
128
+ s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)
129
129
 
130
130
  grad_x = torch.empty_like(o, dtype=torch.float)
131
131
 
@@ -181,7 +181,8 @@ class _BlocksparseSoftmax(torch.autograd.Function):
181
181
  rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
182
182
 
183
183
  if rev_idx_spa_s == -1:
184
- assert False, "Invalid sparsity block"
184
+ tl.device_assert(False)
185
+ return
185
186
 
186
187
  # Load x block
187
188
  blk_x_idx = ((pid_blk * x_b_s) +
blksprs/ops/transpose.py CHANGED
@@ -3,7 +3,7 @@ import triton
3
3
  from torch import Tensor
4
4
  from triton import language as tl
5
5
 
6
- from blksprs.utils.tools import get_triton_block_size
6
+ from blksprs.utils.tools import get_triton_block_size, stride
7
7
  from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
8
8
  validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
9
9
 
@@ -56,20 +56,20 @@ def transpose(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
56
56
  class _BlocksparseTranspose(torch.autograd.Function):
57
57
 
58
58
  @staticmethod
59
- def forward(ctx, x: Tensor,
60
- sparsity_layout: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor, sparsity_block_size: int,
61
- n_sparse_blocks: int, triton_block_size: int) -> (Tensor, Tensor):
59
+ def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
60
+ sparsity_block_size: int,
61
+ n_sparse_blocks: int, triton_block_size: int) -> Tensor:
62
62
  output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
63
63
  dtype=x.dtype, device=x.device)
64
64
 
65
65
  x_b, x_r, x_c = x.size()
66
- x_b_s, x_r_s, x_c_s = x.stride()
67
- s_l_b, s_l_r, s_l_c = sparsity_layout.size()
68
- s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout.stride()
66
+ x_b_s, x_r_s, x_c_s = stride(x)
67
+ s_l_b, s_l_r, s_l_c = sparsity_layout_o.size()
68
+ s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout_o)
69
69
  s_lut_r, s_lut_c = sparsity_lut.shape
70
- s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
70
+ s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
71
71
  o_b, o_r, o_c = output.size()
72
- o_b_s, o_r_s, o_c_s = output.stride()
72
+ o_b_s, o_r_s, o_c_s = stride(output)
73
73
 
74
74
  if triton_block_size is None:
75
75
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -89,8 +89,7 @@ class _BlocksparseTranspose(torch.autograd.Function):
89
89
  triton_block_size))
90
90
 
91
91
  # Save for backward pass
92
- ctx.save_for_backward(sparsity_layout)
93
- ctx.sparsity_layout = sparsity_layout
92
+ ctx.save_for_backward(sparsity_layout_o)
94
93
  ctx.sparsity_block_size = sparsity_block_size
95
94
  ctx.triton_block_size = triton_block_size
96
95
 
@@ -141,7 +140,8 @@ class _BlocksparseTranspose(torch.autograd.Function):
141
140
  rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
142
141
 
143
142
  if rev_idx_spa == -1:
144
- assert False, "Invalid sparsity block"
143
+ tl.device_assert(False)
144
+ return
145
145
 
146
146
  blk_x_idx = (rev_idx_spa * x_b_s +
147
147
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
blksprs/utils/tools.py CHANGED
@@ -23,3 +23,6 @@ def get_triton_block_size(sparsity_block_size: int, limit: int = 128):
23
23
 
24
24
  def disable_validation():
25
25
  _set_skip_validation(True)
26
+
27
def stride(x: Tensor):
    """Return the stride tuple of *x*.

    NOTE(review): the intermediate view() with the tensor's own shape leaves
    data and layout untouched; presumably it normalises stride reporting for
    wrapped/graph tensors — confirm against callers.
    """
    same_shape_view = x.view(x.shape)
    return same_shape_view.stride()
@@ -63,6 +63,8 @@ def validate_sparsity(sparsity_block_size: int, *tensor_sparsity_layout_tuples:
63
63
  for (tensor, sparsity_layout) in tensor_sparsity_layout_tuples:
64
64
  _validate_sparsity_layout_values(sparsity_layout)
65
65
 
66
+ if not sparsity_layout.dim() == 3:
67
+ raise ValueError("Sparsity layout must have exactly 3 dimensions")
66
68
  if not (tensor.size(-1) == tensor.size(-2) == sparsity_block_size):
67
69
  raise ValueError("Blocks not conforming to sparsity block size")
68
70
  if not tensor.size(0) == torch.sum(sparsity_layout.reshape(-1)):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: blksprs
3
- Version: 1.5
3
+ Version: 1.7
4
4
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
5
5
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
6
6
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -31,9 +31,11 @@ Currently supported operations (includes gradient calculation):
31
31
  - Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
32
32
  for `sparse = sparse @ sparse` matmul_)
33
33
  - Softmax
34
- - Transposition
34
+ - Transpose
35
35
  - Gather
36
36
  - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
37
+ - Repeat (_supports target sparsity layout_)
38
+ - Splitting and merging of matrices along the last dimension
37
39
  - Conversion to and from sparse form
38
40
  - Conversion to different sparsity layouts and different sparsity block sizes
39
41
 
@@ -63,7 +65,7 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u
63
65
 
64
66
  ### Dependencies
65
67
 
66
- - [PyTorch](https://pytorch.org/) (built with v2.4.0)
68
+ - [PyTorch](https://pytorch.org/) (built with v2.5.0)
67
69
  - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_
68
70
 
69
71
  ## Changelog
@@ -0,0 +1,22 @@
1
+ blksprs/__init__.py,sha256=FpvHMo1W6XvuiA1PMDp2_EJz-Xwc15cHz7WeIYXQJC4,1019
2
+ blksprs/experimental/distribution_mdi.py,sha256=jE_SbB0SMGPcCoeM0699cceWAikBkBpGM_8Oo5A7Ets,20083
3
+ blksprs/layouting/distribution_layout.py,sha256=nCR3fCF6eNpi97DI6MMLF2hq_9Lwyo6_aUCIZiJfiX4,4170
4
+ blksprs/layouting/sparsity_layout.py,sha256=JNwbJ4L-418zCiCbt-vLfZ0xU7ReP0zr_tLHs_dytRA,9686
5
+ blksprs/misc/broadcast_ops.py,sha256=0RLnLMYV7GAPI2YL8RotcxjIUSBZKGxdVcsGaJFeL_I,5327
6
+ blksprs/misc/repeat_interleave.py,sha256=P5gfsZXuemLiAijUZfFkBFgMjlU9rlPEzai1xeGOFnw,5678
7
+ blksprs/misc/row_wise.py,sha256=0vDJA8uCocmebSIPIbFeND5_PQIE10pUj3DBOQXlTvE,16888
8
+ blksprs/ops/conversion.py,sha256=iyKIlkWGrK6q55KNRM8N6rY1k4b9k8QUkUl158yZUDA,21330
9
+ blksprs/ops/distribution.py,sha256=V3TK5SlNT_JdGHNaDNl-U4U5vwAYsgkAOg4eTmYxbuA,16877
10
+ blksprs/ops/exp.py,sha256=cdF0s93Q9iucIXuEE3howsB0N6D60xgvem7C-a-yiGI,3704
11
+ blksprs/ops/matmul.py,sha256=uqVe6Dz2aaCbCglM1uS2eRHVKh7PQcuecaIBWFubPEw,11256
12
+ blksprs/ops/partitioning.py,sha256=nBRZzfi3XYAhDLEBzYflQkvGa3MIZ-qNeIlrZ16k44g,7533
13
+ blksprs/ops/repeat.py,sha256=6Wa6GG9Cx6rJXuFpvmOe5hHwYd3l9UYMosKEDsbh9XI,10408
14
+ blksprs/ops/softmax.py,sha256=2dMLbkHNH18jSJmkgOJvZOKwWHhuUogAVCWv2Bwc3oQ,11995
15
+ blksprs/ops/transpose.py,sha256=67pDdCEb7r-Xifupl82fBKAYsxKcCUDy--cPPfduRvU,6761
16
+ blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
17
+ blksprs/utils/tools.py,sha256=S3836Zuc-BMigv-5mLTjRznCzuaF6oYW-Ir9zzUnr3o,655
18
+ blksprs/utils/validation.py,sha256=h2oki3xC5qLWZR4-W5QIna-wVSXvRehQEH-ynrOciVE,3467
19
+ blksprs-1.7.dist-info/METADATA,sha256=raZ3ycSMUEAW71bwm-807d_dse44qKdSkWMhH4GI2Qg,7709
20
+ blksprs-1.7.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
21
+ blksprs-1.7.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
22
+ blksprs-1.7.dist-info/RECORD,,
@@ -1,20 +0,0 @@
1
- blksprs/__init__.py,sha256=OY9ofdbzBGsvY6hx0oLCrSszlJFdMns9x7gKE0asFI0,919
2
- blksprs/experimental/distribution_mdi.py,sha256=shu-3Nt7nkaLIb4O2kSajC8Lh7IWFXO9rsjzP14ASYA,20088
3
- blksprs/layouting/distribution_layout.py,sha256=Zv-b2t5VOvW6-ejdX42kUV7X1yYsvDCY_PXFE_wKwi0,4165
4
- blksprs/layouting/sparsity_layout.py,sha256=vZL8r5LkMwILYYqTYPZcN_NYFJuVFIB6mmBkdtRyXmI,7893
5
- blksprs/misc/broadcast_ops.py,sha256=ahm7_lI12bJ6VTKRuSkwEeaEYWRY-BeMIOhtei35zpQ,5323
6
- blksprs/misc/repeat_interleave.py,sha256=KJeapmxbpA7zGFfa5hUhCGrk4aFmhOhlMw-hbTh9PLI,5668
7
- blksprs/misc/row_wise.py,sha256=1UtjLplrGx1FkxhzQ2hjSBBY11ToLQs0JiLaXKRAkL4,16893
8
- blksprs/ops/conversion.py,sha256=vuiNwrwyuGI6H4PKrS_UHI7OKWJwNZd2i3LSjf6RetU,21332
9
- blksprs/ops/distribution.py,sha256=KhtHRVcv4_woyNlldAjIWF-7021-KX-xyIcN6rE-UgE,16879
10
- blksprs/ops/exp.py,sha256=CVWVq_emO2CnS_xk6Unx67P7EI7IL26dwtsmBJZOLzQ,3698
11
- blksprs/ops/matmul.py,sha256=743XeD5M4iUv28sYf7q6mVXDd4jZpV04JAx8bF7hWkw,11254
12
- blksprs/ops/softmax.py,sha256=cs1utM6UCzHhdJpf-ZysBr6CwbjI-5aQG0ahYY37Zy0,11991
13
- blksprs/ops/transpose.py,sha256=Ru4YKyg796WT6OnDSTCYG45tMmdgvju3hMFzkwsJnO8,6801
14
- blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
15
- blksprs/utils/tools.py,sha256=JAuwsLISr_hcvxIgUVvKz5ZPf9M5ycquplsBU5dVfDc,596
16
- blksprs/utils/validation.py,sha256=rP6yr-C2ghXfJEERry_pfvVJ0g0VyqV4sL4HkBRlJg8,3345
17
- blksprs-1.5.dist-info/METADATA,sha256=dql0_6s1Vfdnx6sLFusayZWSeU9uxvfAjBDdLPk43so,7607
18
- blksprs-1.5.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
19
- blksprs-1.5.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
20
- blksprs-1.5.dist-info/RECORD,,
File without changes