blksprs 1.8.1__tar.gz → 1.8.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {blksprs-1.8.1 → blksprs-1.8.3}/PKG-INFO +21 -13
  2. {blksprs-1.8.1 → blksprs-1.8.3}/README.md +20 -12
  3. blksprs-1.8.3/blksprs/__init__.py +40 -0
  4. {blksprs-1.8.1 → blksprs-1.8.3}/blksprs/layouting/distribution_layout.py +3 -2
  5. {blksprs-1.8.1 → blksprs-1.8.3}/blksprs/layouting/sparsity_layout.py +3 -2
  6. {blksprs-1.8.1 → blksprs-1.8.3}/blksprs/ops/conversion.py +35 -25
  7. {blksprs-1.8.1 → blksprs-1.8.3}/blksprs/ops/distribution.py +19 -18
  8. {blksprs-1.8.1/blksprs → blksprs-1.8.3/blksprs/ops}/experimental/distribution_mdi.py +22 -21
  9. {blksprs-1.8.1 → blksprs-1.8.3}/blksprs/ops/matmul.py +14 -13
  10. {blksprs-1.8.1/blksprs → blksprs-1.8.3/blksprs/ops}/misc/broadcast_ops.py +5 -4
  11. {blksprs-1.8.1/blksprs → blksprs-1.8.3/blksprs/ops}/misc/exp.py +5 -4
  12. {blksprs-1.8.1/blksprs → blksprs-1.8.3/blksprs/ops}/misc/row_wise.py +19 -18
  13. {blksprs-1.8.1/blksprs/misc → blksprs-1.8.3/blksprs/ops}/partitioning.py +13 -12
  14. {blksprs-1.8.1 → blksprs-1.8.3}/blksprs/ops/repeat.py +13 -12
  15. {blksprs-1.8.1 → blksprs-1.8.3}/blksprs/ops/softmax.py +8 -7
  16. {blksprs-1.8.1 → blksprs-1.8.3}/blksprs/ops/transpose.py +7 -6
  17. blksprs-1.8.3/blksprs/utils/blksprs_tensor.py +8 -0
  18. blksprs-1.8.3/blksprs/utils/processing.py +41 -0
  19. {blksprs-1.8.1 → blksprs-1.8.3}/blksprs/utils/tools.py +1 -6
  20. {blksprs-1.8.1 → blksprs-1.8.3}/blksprs/utils/validation.py +4 -0
  21. {blksprs-1.8.1 → blksprs-1.8.3}/blksprs.egg-info/PKG-INFO +21 -13
  22. {blksprs-1.8.1 → blksprs-1.8.3}/blksprs.egg-info/SOURCES.txt +7 -5
  23. {blksprs-1.8.1 → blksprs-1.8.3}/pyproject.toml +1 -1
  24. blksprs-1.8.1/blksprs/__init__.py +0 -27
  25. {blksprs-1.8.1 → blksprs-1.8.3}/blksprs/utils/benchmarking.py +0 -0
  26. {blksprs-1.8.1 → blksprs-1.8.3}/blksprs.egg-info/dependency_links.txt +0 -0
  27. {blksprs-1.8.1 → blksprs-1.8.3}/blksprs.egg-info/requires.txt +0 -0
  28. {blksprs-1.8.1 → blksprs-1.8.3}/blksprs.egg-info/top_level.txt +0 -0
  29. {blksprs-1.8.1 → blksprs-1.8.3}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: blksprs
- Version: 1.8.1
+ Version: 1.8.3
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -22,6 +22,14 @@ Requires-Dist: build; extra == "build"
  [![GitHub Release](https://img.shields.io/github/v/release/FelixSchoen/blksprs?include_prereleases&label=Latest%20Release)](https://github.com/FelixSchoen/blksprs/releases)
  [![Python Version](https://img.shields.io/badge/Python%20Version-3.11-blue)](https://www.python.org/downloads/release/python-3119/)

+ ## Important Notice
+
+ 🚨 **Non-Final API** 🚨
+
+ Although it already supports a wide variety of functions, this library is still under active development and the API is
+ subject to change. For feature requests or bug reports, please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
+ We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
+
  ## Overview

  A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
@@ -51,14 +59,14 @@ These include, e.g.,
  Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
  match.

- Further helpful operations (included in the ``bs.misc`` module) that do **not** support gradient calculation include:
+ Further helpful operations (included in the ``bs.ops.misc`` module) that do **not** support gradient calculation include:

  - Row-wise sum, max, addition, and subtraction
  - Broadcast addition and subtraction between slices

  Furthermore, the library provides a set of utility functions for the creation of sparsity layouts based on existing
- dense tensors and for the scatter operation (module ``bs.layout``), as well as utility functions to ensure correct input
- dimensionality (module ``bs.util``).
+ dense tensors and for the scatter operation (module ``bs.layouting``), as well as utility functions to apply linear layers,
+ ensure correct input dimensionality, and validate input (module ``bs.utils``).

  ## Installation

@@ -111,14 +119,14 @@ def test_readme():
  y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()

  # Convert tensors to three-dimensional (dense) tensors since Triton can only handle tensors of exactly three dimensions
- x_dense, x_shape_original = bs.util.do_shape_blocksparse(x)
- y_dense, y_shape_original = bs.util.do_shape_blocksparse(y)
+ x_dense, x_shape_original = bs.utils.do_shape_blocksparse(x)
+ y_dense, y_shape_original = bs.utils.do_shape_blocksparse(y)

  # Create sparsity layouts from existing tensors
- sparsity_layout_x = bs.layout.build_sparsity_layout(x_dense, sparsity_block_size,
- triton_block_size=triton_block_size)
- sparsity_layout_y = bs.layout.build_sparsity_layout(y_dense, sparsity_block_size,
- triton_block_size=triton_block_size)
+ sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size,
+ triton_block_size=triton_block_size)
+ sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size,
+ triton_block_size=triton_block_size)

  # Create random sparsity layout for output tensor
  sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)
@@ -150,12 +158,12 @@ def test_readme():
  assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2) # Note that small numerical differences are expected

  # Assert that the output has the correct sparsity layout
- actual_sparsity_layout_o = bs.layout.build_sparsity_layout(o_dense, sparsity_block_size,
- triton_block_size=triton_block_size)
+ actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size,
+ triton_block_size=triton_block_size)
  assert torch.allclose(actual_sparsity_layout_o.to(torch.int), sparsity_layout_o)

  # Convert output tensor back to original shape
- o = bs.util.undo_shape_blocksparse(o_dense, x_shape_original)
+ o = bs.utils.undo_shape_blocksparse(o_dense, x_shape_original)

  # Other available functions
  bs.transpose(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
@@ -3,6 +3,14 @@
  [![GitHub Release](https://img.shields.io/github/v/release/FelixSchoen/blksprs?include_prereleases&label=Latest%20Release)](https://github.com/FelixSchoen/blksprs/releases)
  [![Python Version](https://img.shields.io/badge/Python%20Version-3.11-blue)](https://www.python.org/downloads/release/python-3119/)

+ ## Important Notice
+
+ 🚨 **Non-Final API** 🚨
+
+ Although it already supports a wide variety of functions, this library is still under active development and the API is
+ subject to change. For feature requests or bug reports, please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
+ We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
+
  ## Overview

  A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
@@ -32,14 +40,14 @@ These include, e.g.,
  Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
  match.

- Further helpful operations (included in the ``bs.misc`` module) that do **not** support gradient calculation include:
+ Further helpful operations (included in the ``bs.ops.misc`` module) that do **not** support gradient calculation include:

  - Row-wise sum, max, addition, and subtraction
  - Broadcast addition and subtraction between slices

  Furthermore, the library provides a set of utility functions for the creation of sparsity layouts based on existing
- dense tensors and for the scatter operation (module ``bs.layout``), as well as utility functions to ensure correct input
- dimensionality (module ``bs.util``).
+ dense tensors and for the scatter operation (module ``bs.layouting``), as well as utility functions to apply linear layers,
+ ensure correct input dimensionality, and validate input (module ``bs.utils``).

  ## Installation

@@ -92,14 +100,14 @@ def test_readme():
  y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()

  # Convert tensors to three-dimensional (dense) tensors since Triton can only handle tensors of exactly three dimensions
- x_dense, x_shape_original = bs.util.do_shape_blocksparse(x)
- y_dense, y_shape_original = bs.util.do_shape_blocksparse(y)
+ x_dense, x_shape_original = bs.utils.do_shape_blocksparse(x)
+ y_dense, y_shape_original = bs.utils.do_shape_blocksparse(y)

  # Create sparsity layouts from existing tensors
- sparsity_layout_x = bs.layout.build_sparsity_layout(x_dense, sparsity_block_size,
- triton_block_size=triton_block_size)
- sparsity_layout_y = bs.layout.build_sparsity_layout(y_dense, sparsity_block_size,
- triton_block_size=triton_block_size)
+ sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size,
+ triton_block_size=triton_block_size)
+ sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size,
+ triton_block_size=triton_block_size)

  # Create random sparsity layout for output tensor
  sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)
@@ -131,12 +139,12 @@ def test_readme():
  assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2) # Note that small numerical differences are expected

  # Assert that the output has the correct sparsity layout
- actual_sparsity_layout_o = bs.layout.build_sparsity_layout(o_dense, sparsity_block_size,
- triton_block_size=triton_block_size)
+ actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size,
+ triton_block_size=triton_block_size)
  assert torch.allclose(actual_sparsity_layout_o.to(torch.int), sparsity_layout_o)

  # Convert output tensor back to original shape
- o = bs.util.undo_shape_blocksparse(o_dense, x_shape_original)
+ o = bs.utils.undo_shape_blocksparse(o_dense, x_shape_original)

  # Other available functions
  bs.transpose(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
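Note: the README hunks above reflect the module rename in 1.8.3 (``bs.misc`` → ``bs.ops.misc``, ``bs.layout`` → ``bs.layouting``, ``bs.util`` → ``bs.utils``). Below is a minimal migration sketch based on those renamed entry points and on the namespaces defined in the new ``blksprs/__init__.py`` shown further down in this diff; tensor shapes and the block size are illustrative only, not taken from the package.

```python
import torch
import blksprs as bs

sparsity_block_size = 32
x = torch.randn(2, 4, 64, 64, device="cuda")

# 1.8.1 spelling (removed above): bs.util.do_shape_blocksparse, bs.layout.build_sparsity_layout
# 1.8.3 spelling (added above):
x_dense, x_shape_original = bs.utils.do_shape_blocksparse(x)
sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size)
x_sparse = bs.ops.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size)
```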
@@ -0,0 +1,40 @@
+ from blksprs.utils.blksprs_tensor import BlksprsTensor
+
+ class ops:
+     from blksprs.ops.conversion import to_dense, to_sparse, from_blksprs, to_blksprs, adapt_layout
+     from blksprs.ops.distribution import gather, scatter, scatter_reduce
+     from blksprs.ops.matmul import matmul
+     from blksprs.ops.softmax import softmax
+     from blksprs.ops.transpose import transpose
+     from blksprs.ops.repeat import repeat, repeat_interleave
+     from blksprs.ops.partitioning import split, merge
+
+     class misc:
+         from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_add, row_wise_sub
+         from blksprs.ops.misc.broadcast_ops import broadcast_add, broadcast_sub
+         from blksprs.ops.misc.exp import exp
+
+     class experimental:
+         from blksprs.ops.experimental.distribution_mdi import gather_mdi, scatter_reduce_mdi
+
+
+ class layouting:
+     from blksprs.layouting.distribution_layout import build_distribution_layout
+     from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption, \
+         build_sparsity_layout_matmul, build_sparsity_layout_matmul_fast
+
+     class experimental:
+         from blksprs.ops.experimental.distribution_mdi import build_distribution_layout_mdi
+
+
+ class utils:
+     from blksprs.utils.processing import apply_torch_linear
+     from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
+     from blksprs.utils.validation import disable_validation
+
+     class validation:
+         from blksprs.utils.validation import disable_validation
+         from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_dtype_float, \
+             validate_dtype_int, validate_device, validate_sparsity, validate_sparsity_dense, \
+             validate_sparsity_block_size, \
+             validate_triton_block_size
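The new ``blksprs/__init__.py`` above uses plain classes as namespaces: imports executed inside a class body become attributes of that class, so ``bs.ops.matmul``, ``bs.ops.misc.exp``, ``bs.layouting.build_sparsity_layout``, and ``bs.utils.do_shape_blocksparse`` resolve without the caller importing any subpackage. A self-contained illustration of the idiom (standard library only, not blksprs code):

```python
# Imports placed in a class body become class attributes, so the class acts as a namespace.
class ops:
    from math import sqrt

    class misc:
        from math import isclose

print(ops.sqrt(9.0))                        # 3.0
print(ops.misc.isclose(1.0, 1.0 + 1e-12))   # True
```

One consequence of this pattern, compared with re-exporting through submodules, is that ``bs.ops`` is a class rather than a module object; this is usually invisible to callers but can matter for tooling that inspects module types.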
@@ -3,18 +3,19 @@ import triton
  from torch import Tensor
  from triton import language as tl

+ from blksprs.utils.blksprs_tensor import BlksprsTensor
  from blksprs.utils.tools import get_triton_block_size, stride
  from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
  validate_contiguous


- def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
+ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: Tensor,
  size_target: torch.Size,
  sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
  """Builds the sparsity layout of either the source of a gather or the target of a scatter operation.

  Args:
- indices (Tensor): The block-sparse indices tensor in compressed form used for the gather or scatter operation.
+ indices (BlksprsTensor): The block-sparse indices tensor in compressed form used for the gather or scatter operation.
  sparsity_layout_indices (Tensor): The sparsity layout of the indices block-sparse tensor.
  size_target (torch.Size): The size of the block-sparse target tensor in regular form.
  sparsity_block_size (int): The size of the sparsity blocks.
@@ -5,6 +5,7 @@ import triton
  from torch import Tensor
  from triton import language as tl

+ from blksprs.utils.blksprs_tensor import BlksprsTensor
  from blksprs.utils.tools import get_triton_block_size, stride
  from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
  validate_contiguous, validate_sparsity, validate_sparsity_block_size
@@ -82,14 +83,14 @@ def kernel_sparsity_layout(x,
  tl.store(o + blk_o_idx, 1, mask=blk_o_msk)


- def build_sparsity_layout_adaption(x: Tensor, sparsity_layout_from: Tensor,
+ def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tensor,
  sparsity_block_size_from: int, sparsity_block_size_to: int,
  triton_block_size: int = None) -> Tensor:
  """Builds the sparsity layout of a block-sparse tensor in compressed form if a different sparsity block size were
  used.

  Args:
- x (Tensor): A block-sparse tensor in compressed form.
+ x (BlksprsTensor): A block-sparse tensor in compressed form.
  sparsity_layout_from (Tensor): The sparsity layout of the input block-sparse tensor.
  sparsity_block_size_from (int): The size of the sparsity blocks of the input tensor.
  sparsity_block_size_to (int): The desired size of the sparsity blocks for the resulting layout.
@@ -6,23 +6,27 @@ from torch import Tensor
  from triton import language as tl

  from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
+ from blksprs.utils.blksprs_tensor import BlksprsTensor
  from blksprs.utils.tools import get_triton_block_size, stride
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
  validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_sparsity_dense


- def from_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
+ def from_blksprs(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
  triton_block_size: int = None) -> Tensor:
+ """Wrapper for ``to_dense``.
+
+ """
  return to_dense(x, sparsity_layout, sparsity_block_size, fill_value, triton_block_size)


- def to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
+ def to_dense(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
  triton_block_size: int = None) -> Tensor:
  """Converts a block-sparse tensor in compressed form to a block-sparse tensor in regular form based on the given
  sparsity layout.

  Args:
- x (Tensor): A block-sparse tensor in compressed form.
+ x (BlksprsTensor): A block-sparse tensor in compressed form.
  sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
  sparsity_block_size (int): The size of the sparsity blocks.
  fill_value (float): The value to fill the resulting dense tensor with where the block-sparse tensor is not
@@ -50,12 +54,12 @@ def to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_
  validate_contiguous(sparsity_reverse_lut)

  if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
- return x
+ return BlksprsTensor(x)

- return _BlocksparseToDense.apply(x,
- sparsity_layout, sparsity_reverse_lut,
- sparsity_block_size, fill_value,
- triton_block_size)
+ return BlksprsTensor(_BlocksparseToDense.apply(x,
+ sparsity_layout, sparsity_reverse_lut,
+ sparsity_block_size, fill_value,
+ triton_block_size))


  class _BlocksparseToDense(torch.autograd.Function):
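The returns above are now wrapped in ``BlksprsTensor``. The added file ``blksprs/utils/blksprs_tensor.py`` is only 8 lines (see the file list), so ``BlksprsTensor`` is presumably a thin marker type for tensors in compressed block-sparse form; its actual definition is not part of this diff. A hypothetical sketch of such a marker subclass, labeled as an assumption:

```python
import torch

# Hypothetical stand-in for blksprs/utils/blksprs_tensor.py (not the real file):
# a subclass that exists only to tag tensors as being in compressed block-sparse form.
class BlksprsTensor(torch.Tensor):
    """Marks a tensor as a block-sparse tensor in compressed form."""
    pass

x = torch.randn(4, 64, 64)
bx = BlksprsTensor(x)  # exact constructor semantics depend on the real class
print(isinstance(bx, torch.Tensor), isinstance(bx, BlksprsTensor))  # True True
```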
@@ -150,11 +154,15 @@ class _BlocksparseToDense(torch.autograd.Function):


  def to_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
- triton_block_size: int = None) -> Tensor:
+ triton_block_size: int = None) -> BlksprsTensor:
+ """Wrapper for ``to_sparse``.
+
+ """
  return to_sparse(x, sparsity_layout, sparsity_block_size, triton_block_size)


- def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+ def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
+ triton_block_size: int = None) -> BlksprsTensor:
  """Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
  sparsity layout.

@@ -165,7 +173,7 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
  triton_block_size (int): The block size to use for the triton kernel (default ``None``).

  Returns:
- Tensor: The block-sparse tensor converted to compressed form.
+ BlksprsTensor: The block-sparse tensor converted to compressed form.

  """
  x = x.contiguous()
@@ -183,12 +191,12 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
  validate_contiguous(sparsity_layout, sparsity_lut)

  if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
- return x
+ return BlksprsTensor(x)

- return _BlocksparseToSparse.apply(x,
- sparsity_layout, sparsity_lut,
- sparsity_block_size, n_sparse_blocks,
- triton_block_size)
+ return BlksprsTensor(_BlocksparseToSparse.apply(x,
+ sparsity_layout, sparsity_lut,
+ sparsity_block_size, n_sparse_blocks,
+ triton_block_size))


  class _BlocksparseToSparse(torch.autograd.Function):
@@ -280,13 +288,14 @@ class _BlocksparseToSparse(torch.autograd.Function):
  tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)


- def adapt_layout(x: Tensor, sparsity_layout_from: Tensor, sparsity_block_size_from: int, sparsity_block_size_to: int,
- preprocess_data: dict = None, triton_block_size: int = None) -> Tensor:
+ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_size_from: int,
+ sparsity_block_size_to: int,
+ preprocess_data: dict = None, triton_block_size: int = None) -> BlksprsTensor:
  """Adapts the sparsity layout of a block-sparse tensor, resulting in a new block-sparse tensor in compressed form
  conforming to the new sparsity layout (and sparsity block size) definition.

  Args:
- x (Tensor): A block-sparse tensor in compressed form.
+ x (BlksprsTensor): A block-sparse tensor in compressed form.
  sparsity_layout_from (Tensor): The sparsity layout of the input block-sparse tensor.
  sparsity_block_size_from (int): The size of the sparsity blocks of the input sparsity layout.
  sparsity_block_size_to (int): The size of the sparsity blocks of the output sparsity layout.
@@ -294,7 +303,7 @@ def adapt_layout(x: Tensor, sparsity_layout_from: Tensor, sparsity_block_size_fr
  triton_block_size (int): The block size to use for the triton kernel (default ``None``).

  Returns:
- Tensor: The block-sparse tensor in compressed form with the adapted sparsity layout and sparsity block size.
+ BlksprsTensor: The block-sparse tensor in compressed form with the adapted sparsity layout and sparsity block size.

  """
  x = x.contiguous()
@@ -339,12 +348,13 @@ def adapt_layout(x: Tensor, sparsity_layout_from: Tensor, sparsity_block_size_fr
  validate_contiguous(sparsity_layout_to, sparsity_reverse_lut_from, sparsity_lut_to)

  if (sparsity_block_size_from == sparsity_block_size_to) and torch.equal(sparsity_layout_from, sparsity_layout_to):
- return x
+ return BlksprsTensor(x)

- return _BlocksparseAdaptLayout.apply(x,
- sparsity_layout_from, sparsity_reverse_lut_from, sparsity_block_size_from,
- sparsity_layout_to, sparsity_lut_to, sparsity_block_size_to,
- n_sparse_blocks_to, min_sparsity_block_size, triton_block_size)
+ return BlksprsTensor(_BlocksparseAdaptLayout.apply(x,
+ sparsity_layout_from, sparsity_reverse_lut_from,
+ sparsity_block_size_from,
+ sparsity_layout_to, sparsity_lut_to, sparsity_block_size_to,
+ n_sparse_blocks_to, min_sparsity_block_size, triton_block_size))


  class _BlocksparseAdaptLayout(torch.autograd.Function):
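``adapt_layout`` re-expresses a compressed tensor under a different sparsity layout and/or sparsity block size, short-circuiting to ``BlksprsTensor(x)`` when nothing changes (see the hunk above). A hedged usage sketch based only on the signature shown; argument values are illustrative:

```python
import torch
import blksprs as bs

sparsity_block_size_from, sparsity_block_size_to = 64, 32
x_dense = torch.randn(2, 128, 128, device="cuda")

sparsity_layout_from = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size_from)
x_sparse = bs.ops.to_sparse(x_dense, sparsity_layout_from, sparsity_block_size_from)

# Re-express the same data with a smaller sparsity block size.
x_adapted = bs.ops.adapt_layout(x_sparse, sparsity_layout_from,
                                sparsity_block_size_from, sparsity_block_size_to)
```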
@@ -3,25 +3,26 @@ import triton
  from torch import Tensor
  from triton import language as tl

+ from blksprs.utils.blksprs_tensor import BlksprsTensor
  from blksprs.utils.tools import get_triton_block_size, stride
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
  validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size


- def gather(src: Tensor, sparsity_layout_src: Tensor, idx: Tensor, sparsity_layout_idx: Tensor,
- sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor, idx: BlksprsTensor, sparsity_layout_idx: Tensor,
+ sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
  """Applies a gather operation on a block-sparse tensor in compressed form.

  Args:
- src (Tensor): The source block-sparse tensor in compressed form to gather from.
+ src (BlksprsTensor): The source block-sparse tensor in compressed form to gather from.
  sparsity_layout_src (Tensor): The sparsity layout of the source block-sparse tensor.
- idx (Tensor): The block-sparse indices tensor in compressed form specifying how to gather from the source tensor.
+ idx (BlksprsTensor): The block-sparse indices tensor in compressed form specifying how to gather from the source tensor.
  sparsity_layout_idx (Tensor): The sparsity layout of the indices block-sparse tensor.
  sparsity_block_size (int): The size of the sparsity blocks.
  triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).

  Returns:
- Tensor: The result of the gather operation as a block-sparse tensor in compressed form.
+ BlksprsTensor: The result of the gather operation as a block-sparse tensor in compressed form.

  """
  src = src.contiguous()
@@ -45,9 +46,9 @@ def gather(src: Tensor, sparsity_layout_src: Tensor, idx: Tensor, sparsity_layou
  validate_contiguous(sparsity_layout_src, sparsity_reverse_lut_x,
  sparsity_layout_idx, sparsity_lut_i)

- return _BlocksparseGather.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
+ return BlksprsTensor(_BlocksparseGather.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
  idx, sparsity_layout_idx, sparsity_lut_i,
- sparsity_block_size, triton_block_size)
+ sparsity_block_size, triton_block_size))


  class _BlocksparseGather(torch.autograd.Function):
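``gather`` reads values from a compressed source tensor according to a compressed integer index tensor, where the source and the indices each carry their own sparsity layout. A hedged usage sketch assembled from the signature and docstring above; shapes, dtypes, and index values are illustrative assumptions:

```python
import torch
import blksprs as bs

sparsity_block_size = 32
src_dense = torch.randn(2, 64, 64, device="cuda")
idx_dense = torch.randint(0, 64, (2, 64, 64), device="cuda")  # integer index tensor

sparsity_layout_src = bs.layouting.build_sparsity_layout(src_dense, sparsity_block_size)
sparsity_layout_idx = bs.layouting.build_sparsity_layout(idx_dense, sparsity_block_size)

src = bs.ops.to_sparse(src_dense, sparsity_layout_src, sparsity_block_size)
idx = bs.ops.to_sparse(idx_dense, sparsity_layout_idx, sparsity_block_size)

out = bs.ops.gather(src, sparsity_layout_src, idx, sparsity_layout_idx, sparsity_block_size)
```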
@@ -168,10 +169,10 @@ class _BlocksparseGather(torch.autograd.Function):
  tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)


- def scatter(src: Tensor, sparsity_layout_src: Tensor,
- idx: Tensor,
+ def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
+ idx: BlksprsTensor,
  sparsity_layout_tgt: Tensor,
- sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+ sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
  """Wrapper for ``scatter_reduce`` with ``reduce_op="none"``.

  """
@@ -182,17 +183,17 @@ def scatter(src: Tensor, sparsity_layout_src: Tensor,
  reduce_op="none", triton_block_size=triton_block_size)


- def scatter_reduce(src: Tensor, sparsity_layout_src: Tensor,
- idx: Tensor,
+ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
+ idx: BlksprsTensor,
  sparsity_layout_tgt: Tensor,
  sparsity_block_size: int,
- reduce_op: str = "sum", triton_block_size: int = None) -> Tensor:
+ reduce_op: str = "sum", triton_block_size: int = None) -> BlksprsTensor:
  """Applies a scatter operation on a block-sparse tensor in compressed form.

  Args:
- src (Tensor): The source block-sparse tensor in compressed form to scatter from.
+ src (BlksprsTensor): The source block-sparse tensor in compressed form to scatter from.
  sparsity_layout_src (Tensor): The sparsity layout of the source block-sparse tensor.
- idx (Tensor): The block-sparse indices tensor in compressed form specifying how to scatter to the target tensor.
+ idx (BlksprsTensor): The block-sparse indices tensor in compressed form specifying how to scatter to the target tensor.
  sparsity_layout_tgt (Tensor): The sparsity layout of the target block-sparse tensor.
  sparsity_block_size (int): The size of the sparsity blocks.
  reduce_op (str, optional): The reduction operation to apply during the scatter operation (default ``"sum"``).
@@ -200,7 +201,7 @@ def scatter_reduce(src: Tensor, sparsity_layout_src: Tensor,
  triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).

  Returns:
- Tensor: The result of the scatter operation as a block-sparse tensor in compressed form.
+ BlksprsTensor: The result of the scatter operation as a block-sparse tensor in compressed form.

  """
  src = src.contiguous()
@@ -229,11 +230,11 @@ def scatter_reduce(src: Tensor, sparsity_layout_src: Tensor,
  validate_contiguous(sparsity_layout_src, sparsity_lut_x,
  sparsity_layout_tgt, sparsity_reverse_lut_o)

- return _BlocksparseScatterReduce.apply(src, sparsity_layout_src, sparsity_lut_x,
+ return BlksprsTensor(_BlocksparseScatterReduce.apply(src, sparsity_layout_src, sparsity_lut_x,
  idx,
  sparsity_layout_tgt, sparsity_reverse_lut_o,
  sparsity_block_size, n_sparse_blocks,
- reduce_op, triton_block_size)
+ reduce_op, triton_block_size))


  class _BlocksparseScatterReduce(torch.autograd.Function):
@@ -3,17 +3,18 @@ import triton
  from torch import Tensor
  from triton import language as tl

+ from blksprs.utils.blksprs_tensor import BlksprsTensor
  from blksprs.utils.tools import get_triton_block_size, stride
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
  validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size


- def gather_mdi(src: Tensor, sparsity_layout_src: Tensor,
- idx_bat: Tensor,
- idx_row: Tensor,
- idx_col: Tensor,
+ def gather_mdi(src: BlksprsTensor, sparsity_layout_src: Tensor,
+ idx_bat: BlksprsTensor,
+ idx_row: BlksprsTensor,
+ idx_col: BlksprsTensor,
  sparsity_layout_idx: Tensor,
- sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+ sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
  src = src.contiguous()
  idx_bat = idx_bat.contiguous()
  idx_col = idx_col.contiguous()
@@ -37,9 +38,9 @@ def gather_mdi(src: Tensor, sparsity_layout_src: Tensor,
  validate_contiguous(sparsity_layout_src, sparsity_reverse_lut_x,
  sparsity_layout_idx, sparsity_lut_i)

- return _BlocksparseGatherMDI.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
- idx_bat, idx_col, sparsity_layout_idx, sparsity_lut_i,
- sparsity_block_size, triton_block_size)
+ return BlksprsTensor(_BlocksparseGatherMDI.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
+ idx_bat, idx_col, sparsity_layout_idx, sparsity_lut_i,
+ sparsity_block_size, triton_block_size))


  class _BlocksparseGatherMDI(torch.autograd.Function):
@@ -167,13 +168,13 @@ class _BlocksparseGatherMDI(torch.autograd.Function):
  tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)


- def scatter_reduce_mdi(src: Tensor, sparsity_layout_src: Tensor,
- idx_bat: Tensor,
- idx_row: Tensor,
- idx_col: Tensor,
+ def scatter_reduce_mdi(src: BlksprsTensor, sparsity_layout_src: Tensor,
+ idx_bat: BlksprsTensor,
+ idx_row: BlksprsTensor,
+ idx_col: BlksprsTensor,
  sparsity_layout_tgt: Tensor,
  sparsity_block_size: int,
- reduce_op: str = "sum", triton_block_size: int = None) -> Tensor:
+ reduce_op: str = "sum", triton_block_size: int = None) -> BlksprsTensor:
  src = src.contiguous()
  idx_bat = idx_bat.contiguous()
  idx_col = idx_col.contiguous()
@@ -203,12 +204,12 @@ def scatter_reduce_mdi(src: Tensor, sparsity_layout_src: Tensor,
  validate_contiguous(sparsity_layout_src, sparsity_lut_x,
  sparsity_layout_tgt, sparsity_reverse_lut_o)

- return _BlocksparseScatterReduceMDI.apply(src, sparsity_layout_src, sparsity_lut_x,
- idx_bat,
- idx_col,
- sparsity_layout_tgt, sparsity_reverse_lut_o,
- sparsity_block_size, n_sparse_blocks,
- reduce_op, triton_block_size)
+ return BlksprsTensor(_BlocksparseScatterReduceMDI.apply(src, sparsity_layout_src, sparsity_lut_x,
+ idx_bat,
+ idx_col,
+ sparsity_layout_tgt, sparsity_reverse_lut_o,
+ sparsity_block_size, n_sparse_blocks,
+ reduce_op, triton_block_size))


  class _BlocksparseScatterReduceMDI(torch.autograd.Function):
@@ -353,8 +354,8 @@ class _BlocksparseScatterReduceMDI(torch.autograd.Function):
  tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)


- def build_distribution_layout_mdi(idx_bat: Tensor, idx_row: Tensor, idx_col: Tensor, sparsity_layout_idx: Tensor,
- size_target: torch.Size,
+ def build_distribution_layout_mdi(idx_bat: BlksprsTensor, idx_row: BlksprsTensor, idx_col: BlksprsTensor,
+ sparsity_layout_idx: Tensor, size_target: torch.Size,
  sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
  validate_dimensions(idx_bat, idx_col)
  validate_contiguous(idx_bat, idx_col)
@@ -4,22 +4,23 @@ from torch import Tensor
  from triton import language as tl

  from blksprs.ops.transpose import transpose
+ from blksprs.utils.blksprs_tensor import BlksprsTensor
  from blksprs.utils.tools import get_triton_block_size, stride
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
  validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_dtype_float


- def matmul(x: Tensor, sparsity_layout_x: Tensor,
- y: Tensor, sparsity_layout_y: Tensor,
+ def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
+ y: BlksprsTensor, sparsity_layout_y: Tensor,
  sparsity_layout_output: Tensor,
- sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+ sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
  """Performs matrix multiplication between two block-sparse tensors.

  The sparsity layout of the output tensor is used to only calculate blocks that will be present in the output.

  Args:
- x (Tensor): A block-sparse tensor in compressed form.
- y (Tensor): A block-sparse tensor in compressed form.
+ x (BlksprsTensor): A block-sparse tensor in compressed form.
+ y (BlksprsTensor): A block-sparse tensor in compressed form.
  sparsity_layout_x (Tensor): The sparsity layout of the first block-sparse tensor.
  sparsity_layout_y (Tensor): The sparsity layout of the second block-sparse tensor.
  sparsity_layout_output (Tensor): The sparsity layout of the output tensor.
@@ -27,7 +28,7 @@ def matmul(x: Tensor, sparsity_layout_x: Tensor,
  triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).

  Returns:
- Tensor: The result of the matrix multiplication as a block-sparse tensor in compressed form.
+ BlksprsTensor: The result of the matrix multiplication as a block-sparse tensor in compressed form.

  """
  x = x.contiguous()
@@ -61,13 +62,13 @@ def matmul(x: Tensor, sparsity_layout_x: Tensor,
  sparsity_layout_y, sparsity_reverse_lut_y,
  sparsity_layout_output, sparsity_lut_o)

- return _BlocksparseMatmulSSS.apply(x, y,
- sparsity_layout_x, sparsity_reverse_lut_x,
- sparsity_layout_y, sparsity_reverse_lut_y,
- sparsity_layout_output, sparsity_lut_o,
- sparsity_block_size,
- n_sparse_blocks,
- triton_block_size)
+ return BlksprsTensor(_BlocksparseMatmulSSS.apply(x, y,
+ sparsity_layout_x, sparsity_reverse_lut_x,
+ sparsity_layout_y, sparsity_reverse_lut_y,
+ sparsity_layout_output, sparsity_lut_o,
+ sparsity_block_size,
+ n_sparse_blocks,
+ triton_block_size))


  class _BlocksparseMatmulSSS(torch.autograd.Function):
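``matmul`` computes only the output blocks marked in ``sparsity_layout_output`` (the autograd function is named ``_BlocksparseMatmulSSS``, i.e. sparse × sparse → sparse). A hedged call sketch based on the signature above; here the output layout is derived by multiplying the operand layouts as 0/1 matrices, which marks every output block that can receive a contribution. The new ``__init__.py`` also exports ``build_sparsity_layout_matmul`` helpers for this purpose, but their signatures are not shown in this diff, so they are not used here.

```python
import torch
import blksprs as bs

sparsity_block_size = 32
x_dense = torch.randn(2, 64, 128, device="cuda")
y_dense = torch.randn(2, 128, 96, device="cuda")

sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size)
sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size)
# Output block (i, j) can be non-zero only if some k has both x-block (i, k) and y-block (k, j) set.
sparsity_layout_o = (sparsity_layout_x.to(torch.float) @ sparsity_layout_y.to(torch.float)
                     ).clamp(max=1).to(torch.int)

x = bs.ops.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size)
y = bs.ops.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size)

o = bs.ops.matmul(x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o,
                  sparsity_block_size)
```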