blksprs 1.8.1__py3-none-any.whl → 1.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/__init__.py CHANGED
@@ -1,27 +1,40 @@
1
- from blksprs.ops.conversion import to_dense, to_sparse, from_blksprs, to_blksprs
2
- from blksprs.ops.distribution import gather, scatter, scatter_reduce
3
- from blksprs.ops.matmul import matmul
4
- from blksprs.ops.softmax import softmax
5
- from blksprs.ops.transpose import transpose
6
- from blksprs.ops.repeat import repeat, repeat_interleave
7
- from blksprs.misc.partitioning import split, merge
1
+ from blksprs.utils.blksprs_tensor import BlksprsTensor
8
2
 
3
+ class ops:
4
+ from blksprs.ops.conversion import to_dense, to_sparse, from_blksprs, to_blksprs, adapt_layout
5
+ from blksprs.ops.distribution import gather, scatter, scatter_reduce
6
+ from blksprs.ops.matmul import matmul
7
+ from blksprs.ops.softmax import softmax
8
+ from blksprs.ops.transpose import transpose
9
+ from blksprs.ops.repeat import repeat, repeat_interleave
10
+ from blksprs.ops.partitioning import split, merge
9
11
 
10
- class layout:
12
+ class misc:
13
+ from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_add, row_wise_sub
14
+ from blksprs.ops.misc.broadcast_ops import broadcast_add, broadcast_sub
15
+ from blksprs.ops.misc.exp import exp
16
+
17
+ class experimental:
18
+ from blksprs.ops.experimental.distribution_mdi import gather_mdi, scatter_reduce_mdi
19
+
20
+
21
+ class layouting:
11
22
  from blksprs.layouting.distribution_layout import build_distribution_layout
12
23
  from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption, \
13
24
  build_sparsity_layout_matmul, build_sparsity_layout_matmul_fast
14
25
 
26
+ class experimental:
27
+ from blksprs.ops.experimental.distribution_mdi import build_distribution_layout_mdi
15
28
 
16
- class misc:
17
- from blksprs.misc.broadcast_ops import broadcast_add, broadcast_sub
18
- from blksprs.misc.exp import exp
19
- from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_add, row_wise_sub
20
-
21
-
22
- class util:
23
- from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse, disable_validation
24
29
 
30
+ class utils:
31
+ from blksprs.utils.processing import apply_torch_linear
32
+ from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
33
+ from blksprs.utils.validation import disable_validation
25
34
 
26
- class experimental:
27
- from blksprs.experimental.distribution_mdi import gather_mdi, scatter_reduce_mdi
35
+ class validation:
36
+ from blksprs.utils.validation import disable_validation
37
+ from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_dtype_float, \
38
+ validate_dtype_int, validate_device, validate_sparsity, validate_sparsity_dense, \
39
+ validate_sparsity_block_size, \
40
+ validate_triton_block_size
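For orientation, a minimal sketch of how the reorganized 1.8.3 namespaces are used (the former ``bs.layout``, ``bs.misc``, and ``bs.util`` accessors become ``bs.layouting``, ``bs.ops.misc``, and ``bs.utils``). The shapes, CUDA device, and block size below are illustrative assumptions, not values taken from the package:

```python
import torch
import blksprs as bs

# assumed example sizes; dimensions must be multiples of the sparsity block size
x = torch.randn(size=(2, 64, 64), device="cuda")
sparsity_block_size = 32

# layout helpers now live under bs.layouting (previously bs.layout)
sparsity_layout = bs.layouting.build_sparsity_layout(x, sparsity_block_size)

# core ops are grouped under bs.ops, misc ops under bs.ops.misc (previously bs.misc)
x_bs = bs.ops.to_sparse(x, sparsity_layout, sparsity_block_size)
x_exp = bs.ops.misc.exp(x_bs, sparsity_block_size)
x_dense = bs.ops.to_dense(x_exp, sparsity_layout, sparsity_block_size)

# utility helpers moved from bs.util to bs.utils
x_shaped, x_shape_original = bs.utils.do_shape_blocksparse(x)
```

The later sketches below reuse ``x_bs``, ``sparsity_layout``, and ``sparsity_block_size`` from this snippet.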
@@ -3,18 +3,19 @@ import triton
3
3
  from torch import Tensor
4
4
  from triton import language as tl
5
5
 
6
+ from blksprs.utils.blksprs_tensor import BlksprsTensor
6
7
  from blksprs.utils.tools import get_triton_block_size, stride
7
8
  from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
8
9
  validate_contiguous
9
10
 
10
11
 
11
- def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
12
+ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: Tensor,
12
13
  size_target: torch.Size,
13
14
  sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
14
15
  """Builds the sparsity layout of either the source of a gather or the target of a scatter operation.
15
16
 
16
17
  Args:
17
- indices (Tensor): The block-sparse indices tensor in compressed form used for the gather or scatter operation.
18
+ indices (BlksprsTensor): The block-sparse indices tensor in compressed form used for the gather or scatter operation.
18
19
  sparsity_layout_indices (Tensor): The sparsity layout of the indices block-sparse tensor.
19
20
  size_target (torch.Size): The size of the block-sparse target tensor in regular form.
20
21
  sparsity_block_size (int): The size of the sparsity blocks.
@@ -5,6 +5,7 @@ import triton
5
5
  from torch import Tensor
6
6
  from triton import language as tl
7
7
 
8
+ from blksprs.utils.blksprs_tensor import BlksprsTensor
8
9
  from blksprs.utils.tools import get_triton_block_size, stride
9
10
  from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
10
11
  validate_contiguous, validate_sparsity, validate_sparsity_block_size
@@ -82,14 +83,14 @@ def kernel_sparsity_layout(x,
82
83
  tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
83
84
 
84
85
 
85
- def build_sparsity_layout_adaption(x: Tensor, sparsity_layout_from: Tensor,
86
+ def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tensor,
86
87
  sparsity_block_size_from: int, sparsity_block_size_to: int,
87
88
  triton_block_size: int = None) -> Tensor:
88
89
  """Builds the sparsity layout of a block-sparse tensor in compressed form if a different sparsity block size were
89
90
  used.
90
91
 
91
92
  Args:
92
- x (Tensor): A block-sparse tensor in compressed form.
93
+ x (BlksprsTensor): A block-sparse tensor in compressed form.
93
94
  sparsity_layout_from (Tensor): The sparsity layout of the input block-sparse tensor.
94
95
  sparsity_block_size_from (int): The size of the sparsity blocks of the input tensor.
95
96
  sparsity_block_size_to (int): The desired size of the sparsity blocks for the resulting layout.
blksprs/ops/conversion.py CHANGED
@@ -6,23 +6,27 @@ from torch import Tensor
6
6
  from triton import language as tl
7
7
 
8
8
  from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
9
+ from blksprs.utils.blksprs_tensor import BlksprsTensor
9
10
  from blksprs.utils.tools import get_triton_block_size, stride
10
11
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
11
12
  validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_sparsity_dense
12
13
 
13
14
 
14
- def from_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
15
+ def from_blksprs(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
15
16
  triton_block_size: int = None) -> Tensor:
17
+ """Wrapper for ``to_dense``.
18
+
19
+ """
16
20
  return to_dense(x, sparsity_layout, sparsity_block_size, fill_value, triton_block_size)
17
21
 
18
22
 
19
- def to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
23
+ def to_dense(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
20
24
  triton_block_size: int = None) -> Tensor:
21
25
  """Converts a block-sparse tensor in compressed form to a block-sparse tensor in regular form based on the given
22
26
  sparsity layout.
23
27
 
24
28
  Args:
25
- x (Tensor): A block-sparse tensor in compressed form.
29
+ x (BlksprsTensor): A block-sparse tensor in compressed form.
26
30
  sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
27
31
  sparsity_block_size (int): The size of the sparsity blocks.
28
32
  fill_value (float): The value to fill the resulting dense tensor with where the block-sparse tensor is not
@@ -50,12 +54,12 @@ def to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_
50
54
  validate_contiguous(sparsity_reverse_lut)
51
55
 
52
56
  if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
53
- return x
57
+ return BlksprsTensor(x)
54
58
 
55
- return _BlocksparseToDense.apply(x,
56
- sparsity_layout, sparsity_reverse_lut,
57
- sparsity_block_size, fill_value,
58
- triton_block_size)
59
+ return BlksprsTensor(_BlocksparseToDense.apply(x,
60
+ sparsity_layout, sparsity_reverse_lut,
61
+ sparsity_block_size, fill_value,
62
+ triton_block_size))
59
63
 
60
64
 
61
65
  class _BlocksparseToDense(torch.autograd.Function):
@@ -150,11 +154,15 @@ class _BlocksparseToDense(torch.autograd.Function):
150
154
 
151
155
 
152
156
  def to_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
153
- triton_block_size: int = None) -> Tensor:
157
+ triton_block_size: int = None) -> BlksprsTensor:
158
+ """Wrapper for ``to_sparse``.
159
+
160
+ """
154
161
  return to_sparse(x, sparsity_layout, sparsity_block_size, triton_block_size)
155
162
 
156
163
 
157
- def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
164
+ def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
165
+ triton_block_size: int = None) -> BlksprsTensor:
158
166
  """Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
159
167
  sparsity layout.
160
168
 
@@ -165,7 +173,7 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
165
173
  triton_block_size (int): The block size to use for the triton kernel (default ``None``).
166
174
 
167
175
  Returns:
168
- Tensor: The block-sparse tensor converted to compressed form.
176
+ BlksprsTensor: The block-sparse tensor converted to compressed form.
169
177
 
170
178
  """
171
179
  x = x.contiguous()
@@ -183,12 +191,12 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
183
191
  validate_contiguous(sparsity_layout, sparsity_lut)
184
192
 
185
193
  if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
186
- return x
194
+ return BlksprsTensor(x)
187
195
 
188
- return _BlocksparseToSparse.apply(x,
189
- sparsity_layout, sparsity_lut,
190
- sparsity_block_size, n_sparse_blocks,
191
- triton_block_size)
196
+ return BlksprsTensor(_BlocksparseToSparse.apply(x,
197
+ sparsity_layout, sparsity_lut,
198
+ sparsity_block_size, n_sparse_blocks,
199
+ triton_block_size))
192
200
 
193
201
 
194
202
  class _BlocksparseToSparse(torch.autograd.Function):
@@ -280,13 +288,14 @@ class _BlocksparseToSparse(torch.autograd.Function):
280
288
  tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
281
289
 
282
290
 
283
- def adapt_layout(x: Tensor, sparsity_layout_from: Tensor, sparsity_block_size_from: int, sparsity_block_size_to: int,
284
- preprocess_data: dict = None, triton_block_size: int = None) -> Tensor:
291
+ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_size_from: int,
292
+ sparsity_block_size_to: int,
293
+ preprocess_data: dict = None, triton_block_size: int = None) -> BlksprsTensor:
285
294
  """Adapts the sparsity layout of a block-sparse tensor, resulting in a new block-sparse tensor in compressed form
286
295
  conforming to the new sparsity layout (and sparsity block size) definition.
287
296
 
288
297
  Args:
289
- x (Tensor): A block-sparse tensor in compressed form.
298
+ x (BlksprsTensor): A block-sparse tensor in compressed form.
290
299
  sparsity_layout_from (Tensor): The sparsity layout of the input block-sparse tensor.
291
300
  sparsity_block_size_from (int): The size of the sparsity blocks of the input sparsity layout.
292
301
  sparsity_block_size_to (int): The size of the sparsity blocks of the output sparsity layout.
@@ -294,7 +303,7 @@ def adapt_layout(x: Tensor, sparsity_layout_from: Tensor, sparsity_block_size_fr
294
303
  triton_block_size (int): The block size to use for the triton kernel (default ``None``).
295
304
 
296
305
  Returns:
297
- Tensor: The block-sparse tensor in compressed form with the adapted sparsity layout and sparsity block size.
306
+ BlksprsTensor: The block-sparse tensor in compressed form with the adapted sparsity layout and sparsity block size.
298
307
 
299
308
  """
300
309
  x = x.contiguous()
@@ -339,12 +348,13 @@ def adapt_layout(x: Tensor, sparsity_layout_from: Tensor, sparsity_block_size_fr
339
348
  validate_contiguous(sparsity_layout_to, sparsity_reverse_lut_from, sparsity_lut_to)
340
349
 
341
350
  if (sparsity_block_size_from == sparsity_block_size_to) and torch.equal(sparsity_layout_from, sparsity_layout_to):
342
- return x
351
+ return BlksprsTensor(x)
343
352
 
344
- return _BlocksparseAdaptLayout.apply(x,
345
- sparsity_layout_from, sparsity_reverse_lut_from, sparsity_block_size_from,
346
- sparsity_layout_to, sparsity_lut_to, sparsity_block_size_to,
347
- n_sparse_blocks_to, min_sparsity_block_size, triton_block_size)
353
+ return BlksprsTensor(_BlocksparseAdaptLayout.apply(x,
354
+ sparsity_layout_from, sparsity_reverse_lut_from,
355
+ sparsity_block_size_from,
356
+ sparsity_layout_to, sparsity_lut_to, sparsity_block_size_to,
357
+ n_sparse_blocks_to, min_sparsity_block_size, triton_block_size))
348
358
 
349
359
 
350
360
  class _BlocksparseAdaptLayout(torch.autograd.Function):
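``adapt_layout`` is newly exported through ``bs.ops`` in 1.8.3. A short sketch, reusing ``x_bs`` and ``sparsity_layout`` from the first snippet and assuming the target block size also divides the tensor dimensions:

```python
# layout the same data would have under 64x64 sparsity blocks
sparsity_layout_64 = bs.layouting.build_sparsity_layout_adaption(
    x_bs, sparsity_layout, sparsity_block_size_from=32, sparsity_block_size_to=64)

# re-compress the tensor to the new block size; returns a BlksprsTensor as of 1.8.3
x_bs_64 = bs.ops.adapt_layout(x_bs, sparsity_layout,
                              sparsity_block_size_from=32, sparsity_block_size_to=64)
```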
@@ -3,25 +3,26 @@ import triton
3
3
  from torch import Tensor
4
4
  from triton import language as tl
5
5
 
6
+ from blksprs.utils.blksprs_tensor import BlksprsTensor
6
7
  from blksprs.utils.tools import get_triton_block_size, stride
7
8
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
8
9
  validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
9
10
 
10
11
 
11
- def gather(src: Tensor, sparsity_layout_src: Tensor, idx: Tensor, sparsity_layout_idx: Tensor,
12
- sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
12
+ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor, idx: BlksprsTensor, sparsity_layout_idx: Tensor,
13
+ sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
13
14
  """Applies a gather operation on a block-sparse tensor in compressed form.
14
15
 
15
16
  Args:
16
- src (Tensor): The source block-sparse tensor in compressed form to gather from.
17
+ src (BlksprsTensor): The source block-sparse tensor in compressed form to gather from.
17
18
  sparsity_layout_src (Tensor): The sparsity layout of the source block-sparse tensor.
18
- idx (Tensor): The block-sparse indices tensor in compressed form specifying how to gather from the source tensor.
19
+ idx (BlksprsTensor): The block-sparse indices tensor in compressed form specifying how to gather from the source tensor.
19
20
  sparsity_layout_idx (Tensor): The sparsity layout of the indices block-sparse tensor.
20
21
  sparsity_block_size (int): The size of the sparsity blocks.
21
22
  triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
22
23
 
23
24
  Returns:
24
- Tensor: The result of the gather operation as a block-sparse tensor in compressed form.
25
+ BlksprsTensor: The result of the gather operation as a block-sparse tensor in compressed form.
25
26
 
26
27
  """
27
28
  src = src.contiguous()
@@ -45,9 +46,9 @@ def gather(src: Tensor, sparsity_layout_src: Tensor, idx: Tensor, sparsity_layou
45
46
  validate_contiguous(sparsity_layout_src, sparsity_reverse_lut_x,
46
47
  sparsity_layout_idx, sparsity_lut_i)
47
48
 
48
- return _BlocksparseGather.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
49
+ return BlksprsTensor(_BlocksparseGather.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
49
50
  idx, sparsity_layout_idx, sparsity_lut_i,
50
- sparsity_block_size, triton_block_size)
51
+ sparsity_block_size, triton_block_size))
51
52
 
52
53
 
53
54
  class _BlocksparseGather(torch.autograd.Function):
@@ -168,10 +169,10 @@ class _BlocksparseGather(torch.autograd.Function):
168
169
  tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
169
170
 
170
171
 
171
- def scatter(src: Tensor, sparsity_layout_src: Tensor,
172
- idx: Tensor,
172
+ def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
173
+ idx: BlksprsTensor,
173
174
  sparsity_layout_tgt: Tensor,
174
- sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
175
+ sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
175
176
  """Wrapper for ``scatter_reduce`` with ``reduce_op="none"``.
176
177
 
177
178
  """
@@ -182,17 +183,17 @@ def scatter(src: Tensor, sparsity_layout_src: Tensor,
182
183
  reduce_op="none", triton_block_size=triton_block_size)
183
184
 
184
185
 
185
- def scatter_reduce(src: Tensor, sparsity_layout_src: Tensor,
186
- idx: Tensor,
186
+ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
187
+ idx: BlksprsTensor,
187
188
  sparsity_layout_tgt: Tensor,
188
189
  sparsity_block_size: int,
189
- reduce_op: str = "sum", triton_block_size: int = None) -> Tensor:
190
+ reduce_op: str = "sum", triton_block_size: int = None) -> BlksprsTensor:
190
191
  """Applies a scatter operation on a block-sparse tensor in compressed form.
191
192
 
192
193
  Args:
193
- src (Tensor): The source block-sparse tensor in compressed form to scatter from.
194
+ src (BlksprsTensor): The source block-sparse tensor in compressed form to scatter from.
194
195
  sparsity_layout_src (Tensor): The sparsity layout of the source block-sparse tensor.
195
- idx (Tensor): The block-sparse indices tensor in compressed form specifying how to scatter to the target tensor.
196
+ idx (BlksprsTensor): The block-sparse indices tensor in compressed form specifying how to scatter to the target tensor.
196
197
  sparsity_layout_tgt (Tensor): The sparsity layout of the target block-sparse tensor.
197
198
  sparsity_block_size (int): The size of the sparsity blocks.
198
199
  reduce_op (str, optional): The reduction operation to apply during the scatter operation (default ``"sum"``).
@@ -200,7 +201,7 @@ def scatter_reduce(src: Tensor, sparsity_layout_src: Tensor,
200
201
  triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
201
202
 
202
203
  Returns:
203
- Tensor: The result of the scatter operation as a block-sparse tensor in compressed form.
204
+ BlksprsTensor: The result of the scatter operation as a block-sparse tensor in compressed form.
204
205
 
205
206
  """
206
207
  src = src.contiguous()
@@ -229,11 +230,11 @@ def scatter_reduce(src: Tensor, sparsity_layout_src: Tensor,
229
230
  validate_contiguous(sparsity_layout_src, sparsity_lut_x,
230
231
  sparsity_layout_tgt, sparsity_reverse_lut_o)
231
232
 
232
- return _BlocksparseScatterReduce.apply(src, sparsity_layout_src, sparsity_lut_x,
233
+ return BlksprsTensor(_BlocksparseScatterReduce.apply(src, sparsity_layout_src, sparsity_lut_x,
233
234
  idx,
234
235
  sparsity_layout_tgt, sparsity_reverse_lut_o,
235
236
  sparsity_block_size, n_sparse_blocks,
236
- reduce_op, triton_block_size)
237
+ reduce_op, triton_block_size))
237
238
 
238
239
 
239
240
  class _BlocksparseScatterReduce(torch.autograd.Function):
@@ -3,17 +3,18 @@ import triton
3
3
  from torch import Tensor
4
4
  from triton import language as tl
5
5
 
6
+ from blksprs.utils.blksprs_tensor import BlksprsTensor
6
7
  from blksprs.utils.tools import get_triton_block_size, stride
7
8
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
8
9
  validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
9
10
 
10
11
 
11
- def gather_mdi(src: Tensor, sparsity_layout_src: Tensor,
12
- idx_bat: Tensor,
13
- idx_row: Tensor,
14
- idx_col: Tensor,
12
+ def gather_mdi(src: BlksprsTensor, sparsity_layout_src: Tensor,
13
+ idx_bat: BlksprsTensor,
14
+ idx_row: BlksprsTensor,
15
+ idx_col: BlksprsTensor,
15
16
  sparsity_layout_idx: Tensor,
16
- sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
17
+ sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
17
18
  src = src.contiguous()
18
19
  idx_bat = idx_bat.contiguous()
19
20
  idx_col = idx_col.contiguous()
@@ -37,9 +38,9 @@ def gather_mdi(src: Tensor, sparsity_layout_src: Tensor,
37
38
  validate_contiguous(sparsity_layout_src, sparsity_reverse_lut_x,
38
39
  sparsity_layout_idx, sparsity_lut_i)
39
40
 
40
- return _BlocksparseGatherMDI.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
41
- idx_bat, idx_col, sparsity_layout_idx, sparsity_lut_i,
42
- sparsity_block_size, triton_block_size)
41
+ return BlksprsTensor(_BlocksparseGatherMDI.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
42
+ idx_bat, idx_col, sparsity_layout_idx, sparsity_lut_i,
43
+ sparsity_block_size, triton_block_size))
43
44
 
44
45
 
45
46
  class _BlocksparseGatherMDI(torch.autograd.Function):
@@ -167,13 +168,13 @@ class _BlocksparseGatherMDI(torch.autograd.Function):
167
168
  tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
168
169
 
169
170
 
170
- def scatter_reduce_mdi(src: Tensor, sparsity_layout_src: Tensor,
171
- idx_bat: Tensor,
172
- idx_row: Tensor,
173
- idx_col: Tensor,
171
+ def scatter_reduce_mdi(src: BlksprsTensor, sparsity_layout_src: Tensor,
172
+ idx_bat: BlksprsTensor,
173
+ idx_row: BlksprsTensor,
174
+ idx_col: BlksprsTensor,
174
175
  sparsity_layout_tgt: Tensor,
175
176
  sparsity_block_size: int,
176
- reduce_op: str = "sum", triton_block_size: int = None) -> Tensor:
177
+ reduce_op: str = "sum", triton_block_size: int = None) -> BlksprsTensor:
177
178
  src = src.contiguous()
178
179
  idx_bat = idx_bat.contiguous()
179
180
  idx_col = idx_col.contiguous()
@@ -203,12 +204,12 @@ def scatter_reduce_mdi(src: Tensor, sparsity_layout_src: Tensor,
203
204
  validate_contiguous(sparsity_layout_src, sparsity_lut_x,
204
205
  sparsity_layout_tgt, sparsity_reverse_lut_o)
205
206
 
206
- return _BlocksparseScatterReduceMDI.apply(src, sparsity_layout_src, sparsity_lut_x,
207
- idx_bat,
208
- idx_col,
209
- sparsity_layout_tgt, sparsity_reverse_lut_o,
210
- sparsity_block_size, n_sparse_blocks,
211
- reduce_op, triton_block_size)
207
+ return BlksprsTensor(_BlocksparseScatterReduceMDI.apply(src, sparsity_layout_src, sparsity_lut_x,
208
+ idx_bat,
209
+ idx_col,
210
+ sparsity_layout_tgt, sparsity_reverse_lut_o,
211
+ sparsity_block_size, n_sparse_blocks,
212
+ reduce_op, triton_block_size))
212
213
 
213
214
 
214
215
  class _BlocksparseScatterReduceMDI(torch.autograd.Function):
@@ -353,8 +354,8 @@ class _BlocksparseScatterReduceMDI(torch.autograd.Function):
353
354
  tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
354
355
 
355
356
 
356
- def build_distribution_layout_mdi(idx_bat: Tensor, idx_row: Tensor, idx_col: Tensor, sparsity_layout_idx: Tensor,
357
- size_target: torch.Size,
357
+ def build_distribution_layout_mdi(idx_bat: BlksprsTensor, idx_row: BlksprsTensor, idx_col: BlksprsTensor,
358
+ sparsity_layout_idx: Tensor, size_target: torch.Size,
358
359
  sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
359
360
  validate_dimensions(idx_bat, idx_col)
360
361
  validate_contiguous(idx_bat, idx_col)
blksprs/ops/matmul.py CHANGED
@@ -4,22 +4,23 @@ from torch import Tensor
4
4
  from triton import language as tl
5
5
 
6
6
  from blksprs.ops.transpose import transpose
7
+ from blksprs.utils.blksprs_tensor import BlksprsTensor
7
8
  from blksprs.utils.tools import get_triton_block_size, stride
8
9
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
9
10
  validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_dtype_float
10
11
 
11
12
 
12
- def matmul(x: Tensor, sparsity_layout_x: Tensor,
13
- y: Tensor, sparsity_layout_y: Tensor,
13
+ def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
14
+ y: BlksprsTensor, sparsity_layout_y: Tensor,
14
15
  sparsity_layout_output: Tensor,
15
- sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
16
+ sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
16
17
  """Performs matrix multiplication between two block-sparse tensors.
17
18
 
18
19
  The sparsity layout of the output tensor is used to only calculate blocks that will be present in the output.
19
20
 
20
21
  Args:
21
- x (Tensor): A block-sparse tensor in compressed form.
22
- y (Tensor): A block-sparse tensor in compressed form.
22
+ x (BlksprsTensor): A block-sparse tensor in compressed form.
23
+ y (BlksprsTensor): A block-sparse tensor in compressed form.
23
24
  sparsity_layout_x (Tensor): The sparsity layout of the first block-sparse tensor.
24
25
  sparsity_layout_y (Tensor): The sparsity layout of the second block-sparse tensor.
25
26
  sparsity_layout_output (Tensor): The sparsity layout of the output tensor.
@@ -27,7 +28,7 @@ def matmul(x: Tensor, sparsity_layout_x: Tensor,
27
28
  triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
28
29
 
29
30
  Returns:
30
- Tensor: The result of the matrix multiplication as a block-sparse tensor in compressed form.
31
+ BlksprsTensor: The result of the matrix multiplication as a block-sparse tensor in compressed form.
31
32
 
32
33
  """
33
34
  x = x.contiguous()
@@ -61,13 +62,13 @@ def matmul(x: Tensor, sparsity_layout_x: Tensor,
61
62
  sparsity_layout_y, sparsity_reverse_lut_y,
62
63
  sparsity_layout_output, sparsity_lut_o)
63
64
 
64
- return _BlocksparseMatmulSSS.apply(x, y,
65
- sparsity_layout_x, sparsity_reverse_lut_x,
66
- sparsity_layout_y, sparsity_reverse_lut_y,
67
- sparsity_layout_output, sparsity_lut_o,
68
- sparsity_block_size,
69
- n_sparse_blocks,
70
- triton_block_size)
65
+ return BlksprsTensor(_BlocksparseMatmulSSS.apply(x, y,
66
+ sparsity_layout_x, sparsity_reverse_lut_x,
67
+ sparsity_layout_y, sparsity_reverse_lut_y,
68
+ sparsity_layout_output, sparsity_lut_o,
69
+ sparsity_block_size,
70
+ n_sparse_blocks,
71
+ triton_block_size))
71
72
 
72
73
 
73
74
  class _BlocksparseMatmulSSS(torch.autograd.Function):
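A sketch of the sparse-sparse matrix multiplication with the new ``BlksprsTensor`` return type; ``y`` and its layout are additional illustrative inputs, and the output layout is derived with ``build_sparsity_layout_matmul_fast``:

```python
y = torch.randn(size=(2, 64, 64), device="cuda")
sparsity_layout_y = bs.layouting.build_sparsity_layout(y, sparsity_block_size)
y_bs = bs.ops.to_sparse(y, sparsity_layout_y, sparsity_block_size)

# only blocks present in the output layout are computed
sparsity_layout_o = bs.layouting.build_sparsity_layout_matmul_fast(sparsity_layout, sparsity_layout_y)
o_bs = bs.ops.matmul(x_bs, sparsity_layout,
                     y_bs, sparsity_layout_y,
                     sparsity_layout_o, sparsity_block_size)
```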
@@ -3,13 +3,14 @@ import triton
3
3
  from torch import Tensor
4
4
  from triton import language as tl
5
5
 
6
+ from blksprs.utils.blksprs_tensor import BlksprsTensor
6
7
  from blksprs.utils.tools import get_triton_block_size, stride
7
8
  from blksprs.utils.validation import validate_contiguous, validate_device, \
8
9
  validate_sparsity_block_size, validate_triton_block_size
9
10
 
10
11
 
11
12
  def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
12
- sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
13
+ sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
13
14
  """Performs a broadcast and subsequent addition of two dense tensors x and y. Returns a block-sparse tensor in
14
15
  compressed form.
15
16
 
@@ -21,7 +22,7 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
21
22
  triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
22
23
 
23
24
  Returns:
24
- Tensor: The result of the operation as a block-sparse tensor in compressed form. Each element o(i, j) of the
25
+ BlksprsTensor: The result of the operation as a block-sparse tensor in compressed form. Each element o(i, j) of the
25
26
  output tensor corresponds to x(i) + y(j).
26
27
 
27
28
  """
@@ -70,11 +71,11 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
70
71
  sparsity_block_size,
71
72
  triton_block_size))
72
73
 
73
- return output
74
+ return BlksprsTensor(output)
74
75
 
75
76
 
76
77
  def broadcast_sub(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
77
- sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
78
+ sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
78
79
  """Wrapper for ``broadcast_add`` with negated y.
79
80
 
80
81
  """
@@ -3,12 +3,13 @@ import triton
3
3
  from torch import Tensor
4
4
  from triton import language as tl
5
5
 
6
+ from blksprs.utils.blksprs_tensor import BlksprsTensor
6
7
  from blksprs.utils.tools import get_triton_block_size, stride
7
8
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
8
9
  validate_sparsity_block_size, validate_triton_block_size
9
10
 
10
11
 
11
- def exp(x: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
12
+ def exp(x: BlksprsTensor, sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
12
13
  """Applies the element-wise exponential function to a block-sparse tensor.
13
14
 
14
15
  Note:
@@ -16,12 +17,12 @@ def exp(x: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> T
16
17
  Consider this when converting back to tensors in regular form.
17
18
 
18
19
  Args:
19
- x (Tensor): A block-sparse tensor in compressed form.
20
+ x (BlksprsTensor): A block-sparse tensor in compressed form.
20
21
  sparsity_block_size (int): The size of the sparsity blocks.
21
22
  triton_block_size (int): The block size to use for the triton kernel (default ``None``).
22
23
 
23
24
  Returns:
24
- Tensor: The exponential function applied to all elements of the input tensor as a block-sparse tensor in
25
+ BlksprsTensor: The exponential function applied to all elements of the input tensor as a block-sparse tensor in
25
26
  compressed form.
26
27
 
27
28
  """
@@ -33,7 +34,7 @@ def exp(x: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> T
33
34
  validate_sparsity_block_size(sparsity_block_size, x)
34
35
  validate_triton_block_size(triton_block_size, sparsity_block_size)
35
36
 
36
- return _BlocksparseExp.apply(x, sparsity_block_size, triton_block_size)
37
+ return BlksprsTensor(_BlocksparseExp.apply(x, sparsity_block_size, triton_block_size))
37
38
 
38
39
 
39
40
  class _BlocksparseExp(torch.autograd.Function):
@@ -3,13 +3,14 @@ import triton
3
3
  from torch import Tensor
4
4
  from triton import language as tl
5
5
 
6
+ from blksprs.utils.blksprs_tensor import BlksprsTensor
6
7
  from blksprs.utils.tools import get_triton_block_size, stride
7
8
  from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, validate_sparsity, \
8
9
  validate_sparsity_block_size, validate_triton_block_size
9
10
 
10
11
 
11
- def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
12
- flag_slice_only: bool = False, triton_block_size: int = None) -> tuple[Tensor, Tensor]:
12
+ def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
13
+ flag_slice_only: bool = False, triton_block_size: int = None) -> (BlksprsTensor, Tensor):
13
14
  """Computes the row-wise sum of a block-sparse tensor.
14
15
 
15
16
  Returns a block-sparse tensor in compressed form with only one block per row, where the first entry contains the sum
@@ -19,7 +20,7 @@ def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
19
20
  If ``flag_slice_only`` is set the output will be of shape ``[x.size(0), x.size(1), 1]``.
20
21
 
21
22
  Args:
22
- x (Tensor): A block-sparse tensor in compressed form.
23
+ x (BlksprsTensor): A block-sparse tensor in compressed form.
23
24
  sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
24
25
  sparsity_block_size (int): The size of the sparsity blocks.
25
26
  flag_slice_only (bool, optional): If set the output will be of shape ``[x.size(0), x.size(1), 1]``
@@ -27,7 +28,7 @@ def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
27
28
  triton_block_size (int): The block size to use for the triton kernel (default ``None``).
28
29
 
29
30
  Returns:
30
- tuple[Tensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise sum
31
+ tuple[BlksprsTensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise sum
31
32
  of the input and the sparsity layout of the output tensor.
32
33
 
33
34
  """
@@ -85,7 +86,7 @@ def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
85
86
  sparsity_reverse_lut_output,
86
87
  triton_block_size))
87
88
 
88
- return (output, sparsity_layout_output)
89
+ return BlksprsTensor(output), sparsity_layout_output
89
90
 
90
91
 
91
92
  @triton.jit
@@ -131,8 +132,8 @@ def kernel_blocksparse_row_wise_sum(x,
131
132
  tl.atomic_add(o + o_idx, buf, o_msk)
132
133
 
133
134
 
134
- def row_wise_max(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
135
- flag_slice_only: bool = False, triton_block_size: int = None) -> tuple[Tensor, Tensor]:
135
+ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
136
+ flag_slice_only: bool = False, triton_block_size: int = None) -> (BlksprsTensor, Tensor):
136
137
  """Computes the row-wise max of a block-sparse tensor.
137
138
 
138
139
  Returns a block-sparse tensor in compressed form with only one block per row, where the first entry contains the
@@ -142,7 +143,7 @@ def row_wise_max(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
142
143
  If ``flag_slice_only`` is set the output will be of shape ``[x.size(0), x.size(1), 1]``.
143
144
 
144
145
  Args:
145
- x (Tensor): A block-sparse tensor in compressed form.
146
+ x (BlksprsTensor): A block-sparse tensor in compressed form.
146
147
  sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
147
148
  sparsity_block_size (int): The size of the sparsity blocks.
148
149
  flag_slice_only (bool, optional): If set the output will be of shape ``[x.size(0), x.size(1), 1]``
@@ -150,7 +151,7 @@ def row_wise_max(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
150
151
  triton_block_size (int): The block size to use for the triton kernel (default ``None``).
151
152
 
152
153
  Returns:
153
- tuple[Tensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise max
154
+ tuple[BlksprsTensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise max
154
155
  of the input and the sparsity layout of the output tensor.
155
156
 
156
157
  """
@@ -208,7 +209,7 @@ def row_wise_max(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
208
209
  sparsity_reverse_lut_output,
209
210
  triton_block_size))
210
211
 
211
- return output, sparsity_layout_output
212
+ return BlksprsTensor(output), sparsity_layout_output
212
213
 
213
214
 
214
215
  @triton.jit
@@ -254,19 +255,19 @@ def kernel_blocksparse_row_wise_max(x,
254
255
  tl.atomic_max(o + o_idx, buf, o_msk)
255
256
 
256
257
 
257
- def row_wise_add(x: Tensor, sparsity_layout_x: Tensor, y: Tensor,
258
- sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
258
+ def row_wise_add(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
259
+ sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
259
260
  """For each row in ``y`` adds the value to each value in the corresponding row of the block-sparse tensor ``x``.
260
261
 
261
262
  Args:
262
- x (Tensor): A block-sparse tensor in compressed form.
263
+ x (BlksprsTensor): A block-sparse tensor in compressed form.
263
264
  sparsity_layout_x (Tensor): The sparsity layout of the block-sparse tensor.
264
- y (Tensor): A block-sparse tensor in compressed form with only one value per row and a single column of sparse blocks.
265
+ y (BlksprsTensor): A block-sparse tensor in compressed form with only one value per row and a single column of sparse blocks.
265
266
  sparsity_block_size (int): The size of the sparsity blocks.
266
267
  triton_block_size (int): The block size to use for the triton kernel (default ``None``).
267
268
 
268
269
  Returns:
269
- Tensor: The values of ``x`` with the first value of ``y`` in each row added to them as a block-sparse tensor in
270
+ BlksprsTensor: The values of ``x`` with the first value of ``y`` in each row added to them as a block-sparse tensor in
270
271
  compressed form.
271
272
 
272
273
  """
@@ -319,11 +320,11 @@ def row_wise_add(x: Tensor, sparsity_layout_x: Tensor, y: Tensor,
319
320
  triton_block_size
320
321
  ))
321
322
 
322
- return output
323
+ return BlksprsTensor(output)
323
324
 
324
325
 
325
- def row_wise_sub(x: Tensor, sparsity_layout_x: Tensor, y: Tensor,
326
- sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
326
+ def row_wise_sub(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
327
+ sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
327
328
  """Wrapper for ``row_wise_add`` with negated y.
328
329
 
329
330
  """
@@ -2,24 +2,25 @@ import torch
2
2
  from torch import Tensor
3
3
 
4
4
  from blksprs.ops.repeat import forward_flow
5
+ from blksprs.utils.blksprs_tensor import BlksprsTensor
5
6
 
6
7
  from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
7
8
  validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
8
9
 
9
10
 
10
- def split(x: Tensor, sparsity_layout: Tensor, partitions: int,
11
- sparsity_block_size: int, triton_block_size: int = None) -> (Tensor, Tensor):
11
+ def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
12
+ sparsity_block_size: int, triton_block_size: int = None) -> (BlksprsTensor, Tensor):
12
13
  """Splits a block-sparse tensor in compressed form along the last dimension into partitions.
13
14
 
14
15
  Args:
15
- x (Tensor): A block-sparse tensor in compressed form.
16
+ x (BlksprsTensor): A block-sparse tensor in compressed form.
16
17
  sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
17
18
  partitions (int): The number of partitions to split the block-sparse tensor into.
18
19
  sparsity_block_size (int): The size of the sparsity blocks.
19
20
  triton_block_size (int): The block size to use for the triton kernel (default ``None``).
20
21
 
21
22
  Returns:
22
- Tensor: The block-sparse tensor split into partitions in compressed form.
23
+ BlksprsTensor: The block-sparse tensor split into partitions in compressed form.
23
24
  Tensor: The sparsity layout of the output tensor.
24
25
 
25
26
  """
@@ -53,8 +54,8 @@ def split(x: Tensor, sparsity_layout: Tensor, partitions: int,
53
54
 
54
55
  validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
55
56
 
56
- return _BlocksparseSplit.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
57
- sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_output
57
+ return BlksprsTensor(_BlocksparseSplit.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
58
+ sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_output
58
59
 
59
60
 
60
61
  class _BlocksparseSplit(torch.autograd.Function):
@@ -79,19 +80,19 @@ class _BlocksparseSplit(torch.autograd.Function):
79
80
  sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None
80
81
 
81
82
 
82
- def merge(x: Tensor, sparsity_layout: Tensor, partitions: int,
83
- sparsity_block_size: int, triton_block_size: int = None) -> (Tensor, Tensor):
83
+ def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
84
+ sparsity_block_size: int, triton_block_size: int = None) -> (BlksprsTensor, Tensor):
84
85
  """Merges the specified partitions of a block-sparse tensor in compressed form along the last dimension.
85
86
 
86
87
  Args:
87
- x (Tensor): A block-sparse tensor in compressed form.
88
+ x (BlksprsTensor): A block-sparse tensor in compressed form.
88
89
  sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
89
90
  partitions (int): The number of partitions to be merged.
90
91
  sparsity_block_size (int): The size of the sparsity blocks.
91
92
  triton_block_size (int): The block size to use for the triton kernel (default ``None``).
92
93
 
93
94
  Returns:
94
- Tensor: The merged block-sparse tensor in compressed form.
95
+ BlksprsTensor: The merged block-sparse tensor in compressed form.
95
96
  Tensor: The sparsity layout of the output tensor.
96
97
 
97
98
  """
@@ -127,8 +128,8 @@ def merge(x: Tensor, sparsity_layout: Tensor, partitions: int,
127
128
 
128
129
  validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
129
130
 
130
- return _BlocksparseMerge.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
131
- sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_output
131
+ return BlksprsTensor(_BlocksparseMerge.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
132
+ sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_output
132
133
 
133
134
 
134
135
  class _BlocksparseMerge(torch.autograd.Function):
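``split`` and ``merge`` moved from ``blksprs.misc.partitioning`` to ``blksprs.ops.partitioning`` and are exposed as ``bs.ops.split`` / ``bs.ops.merge``. A sketch, assuming the number of partitions divides the number of block columns:

```python
# split along the last dimension into 2 partitions, then undo it
x_split, layout_split = bs.ops.split(x_bs, sparsity_layout, 2, sparsity_block_size)
x_merged, layout_merged = bs.ops.merge(x_split, layout_split, 2, sparsity_block_size)
```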
blksprs/ops/repeat.py CHANGED
@@ -3,14 +3,15 @@ import triton
3
3
  from triton import language as tl
4
4
  from torch import Tensor
5
5
 
6
+ from blksprs.utils.blksprs_tensor import BlksprsTensor
6
7
  from blksprs.utils.tools import get_triton_block_size, stride
7
8
  from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
8
9
  validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
9
10
 
10
11
 
11
- def repeat(x: Tensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
12
+ def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
12
13
  sparsity_block_size: int, sparsity_layout_output: Tensor = None, triton_block_size: int = None) -> (
13
- Tensor, Tensor):
14
+ BlksprsTensor, Tensor):
14
15
  """Repeats a block-spare tensor in compressed form according to the given repeats.
15
16
 
16
17
  Repeats is a 3-tuple of integers, where each integer represents the number of times the tensor should be repeated in
@@ -22,7 +23,7 @@ def repeat(x: Tensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
22
23
  them to be sparse.
23
24
 
24
25
  Args:
25
- x (Tensor): A block-sparse tensor in compressed form.
26
+ x (BlksprsTensor): A block-sparse tensor in compressed form.
26
27
  sparsity_layout_x (Tensor): The sparsity layout of the block-sparse tensor.
27
28
  repeats (tuple[int, int, int]): The number of times the tensor should be repeated in the first, second and
28
29
  third dimension respectively.
@@ -31,7 +32,7 @@ def repeat(x: Tensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
31
32
  triton_block_size (int): The block size to use for the triton kernel (default ``None``).
32
33
 
33
34
  Returns:
34
- Tensor: A block-sparse tensor in compressed form containing the repeated values.
35
+ BlksprsTensor: A block-sparse tensor in compressed form containing the repeated values.
35
36
  Tensor: The sparsity layout of the resulting output tensor.
36
37
 
37
38
  """
@@ -63,14 +64,14 @@ def repeat(x: Tensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
63
64
 
64
65
  validate_contiguous(sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
65
66
 
66
- return _BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
67
- sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_o
67
+ return BlksprsTensor(_BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
68
+ sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_o
68
69
 
69
70
 
70
- def repeat_interleave(x: Tensor, sparsity_layout_x: Tensor, repeats: int,
71
+ def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
71
72
  sparsity_block_size: int, sparsity_layout_output: Tensor = None,
72
73
  triton_block_size: int = None) -> (
73
- Tensor, Tensor):
74
+ BlksprsTensor, Tensor):
74
75
  """Repeats and interleaves the block-sparse tensor in compressed form.
75
76
 
76
77
  Repeats each matrix contained in the tensors by ``repeats`` amount and places them consecutively in the output
@@ -81,7 +82,7 @@ def repeat_interleave(x: Tensor, sparsity_layout_x: Tensor, repeats: int,
81
82
  non-sparse blocks will be filled.
82
83
 
83
84
  Args:
84
- x (Tensor): A block-sparse tensor in compressed form.
85
+ x (BlksprsTensor): A block-sparse tensor in compressed form.
85
86
  sparsity_layout_x (Tensor): The sparsity layout of the block-sparse tensor.
86
87
  repeats (int): The number of times to repeat the matrices.
87
88
  sparsity_block_size (int): The size of the sparsity blocks.
@@ -89,7 +90,7 @@ def repeat_interleave(x: Tensor, sparsity_layout_x: Tensor, repeats: int,
89
90
  triton_block_size (int): The block size to use for the triton kernel (default ``None``).
90
91
 
91
92
  Returns:
92
- Tensor: A block-sparse tensor in compressed form containing the repeated and interleaved matrices.
93
+ BlksprsTensor: A block-sparse tensor in compressed form containing the repeated and interleaved matrices.
93
94
  Tensor: The sparsity layout of the resulting output tensor.
94
95
 
95
96
  """
@@ -121,8 +122,8 @@ def repeat_interleave(x: Tensor, sparsity_layout_x: Tensor, repeats: int,
121
122
 
122
123
  validate_contiguous(sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
123
124
 
124
- return _BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
125
- sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_o
125
+ return BlksprsTensor(_BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
126
+ sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_o
126
127
 
127
128
 
128
129
  class _BlocksparseRepeat(torch.autograd.Function):
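A sketch of the ``repeat`` variants with the updated return type, continuing the variables from the first snippet:

```python
# repeat twice along the batch dimension; returns the tensor and its new layout
x_rep, layout_rep = bs.ops.repeat(x_bs, sparsity_layout, (2, 1, 1), sparsity_block_size)

# repeat_interleave places the copies of each matrix consecutively
x_ril, layout_ril = bs.ops.repeat_interleave(x_bs, sparsity_layout, 2, sparsity_block_size)
```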
blksprs/ops/softmax.py CHANGED
@@ -3,27 +3,28 @@ import triton
3
3
  from torch import Tensor
4
4
  from triton import language as tl
5
5
 
6
- from blksprs.misc.exp import exp
7
- from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
6
+ from blksprs.ops.misc.exp import exp
7
+ from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
8
+ from blksprs.utils.blksprs_tensor import BlksprsTensor
8
9
  from blksprs.utils.tools import get_triton_block_size, stride
9
10
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
10
11
  validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
11
12
 
12
13
 
13
- def softmax(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
14
+ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
14
15
  """Computes the softmax of a block-sparse tensor in compressed form.
15
16
 
16
17
  Note:
17
18
  Sparse blocks are not considered for the calculation of the softmax, i.e., all values are assumed to be ``-inf``.
18
19
 
19
20
  Args:
20
- x (Tensor): A block-sparse tensor in compressed form.
21
+ x (BlksprsTensor): A block-sparse tensor in compressed form.
21
22
  sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
22
23
  sparsity_block_size (int): The size of the sparsity blocks.
23
24
  triton_block_size (int): The block size to use for the triton kernel (default ``None``).
24
25
 
25
26
  Returns:
26
- Tensor: The result of the softmax operation as a block-sparse tensor in compressed form.
27
+ BlksprsTensor: The result of the softmax operation as a block-sparse tensor in compressed form.
27
28
 
28
29
  """
29
30
  x = x.contiguous()
@@ -45,10 +46,10 @@ def softmax(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton
45
46
 
46
47
  validate_contiguous(sparsity_layout, sparsity_lut, sparsity_reverse_lut_rws)
47
48
 
48
- return _BlocksparseSoftmax.apply(x, sparsity_layout,
49
+ return BlksprsTensor(_BlocksparseSoftmax.apply(x, sparsity_layout,
49
50
  sparsity_lut,
50
51
  sparsity_reverse_lut_rws,
51
- sparsity_block_size, triton_block_size)
52
+ sparsity_block_size, triton_block_size))
52
53
 
53
54
 
54
55
  class _BlocksparseSoftmax(torch.autograd.Function):
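A one-line sketch of the updated ``softmax`` signature, reusing the first snippet's variables (sparse blocks are not considered, i.e. treated as ``-inf``):

```python
attn_bs = bs.ops.softmax(x_bs, sparsity_layout, sparsity_block_size)
```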
blksprs/ops/transpose.py CHANGED
@@ -3,26 +3,27 @@ import triton
3
3
  from torch import Tensor
4
4
  from triton import language as tl
5
5
 
6
+ from blksprs.utils.blksprs_tensor import BlksprsTensor
6
7
  from blksprs.utils.tools import get_triton_block_size, stride
7
8
  from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
8
9
  validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
9
10
 
10
11
 
11
- def transpose(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> (
12
- Tensor, Tensor):
12
+ def transpose(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> (
13
+ BlksprsTensor, Tensor):
13
14
  """Transposes a block-sparse tensor in compressed form.
14
15
 
15
16
  Note:
16
17
  Returns the transposed tensor and the sparsity layout of the transposed tensor.
17
18
 
18
19
  Args:
19
- x (Tensor): A block-sparse tensor in compressed form.
20
+ x (BlksprsTensor): A block-sparse tensor in compressed form.
20
21
  sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
21
22
  sparsity_block_size (int): The size of the sparsity blocks.
22
23
  triton_block_size (int): The block size to use for the triton kernel (default ``None``).
23
24
 
24
25
  Returns:
25
- Tensor: The transposed block-sparse tensor in compressed form.
26
+ BlksprsTensor: The transposed block-sparse tensor in compressed form.
26
27
  Tensor: The sparsity layout of the transposed tensor.
27
28
 
28
29
  """
@@ -49,8 +50,8 @@ def transpose(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
49
50
 
50
51
  validate_contiguous(sparsity_layout_t, sparsity_lut, sparsity_reverse_lut)
51
52
 
52
- return _BlocksparseTranspose.apply(x, sparsity_layout_t, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
53
- n_sparse_blocks, triton_block_size), sparsity_layout_t
53
+ return BlksprsTensor(_BlocksparseTranspose.apply(x, sparsity_layout_t, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
54
+ n_sparse_blocks, triton_block_size)), sparsity_layout_t
54
55
 
55
56
 
56
57
  class _BlocksparseTranspose(torch.autograd.Function):
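Sketch of ``transpose``, which returns both the transposed ``BlksprsTensor`` and the transposed sparsity layout:

```python
x_t_bs, sparsity_layout_t = bs.ops.transpose(x_bs, sparsity_layout, sparsity_block_size)
```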
@@ -0,0 +1,8 @@
1
+ from torch import Tensor
2
+
3
+
4
+ class BlksprsTensor(Tensor):
5
+ """A wrapper class representing a block-sparse tensor in compressed form.
6
+ """
7
+
8
+ pass
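``BlksprsTensor`` is a plain marker subclass of ``torch.Tensor``; as of 1.8.3 the public ops wrap their results in it, so compressed tensors can be distinguished from regular ones without changing behaviour. A minimal check, reusing the first snippet:

```python
assert isinstance(x_bs, bs.BlksprsTensor)  # marks compressed block-sparse form
assert isinstance(x_bs, torch.Tensor)      # still usable like any other tensor
```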
@@ -0,0 +1,41 @@
1
+ import torch
2
+ from torch import Tensor, nn
3
+ from triton.language import dtype
4
+
5
+ from blksprs.layouting.sparsity_layout import build_sparsity_layout_matmul_fast
6
+ from blksprs.ops.conversion import to_sparse
7
+ from blksprs.ops.matmul import matmul
8
+ from blksprs.ops.repeat import repeat
9
+ from blksprs.utils.blksprs_tensor import BlksprsTensor
10
+
11
+
12
+ def apply_torch_linear(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
13
+ linear: nn.Linear) -> (BlksprsTensor, Tensor):
14
+ # Extract weight and bias
15
+ w = linear.weight
16
+ b = linear.bias
17
+
18
+ # Convert w to block-sparse representation
19
+ sparsity_layout_w_t = torch.ones(size=(sparsity_layout.size(0), w.size(1) // sparsity_block_size,
20
+ w.size(0) // sparsity_block_size), dtype=torch.bool, device=x.device)
21
+ w_t_bs = to_sparse(w.transpose(-1, -2).unsqueeze(0).repeat(sparsity_layout.size(0), 1, 1),
22
+ sparsity_layout_w_t, sparsity_block_size)
23
+
24
+ # Apply weights
25
+ sparsity_layout_xw = build_sparsity_layout_matmul_fast(sparsity_layout, sparsity_layout_w_t)
26
+ xw = matmul(x, sparsity_layout, w_t_bs, sparsity_layout_w_t, sparsity_layout_xw, sparsity_block_size)
27
+ interim = xw
28
+
29
+ # Apply bias
30
+ if b is not None:
31
+ b_slice = b.unsqueeze(0).unsqueeze(0).repeat(1, sparsity_block_size, 1)
32
+ sparsity_layout_b_slice = torch.ones(size=(1, b_slice.size(1) // sparsity_block_size,
33
+ b_slice.size(2) // sparsity_block_size), dtype=torch.bool,
34
+ device=x.device)
35
+ b_slice_bs = to_sparse(b_slice, sparsity_layout_b_slice, sparsity_block_size)
36
+ b_bs, sparsity_layout_b = repeat(b_slice_bs, sparsity_layout_b_slice,
37
+ (sparsity_layout.size(0), sparsity_layout_xw.size(1), 1), sparsity_block_size,
38
+ sparsity_layout_output=sparsity_layout_xw)
39
+ interim = interim + b_bs
40
+
41
+ return interim, sparsity_layout_xw
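A usage sketch for the new ``bs.utils.apply_torch_linear`` helper; the layer sizes are illustrative and, as in the implementation above, the weight (and bias) dimensions are assumed to be multiples of the sparsity block size:

```python
# apply a dense nn.Linear to a block-sparse input (in_features=64, out_features=128 assumed)
linear = torch.nn.Linear(64, 128, device="cuda")
out_bs, sparsity_layout_out = bs.utils.apply_torch_linear(x_bs, sparsity_layout,
                                                          sparsity_block_size, linear)
```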
blksprs/utils/tools.py CHANGED
@@ -1,7 +1,5 @@
1
1
  from torch import Tensor, Size
2
2
 
3
- from blksprs.utils.validation import _set_skip_validation
4
-
5
3
 
6
4
  def do_shape_blocksparse(x: Tensor):
7
5
  if x.dim() == 3:
@@ -21,8 +19,5 @@ def get_triton_block_size(sparsity_block_size: int, limit: int = 128):
21
19
  return min(sparsity_block_size, limit)
22
20
 
23
21
 
24
- def disable_validation():
25
- _set_skip_validation(True)
26
-
27
22
  def stride(x: Tensor):
28
- return x.view(x.shape).stride()
23
+ return x.view(x.shape).stride()
@@ -124,3 +124,7 @@ def _check_skip_validation():
124
124
  def _set_skip_validation(skip_validation: bool):
125
125
  global VALIDATION
126
126
  VALIDATION = not skip_validation
127
+
128
+
129
+ def disable_validation():
130
+ _set_skip_validation(True)
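``disable_validation`` moved from ``blksprs.utils.tools`` into ``blksprs.utils.validation`` and is now reachable from the package root in two ways:

```python
# switch off input validation globally (e.g. for benchmarking)
bs.utils.disable_validation()
# equivalent access via the new validation namespace
bs.validation.disable_validation()
```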
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: blksprs
3
- Version: 1.8.1
3
+ Version: 1.8.3
4
4
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
5
5
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
6
6
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -22,6 +22,14 @@ Requires-Dist: matplotlib; extra == "test"
22
22
  [![GitHub Release](https://img.shields.io/github/v/release/FelixSchoen/blksprs?include_prereleases&label=Latest%20Release)](https://github.com/FelixSchoen/blksprs/releases)
23
23
  [![Python Version](https://img.shields.io/badge/Python%20Version-3.11-blue)](https://www.python.org/downloads/release/python-3119/)
24
24
 
25
+ ## Important Notice
26
+
27
+ 🚨 **Non-Final API** 🚨
28
+
29
+ Although it already supports a wide variety of functions, this library is still under active development and the API is
30
+ subject to change. For feature requests or bug reports, please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
31
+ We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
32
+
25
33
  ## Overview
26
34
 
27
35
  A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
@@ -51,14 +59,14 @@ These include, e.g.,
51
59
  Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
52
60
  match.
53
61
 
54
- Further helpful operations (included in the ``bs.misc`` module) that do **not** support gradient calculation include:
62
+ Further helpful operations (included in the ``bs.ops.misc`` module) that do **not** support gradient calculation include:
55
63
 
56
64
  - Row-wise sum, max, addition, and subtraction
57
65
  - Broadcast addition and subtraction between slices
58
66
 
59
67
  Furthermore, the library provides a set of utility functions for the creation of sparsity layouts based on existing
60
- dense tensors and for the scatter operation (module ``bs.layout``), as well as utility functions to ensure correct input
61
- dimensionality (module ``bs.util``).
68
+ dense tensors and for the scatter operation (module ``bs.layouting``), as well as utility functions to apply linear layers,
69
+ ensure correct input dimensionality, and validate input (module ``bs.utils``).
62
70
 
63
71
  ## Installation
64
72
 
@@ -111,14 +119,14 @@ def test_readme():
111
119
  y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()
112
120
 
113
121
  # Convert tensors to three-dimensional (dense) tensors since Triton can only handle tensors of exactly three dimensions
114
- x_dense, x_shape_original = bs.util.do_shape_blocksparse(x)
115
- y_dense, y_shape_original = bs.util.do_shape_blocksparse(y)
122
+ x_dense, x_shape_original = bs.utils.do_shape_blocksparse(x)
123
+ y_dense, y_shape_original = bs.utils.do_shape_blocksparse(y)
116
124
 
117
125
  # Create sparsity layouts from existing tensors
118
- sparsity_layout_x = bs.layout.build_sparsity_layout(x_dense, sparsity_block_size,
119
- triton_block_size=triton_block_size)
120
- sparsity_layout_y = bs.layout.build_sparsity_layout(y_dense, sparsity_block_size,
121
- triton_block_size=triton_block_size)
126
+ sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size,
127
+ triton_block_size=triton_block_size)
128
+ sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size,
129
+ triton_block_size=triton_block_size)
122
130
 
123
131
  # Create random sparsity layout for output tensor
124
132
  sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)
@@ -150,12 +158,12 @@ def test_readme():
150
158
  assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2) # Note that small numerical differences are expected
151
159
 
152
160
  # Assert that the output has the correct sparsity layout
153
- actual_sparsity_layout_o = bs.layout.build_sparsity_layout(o_dense, sparsity_block_size,
154
- triton_block_size=triton_block_size)
161
+ actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size,
162
+ triton_block_size=triton_block_size)
155
163
  assert torch.allclose(actual_sparsity_layout_o.to(torch.int), sparsity_layout_o)
156
164
 
157
165
  # Convert output tensor back to original shape
158
- o = bs.util.undo_shape_blocksparse(o_dense, x_shape_original)
166
+ o = bs.utils.undo_shape_blocksparse(o_dense, x_shape_original)
159
167
 
160
168
  # Other available functions
161
169
  bs.transpose(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
@@ -0,0 +1,23 @@
1
+ blksprs/__init__.py,sha256=YMrERuEf1hTv5vVdOvPEzh9rESn4uqOB7WHB12Qs5lU,1836
2
+ blksprs/layouting/distribution_layout.py,sha256=wmj1SwWyY_fhbvMmh6AXrR77LoSp6xLwUWCCyO9i5lk,4239
3
+ blksprs/layouting/sparsity_layout.py,sha256=-sScIn4hhG35j9BXytrojEzp8jnFkMargJjtivPV1fc,9755
4
+ blksprs/ops/conversion.py,sha256=ol-iV45wDzp9G1dJEkY53EdrvnmHzcl7QQmPJ-xqQTs,22410
5
+ blksprs/ops/distribution.py,sha256=fXZV6UegCVpIwzh-A825OSYClHWu5k0UMYdO2UGDUpM,17067
6
+ blksprs/ops/matmul.py,sha256=yh2ZnO0ZltT1AgadiFP0vX28YJ4n74xO-I_5vFUmOmA,11452
7
+ blksprs/ops/partitioning.py,sha256=K0ExR2a3W62d_9xxCJzsdJDLgtbxTI6P8loOOBdhPzE,7674
8
+ blksprs/ops/repeat.py,sha256=IvSIRbuyFn0b57LObymLgup0LqlWQ3ndIw-QuiYQcaU,14564
9
+ blksprs/ops/softmax.py,sha256=CDQT2KnwkJ4hGIgT0EUp6P92uiYpCdJQ9zxcdgSAAJA,12102
10
+ blksprs/ops/transpose.py,sha256=jxzFFffrj4S_9tiCrwwUMdz6EA98o1dziWXjlqb64a4,6859
11
+ blksprs/ops/experimental/distribution_mdi.py,sha256=HaRUu6LTWATzjuHWgddIUE-0fgY-O87STpJO4JY7k_8,20357
12
+ blksprs/ops/misc/broadcast_ops.py,sha256=cPtRJa3pkZfY1QG51CJ-zDn4SK-CRpX5LEXoKGGMvRU,5418
13
+ blksprs/ops/misc/exp.py,sha256=FnSFosBfJHuiEbD0MD-i4axLghRn4a0f8KvHXrKBB6M,3802
14
+ blksprs/ops/misc/row_wise.py,sha256=SvJuNww-_QoVKTyTjMvjmzHlBuUlTKamkuq_rKzwAqs,17081
15
+ blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
16
+ blksprs/utils/blksprs_tensor.py,sha256=VjplBgDhnf9sxf-1R5feA0xp5FDCDdaeZmCeoIRdCnc,151
17
+ blksprs/utils/processing.py,sha256=hYsFxEbQKcbqU4WtZWusPnWMHg8ZAZF1SKZJYjez9aU,2060
18
+ blksprs/utils/tools.py,sha256=r7Y4C37vfSWUyQTGwa8NyRqgovmsq9hMufkenqYHOxo,539
19
+ blksprs/utils/validation.py,sha256=IZxH2HZpePmv7lRqLsSwV_6FwsdnTXv9q4j98vCMSsQ,4195
20
+ blksprs-1.8.3.dist-info/METADATA,sha256=DZkJ_HeetF1V6-_F6GeG0uXT-QmttMFOq4ao8fiSMgQ,8458
21
+ blksprs-1.8.3.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
22
+ blksprs-1.8.3.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
23
+ blksprs-1.8.3.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.2.0)
2
+ Generator: setuptools (75.3.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,21 +0,0 @@
1
- blksprs/__init__.py,sha256=np0msosWMaZNVVfuFGt8rE6HZURyIald391dKAs1dSQ,1093
2
- blksprs/experimental/distribution_mdi.py,sha256=jE_SbB0SMGPcCoeM0699cceWAikBkBpGM_8Oo5A7Ets,20083
3
- blksprs/layouting/distribution_layout.py,sha256=nCR3fCF6eNpi97DI6MMLF2hq_9Lwyo6_aUCIZiJfiX4,4170
4
- blksprs/layouting/sparsity_layout.py,sha256=JNwbJ4L-418zCiCbt-vLfZ0xU7ReP0zr_tLHs_dytRA,9686
5
- blksprs/misc/broadcast_ops.py,sha256=0RLnLMYV7GAPI2YL8RotcxjIUSBZKGxdVcsGaJFeL_I,5327
6
- blksprs/misc/exp.py,sha256=cdF0s93Q9iucIXuEE3howsB0N6D60xgvem7C-a-yiGI,3704
7
- blksprs/misc/partitioning.py,sha256=nBRZzfi3XYAhDLEBzYflQkvGa3MIZ-qNeIlrZ16k44g,7533
8
- blksprs/misc/row_wise.py,sha256=0vDJA8uCocmebSIPIbFeND5_PQIE10pUj3DBOQXlTvE,16888
9
- blksprs/ops/conversion.py,sha256=9xVdCrj38m1cMh43LQs-GrXZ5pNRjhQyKx6paaw3C6A,21898
10
- blksprs/ops/distribution.py,sha256=V3TK5SlNT_JdGHNaDNl-U4U5vwAYsgkAOg4eTmYxbuA,16877
11
- blksprs/ops/matmul.py,sha256=uqVe6Dz2aaCbCglM1uS2eRHVKh7PQcuecaIBWFubPEw,11256
12
- blksprs/ops/repeat.py,sha256=OSsa2rj6BHL3Kedfu3wr0D82mn4HmbJ1l7XEmT-6ehg,14423
13
- blksprs/ops/softmax.py,sha256=5nAgeT68nucgOugjtCy1aBIMa7Kyk1KNN-j8fgmeVuk,11996
14
- blksprs/ops/transpose.py,sha256=67pDdCEb7r-Xifupl82fBKAYsxKcCUDy--cPPfduRvU,6761
15
- blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
16
- blksprs/utils/tools.py,sha256=S3836Zuc-BMigv-5mLTjRznCzuaF6oYW-Ir9zzUnr3o,655
17
- blksprs/utils/validation.py,sha256=WzihRPibXYzss3PMkhDt5_d3Q3NHA_d1TzTz3CoGPGg,4136
18
- blksprs-1.8.1.dist-info/METADATA,sha256=UDXUjS8PHyD4Zm-gWF4maXzY1k2SjKHMQllu-uOwLIA,8009
19
- blksprs-1.8.1.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
20
- blksprs-1.8.1.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
21
- blksprs-1.8.1.dist-info/RECORD,,