blksprs-1.8-py3-none-any.whl → blksprs-1.8.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/__init__.py CHANGED
@@ -10,7 +10,7 @@ from blksprs.misc.partitioning import split, merge
 class layout:
     from blksprs.layouting.distribution_layout import build_distribution_layout
     from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption, \
-        build_sparsity_layout_matmul
+        build_sparsity_layout_matmul, build_sparsity_layout_matmul_fast


 class misc:
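
In 1.8.2 the nested ``layout`` namespace class re-exports the new ``build_sparsity_layout_matmul_fast`` builder alongside the existing one. A hedged sketch of how the export is reached; the fast variant's signature is not part of this diff, so it is only looked up, not called:

    import blksprs as bs

    # Both builders are exposed on the ``layout`` namespace class as of 1.8.2.
    # Only attribute access is shown here, since the signature of the fast
    # variant does not appear in this diff.
    assert hasattr(bs.layout, "build_sparsity_layout_matmul")
    assert hasattr(bs.layout, "build_sparsity_layout_matmul_fast")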
blksprs/experimental/distribution_mdi.py CHANGED
@@ -3,17 +3,18 @@ import triton
 from torch import Tensor
 from triton import language as tl

+from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size


-def gather_mdi(src: Tensor, sparsity_layout_src: Tensor,
-               idx_bat: Tensor,
-               idx_row: Tensor,
-               idx_col: Tensor,
+def gather_mdi(src: BlksprsTensor, sparsity_layout_src: Tensor,
+               idx_bat: BlksprsTensor,
+               idx_row: BlksprsTensor,
+               idx_col: BlksprsTensor,
                sparsity_layout_idx: Tensor,
-               sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+               sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
     src = src.contiguous()
     idx_bat = idx_bat.contiguous()
     idx_col = idx_col.contiguous()
@@ -37,9 +38,9 @@ def gather_mdi(src: Tensor, sparsity_layout_src: Tensor,
     validate_contiguous(sparsity_layout_src, sparsity_reverse_lut_x,
                         sparsity_layout_idx, sparsity_lut_i)

-    return _BlocksparseGatherMDI.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
-                                       idx_bat, idx_col, sparsity_layout_idx, sparsity_lut_i,
-                                       sparsity_block_size, triton_block_size)
+    return BlksprsTensor(_BlocksparseGatherMDI.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
+                                                     idx_bat, idx_col, sparsity_layout_idx, sparsity_lut_i,
+                                                     sparsity_block_size, triton_block_size))


 class _BlocksparseGatherMDI(torch.autograd.Function):
@@ -167,13 +168,13 @@ class _BlocksparseGatherMDI(torch.autograd.Function):
         tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)


-def scatter_reduce_mdi(src: Tensor, sparsity_layout_src: Tensor,
-                       idx_bat: Tensor,
-                       idx_row: Tensor,
-                       idx_col: Tensor,
+def scatter_reduce_mdi(src: BlksprsTensor, sparsity_layout_src: Tensor,
+                       idx_bat: BlksprsTensor,
+                       idx_row: BlksprsTensor,
+                       idx_col: BlksprsTensor,
                        sparsity_layout_tgt: Tensor,
                        sparsity_block_size: int,
-                       reduce_op: str = "sum", triton_block_size: int = None) -> Tensor:
+                       reduce_op: str = "sum", triton_block_size: int = None) -> BlksprsTensor:
     src = src.contiguous()
     idx_bat = idx_bat.contiguous()
     idx_col = idx_col.contiguous()
@@ -203,12 +204,12 @@ def scatter_reduce_mdi(src: Tensor, sparsity_layout_src: Tensor,
     validate_contiguous(sparsity_layout_src, sparsity_lut_x,
                         sparsity_layout_tgt, sparsity_reverse_lut_o)

-    return _BlocksparseScatterReduceMDI.apply(src, sparsity_layout_src, sparsity_lut_x,
-                                              idx_bat,
-                                              idx_col,
-                                              sparsity_layout_tgt, sparsity_reverse_lut_o,
-                                              sparsity_block_size, n_sparse_blocks,
-                                              reduce_op, triton_block_size)
+    return BlksprsTensor(_BlocksparseScatterReduceMDI.apply(src, sparsity_layout_src, sparsity_lut_x,
+                                                            idx_bat,
+                                                            idx_col,
+                                                            sparsity_layout_tgt, sparsity_reverse_lut_o,
+                                                            sparsity_block_size, n_sparse_blocks,
+                                                            reduce_op, triton_block_size))


 class _BlocksparseScatterReduceMDI(torch.autograd.Function):
@@ -353,8 +354,8 @@ class _BlocksparseScatterReduceMDI(torch.autograd.Function):
         tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)


-def build_distribution_layout_mdi(idx_bat: Tensor, idx_row: Tensor, idx_col: Tensor, sparsity_layout_idx: Tensor,
-                                  size_target: torch.Size,
+def build_distribution_layout_mdi(idx_bat: BlksprsTensor, idx_row: BlksprsTensor, idx_col: BlksprsTensor,
+                                  sparsity_layout_idx: Tensor, size_target: torch.Size,
                                   sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
     validate_dimensions(idx_bat, idx_col)
     validate_contiguous(idx_bat, idx_col)
blksprs/layouting/distribution_layout.py CHANGED
@@ -3,18 +3,19 @@ import triton
 from torch import Tensor
 from triton import language as tl

+from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
     validate_contiguous


-def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
+def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: Tensor,
                               size_target: torch.Size,
                               sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
     """Builds the sparsity layout of either the source of a gather or the target of a scatter operation.

     Args:
-        indices (Tensor): The block-sparse indices tensor in compressed form used for the gather or scatter operation.
+        indices (BlksprsTensor): The block-sparse indices tensor in compressed form used for the gather or scatter operation.
         sparsity_layout_indices (Tensor): The sparsity layout of the indices block-sparse tensor.
         size_target (torch.Size): The size of the block-sparse target tensor in regular form.
         sparsity_block_size (int): The size of the sparsity blocks.
blksprs/layouting/sparsity_layout.py CHANGED
@@ -5,6 +5,7 @@ import triton
 from torch import Tensor
 from triton import language as tl

+from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
     validate_contiguous, validate_sparsity, validate_sparsity_block_size
@@ -82,14 +83,14 @@ def kernel_sparsity_layout(x,
         tl.store(o + blk_o_idx, 1, mask=blk_o_msk)


-def build_sparsity_layout_adaption(x: Tensor, sparsity_layout_from: Tensor,
+def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tensor,
                                    sparsity_block_size_from: int, sparsity_block_size_to: int,
                                    triton_block_size: int = None) -> Tensor:
     """Builds the sparsity layout of a block-sparse tensor in compressed form if a different sparsity block size were
     used.

     Args:
-        x (Tensor): A block-sparse tensor in compressed form.
+        x (BlksprsTensor): A block-sparse tensor in compressed form.
         sparsity_layout_from (Tensor): The sparsity layout of the input block-sparse tensor.
         sparsity_block_size_from (int): The size of the sparsity blocks of the input tensor.
         sparsity_block_size_to (int): The desired size of the sparsity blocks for the resulting layout.
blksprs/misc/broadcast_ops.py CHANGED
@@ -3,13 +3,14 @@ import triton
 from torch import Tensor
 from triton import language as tl

+from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_contiguous, validate_device, \
     validate_sparsity_block_size, validate_triton_block_size


 def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
-                  sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+                  sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
     """Performs a broadcast and subsequent addition of two dense tensors x and y. Returns a block-sparse tensor in
     compressed form.

@@ -21,7 +22,7 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).

     Returns:
-        Tensor: The result of the operation as a block-sparse tensor in compressed form. Each element o(i, j) of the
+        BlksprsTensor: The result of the operation as a block-sparse tensor in compressed form. Each element o(i, j) of the
         output tensor corresponds to x(i) + y(j).

     """
@@ -70,11 +71,11 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
                                            sparsity_block_size,
                                            triton_block_size))

-    return output
+    return BlksprsTensor(output)


 def broadcast_sub(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
-                  sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+                  sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
     """Wrapper for ``broadcast_add`` with negated y.

     """
blksprs/misc/exp.py CHANGED
@@ -3,12 +3,13 @@ import triton
 from torch import Tensor
 from triton import language as tl

+from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity_block_size, validate_triton_block_size


-def exp(x: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+def exp(x: BlksprsTensor, sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
     """Applies the element-wise exponential function to a block-sparse tensor.

     Note:
@@ -16,12 +17,12 @@ def exp(x: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> T
         Consider this when converting back to tensors in regular form.

     Args:
-        x (Tensor): A block-sparse tensor in compressed form.
+        x (BlksprsTensor): A block-sparse tensor in compressed form.
         sparsity_block_size (int): The size of the sparsity blocks.
         triton_block_size (int): The block size to use for the triton kernel (default ``None``).

     Returns:
-        Tensor: The exponential function applied to all elements of the input tensor as a block-sparse tensor in
+        BlksprsTensor: The exponential function applied to all elements of the input tensor as a block-sparse tensor in
         compressed form.

     """
@@ -33,7 +34,7 @@ def exp(x: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> T
     validate_sparsity_block_size(sparsity_block_size, x)
     validate_triton_block_size(triton_block_size, sparsity_block_size)

-    return _BlocksparseExp.apply(x, sparsity_block_size, triton_block_size)
+    return BlksprsTensor(_BlocksparseExp.apply(x, sparsity_block_size, triton_block_size))


 class _BlocksparseExp(torch.autograd.Function):
blksprs/misc/partitioning.py CHANGED
@@ -2,24 +2,25 @@ import torch
 from torch import Tensor

 from blksprs.ops.repeat import forward_flow
+from blksprs.utils.blksprs_tensor import BlksprsTensor

 from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size


-def split(x: Tensor, sparsity_layout: Tensor, partitions: int,
-          sparsity_block_size: int, triton_block_size: int = None) -> (Tensor, Tensor):
+def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
+          sparsity_block_size: int, triton_block_size: int = None) -> (BlksprsTensor, Tensor):
     """Splits a block-sparse tensor in compressed form along the last dimension into partitions.

     Args:
-        x (Tensor): A block-sparse tensor in compressed form.
+        x (BlksprsTensor): A block-sparse tensor in compressed form.
         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
         partitions (int): The number of partitions to split the block-sparse tensor into.
         sparsity_block_size (int): The size of the sparsity blocks.
         triton_block_size (int): The block size to use for the triton kernel (default ``None``).

     Returns:
-        Tensor: The block-sparse tensor split into partitions in compressed form.
+        BlksprsTensor: The block-sparse tensor split into partitions in compressed form.
         Tensor: The sparsity layout of the output tensor.

     """
@@ -53,8 +54,8 @@ def split(x: Tensor, sparsity_layout: Tensor, partitions: int,

     validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)

-    return _BlocksparseSplit.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
-                                   sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_output
+    return BlksprsTensor(_BlocksparseSplit.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
+                                                 sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_output


 class _BlocksparseSplit(torch.autograd.Function):
@@ -79,19 +80,19 @@ class _BlocksparseSplit(torch.autograd.Function):
                sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None


-def merge(x: Tensor, sparsity_layout: Tensor, partitions: int,
-          sparsity_block_size: int, triton_block_size: int = None) -> (Tensor, Tensor):
+def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
+          sparsity_block_size: int, triton_block_size: int = None) -> (BlksprsTensor, Tensor):
     """Merges the specified partitions of a block-sparse tensor in compressed form along the last dimension.

     Args:
-        x (Tensor): A block-sparse tensor in compressed form.
+        x (BlksprsTensor): A block-sparse tensor in compressed form.
         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
         partitions (int): The number of partitions to be merged.
         sparsity_block_size (int): The size of the sparsity blocks.
         triton_block_size (int): The block size to use for the triton kernel (default ``None``).

     Returns:
-        Tensor: The merged block-sparse tensor in compressed form.
+        BlksprsTensor: The merged block-sparse tensor in compressed form.
         Tensor: The sparsity layout of the output tensor.

     """
@@ -127,8 +128,8 @@ def merge(x: Tensor, sparsity_layout: Tensor, partitions: int,

     validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)

-    return _BlocksparseMerge.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
-                                   sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_output
+    return BlksprsTensor(_BlocksparseMerge.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
+                                                 sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_output


 class _BlocksparseMerge(torch.autograd.Function):
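
Since ``split`` and ``merge`` act along the last dimension and are inverses of each other, a round trip is the natural smoke test. A minimal sketch, assuming a compressed input produced by ``to_sparse`` and an all-ones boolean block mask; both conventions are assumptions carried over from the conversion module, not part of this hunk:

    import torch
    from blksprs.ops.conversion import to_sparse
    from blksprs.misc.partitioning import split, merge

    block = 16
    x = torch.randn(2, 32, 64, device="cuda")
    # assumed layout convention: boolean [batch, rows // block, cols // block] mask, every block present
    layout = torch.ones(2, 32 // block, 64 // block, dtype=torch.bool, device="cuda")
    xs = to_sparse(x, layout, block)  # BlksprsTensor as of 1.8.2

    # split the last dimension into 2 partitions, then merge them back together
    xs_split, layout_split = split(xs, layout, 2, block)
    xs_round, layout_round = merge(xs_split, layout_split, 2, block)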
blksprs/misc/row_wise.py CHANGED
@@ -3,13 +3,14 @@ import triton
 from torch import Tensor
 from triton import language as tl

+from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, validate_sparsity, \
     validate_sparsity_block_size, validate_triton_block_size


-def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
-                 flag_slice_only: bool = False, triton_block_size: int = None) -> tuple[Tensor, Tensor]:
+def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
+                 flag_slice_only: bool = False, triton_block_size: int = None) -> (BlksprsTensor, Tensor):
     """Computes the row-wise sum of a block-sparse tensor.

     Returns a block-sparse tensor in compressed form with only one block per row, where the first entry contains the sum
@@ -19,7 +20,7 @@ def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
     If ``flag_slice_only`` is set the output will be of shape ``[x.size(0), x.size(1), 1]``.

     Args:
-        x (Tensor): A block-sparse tensor in compressed form.
+        x (BlksprsTensor): A block-sparse tensor in compressed form.
         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
         flag_slice_only (bool, optional): If set the output will be of shape ``[x.size(0), x.size(1), 1]``
@@ -27,7 +28,7 @@ def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
         triton_block_size (int): The block size to use for the triton kernel (default ``None``).

     Returns:
-        tuple[Tensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise sum
+        tuple[BlksprsTensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise sum
         of the input and the sparsity layout of the output tensor.

     """
@@ -85,7 +86,7 @@ def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
                                          sparsity_reverse_lut_output,
                                          triton_block_size))

-    return (output, sparsity_layout_output)
+    return BlksprsTensor(output), sparsity_layout_output


 @triton.jit
@@ -131,8 +132,8 @@ def kernel_blocksparse_row_wise_sum(x,
     tl.atomic_add(o + o_idx, buf, o_msk)


-def row_wise_max(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
-                 flag_slice_only: bool = False, triton_block_size: int = None) -> tuple[Tensor, Tensor]:
+def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
+                 flag_slice_only: bool = False, triton_block_size: int = None) -> (BlksprsTensor, Tensor):
     """Computes the row-wise max of a block-sparse tensor.

     Returns a block-sparse tensor in compressed form with only one block per row, where the first entry contains the
@@ -142,7 +143,7 @@ def row_wise_max(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
     If ``flag_slice_only`` is set the output will be of shape ``[x.size(0), x.size(1), 1]``.

     Args:
-        x (Tensor): A block-sparse tensor in compressed form.
+        x (BlksprsTensor): A block-sparse tensor in compressed form.
         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
         flag_slice_only (bool, optional): If set the output will be of shape ``[x.size(0), x.size(1), 1]``
@@ -150,7 +151,7 @@ def row_wise_max(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
         triton_block_size (int): The block size to use for the triton kernel (default ``None``).

     Returns:
-        tuple[Tensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise max
+        tuple[BlksprsTensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise max
         of the input and the sparsity layout of the output tensor.

     """
@@ -208,7 +209,7 @@ def row_wise_max(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
                                          sparsity_reverse_lut_output,
                                          triton_block_size))

-    return output, sparsity_layout_output
+    return BlksprsTensor(output), sparsity_layout_output


 @triton.jit
@@ -254,19 +255,19 @@ def kernel_blocksparse_row_wise_max(x,
     tl.atomic_max(o + o_idx, buf, o_msk)


-def row_wise_add(x: Tensor, sparsity_layout_x: Tensor, y: Tensor,
-                 sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+def row_wise_add(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
+                 sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
     """For each row in ``y`` adds the value to each value in the corresponding row of the block-sparse tensor ``x``.

     Args:
-        x (Tensor): A block-sparse tensor in compressed form.
+        x (BlksprsTensor): A block-sparse tensor in compressed form.
         sparsity_layout_x (Tensor): The sparsity layout of the block-sparse tensor.
-        y (Tensor): A block-sparse tensor in compressed form with only one value per row and a single column of sparse blocks.
+        y (BlksprsTensor): A block-sparse tensor in compressed form with only one value per row and a single column of sparse blocks.
         sparsity_block_size (int): The size of the sparsity blocks.
         triton_block_size (int): The block size to use for the triton kernel (default ``None``).

     Returns:
-        Tensor: The values of ``x`` with the first value of ``y`` in each row added to them as a block-sparse tensor in
+        BlksprsTensor: The values of ``x`` with the first value of ``y`` in each row added to them as a block-sparse tensor in
         compressed form.

     """
@@ -319,11 +320,11 @@ def row_wise_add(x: Tensor, sparsity_layout_x: Tensor, y: Tensor,
                                          triton_block_size
                                          ))

-    return output
+    return BlksprsTensor(output)


-def row_wise_sub(x: Tensor, sparsity_layout_x: Tensor, y: Tensor,
-                 sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+def row_wise_sub(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
+                 sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
     """Wrapper for ``row_wise_add`` with negated y.

     """
blksprs/ops/conversion.py CHANGED
@@ -6,23 +6,27 @@ from torch import Tensor
 from triton import language as tl

 from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
+from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_sparsity_dense


-def from_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
+def from_blksprs(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
                  triton_block_size: int = None) -> Tensor:
+    """Wrapper for ``to_dense``.
+
+    """
     return to_dense(x, sparsity_layout, sparsity_block_size, fill_value, triton_block_size)


-def to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
+def to_dense(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
              triton_block_size: int = None) -> Tensor:
     """Converts a block-sparse tensor in compressed form to a block-sparse tensor in regular form based on the given
     sparsity layout.

     Args:
-        x (Tensor): A block-sparse tensor in compressed form.
+        x (BlksprsTensor): A block-sparse tensor in compressed form.
         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
         fill_value (float): The value to fill the resulting dense tensor with where the block-sparse tensor is not
@@ -50,12 +54,12 @@ def to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_
     validate_contiguous(sparsity_reverse_lut)

     if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
-        return x
+        return BlksprsTensor(x)

-    return _BlocksparseToDense.apply(x,
-                                     sparsity_layout, sparsity_reverse_lut,
-                                     sparsity_block_size, fill_value,
-                                     triton_block_size)
+    return BlksprsTensor(_BlocksparseToDense.apply(x,
+                                                   sparsity_layout, sparsity_reverse_lut,
+                                                   sparsity_block_size, fill_value,
+                                                   triton_block_size))


 class _BlocksparseToDense(torch.autograd.Function):
@@ -150,11 +154,15 @@ class _BlocksparseToDense(torch.autograd.Function):


 def to_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
-               triton_block_size: int = None) -> Tensor:
+               triton_block_size: int = None) -> BlksprsTensor:
+    """Wrapper for ``to_sparse``.
+
+    """
     return to_sparse(x, sparsity_layout, sparsity_block_size, triton_block_size)


-def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
+              triton_block_size: int = None) -> BlksprsTensor:
     """Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
     sparsity layout.

@@ -165,7 +173,7 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
         triton_block_size (int): The block size to use for the triton kernel (default ``None``).

     Returns:
-        Tensor: The block-sparse tensor converted to compressed form.
+        BlksprsTensor: The block-sparse tensor converted to compressed form.

     """
     x = x.contiguous()
@@ -183,12 +191,12 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
     validate_contiguous(sparsity_layout, sparsity_lut)

     if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
-        return x
+        return BlksprsTensor(x)

-    return _BlocksparseToSparse.apply(x,
-                                      sparsity_layout, sparsity_lut,
-                                      sparsity_block_size, n_sparse_blocks,
-                                      triton_block_size)
+    return BlksprsTensor(_BlocksparseToSparse.apply(x,
+                                                    sparsity_layout, sparsity_lut,
+                                                    sparsity_block_size, n_sparse_blocks,
+                                                    triton_block_size))


 class _BlocksparseToSparse(torch.autograd.Function):
@@ -280,13 +288,14 @@ class _BlocksparseToSparse(torch.autograd.Function):
         tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)


-def adapt_layout(x: Tensor, sparsity_layout_from: Tensor, sparsity_block_size_from: int, sparsity_block_size_to: int,
-                 preprocess_data: dict = None, triton_block_size: int = None) -> Tensor:
+def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_size_from: int,
+                 sparsity_block_size_to: int,
+                 preprocess_data: dict = None, triton_block_size: int = None) -> BlksprsTensor:
     """Adapts the sparsity layout of a block-sparse tensor, resulting in a new block-sparse tensor in compressed form
     conforming to the new sparsity layout (and sparsity block size) definition.

     Args:
-        x (Tensor): A block-sparse tensor in compressed form.
+        x (BlksprsTensor): A block-sparse tensor in compressed form.
         sparsity_layout_from (Tensor): The sparsity layout of the input block-sparse tensor.
         sparsity_block_size_from (int): The size of the sparsity blocks of the input sparsity layout.
         sparsity_block_size_to (int): The size of the sparsity blocks of the output sparsity layout.
@@ -294,7 +303,7 @@ def adapt_layout(x: Tensor, sparsity_layout_from: Tensor, sparsity_block_size_fr
         triton_block_size (int): The block size to use for the triton kernel (default ``None``).

     Returns:
-        Tensor: The block-sparse tensor in compressed form with the adapted sparsity layout and sparsity block size.
+        BlksprsTensor: The block-sparse tensor in compressed form with the adapted sparsity layout and sparsity block size.

     """
     x = x.contiguous()
@@ -339,12 +348,13 @@ def adapt_layout(x: Tensor, sparsity_layout_from: Tensor, sparsity_block_size_fr
     validate_contiguous(sparsity_layout_to, sparsity_reverse_lut_from, sparsity_lut_to)

     if (sparsity_block_size_from == sparsity_block_size_to) and torch.equal(sparsity_layout_from, sparsity_layout_to):
-        return x
+        return BlksprsTensor(x)

-    return _BlocksparseAdaptLayout.apply(x,
-                                         sparsity_layout_from, sparsity_reverse_lut_from, sparsity_block_size_from,
-                                         sparsity_layout_to, sparsity_lut_to, sparsity_block_size_to,
-                                         n_sparse_blocks_to, min_sparsity_block_size, triton_block_size)
+    return BlksprsTensor(_BlocksparseAdaptLayout.apply(x,
+                                                       sparsity_layout_from, sparsity_reverse_lut_from,
+                                                       sparsity_block_size_from,
+                                                       sparsity_layout_to, sparsity_lut_to, sparsity_block_size_to,
+                                                       n_sparse_blocks_to, min_sparsity_block_size, triton_block_size))


 class _BlocksparseAdaptLayout(torch.autograd.Function):
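
The conversion helpers are the entry and exit points for the new wrapper type, so a round trip is the clearest illustration. A minimal sketch, assuming the boolean ``[batch, row blocks, col blocks]`` layout convention implied by the ``size(1)``/``size(2)`` checks above:

    import torch
    from blksprs.ops.conversion import to_sparse, to_dense
    from blksprs.utils.blksprs_tensor import BlksprsTensor

    block = 16
    x = torch.randn(2, 32, 64, device="cuda")
    layout = torch.ones(2, 32 // block, 64 // block, dtype=torch.bool, device="cuda")  # assumed: keep every block

    xs = to_sparse(x, layout, block)          # compressed form, tagged as BlksprsTensor in 1.8.2
    assert isinstance(xs, BlksprsTensor)

    x_back = to_dense(xs, layout, block)      # regular form again; fill_value=0 for pruned blocks
    assert torch.allclose(x, x_back)          # lossless with an all-ones layout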
blksprs/ops/distribution.py CHANGED
@@ -3,25 +3,26 @@ import triton
 from torch import Tensor
 from triton import language as tl

+from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size


-def gather(src: Tensor, sparsity_layout_src: Tensor, idx: Tensor, sparsity_layout_idx: Tensor,
-           sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+def gather(src: BlksprsTensor, sparsity_layout_src: Tensor, idx: BlksprsTensor, sparsity_layout_idx: Tensor,
+           sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
     """Applies a gather operation on a block-sparse tensor in compressed form.

     Args:
-        src (Tensor): The source block-sparse tensor in compressed form to gather from.
+        src (BlksprsTensor): The source block-sparse tensor in compressed form to gather from.
         sparsity_layout_src (Tensor): The sparsity layout of the source block-sparse tensor.
-        idx (Tensor): The block-sparse indices tensor in compressed form specifying how to gather from the source tensor.
+        idx (BlksprsTensor): The block-sparse indices tensor in compressed form specifying how to gather from the source tensor.
         sparsity_layout_idx (Tensor): The sparsity layout of the indices block-sparse tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).

     Returns:
-        Tensor: The result of the gather operation as a block-sparse tensor in compressed form.
+        BlksprsTensor: The result of the gather operation as a block-sparse tensor in compressed form.

     """
     src = src.contiguous()
@@ -45,9 +46,9 @@ def gather(src: Tensor, sparsity_layout_src: Tensor, idx: Tensor, sparsity_layou
     validate_contiguous(sparsity_layout_src, sparsity_reverse_lut_x,
                         sparsity_layout_idx, sparsity_lut_i)

-    return _BlocksparseGather.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
+    return BlksprsTensor(_BlocksparseGather.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
                                     idx, sparsity_layout_idx, sparsity_lut_i,
-                                    sparsity_block_size, triton_block_size)
+                                    sparsity_block_size, triton_block_size))


 class _BlocksparseGather(torch.autograd.Function):
@@ -168,10 +169,10 @@ class _BlocksparseGather(torch.autograd.Function):
         tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)


-def scatter(src: Tensor, sparsity_layout_src: Tensor,
-            idx: Tensor,
+def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
+            idx: BlksprsTensor,
             sparsity_layout_tgt: Tensor,
-            sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+            sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
     """Wrapper for ``scatter_reduce`` with ``reduce_op="none"``.

     """
@@ -182,17 +183,17 @@ def scatter(src: Tensor, sparsity_layout_src: Tensor,
                           reduce_op="none", triton_block_size=triton_block_size)


-def scatter_reduce(src: Tensor, sparsity_layout_src: Tensor,
-                   idx: Tensor,
+def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
+                   idx: BlksprsTensor,
                    sparsity_layout_tgt: Tensor,
                    sparsity_block_size: int,
-                   reduce_op: str = "sum", triton_block_size: int = None) -> Tensor:
+                   reduce_op: str = "sum", triton_block_size: int = None) -> BlksprsTensor:
     """Applies a scatter operation on a block-sparse tensor in compressed form.

     Args:
-        src (Tensor): The source block-sparse tensor in compressed form to scatter from.
+        src (BlksprsTensor): The source block-sparse tensor in compressed form to scatter from.
         sparsity_layout_src (Tensor): The sparsity layout of the source block-sparse tensor.
-        idx (Tensor): The block-sparse indices tensor in compressed form specifying how to scatter to the target tensor.
+        idx (BlksprsTensor): The block-sparse indices tensor in compressed form specifying how to scatter to the target tensor.
         sparsity_layout_tgt (Tensor): The sparsity layout of the target block-sparse tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
         reduce_op (str, optional): The reduction operation to apply during the scatter operation (default ``"sum"``).
@@ -200,7 +201,7 @@ def scatter_reduce(src: Tensor, sparsity_layout_src: Tensor,
         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).

     Returns:
-        Tensor: The result of the scatter operation as a block-sparse tensor in compressed form.
+        BlksprsTensor: The result of the scatter operation as a block-sparse tensor in compressed form.

     """
     src = src.contiguous()
@@ -229,11 +230,11 @@ def scatter_reduce(src: Tensor, sparsity_layout_src: Tensor,
     validate_contiguous(sparsity_layout_src, sparsity_lut_x,
                         sparsity_layout_tgt, sparsity_reverse_lut_o)

-    return _BlocksparseScatterReduce.apply(src, sparsity_layout_src, sparsity_lut_x,
+    return BlksprsTensor(_BlocksparseScatterReduce.apply(src, sparsity_layout_src, sparsity_lut_x,
                                            idx,
                                            sparsity_layout_tgt, sparsity_reverse_lut_o,
                                            sparsity_block_size, n_sparse_blocks,
-                                           reduce_op, triton_block_size)
+                                           reduce_op, triton_block_size))


 class _BlocksparseScatterReduce(torch.autograd.Function):
blksprs/ops/matmul.py CHANGED
@@ -4,22 +4,23 @@ from torch import Tensor
 from triton import language as tl

 from blksprs.ops.transpose import transpose
+from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_dtype_float


-def matmul(x: Tensor, sparsity_layout_x: Tensor,
-           y: Tensor, sparsity_layout_y: Tensor,
+def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
+           y: BlksprsTensor, sparsity_layout_y: Tensor,
           sparsity_layout_output: Tensor,
-           sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+           sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
     """Performs matrix multiplication between two block-sparse tensors.

     The sparsity layout of the output tensor is used to only calculate blocks that will be present in the output.

     Args:
-        x (Tensor): A block-sparse tensor in compressed form.
-        y (Tensor): A block-sparse tensor in compressed form.
+        x (BlksprsTensor): A block-sparse tensor in compressed form.
+        y (BlksprsTensor): A block-sparse tensor in compressed form.
         sparsity_layout_x (Tensor): The sparsity layout of the first block-sparse tensor.
         sparsity_layout_y (Tensor): The sparsity layout of the second block-sparse tensor.
         sparsity_layout_output (Tensor): The sparsity layout of the output tensor.
@@ -27,7 +28,7 @@ def matmul(x: Tensor, sparsity_layout_x: Tensor,
         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).

     Returns:
-        Tensor: The result of the matrix multiplication as a block-sparse tensor in compressed form.
+        BlksprsTensor: The result of the matrix multiplication as a block-sparse tensor in compressed form.

     """
     x = x.contiguous()
@@ -61,13 +62,13 @@ def matmul(x: Tensor, sparsity_layout_x: Tensor,
                         sparsity_layout_y, sparsity_reverse_lut_y,
                         sparsity_layout_output, sparsity_lut_o)

-    return _BlocksparseMatmulSSS.apply(x, y,
-                                       sparsity_layout_x, sparsity_reverse_lut_x,
-                                       sparsity_layout_y, sparsity_reverse_lut_y,
-                                       sparsity_layout_output, sparsity_lut_o,
-                                       sparsity_block_size,
-                                       n_sparse_blocks,
-                                       triton_block_size)
+    return BlksprsTensor(_BlocksparseMatmulSSS.apply(x, y,
+                                                     sparsity_layout_x, sparsity_reverse_lut_x,
+                                                     sparsity_layout_y, sparsity_reverse_lut_y,
+                                                     sparsity_layout_output, sparsity_lut_o,
+                                                     sparsity_block_size,
+                                                     n_sparse_blocks,
+                                                     triton_block_size))


 class _BlocksparseMatmulSSS(torch.autograd.Function):
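
A hedged sketch of the sparse-sparse-sparse multiplication under the same assumed all-ones layout convention; per the docstring above, only blocks flagged in the output layout are computed:

    import torch
    from blksprs.ops.conversion import to_sparse
    from blksprs.ops.matmul import matmul

    block = 16
    a = torch.randn(2, 32, 48, device="cuda")
    b = torch.randn(2, 48, 64, device="cuda")
    layout_a = torch.ones(2, 32 // block, 48 // block, dtype=torch.bool, device="cuda")  # assumed block masks
    layout_b = torch.ones(2, 48 // block, 64 // block, dtype=torch.bool, device="cuda")
    layout_o = torch.ones(2, 32 // block, 64 // block, dtype=torch.bool, device="cuda")

    a_s = to_sparse(a, layout_a, block)
    b_s = to_sparse(b, layout_b, block)

    # result is a BlksprsTensor in compressed form as of 1.8.2
    o = matmul(a_s, layout_a, b_s, layout_b, layout_o, block)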
blksprs/ops/repeat.py CHANGED
@@ -3,14 +3,15 @@ import triton
 from triton import language as tl
 from torch import Tensor

+from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size


-def repeat(x: Tensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
+def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
            sparsity_block_size: int, sparsity_layout_output: Tensor = None, triton_block_size: int = None) -> (
-        Tensor, Tensor):
+        BlksprsTensor, Tensor):
     """Repeats a block-spare tensor in compressed form according to the given repeats.

     Repeats is a 3-tuple of integers, where each integer represents the number of times the tensor should be repeated in
@@ -22,7 +23,7 @@ def repeat(x: Tensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
     them to be sparse.

     Args:
-        x (Tensor): A block-sparse tensor in compressed form.
+        x (BlksprsTensor): A block-sparse tensor in compressed form.
         sparsity_layout_x (Tensor): The sparsity layout of the block-sparse tensor.
         repeats (tuple[int, int, int]): The number of times the tensor should be repeated in the first, second and
         third dimension respectively.
@@ -31,7 +32,7 @@ def repeat(x: Tensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
         triton_block_size (int): The block size to use for the triton kernel (default ``None``).

     Returns:
-        Tensor: A block-sparse tensor in compressed form containing the repeated values.
+        BlksprsTensor: A block-sparse tensor in compressed form containing the repeated values.
         Tensor: The sparsity layout of the resulting output tensor.

     """
@@ -63,14 +64,14 @@ def repeat(x: Tensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],

     validate_contiguous(sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)

-    return _BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
-                                    sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_o
+    return BlksprsTensor(_BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
+                                                  sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_o


-def repeat_interleave(x: Tensor, sparsity_layout_x: Tensor, repeats: int,
+def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
                       sparsity_block_size: int, sparsity_layout_output: Tensor = None,
                       triton_block_size: int = None) -> (
-        Tensor, Tensor):
+        BlksprsTensor, Tensor):
     """Repeats and interleaves the block-sparse tensor in compressed form.

     Repeats each matrix contained in the tensors by ``repeats`` amount and places them consecutively in the output
@@ -81,7 +82,7 @@ def repeat_interleave(x: Tensor, sparsity_layout_x: Tensor, repeats: int,
     non-sparse blocks will be filled.

     Args:
-        x (Tensor): A block-sparse tensor in compressed form.
+        x (BlksprsTensor): A block-sparse tensor in compressed form.
         sparsity_layout_x (Tensor): The sparsity layout of the block-sparse tensor.
         repeats (int): The number of times to repeat the matrices.
         sparsity_block_size (int): The size of the sparsity blocks.
@@ -89,7 +90,7 @@ def repeat_interleave(x: Tensor, sparsity_layout_x: Tensor, repeats: int,
         triton_block_size (int): The block size to use for the triton kernel (default ``None``).

     Returns:
-        Tensor: A block-sparse tensor in compressed form containing the repeated and interleaved matrices.
+        BlksprsTensor: A block-sparse tensor in compressed form containing the repeated and interleaved matrices.
         Tensor: The sparsity layout of the resulting output tensor.

     """
@@ -121,8 +122,8 @@ def repeat_interleave(x: Tensor, sparsity_layout_x: Tensor, repeats: int,

     validate_contiguous(sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)

-    return _BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
-                                    sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_o
+    return BlksprsTensor(_BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
+                                                  sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_o


 class _BlocksparseRepeat(torch.autograd.Function):
blksprs/ops/softmax.py CHANGED
@@ -5,25 +5,26 @@ from triton import language as tl

 from blksprs.misc.exp import exp
 from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
+from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size


-def softmax(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
     """Computes the softmax of a block-sparse tensor in compressed form.

     Note:
         Sparse blocks are not considered for the calculation of the softmax, i.e., all values are assumed to be ``-inf``.

     Args:
-        x (Tensor): A block-sparse tensor in compressed form.
+        x (BlksprsTensor): A block-sparse tensor in compressed form.
         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
         triton_block_size (int): The block size to use for the triton kernel (default ``None``).

     Returns:
-        Tensor: The result of the softmax operation as a block-sparse tensor in compressed form.
+        BlksprsTensor: The result of the softmax operation as a block-sparse tensor in compressed form.

     """
     x = x.contiguous()
@@ -45,10 +46,10 @@ def softmax(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton

     validate_contiguous(sparsity_layout, sparsity_lut, sparsity_reverse_lut_rws)

-    return _BlocksparseSoftmax.apply(x, sparsity_layout,
+    return BlksprsTensor(_BlocksparseSoftmax.apply(x, sparsity_layout,
                                      sparsity_lut,
                                      sparsity_reverse_lut_rws,
-                                     sparsity_block_size, triton_block_size)
+                                     sparsity_block_size, triton_block_size))


 class _BlocksparseSoftmax(torch.autograd.Function):
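
A minimal sketch of the block-sparse softmax. Per the note above, missing blocks are treated as ``-inf``, so with an all-ones layout (an assumption used here, together with the ``to_sparse`` setup) every row should sum to approximately one:

    import torch
    from blksprs.ops.conversion import to_sparse, to_dense
    from blksprs.ops.softmax import softmax

    block = 16
    x = torch.randn(2, 32, 32, device="cuda")
    layout = torch.ones(2, 32 // block, 32 // block, dtype=torch.bool, device="cuda")  # assumed block mask

    probs = softmax(to_sparse(x, layout, block), layout, block)  # BlksprsTensor in 1.8.2
    row_sums = to_dense(probs, layout, block).sum(dim=-1)        # ~1 per row with a full layout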
blksprs/ops/transpose.py CHANGED
@@ -3,26 +3,27 @@ import triton
 from torch import Tensor
 from triton import language as tl

+from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size


-def transpose(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> (
-        Tensor, Tensor):
+def transpose(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> (
+        BlksprsTensor, Tensor):
     """Transposes a block-sparse tensor in compressed form.

     Note:
         Returns the transposed tensor and the sparsity layout of the transposed tensor.

     Args:
-        x (Tensor): A block-sparse tensor in compressed form.
+        x (BlksprsTensor): A block-sparse tensor in compressed form.
         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
         triton_block_size (int): The block size to use for the triton kernel (default ``None``).

     Returns:
-        Tensor: The transposed block-sparse tensor in compressed form.
+        BlksprsTensor: The transposed block-sparse tensor in compressed form.
         Tensor: The sparsity layout of the transposed tensor.

     """
@@ -49,8 +50,8 @@ def transpose(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit

     validate_contiguous(sparsity_layout_t, sparsity_lut, sparsity_reverse_lut)

-    return _BlocksparseTranspose.apply(x, sparsity_layout_t, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
-                                       n_sparse_blocks, triton_block_size), sparsity_layout_t
+    return BlksprsTensor(_BlocksparseTranspose.apply(x, sparsity_layout_t, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                                                     n_sparse_blocks, triton_block_size)), sparsity_layout_t


 class _BlocksparseTranspose(torch.autograd.Function):
blksprs/utils/blksprs_tensor.py ADDED
@@ -0,0 +1,8 @@
+from torch import Tensor
+
+
+class BlksprsTensor(Tensor):
+    """A wrapper class representing a block-sparse tensor in compressed form.
+    """
+
+    pass
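
The new wrapper is a plain ``torch.Tensor`` subclass with no extra state or behaviour; it only tags values that are in compressed form, which is why every signature above could move from ``Tensor`` to ``BlksprsTensor`` without breaking callers. A small sketch of what that implies:

    import torch
    from blksprs.utils.blksprs_tensor import BlksprsTensor

    t = BlksprsTensor(torch.zeros(2, 16, 16))

    assert isinstance(t, torch.Tensor)   # still an ordinary tensor as far as PyTorch is concerned
    assert isinstance(t, BlksprsTensor)  # but the compressed form is now visible to type hints and checks
    print((t + 1).sum())                 # regular tensor arithmetic is unaffected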
blksprs-1.8.dist-info/METADATA → blksprs-1.8.2.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.8
+Version: 1.8.2
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
blksprs-1.8.2.dist-info/RECORD ADDED
@@ -0,0 +1,22 @@
+blksprs/__init__.py,sha256=np0msosWMaZNVVfuFGt8rE6HZURyIald391dKAs1dSQ,1093
+blksprs/experimental/distribution_mdi.py,sha256=HaRUu6LTWATzjuHWgddIUE-0fgY-O87STpJO4JY7k_8,20357
+blksprs/layouting/distribution_layout.py,sha256=wmj1SwWyY_fhbvMmh6AXrR77LoSp6xLwUWCCyO9i5lk,4239
+blksprs/layouting/sparsity_layout.py,sha256=-sScIn4hhG35j9BXytrojEzp8jnFkMargJjtivPV1fc,9755
+blksprs/misc/broadcast_ops.py,sha256=cPtRJa3pkZfY1QG51CJ-zDn4SK-CRpX5LEXoKGGMvRU,5418
+blksprs/misc/exp.py,sha256=FnSFosBfJHuiEbD0MD-i4axLghRn4a0f8KvHXrKBB6M,3802
+blksprs/misc/partitioning.py,sha256=K0ExR2a3W62d_9xxCJzsdJDLgtbxTI6P8loOOBdhPzE,7674
+blksprs/misc/row_wise.py,sha256=SvJuNww-_QoVKTyTjMvjmzHlBuUlTKamkuq_rKzwAqs,17081
+blksprs/ops/conversion.py,sha256=ol-iV45wDzp9G1dJEkY53EdrvnmHzcl7QQmPJ-xqQTs,22410
+blksprs/ops/distribution.py,sha256=fXZV6UegCVpIwzh-A825OSYClHWu5k0UMYdO2UGDUpM,17067
+blksprs/ops/matmul.py,sha256=yh2ZnO0ZltT1AgadiFP0vX28YJ4n74xO-I_5vFUmOmA,11452
+blksprs/ops/repeat.py,sha256=IvSIRbuyFn0b57LObymLgup0LqlWQ3ndIw-QuiYQcaU,14564
+blksprs/ops/softmax.py,sha256=D9wITz3KB24QXGGjgn_RLQ0Iiq_SjX0bTbUyv9479uU,12094
+blksprs/ops/transpose.py,sha256=jxzFFffrj4S_9tiCrwwUMdz6EA98o1dziWXjlqb64a4,6859
+blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
+blksprs/utils/blksprs_tensor.py,sha256=VjplBgDhnf9sxf-1R5feA0xp5FDCDdaeZmCeoIRdCnc,151
+blksprs/utils/tools.py,sha256=S3836Zuc-BMigv-5mLTjRznCzuaF6oYW-Ir9zzUnr3o,655
+blksprs/utils/validation.py,sha256=WzihRPibXYzss3PMkhDt5_d3Q3NHA_d1TzTz3CoGPGg,4136
+blksprs-1.8.2.dist-info/METADATA,sha256=Zoc860mYmFss7v5ChNoi9407v1qDo_ecc6JUWCvaesg,8009
+blksprs-1.8.2.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+blksprs-1.8.2.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
+blksprs-1.8.2.dist-info/RECORD,,
blksprs-1.8.dist-info/WHEEL → blksprs-1.8.2.dist-info/WHEEL
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (75.2.0)
+Generator: setuptools (75.3.0)
 Root-Is-Purelib: true
 Tag: py3-none-any

blksprs-1.8.dist-info/RECORD REMOVED
@@ -1,21 +0,0 @@
-blksprs/__init__.py,sha256=qDqoB-X5vo5_3PlrN54sp59XR5hg6EanIsADS67QnH0,1058
-blksprs/experimental/distribution_mdi.py,sha256=jE_SbB0SMGPcCoeM0699cceWAikBkBpGM_8Oo5A7Ets,20083
-blksprs/layouting/distribution_layout.py,sha256=nCR3fCF6eNpi97DI6MMLF2hq_9Lwyo6_aUCIZiJfiX4,4170
-blksprs/layouting/sparsity_layout.py,sha256=JNwbJ4L-418zCiCbt-vLfZ0xU7ReP0zr_tLHs_dytRA,9686
-blksprs/misc/broadcast_ops.py,sha256=0RLnLMYV7GAPI2YL8RotcxjIUSBZKGxdVcsGaJFeL_I,5327
-blksprs/misc/exp.py,sha256=cdF0s93Q9iucIXuEE3howsB0N6D60xgvem7C-a-yiGI,3704
-blksprs/misc/partitioning.py,sha256=nBRZzfi3XYAhDLEBzYflQkvGa3MIZ-qNeIlrZ16k44g,7533
-blksprs/misc/row_wise.py,sha256=0vDJA8uCocmebSIPIbFeND5_PQIE10pUj3DBOQXlTvE,16888
-blksprs/ops/conversion.py,sha256=9xVdCrj38m1cMh43LQs-GrXZ5pNRjhQyKx6paaw3C6A,21898
-blksprs/ops/distribution.py,sha256=V3TK5SlNT_JdGHNaDNl-U4U5vwAYsgkAOg4eTmYxbuA,16877
-blksprs/ops/matmul.py,sha256=uqVe6Dz2aaCbCglM1uS2eRHVKh7PQcuecaIBWFubPEw,11256
-blksprs/ops/repeat.py,sha256=OSsa2rj6BHL3Kedfu3wr0D82mn4HmbJ1l7XEmT-6ehg,14423
-blksprs/ops/softmax.py,sha256=5nAgeT68nucgOugjtCy1aBIMa7Kyk1KNN-j8fgmeVuk,11996
-blksprs/ops/transpose.py,sha256=67pDdCEb7r-Xifupl82fBKAYsxKcCUDy--cPPfduRvU,6761
-blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
-blksprs/utils/tools.py,sha256=S3836Zuc-BMigv-5mLTjRznCzuaF6oYW-Ir9zzUnr3o,655
-blksprs/utils/validation.py,sha256=WzihRPibXYzss3PMkhDt5_d3Q3NHA_d1TzTz3CoGPGg,4136
-blksprs-1.8.dist-info/METADATA,sha256=koey4w8ynY84Z0dM5u9y_P831rtR0w-Z-dBcje4O6ko,8007
-blksprs-1.8.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
-blksprs-1.8.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
-blksprs-1.8.dist-info/RECORD,,