blksprs-1.5.tar.gz → blksprs-1.6.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. {blksprs-1.5 → blksprs-1.6.1}/PKG-INFO +3 -2
  2. {blksprs-1.5 → blksprs-1.6.1}/README.md +2 -1
  3. {blksprs-1.5 → blksprs-1.6.1}/blksprs/__init__.py +3 -2
  4. {blksprs-1.5 → blksprs-1.6.1}/blksprs/layouting/sparsity_layout.py +36 -0
  5. {blksprs-1.5 → blksprs-1.6.1}/blksprs/misc/repeat_interleave.py +1 -1
  6. blksprs-1.6.1/blksprs/ops/partitioning.py +244 -0
  7. {blksprs-1.5 → blksprs-1.6.1}/blksprs/ops/transpose.py +6 -7
  8. {blksprs-1.5 → blksprs-1.6.1}/blksprs/utils/validation.py +2 -0
  9. {blksprs-1.5 → blksprs-1.6.1}/blksprs.egg-info/PKG-INFO +3 -2
  10. {blksprs-1.5 → blksprs-1.6.1}/blksprs.egg-info/SOURCES.txt +1 -0
  11. {blksprs-1.5 → blksprs-1.6.1}/pyproject.toml +1 -1
  12. {blksprs-1.5 → blksprs-1.6.1}/blksprs/experimental/distribution_mdi.py +0 -0
  13. {blksprs-1.5 → blksprs-1.6.1}/blksprs/layouting/distribution_layout.py +0 -0
  14. {blksprs-1.5 → blksprs-1.6.1}/blksprs/misc/broadcast_ops.py +0 -0
  15. {blksprs-1.5 → blksprs-1.6.1}/blksprs/misc/row_wise.py +0 -0
  16. {blksprs-1.5 → blksprs-1.6.1}/blksprs/ops/conversion.py +0 -0
  17. {blksprs-1.5 → blksprs-1.6.1}/blksprs/ops/distribution.py +0 -0
  18. {blksprs-1.5 → blksprs-1.6.1}/blksprs/ops/exp.py +0 -0
  19. {blksprs-1.5 → blksprs-1.6.1}/blksprs/ops/matmul.py +0 -0
  20. {blksprs-1.5 → blksprs-1.6.1}/blksprs/ops/softmax.py +0 -0
  21. {blksprs-1.5 → blksprs-1.6.1}/blksprs/utils/benchmarking.py +0 -0
  22. {blksprs-1.5 → blksprs-1.6.1}/blksprs/utils/tools.py +0 -0
  23. {blksprs-1.5 → blksprs-1.6.1}/blksprs.egg-info/dependency_links.txt +0 -0
  24. {blksprs-1.5 → blksprs-1.6.1}/blksprs.egg-info/requires.txt +0 -0
  25. {blksprs-1.5 → blksprs-1.6.1}/blksprs.egg-info/top_level.txt +0 -0
  26. {blksprs-1.5 → blksprs-1.6.1}/setup.cfg +0 -0
{blksprs-1.5 → blksprs-1.6.1}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: blksprs
- Version: 1.5
+ Version: 1.6.1
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -31,9 +31,10 @@ Currently supported operations (includes gradient calculation):
  - Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
    for `sparse = sparse @ sparse` matmul_)
  - Softmax
- - Transposition
+ - Transpose
  - Gather
  - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
+ - Splitting and merging of matrices along the last dimension
  - Conversion to and from sparse form
  - Conversion to different sparsity layouts and different sparsity block sizes

{blksprs-1.5 → blksprs-1.6.1}/README.md

@@ -12,9 +12,10 @@ Currently supported operations (includes gradient calculation):
  - Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
    for `sparse = sparse @ sparse` matmul_)
  - Softmax
- - Transposition
+ - Transpose
  - Gather
  - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
+ - Splitting and merging of matrices along the last dimension
  - Conversion to and from sparse form
  - Conversion to different sparsity layouts and different sparsity block sizes

{blksprs-1.5 → blksprs-1.6.1}/blksprs/__init__.py

@@ -4,10 +4,11 @@ from blksprs.ops.exp import exp
  from blksprs.ops.matmul import matmul
  from blksprs.ops.softmax import softmax
  from blksprs.ops.transpose import transpose
+ from blksprs.ops.partitioning import split, merge

  class layout:
      from blksprs.layouting.distribution_layout import build_distribution_layout
-     from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption
+     from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption, build_sparsity_layout_matmul

  class misc:
      from blksprs.misc.broadcast_ops import broadcast_add, broadcast_sub
@@ -18,4 +19,4 @@ class util:
      from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse, disable_validation

  class experimental:
-     from blksprs.experimental.distribution_mdi import gather_mdi
+     from blksprs.experimental.distribution_mdi import gather_mdi, scatter_reduce_mdi
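
With these re-exports, the new functionality is reachable from the package root. A quick sketch of the resulting import surface (the `bs` alias is illustrative, not mandated by the package):

    import blksprs as bs

    bs.split, bs.merge                          # new ops from blksprs.ops.partitioning
    bs.layout.build_sparsity_layout_matmul      # newly re-exported layout helper
    bs.experimental.scatter_reduce_mdi          # joins gather_mdi under experimental

Note that `build_sparsity_layout_matmul_fast` is added to `blksprs/layouting/sparsity_layout.py` (next hunk) but is not re-exported here.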
{blksprs-1.5 → blksprs-1.6.1}/blksprs/layouting/sparsity_layout.py

@@ -188,3 +188,39 @@ def kernel_sparsity_layout_adaption(x,
                    // sparsity_block_size_to) * o_c_s))
      blk_o_msk = (blk_o_idx < o_b * o_b_s)
      tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
+
+
+ def build_sparsity_layout_matmul(sparsity_layout_x: Tensor, sparsity_layout_y: Tensor):
+     """Builds the precise sparsity layout of the result of a matrix multiplication between the two input tensors.
+
+     Args:
+         sparsity_layout_x (Tensor): The sparsity layout of the first block-sparse tensor.
+         sparsity_layout_y (Tensor): The sparsity layout of the second block-sparse tensor.
+
+     Returns:
+         Tensor: The precise sparsity layout of the result of a matrix multiplication between the two input tensors.
+
+     """
+     return torch.matmul(sparsity_layout_x.to(torch.float), sparsity_layout_y.to(torch.float)).to(torch.bool)
+
+
+ def build_sparsity_layout_matmul_fast(sparsity_layout_x: Tensor, sparsity_layout_y: Tensor):
+     """Builds the approximate sparsity layout of the result of a matrix multiplication between the two input tensors.
+
+     Note:
+         This function is faster than the ``build_sparsity_layout_matmul`` function due to the fact that it only checks
+         whether at least one of the blocks in either of the vectors participating in the matmul is non-sparse. The
+         resulting sparsity layout may thus overestimate the actual sparsity of the result.
+
+     Args:
+         sparsity_layout_x (Tensor): The sparsity layout of the first block-sparse tensor.
+         sparsity_layout_y (Tensor): The sparsity layout of the second block-sparse tensor.
+
+     Returns:
+         Tensor: The approximate sparsity layout of the result of a matrix multiplication between the two input tensors.
+
+     """
+     sparsity_layout_x_slice = torch.max(sparsity_layout_x, dim=-1).values.unsqueeze(-1)
+     sparsity_layout_y_slice = torch.max(sparsity_layout_y, dim=-2).values.unsqueeze(1)
+
+     return torch.logical_or(sparsity_layout_x_slice, sparsity_layout_y_slice)
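
A minimal sketch of how the two builders relate, with hand-written layouts purely for illustration (the shapes and values below are invented; both helpers take batched layouts of shape (B, R, C)):

    import torch

    sparsity_layout_x = torch.tensor([[[1, 0],
                                       [0, 1]]])
    sparsity_layout_y = torch.tensor([[[1, 1],
                                       [0, 0]]])

    precise = build_sparsity_layout_matmul(sparsity_layout_x, sparsity_layout_y)
    fast = build_sparsity_layout_matmul_fast(sparsity_layout_x, sparsity_layout_y)

    # The fast variant marks a block whenever the corresponding row of x or
    # column of y contains any non-sparse block, so it never marks fewer
    # blocks than the precise variant does.
    assert torch.all(fast | ~precise)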
{blksprs-1.5 → blksprs-1.6.1}/blksprs/misc/repeat_interleave.py

@@ -35,7 +35,7 @@ def repeat_interleave(x: Tensor, sparsity_layout: Tensor, repeats: int,
      validate_sparsity_block_size(sparsity_block_size, x)
      validate_triton_block_size(triton_block_size, sparsity_block_size)

-     sparsity_layout_output = torch.repeat_interleave(sparsity_layout, 3, dim=0).contiguous()
+     sparsity_layout_output = torch.repeat_interleave(sparsity_layout, repeats, dim=0).contiguous()

      sparsity_lut = torch.nonzero(sparsity_layout).contiguous()

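This hunk fixes a hardcoded repeat factor: in 1.5 the output layout was always tripled along the batch dimension, regardless of the `repeats` argument. A layout-level sketch of the corrected behaviour (values invented for illustration):

    import torch

    sparsity_layout = torch.ones(2, 1, 1)
    repeats = 4

    out_layout = torch.repeat_interleave(sparsity_layout, repeats, dim=0).contiguous()
    assert out_layout.size(0) == sparsity_layout.size(0) * repeats  # 8, not the old 6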
blksprs-1.6.1/blksprs/ops/partitioning.py (new)

@@ -0,0 +1,244 @@
+ import torch
+ import triton
+ from sympy.utilities.iterables import partitions
+ from torch import Tensor
+ from triton import language as tl
+
+ from blksprs.utils.tools import get_triton_block_size
+
+ from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
+     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
+
+
+ def split(x: Tensor, sparsity_layout: Tensor, partitions: int,
+           sparsity_block_size: int, triton_block_size: int = None) -> (Tensor, Tensor):
+     """Splits a block-sparse tensor in compressed form along the last dimension into partitions.
+
+     Args:
+         x (Tensor): A block-sparse tensor in compressed form.
+         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+         partitions (int): The number of partitions to split the block-sparse tensor into.
+         sparsity_block_size (int): The size of the sparsity blocks.
+         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+     Returns:
+         Tensor: The block-sparse tensor split into partitions in compressed form.
+         Tensor: The sparsity layout of the output tensor.
+
+     """
+     x = x.contiguous()
+
+     validate_dimensions(x)
+     validate_contiguous(x)
+     validate_device(x)
+     validate_sparsity(sparsity_block_size, (x, sparsity_layout))
+     validate_sparsity_block_size(sparsity_block_size, x)
+     validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+     sparsity_layout_output = (sparsity_layout
+                               .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
+                                        sparsity_layout.size(2) // partitions)
+                               .permute(0, 2, 1, 3)
+                               .reshape(sparsity_layout.size(0) * partitions, sparsity_layout.size(1),
+                                        sparsity_layout.size(2) // partitions).contiguous())
+
+     sparsity_lut = torch.nonzero(sparsity_layout_output).contiguous()
+
+     sparsity_layout_flat = sparsity_layout.reshape(-1)
+     sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                              (sparsity_layout_flat == 1) -
+                              (1 * (sparsity_layout_flat == 0)))
+                             .reshape(sparsity_layout.size())
+                             .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
+                                      sparsity_layout.size(2) // partitions)
+                             .permute(0, 2, 1, 3).reshape(-1).contiguous())
+
+     n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
+
+     validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
+
+     return _BlocksparseSplit.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
+                                    sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_output
+
+
+ class _BlocksparseSplit(torch.autograd.Function):
+
+     @staticmethod
+     def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                 num_partitions: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+         ctx.num_partitions = num_partitions
+
+         return forward_reorder(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                                n_sparse_blocks, triton_block_size)
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         sparsity_layout = ctx.saved_tensors[0]
+         num_partitions = ctx.num_partitions
+         sparsity_block_size = ctx.sparsity_block_size
+         triton_block_size = ctx.triton_block_size
+
+         return merge(grad_output, sparsity_layout, num_partitions,
+                      sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None
+
+
+ def merge(x: Tensor, sparsity_layout: Tensor, partitions: int,
+           sparsity_block_size: int, triton_block_size: int = None) -> (Tensor, Tensor):
+     """Merges the specified partitions of a block-sparse tensor in compressed form along the last dimension.
+
+     Args:
+         x (Tensor): A block-sparse tensor in compressed form.
+         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+         partitions (int): The number of partitions to be merged.
+         sparsity_block_size (int): The size of the sparsity blocks.
+         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+     Returns:
+         Tensor: The merged block-sparse tensor in compressed form.
+         Tensor: The sparsity layout of the output tensor.
+
+     """
+     x = x.contiguous()
+
+     validate_dimensions(x)
+     validate_contiguous(x)
+     validate_device(x)
+     validate_sparsity(sparsity_block_size, (x, sparsity_layout))
+     validate_sparsity_block_size(sparsity_block_size, x)
+     validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+     sparsity_layout_output = (sparsity_layout.reshape(sparsity_layout.size(0) // partitions, partitions,
+                                                       sparsity_layout.size(1), sparsity_layout.size(2))
+                               .permute(0, 2, 1, 3)
+                               .reshape(sparsity_layout.size(0) // partitions,
+                                        sparsity_layout.size(1), sparsity_layout.size(2) * partitions).contiguous())
+
+     sparsity_lut = torch.nonzero(sparsity_layout_output).contiguous()
+
+     sparsity_layout_flat = sparsity_layout.reshape(-1)
+     sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                              (sparsity_layout_flat == 1) -
+                              (1 * (sparsity_layout_flat == 0)))
+                             .reshape(sparsity_layout.size(0) // partitions, partitions,
+                                      sparsity_layout.size(1), sparsity_layout.size(2))
+                             .permute(0, 2, 1, 3)
+                             .reshape(sparsity_layout.size(0) // partitions,
+                                      sparsity_layout.size(1), sparsity_layout.size(2) * partitions)
+                             .reshape(-1).contiguous())
+
+     n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
+
+     validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
+
+     return _BlocksparseMerge.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
+                                    sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_output
+
+
+ class _BlocksparseMerge(torch.autograd.Function):
+
+     @staticmethod
+     def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                 num_partitions: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+         ctx.num_partitions = num_partitions
+
+         return forward_reorder(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                                n_sparse_blocks, triton_block_size)
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         sparsity_layout = ctx.saved_tensors[0]
+         num_partitions = ctx.num_partitions
+         sparsity_block_size = ctx.sparsity_block_size
+         triton_block_size = ctx.triton_block_size
+
+         return split(grad_output, sparsity_layout, num_partitions,
+                      sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None
+
+
+ def forward_reorder(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                     sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+     output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                          dtype=x.dtype, device=x.device)
+
+     x_b, x_r, x_c = x.size()
+     x_b_s, x_r_s, x_c_s = x.stride()
+     s_l_b, s_l_r, s_l_c = sparsity_layout_o.size()
+     s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout_o.stride()
+     s_lut_r, s_lut_c = sparsity_lut.shape
+     s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
+     o_b, o_r, o_c = output.size()
+     o_b_s, o_r_s, o_c_s = output.stride()
+
+     if triton_block_size is None:
+         triton_block_size = get_triton_block_size(sparsity_block_size)
+
+     triton_grid = lambda meta: [o_b,
+                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                 triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+     (kernel_blocksparse_reorder[triton_grid]
+      (x,
+       x_b, x_b_s, x_r_s, x_c_s,
+       s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+       sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+       sparsity_reverse_lut,
+       output,
+       o_b, o_b_s,
+       triton_block_size))
+
+     # Save for backward pass
+     ctx.save_for_backward(sparsity_layout_o)
+     ctx.sparsity_block_size = sparsity_block_size
+     ctx.triton_block_size = triton_block_size
+
+     return output
+
+
+ @triton.jit
+ def kernel_blocksparse_reorder(x,
+                                x_b, x_b_s, x_r_s, x_c_s,
+                                s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+                                s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                                r_lut,
+                                o,
+                                o_b, o_b_s,
+                                TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     # Get triton block indices
+     pid_blk = tl.program_id(axis=0)
+     pid_row = tl.program_id(axis=1)
+     pid_col = tl.program_id(axis=2)
+
+     # Get sparsity index of current output block consisting of its batch, row, and column index
+     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+     spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+     spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+     spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+     spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+     spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+     spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
+     spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
+     spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+
+     # Get reverse sparsity index
+     rev_idx_spa_idx = (spa_bat * s_l_b_s +
+                        spa_row * s_l_r_s +
+                        spa_col * s_l_c_s)
+     rev_idx_spa_msk = (rev_idx_spa_idx < s_l_b * s_l_b_s)
+     rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
+
+     if rev_idx_spa == -1:
+         assert False, "Invalid sparsity block"
+
+     blk_x_idx = (rev_idx_spa * x_b_s +
+                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+     blk_x_msk = (blk_x_idx < x_b * x_b_s)
+     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+     blk_o_idx = (pid_blk * o_b_s +
+                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+     blk_o_msk = (blk_o_idx < o_b * o_b_s)
+     tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
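
The core of `split` and `merge` is the batch/column reshuffle applied to the sparsity layout; the Triton kernel above then reorders the compressed blocks to match. A self-contained sketch of just that layout transformation, using an invented 1 × 2 × 4 layout and `partitions = 2` (this mirrors the reshape/permute chains in `split` and `merge` but is not itself part of the library):

    import torch

    sparsity_layout = torch.tensor([[[1, 0, 1, 1],
                                     [0, 1, 1, 0]]])
    partitions = 2
    b, r, c = sparsity_layout.size()

    # split(): group the columns into `partitions` chunks; each chunk becomes
    # its own batch entry, so the batch dimension grows by the same factor.
    layout_split = (sparsity_layout
                    .reshape(b, r, partitions, c // partitions)
                    .permute(0, 2, 1, 3)
                    .reshape(b * partitions, r, c // partitions))
    print(layout_split.shape)  # torch.Size([2, 2, 2])

    # merge(): the inverse transformation restores the original layout, which
    # is also why each op implements its backward pass by calling the other.
    layout_merged = (layout_split
                     .reshape(b, partitions, r, c // partitions)
                     .permute(0, 2, 1, 3)
                     .reshape(b, r, c))
    assert torch.equal(layout_merged, sparsity_layout)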
{blksprs-1.5 → blksprs-1.6.1}/blksprs/ops/transpose.py

@@ -56,16 +56,16 @@ def transpose(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
  class _BlocksparseTranspose(torch.autograd.Function):

      @staticmethod
-     def forward(ctx, x: Tensor,
-                 sparsity_layout: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor, sparsity_block_size: int,
-                 n_sparse_blocks: int, triton_block_size: int) -> (Tensor, Tensor):
+     def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                 sparsity_block_size: int,
+                 n_sparse_blocks: int, triton_block_size: int) -> Tensor:
          output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                               dtype=x.dtype, device=x.device)

          x_b, x_r, x_c = x.size()
          x_b_s, x_r_s, x_c_s = x.stride()
-         s_l_b, s_l_r, s_l_c = sparsity_layout.size()
-         s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout.stride()
+         s_l_b, s_l_r, s_l_c = sparsity_layout_o.size()
+         s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout_o.stride()
          s_lut_r, s_lut_c = sparsity_lut.shape
          s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
          o_b, o_r, o_c = output.size()
@@ -89,8 +89,7 @@ class _BlocksparseTranspose(torch.autograd.Function):
            triton_block_size))

          # Save for backward pass
-         ctx.save_for_backward(sparsity_layout)
-         ctx.sparsity_layout = sparsity_layout
+         ctx.save_for_backward(sparsity_layout_o)
          ctx.sparsity_block_size = sparsity_block_size
          ctx.triton_block_size = triton_block_size

{blksprs-1.5 → blksprs-1.6.1}/blksprs/utils/validation.py

@@ -63,6 +63,8 @@ def validate_sparsity(sparsity_block_size: int, *tensor_sparsity_layout_tuples:
      for (tensor, sparsity_layout) in tensor_sparsity_layout_tuples:
          _validate_sparsity_layout_values(sparsity_layout)

+         if not sparsity_layout.dim() == 3:
+             raise ValueError("Sparsity layout must have exactly 3 dimensions")
          if not (tensor.size(-1) == tensor.size(-2) == sparsity_block_size):
              raise ValueError("Blocks not conforming to sparsity block size")
          if not tensor.size(0) == torch.sum(sparsity_layout.reshape(-1)):
{blksprs-1.5 → blksprs-1.6.1}/blksprs.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: blksprs
- Version: 1.5
+ Version: 1.6.1
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -31,9 +31,10 @@ Currently supported operations (includes gradient calculation):
  - Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
    for `sparse = sparse @ sparse` matmul_)
  - Softmax
- - Transposition
+ - Transpose
  - Gather
  - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
+ - Splitting and merging of matrices along the last dimension
  - Conversion to and from sparse form
  - Conversion to different sparsity layouts and different sparsity block sizes

{blksprs-1.5 → blksprs-1.6.1}/blksprs.egg-info/SOURCES.txt

@@ -16,6 +16,7 @@ blksprs/ops/conversion.py
  blksprs/ops/distribution.py
  blksprs/ops/exp.py
  blksprs/ops/matmul.py
+ blksprs/ops/partitioning.py
  blksprs/ops/softmax.py
  blksprs/ops/transpose.py
  blksprs/utils/benchmarking.py
{blksprs-1.5 → blksprs-1.6.1}/pyproject.toml

@@ -1,6 +1,6 @@
  [project]
  name = "blksprs"
- version = "1.5"
+ version = "1.6.1"
  authors = [{ name = "Felix Schön", email = "schoen@kr.tuwien.ac.at" }]
  description = "A lightweight library for operations on blocksparse matrices in PyTorch."
  readme = "README.md"