blksprs 1.6.1__tar.gz → 1.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27) hide show
  1. {blksprs-1.6.1 → blksprs-1.8}/PKG-INFO +12 -5
  2. {blksprs-1.6.1 → blksprs-1.8}/README.md +11 -4
  3. {blksprs-1.6.1 → blksprs-1.8}/blksprs/__init__.py +11 -6
  4. {blksprs-1.6.1 → blksprs-1.8}/blksprs/experimental/distribution_mdi.py +14 -14
  5. {blksprs-1.6.1 → blksprs-1.8}/blksprs/layouting/distribution_layout.py +4 -4
  6. {blksprs-1.6.1 → blksprs-1.8}/blksprs/layouting/sparsity_layout.py +6 -6
  7. {blksprs-1.6.1 → blksprs-1.8}/blksprs/misc/broadcast_ops.py +5 -5
  8. {blksprs-1.6.1/blksprs/ops → blksprs-1.8/blksprs/misc}/exp.py +3 -3
  9. {blksprs-1.6.1/blksprs/ops → blksprs-1.8/blksprs/misc}/partitioning.py +9 -98
  10. {blksprs-1.6.1 → blksprs-1.8}/blksprs/misc/row_wise.py +16 -15
  11. {blksprs-1.6.1 → blksprs-1.8}/blksprs/ops/conversion.py +23 -12
  12. {blksprs-1.6.1 → blksprs-1.8}/blksprs/ops/distribution.py +11 -11
  13. {blksprs-1.6.1 → blksprs-1.8}/blksprs/ops/matmul.py +7 -7
  14. blksprs-1.8/blksprs/ops/repeat.py +322 -0
  15. {blksprs-1.6.1 → blksprs-1.8}/blksprs/ops/softmax.py +12 -11
  16. {blksprs-1.6.1 → blksprs-1.8}/blksprs/ops/transpose.py +7 -6
  17. {blksprs-1.6.1 → blksprs-1.8}/blksprs/utils/tools.py +3 -0
  18. {blksprs-1.6.1 → blksprs-1.8}/blksprs/utils/validation.py +20 -1
  19. {blksprs-1.6.1 → blksprs-1.8}/blksprs.egg-info/PKG-INFO +12 -5
  20. {blksprs-1.6.1 → blksprs-1.8}/blksprs.egg-info/SOURCES.txt +3 -3
  21. {blksprs-1.6.1 → blksprs-1.8}/pyproject.toml +1 -1
  22. blksprs-1.6.1/blksprs/misc/repeat_interleave.py +0 -132
  23. {blksprs-1.6.1 → blksprs-1.8}/blksprs/utils/benchmarking.py +0 -0
  24. {blksprs-1.6.1 → blksprs-1.8}/blksprs.egg-info/dependency_links.txt +0 -0
  25. {blksprs-1.6.1 → blksprs-1.8}/blksprs.egg-info/requires.txt +0 -0
  26. {blksprs-1.6.1 → blksprs-1.8}/blksprs.egg-info/top_level.txt +0 -0
  27. {blksprs-1.6.1 → blksprs-1.8}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: blksprs
3
- Version: 1.6.1
3
+ Version: 1.8
4
4
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
5
5
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
6
6
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -28,12 +28,13 @@ A lightweight and efficient library for operations on block-sparse matrices in P
28
28
 
29
29
  Currently supported operations (includes gradient calculation):
30
30
 
31
- - Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
32
- for `sparse = sparse @ sparse` matmul_)
31
+ - Matrix multiplication
33
32
  - Softmax
34
33
  - Transpose
35
34
  - Gather
36
35
  - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
36
+ - Repeat (_supports target sparsity layout_)
37
+ - Repeat Interleave (_supports target sparsity layout_)
37
38
  - Splitting and merging of matrices along the last dimension
38
39
  - Conversion to and from sparse form
39
40
  - Conversion to different sparsity layouts and different sparsity block sizes
@@ -50,8 +51,14 @@ These include, e.g.,
50
51
  Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
51
52
  match.
52
53
 
54
+ Further helpful operations (included in the ``bs.misc`` module) that do **not** support gradient calculation include:
55
+
56
+ - Row-wise sum, max, addition, and subtraction
57
+ - Broadcast addition and subtraction between slices
58
+
53
59
  Furthermore, the library provides a set of utility functions for the creation of sparsity layouts based on existing
54
- dense tensors.
60
+ dense tensors and for the scatter operation (module ``bs.layout``), as well as utility functions to ensure correct input
61
+ dimensionality (module ``bs.util``).
55
62
 
56
63
  ## Installation
57
64
 
@@ -64,7 +71,7 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u
64
71
 
65
72
  ### Dependencies
66
73
 
67
- - [PyTorch](https://pytorch.org/) (built with v2.4.0)
74
+ - [PyTorch](https://pytorch.org/) (built with v2.5.0)
68
75
  - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_
69
76
 
70
77
  ## Changelog
@@ -9,12 +9,13 @@ A lightweight and efficient library for operations on block-sparse matrices in P
9
9
 
10
10
  Currently supported operations (includes gradient calculation):
11
11
 
12
- - Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
13
- for `sparse = sparse @ sparse` matmul_)
12
+ - Matrix multiplication
14
13
  - Softmax
15
14
  - Transpose
16
15
  - Gather
17
16
  - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
17
+ - Repeat (_supports target sparsity layout_)
18
+ - Repeat Interleave (_supports target sparsity layout_)
18
19
  - Splitting and merging of matrices along the last dimension
19
20
  - Conversion to and from sparse form
20
21
  - Conversion to different sparsity layouts and different sparsity block sizes
@@ -31,8 +32,14 @@ These include, e.g.,
31
32
  Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
32
33
  match.
33
34
 
35
+ Further helpful operations (included in the ``bs.misc`` module) that do **not** support gradient calculation include:
36
+
37
+ - Row-wise sum, max, addition, and subtraction
38
+ - Broadcast addition and subtraction between slices
39
+
34
40
  Furthermore, the library provides a set of utility functions for the creation of sparsity layouts based on existing
35
- dense tensors.
41
+ dense tensors and for the scatter operation (module ``bs.layout``), as well as utility functions to ensure correct input
42
+ dimensionality (module ``bs.util``).
36
43
 
37
44
  ## Installation
38
45
 
@@ -45,7 +52,7 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u
45
52
 
46
53
  ### Dependencies
47
54
 
48
- - [PyTorch](https://pytorch.org/) (built with v2.4.0)
55
+ - [PyTorch](https://pytorch.org/) (built with v2.5.0)
49
56
  - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_
50
57
 
51
58
  ## Changelog
@@ -1,22 +1,27 @@
1
- from blksprs.ops.conversion import to_dense, to_sparse
1
+ from blksprs.ops.conversion import to_dense, to_sparse, from_blksprs, to_blksprs
2
2
  from blksprs.ops.distribution import gather, scatter, scatter_reduce
3
- from blksprs.ops.exp import exp
4
3
  from blksprs.ops.matmul import matmul
5
4
  from blksprs.ops.softmax import softmax
6
5
  from blksprs.ops.transpose import transpose
7
- from blksprs.ops.partitioning import split, merge
6
+ from blksprs.ops.repeat import repeat, repeat_interleave
7
+ from blksprs.misc.partitioning import split, merge
8
+
8
9
 
9
10
  class layout:
10
11
  from blksprs.layouting.distribution_layout import build_distribution_layout
11
- from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption, build_sparsity_layout_matmul
12
+ from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption, \
13
+ build_sparsity_layout_matmul
14
+
12
15
 
13
16
  class misc:
14
17
  from blksprs.misc.broadcast_ops import broadcast_add, broadcast_sub
15
- from blksprs.misc.repeat_interleave import repeat_interleave
18
+ from blksprs.misc.exp import exp
16
19
  from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_add, row_wise_sub
17
20
 
21
+
18
22
  class util:
19
23
  from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse, disable_validation
20
24
 
25
+
21
26
  class experimental:
22
- from blksprs.experimental.distribution_mdi import gather_mdi, scatter_reduce_mdi
27
+ from blksprs.experimental.distribution_mdi import gather_mdi, scatter_reduce_mdi
@@ -3,7 +3,7 @@ import triton
3
3
  from torch import Tensor
4
4
  from triton import language as tl
5
5
 
6
- from blksprs.utils.tools import get_triton_block_size
6
+ from blksprs.utils.tools import get_triton_block_size, stride
7
7
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
8
8
  validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
9
9
 
@@ -51,15 +51,15 @@ class _BlocksparseGatherMDI(torch.autograd.Function):
51
51
  output = torch.empty_like(idx_col, dtype=x.dtype)
52
52
 
53
53
  x_b, x_r, x_c = x.size()
54
- x_b_s, x_r_s, x_c_s = x.stride()
54
+ x_b_s, x_r_s, x_c_s = stride(x)
55
55
  s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
56
- s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_x.stride()
56
+ s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
57
57
  i_b, i_r, i_c = idx_col.size()
58
- i_b_s, i_r_s, i_c_s = idx_col.stride()
58
+ i_b_s, i_r_s, i_c_s = stride(idx_col)
59
59
  s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
60
- s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
60
+ s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
61
61
  o_b, o_r, o_c = output.size()
62
- o_b_s, o_r_s, o_c_s = output.stride()
62
+ o_b_s, o_r_s, o_c_s = stride(output)
63
63
 
64
64
  if triton_block_size is None:
65
65
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -224,15 +224,15 @@ class _BlocksparseScatterReduceMDI(torch.autograd.Function):
224
224
  dtype=x.dtype, device=x.device)
225
225
 
226
226
  x_b, x_r, x_c = x.size()
227
- x_b_s, x_r_s, x_c_s = x.stride()
227
+ x_b_s, x_r_s, x_c_s = stride(x)
228
228
  s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
229
- s_lut_x_r_s, s_lut_x_c_s = sparsity_lut_x.stride()
229
+ s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut_x)
230
230
  i_b, i_r, i_c = idx_col.size()
231
- i_b_s, i_r_s, i_c_s = idx_col.stride()
231
+ i_b_s, i_r_s, i_c_s = stride(idx_col)
232
232
  o_b, o_r, o_c = output.size()
233
- o_b_s, o_r_s, o_c_s = output.stride()
233
+ o_b_s, o_r_s, o_c_s = stride(output)
234
234
  s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
235
- s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_o.stride()
235
+ s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
236
236
 
237
237
  if triton_block_size is None:
238
238
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -366,11 +366,11 @@ def build_distribution_layout_mdi(idx_bat: Tensor, idx_row: Tensor, idx_col: Ten
366
366
  dtype=torch.bool, device=idx_col.device)
367
367
 
368
368
  i_b, i_r, i_c = idx_col.size()
369
- i_b_s, i_r_s, i_c_s = idx_col.stride()
369
+ i_b_s, i_r_s, i_c_s = stride(idx_col)
370
370
  s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
371
- s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
371
+ s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
372
372
  o_b, o_r, o_c = output.size()
373
- o_b_s, o_r_s, o_c_s = output.stride()
373
+ o_b_s, o_r_s, o_c_s = stride(output)
374
374
 
375
375
  if triton_block_size is None:
376
376
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -3,7 +3,7 @@ import triton
3
3
  from torch import Tensor
4
4
  from triton import language as tl
5
5
 
6
- from blksprs.utils.tools import get_triton_block_size
6
+ from blksprs.utils.tools import get_triton_block_size, stride
7
7
  from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
8
8
  validate_contiguous
9
9
 
@@ -34,11 +34,11 @@ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
34
34
  dtype=torch.bool, device=indices.device)
35
35
 
36
36
  i_b, i_r, i_c = indices.size()
37
- i_b_s, i_r_s, i_c_s = indices.stride()
37
+ i_b_s, i_r_s, i_c_s = stride(indices)
38
38
  s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
39
- s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
39
+ s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
40
40
  o_b, o_r, o_c = output.size()
41
- o_b_s, o_r_s, o_c_s = output.stride()
41
+ o_b_s, o_r_s, o_c_s = stride(output)
42
42
 
43
43
  if triton_block_size is None:
44
44
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -5,7 +5,7 @@ import triton
5
5
  from torch import Tensor
6
6
  from triton import language as tl
7
7
 
8
- from blksprs.utils.tools import get_triton_block_size
8
+ from blksprs.utils.tools import get_triton_block_size, stride
9
9
  from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
10
10
  validate_contiguous, validate_sparsity, validate_sparsity_block_size
11
11
 
@@ -30,9 +30,9 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_size
30
30
  dtype=torch.bool, device=x.device)
31
31
 
32
32
  x_b, x_r, x_c = x.size()
33
- x_b_s, x_r_s, x_c_s = x.stride()
33
+ x_b_s, x_r_s, x_c_s = stride(x)
34
34
  o_b, o_r, o_c = output.size()
35
- o_b_s, o_r_s, o_c_s = output.stride()
35
+ o_b_s, o_r_s, o_c_s = stride(output)
36
36
 
37
37
  if triton_block_size is None:
38
38
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -120,10 +120,10 @@ def build_sparsity_layout_adaption(x: Tensor, sparsity_layout_from: Tensor,
120
120
  output = torch.zeros(o_b, o_r, o_c, dtype=torch.bool, device=x.device)
121
121
 
122
122
  x_b, x_r, x_c = x.size()
123
- x_b_s, x_r_s, x_c_s = x.stride()
123
+ x_b_s, x_r_s, x_c_s = stride(x)
124
124
  s_lut_r, s_lut_c = sparsity_lut.size()
125
- s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
126
- o_b_s, o_r_s, o_c_s = output.stride()
125
+ s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
126
+ o_b_s, o_r_s, o_c_s = stride(output)
127
127
 
128
128
  if triton_block_size is None:
129
129
  triton_block_size = get_triton_block_size(sparsity_block_size_from)
@@ -3,7 +3,7 @@ import triton
3
3
  from torch import Tensor
4
4
  from triton import language as tl
5
5
 
6
- from blksprs.utils.tools import get_triton_block_size
6
+ from blksprs.utils.tools import get_triton_block_size, stride
7
7
  from blksprs.utils.validation import validate_contiguous, validate_device, \
8
8
  validate_sparsity_block_size, validate_triton_block_size
9
9
 
@@ -44,13 +44,13 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
44
44
  output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, dtype=x.dtype, device=x.device)
45
45
 
46
46
  x_b, x_c = x.size()
47
- x_b_s, x_c_s = x.stride()
47
+ x_b_s, x_c_s = stride(x)
48
48
  y_b, y_c = y.size()
49
- y_b_s, y_c_s = y.stride()
49
+ y_b_s, y_c_s = stride(y)
50
50
  o_b, o_r, o_c = output.size()
51
- o_b_s, o_r_s, o_c_s = output.stride()
51
+ o_b_s, o_r_s, o_c_s = stride(output)
52
52
  s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
53
- s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_o.stride()
53
+ s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
54
54
 
55
55
  if triton_block_size is None:
56
56
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -3,7 +3,7 @@ import triton
3
3
  from torch import Tensor
4
4
  from triton import language as tl
5
5
 
6
- from blksprs.utils.tools import get_triton_block_size
6
+ from blksprs.utils.tools import get_triton_block_size, stride
7
7
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
8
8
  validate_sparsity_block_size, validate_triton_block_size
9
9
 
@@ -43,9 +43,9 @@ class _BlocksparseExp(torch.autograd.Function):
43
43
  output = torch.empty_like(x)
44
44
 
45
45
  x_b, x_r, x_c = x.shape
46
- x_b_s, x_r_s, x_c_s = x.stride()
46
+ x_b_s, x_r_s, x_c_s = stride(x)
47
47
  o_b, o_r, o_c = output.shape
48
- o_b_s, o_r_s, o_c_s = output.stride()
48
+ o_b_s, o_r_s, o_c_s = stride(output)
49
49
 
50
50
  if triton_block_size is None:
51
51
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -1,10 +1,7 @@
1
1
  import torch
2
- import triton
3
- from sympy.utilities.iterables import partitions
4
2
  from torch import Tensor
5
- from triton import language as tl
6
3
 
7
- from blksprs.utils.tools import get_triton_block_size
4
+ from blksprs.ops.repeat import forward_flow
8
5
 
9
6
  from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
10
7
  validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
@@ -48,12 +45,11 @@ def split(x: Tensor, sparsity_layout: Tensor, partitions: int,
48
45
  sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
49
46
  (sparsity_layout_flat == 1) -
50
47
  (1 * (sparsity_layout_flat == 0)))
51
- .reshape(sparsity_layout.size())
52
48
  .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
53
49
  sparsity_layout.size(2) // partitions)
54
50
  .permute(0, 2, 1, 3).reshape(-1).contiguous())
55
51
 
56
- n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
52
+ n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
57
53
 
58
54
  validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
59
55
 
@@ -66,10 +62,11 @@ class _BlocksparseSplit(torch.autograd.Function):
66
62
  @staticmethod
67
63
  def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
68
64
  num_partitions: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
65
+ ctx.save_for_backward(sparsity_layout_o)
69
66
  ctx.num_partitions = num_partitions
70
67
 
71
- return forward_reorder(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
72
- n_sparse_blocks, triton_block_size)
68
+ return forward_flow(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
69
+ n_sparse_blocks, triton_block_size)
73
70
 
74
71
  @staticmethod
75
72
  def backward(ctx, grad_output):
@@ -126,7 +123,7 @@ def merge(x: Tensor, sparsity_layout: Tensor, partitions: int,
126
123
  sparsity_layout.size(1), sparsity_layout.size(2) * partitions)
127
124
  .reshape(-1).contiguous())
128
125
 
129
- n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
126
+ n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
130
127
 
131
128
  validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
132
129
 
@@ -139,10 +136,11 @@ class _BlocksparseMerge(torch.autograd.Function):
139
136
  @staticmethod
140
137
  def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
141
138
  num_partitions: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
139
+ ctx.save_for_backward(sparsity_layout_o)
142
140
  ctx.num_partitions = num_partitions
143
141
 
144
- return forward_reorder(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
145
- n_sparse_blocks, triton_block_size)
142
+ return forward_flow(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
143
+ n_sparse_blocks, triton_block_size)
146
144
 
147
145
  @staticmethod
148
146
  def backward(ctx, grad_output):
@@ -155,90 +153,3 @@ class _BlocksparseMerge(torch.autograd.Function):
155
153
  sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None
156
154
 
157
155
 
158
- def forward_reorder(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
159
- sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
160
- output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
161
- dtype=x.dtype, device=x.device)
162
-
163
- x_b, x_r, x_c = x.size()
164
- x_b_s, x_r_s, x_c_s = x.stride()
165
- s_l_b, s_l_r, s_l_c = sparsity_layout_o.size()
166
- s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout_o.stride()
167
- s_lut_r, s_lut_c = sparsity_lut.shape
168
- s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
169
- o_b, o_r, o_c = output.size()
170
- o_b_s, o_r_s, o_c_s = output.stride()
171
-
172
- if triton_block_size is None:
173
- triton_block_size = get_triton_block_size(sparsity_block_size)
174
-
175
- triton_grid = lambda meta: [o_b,
176
- triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
177
- triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
178
-
179
- (kernel_blocksparse_reorder[triton_grid]
180
- (x,
181
- x_b, x_b_s, x_r_s, x_c_s,
182
- s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
183
- sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
184
- sparsity_reverse_lut,
185
- output,
186
- o_b, o_b_s,
187
- triton_block_size))
188
-
189
- # Save for backward pass
190
- ctx.save_for_backward(sparsity_layout_o)
191
- ctx.sparsity_block_size = sparsity_block_size
192
- ctx.triton_block_size = triton_block_size
193
-
194
- return output
195
-
196
-
197
- @triton.jit
198
- def kernel_blocksparse_reorder(x,
199
- x_b, x_b_s, x_r_s, x_c_s,
200
- s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
201
- s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
202
- r_lut,
203
- o,
204
- o_b, o_b_s,
205
- TRITON_BLOCK_SIZE: tl.constexpr) -> None:
206
- # Get triton block indices
207
- pid_blk = tl.program_id(axis=0)
208
- pid_row = tl.program_id(axis=1)
209
- pid_col = tl.program_id(axis=2)
210
-
211
- # Get sparsity index of current output block consisting of its batch, row, and column index
212
- spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
213
- spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
214
- spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
215
-
216
- spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
217
- spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
218
- spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
219
-
220
- spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
221
- spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
222
- spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
223
-
224
- # Get reverse sparsity index
225
- rev_idx_spa_idx = (spa_bat * s_l_b_s +
226
- spa_row * s_l_r_s +
227
- spa_col * s_l_c_s)
228
- rev_idx_spa_msk = (rev_idx_spa_idx < s_l_b * s_l_b_s)
229
- rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
230
-
231
- if rev_idx_spa == -1:
232
- assert False, "Invalid sparsity block"
233
-
234
- blk_x_idx = (rev_idx_spa * x_b_s +
235
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
236
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
237
- blk_x_msk = (blk_x_idx < x_b * x_b_s)
238
- blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
239
-
240
- blk_o_idx = (pid_blk * o_b_s +
241
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
242
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
243
- blk_o_msk = (blk_o_idx < o_b * o_b_s)
244
- tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
@@ -3,7 +3,7 @@ import triton
3
3
  from torch import Tensor
4
4
  from triton import language as tl
5
5
 
6
- from blksprs.utils.tools import get_triton_block_size
6
+ from blksprs.utils.tools import get_triton_block_size, stride
7
7
  from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, validate_sparsity, \
8
8
  validate_sparsity_block_size, validate_triton_block_size
9
9
 
@@ -60,13 +60,13 @@ def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
60
60
  device=x.device)
61
61
 
62
62
  x_b, x_r, x_c = x.size()
63
- x_b_s, x_r_s, x_c_s = x.stride()
63
+ x_b_s, x_r_s, x_c_s = stride(x)
64
64
  s_lut_x_r, s_lut_x_c = sparsity_lut.size()
65
- s_lut_x_r_s, s_lut_x_c_s = sparsity_lut.stride()
65
+ s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
66
66
  o_b, o_r, o_c = output.size()
67
- o_b_s, o_r_s, o_c_s = output.stride()
67
+ o_b_s, o_r_s, o_c_s = stride(output)
68
68
  s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
69
- s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_output.stride()
69
+ s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
70
70
 
71
71
  if triton_block_size is None:
72
72
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -183,13 +183,13 @@ def row_wise_max(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
183
183
  device=x.device)
184
184
 
185
185
  x_b, x_r, x_c = x.size()
186
- x_b_s, x_r_s, x_c_s = x.stride()
186
+ x_b_s, x_r_s, x_c_s = stride(x)
187
187
  s_lut_x_r, s_lut_x_c = sparsity_lut.size()
188
- s_lut_x_r_s, s_lut_x_c_s = sparsity_lut.stride()
188
+ s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
189
189
  o_b, o_r, o_c = output.size()
190
- o_b_s, o_r_s, o_c_s = output.stride()
190
+ o_b_s, o_r_s, o_c_s = stride(output)
191
191
  s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
192
- s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_output.stride()
192
+ s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
193
193
 
194
194
  if triton_block_size is None:
195
195
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -290,15 +290,15 @@ def row_wise_add(x: Tensor, sparsity_layout_x: Tensor, y: Tensor,
290
290
  output = torch.empty_like(x)
291
291
 
292
292
  x_b, x_r, x_c = x.size()
293
- x_b_s, x_r_s, x_c_s = x.stride()
293
+ x_b_s, x_r_s, x_c_s = stride(x)
294
294
  s_lut_r, s_lut_c = sparsity_lut.size()
295
- s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
295
+ s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
296
296
  y_b, y_r, y_c = y.size()
297
- y_b_s, y_r_s, y_c_s = y.stride()
297
+ y_b_s, y_r_s, y_c_s = stride(y)
298
298
  s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_rwm.size()
299
- s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = sparsity_layout_rwm.stride()
299
+ s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_rwm)
300
300
  o_b, o_r, o_c = output.size()
301
- o_b_s, o_r_s, o_c_s = output.stride()
301
+ o_b_s, o_r_s, o_c_s = stride(output)
302
302
 
303
303
  if triton_block_size is None:
304
304
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -361,7 +361,8 @@ def kernel_blocksparse_row_wise_add(x,
361
361
  rev_idx_spa_s = tl.load(r_lut_y + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
362
362
 
363
363
  if rev_idx_spa_s == -1:
364
- assert False, "Invalid sparsity block"
364
+ tl.device_assert(False)
365
+ return
365
366
 
366
367
  # Load x block
367
368
  blk_x_idx = ((pid_blk * x_b_s) +
@@ -6,9 +6,14 @@ from torch import Tensor
6
6
  from triton import language as tl
7
7
 
8
8
  from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
9
- from blksprs.utils.tools import get_triton_block_size
9
+ from blksprs.utils.tools import get_triton_block_size, stride
10
10
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
11
- validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
11
+ validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_sparsity_dense
12
+
13
+
14
+ def from_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
15
+ triton_block_size: int = None) -> Tensor:
16
+ return to_dense(x, sparsity_layout, sparsity_block_size, fill_value, triton_block_size)
12
17
 
13
18
 
14
19
  def to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
@@ -65,11 +70,11 @@ class _BlocksparseToDense(torch.autograd.Function):
65
70
  dtype=x.dtype, device=x.device)
66
71
 
67
72
  x_b, x_r, x_c = x.shape
68
- x_b_s, x_r_s, x_c_s = x.stride()
73
+ x_b_s, x_r_s, x_c_s = stride(x)
69
74
  s_l_b, s_l_r, s_l_c = sparsity_layout.size()
70
- s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout.stride()
75
+ s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout)
71
76
  o_b, o_r, o_c = output.size()
72
- o_b_s, o_r_s, o_c_s = output.stride()
77
+ o_b_s, o_r_s, o_c_s = stride(output)
73
78
 
74
79
  if triton_block_size is None:
75
80
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -144,6 +149,11 @@ class _BlocksparseToDense(torch.autograd.Function):
144
149
  tl.store(o + o_idx, blk, o_msk)
145
150
 
146
151
 
152
+ def to_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
153
+ triton_block_size: int = None) -> Tensor:
154
+ return to_sparse(x, sparsity_layout, sparsity_block_size, triton_block_size)
155
+
156
+
147
157
  def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
148
158
  """Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
149
159
  sparsity layout.
@@ -163,6 +173,7 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
163
173
  validate_dimensions(x)
164
174
  validate_contiguous(x)
165
175
  validate_device(x)
176
+ validate_sparsity_dense(sparsity_block_size, (x, sparsity_layout))
166
177
  validate_sparsity_block_size(sparsity_block_size, x)
167
178
  validate_triton_block_size(triton_block_size, sparsity_block_size)
168
179
 
@@ -190,11 +201,11 @@ class _BlocksparseToSparse(torch.autograd.Function):
190
201
  dtype=x.dtype, device=x.device)
191
202
 
192
203
  x_b, x_r, x_c = x.size()
193
- x_b_s, x_r_s, x_c_s = x.stride()
204
+ x_b_s, x_r_s, x_c_s = stride(x)
194
205
  s_lut_r, s_lut_c = sparsity_lut.size()
195
- s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
206
+ s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
196
207
  o_b, o_r, o_c = output.size()
197
- o_b_s, o_r_s, o_c_s = output.stride()
208
+ o_b_s, o_r_s, o_c_s = stride(output)
198
209
 
199
210
  if triton_block_size is None:
200
211
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -347,13 +358,13 @@ class _BlocksparseAdaptLayout(torch.autograd.Function):
347
358
  dtype=x.dtype, device=x.device)
348
359
 
349
360
  x_b, x_r, x_c = x.size()
350
- x_b_s, x_r_s, x_c_s = x.stride()
361
+ x_b_s, x_r_s, x_c_s = stride(x)
351
362
  s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_from.size()
352
- s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_from.stride()
363
+ s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_from)
353
364
  o_b, o_r, o_c = output.size()
354
- o_b_s, o_r_s, o_c_s = output.stride()
365
+ o_b_s, o_r_s, o_c_s = stride(output)
355
366
  s_lut_o_r, s_lut_o_c = sparsity_lut_to.size()
356
- s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_to.stride()
367
+ s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_to)
357
368
 
358
369
  if triton_block_size is None:
359
370
  triton_block_size = get_triton_block_size(min_sparsity_block_size)
@@ -3,7 +3,7 @@ import triton
3
3
  from torch import Tensor
4
4
  from triton import language as tl
5
5
 
6
- from blksprs.utils.tools import get_triton_block_size
6
+ from blksprs.utils.tools import get_triton_block_size, stride
7
7
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
8
8
  validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
9
9
 
@@ -59,15 +59,15 @@ class _BlocksparseGather(torch.autograd.Function):
59
59
  output = torch.empty_like(i, dtype=x.dtype)
60
60
 
61
61
  x_b, x_r, x_c = x.size()
62
- x_b_s, x_r_s, x_c_s = x.stride()
62
+ x_b_s, x_r_s, x_c_s = stride(x)
63
63
  s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
64
- s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_x.stride()
64
+ s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
65
65
  i_b, i_r, i_c = i.size()
66
- i_b_s, i_r_s, i_c_s = i.stride()
66
+ i_b_s, i_r_s, i_c_s = stride(i)
67
67
  s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
68
- s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
68
+ s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
69
69
  o_b, o_r, o_c = output.size()
70
- o_b_s, o_r_s, o_c_s = output.stride()
70
+ o_b_s, o_r_s, o_c_s = stride(output)
71
71
 
72
72
  if triton_block_size is None:
73
73
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -248,15 +248,15 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
248
248
  dtype=x.dtype, device=x.device)
249
249
 
250
250
  x_b, x_r, x_c = x.size()
251
- x_b_s, x_r_s, x_c_s = x.stride()
251
+ x_b_s, x_r_s, x_c_s = stride(x)
252
252
  s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
253
- s_lut_x_r_s, s_lut_x_c_s = sparsity_lut_x.stride()
253
+ s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut_x)
254
254
  i_b, i_r, i_c = i.size()
255
- i_b_s, i_r_s, i_c_s = i.stride()
255
+ i_b_s, i_r_s, i_c_s = stride(i)
256
256
  o_b, o_r, o_c = output.size()
257
- o_b_s, o_r_s, o_c_s = output.stride()
257
+ o_b_s, o_r_s, o_c_s = stride(output)
258
258
  s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
259
- s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_o.stride()
259
+ s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
260
260
 
261
261
  if triton_block_size is None:
262
262
  triton_block_size = get_triton_block_size(sparsity_block_size)