blksprs 1.4.2.tar.gz → 1.6.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. {blksprs-1.4.2 → blksprs-1.6.1}/PKG-INFO +3 -2
  2. {blksprs-1.4.2 → blksprs-1.6.1}/README.md +2 -1
  3. {blksprs-1.4.2 → blksprs-1.6.1}/blksprs/__init__.py +6 -2
  4. blksprs-1.6.1/blksprs/experimental/distribution_mdi.py +438 -0
  5. {blksprs-1.4.2 → blksprs-1.6.1}/blksprs/layouting/distribution_layout.py +4 -17
  6. {blksprs-1.4.2 → blksprs-1.6.1}/blksprs/layouting/sparsity_layout.py +36 -0
  7. {blksprs-1.4.2 → blksprs-1.6.1}/blksprs/misc/repeat_interleave.py +1 -1
  8. blksprs-1.6.1/blksprs/ops/partitioning.py +244 -0
  9. {blksprs-1.4.2 → blksprs-1.6.1}/blksprs/ops/transpose.py +6 -7
  10. {blksprs-1.4.2 → blksprs-1.6.1}/blksprs/utils/validation.py +2 -0
  11. {blksprs-1.4.2 → blksprs-1.6.1}/blksprs.egg-info/PKG-INFO +3 -2
  12. {blksprs-1.4.2 → blksprs-1.6.1}/blksprs.egg-info/SOURCES.txt +2 -0
  13. {blksprs-1.4.2 → blksprs-1.6.1}/pyproject.toml +1 -1
  14. {blksprs-1.4.2 → blksprs-1.6.1}/blksprs/misc/broadcast_ops.py +0 -0
  15. {blksprs-1.4.2 → blksprs-1.6.1}/blksprs/misc/row_wise.py +0 -0
  16. {blksprs-1.4.2 → blksprs-1.6.1}/blksprs/ops/conversion.py +0 -0
  17. {blksprs-1.4.2 → blksprs-1.6.1}/blksprs/ops/distribution.py +0 -0
  18. {blksprs-1.4.2 → blksprs-1.6.1}/blksprs/ops/exp.py +0 -0
  19. {blksprs-1.4.2 → blksprs-1.6.1}/blksprs/ops/matmul.py +0 -0
  20. {blksprs-1.4.2 → blksprs-1.6.1}/blksprs/ops/softmax.py +0 -0
  21. {blksprs-1.4.2 → blksprs-1.6.1}/blksprs/utils/benchmarking.py +0 -0
  22. {blksprs-1.4.2 → blksprs-1.6.1}/blksprs/utils/tools.py +0 -0
  23. {blksprs-1.4.2 → blksprs-1.6.1}/blksprs.egg-info/dependency_links.txt +0 -0
  24. {blksprs-1.4.2 → blksprs-1.6.1}/blksprs.egg-info/requires.txt +0 -0
  25. {blksprs-1.4.2 → blksprs-1.6.1}/blksprs.egg-info/top_level.txt +0 -0
  26. {blksprs-1.4.2 → blksprs-1.6.1}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: blksprs
- Version: 1.4.2
+ Version: 1.6.1
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -31,9 +31,10 @@ Currently supported operations (includes gradient calculation):
  - Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
  for `sparse = sparse @ sparse` matmul_)
  - Softmax
- - Transposition
+ - Transpose
  - Gather
  - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
+ - Splitting and merging of matrices along the last dimension
  - Conversion to and from sparse form
  - Conversion to different sparsity layouts and different sparsity block sizes
@@ -12,9 +12,10 @@ Currently supported operations (includes gradient calculation):
  - Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
  for `sparse = sparse @ sparse` matmul_)
  - Softmax
- - Transposition
+ - Transpose
  - Gather
  - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
+ - Splitting and merging of matrices along the last dimension
  - Conversion to and from sparse form
  - Conversion to different sparsity layouts and different sparsity block sizes
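
The 1.6.x releases add the split and merge operations listed above. A minimal usage sketch (assuming a CUDA device, since the kernels are written in Triton, and the `split`/`merge` signatures from `blksprs/ops/partitioning.py` further down in this diff; this is not an excerpt from the project's documentation):

```python
import torch
import blksprs as bs

sparsity_block_size = 16
# Dense layout: 2 batches of 4 x 4 blocks, every block present (32 blocks in total).
sparsity_layout = torch.ones(2, 4, 4, dtype=torch.bool, device="cuda")
# Matching tensor in compressed form: one 16 x 16 tile per non-sparse block.
x = torch.randn(32, sparsity_block_size, sparsity_block_size, device="cuda")

# Split along the last dimension into 2 partitions; the compressed tensor and
# the sparsity layout of the result are both returned.
x_split, layout_split = bs.split(x, sparsity_layout, 2, sparsity_block_size)

# Merging the same number of partitions is the inverse operation.
x_merged, layout_merged = bs.merge(x_split, layout_split, 2, sparsity_block_size)
```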
@@ -4,10 +4,11 @@ from blksprs.ops.exp import exp
  from blksprs.ops.matmul import matmul
  from blksprs.ops.softmax import softmax
  from blksprs.ops.transpose import transpose
+ from blksprs.ops.partitioning import split, merge

  class layout:
  from blksprs.layouting.distribution_layout import build_distribution_layout
- from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption
+ from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption, build_sparsity_layout_matmul

  class misc:
  from blksprs.misc.broadcast_ops import broadcast_add, broadcast_sub
@@ -15,4 +16,7 @@ class misc:
  from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_add, row_wise_sub

  class util:
- from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse, disable_validation
+ from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse, disable_validation
+
+ class experimental:
+ from blksprs.experimental.distribution_mdi import gather_mdi, scatter_reduce_mdi
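
The nested `layout`, `misc`, `util`, and `experimental` classes above act as plain namespaces, so the names they import become attributes of the package. A short sketch of the resulting import surface for the additions in this release (attribute access only; everything else is unchanged):

```python
import blksprs as bs

bs.split                                # blksprs.ops.partitioning.split
bs.merge                                # blksprs.ops.partitioning.merge
bs.layout.build_sparsity_layout_matmul  # blksprs.layouting.sparsity_layout
bs.experimental.gather_mdi              # blksprs.experimental.distribution_mdi
bs.experimental.scatter_reduce_mdi      # blksprs.experimental.distribution_mdi
```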
@@ -0,0 +1,438 @@
1
+ import torch
2
+ import triton
3
+ from torch import Tensor
4
+ from triton import language as tl
5
+
6
+ from blksprs.utils.tools import get_triton_block_size
7
+ from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
8
+ validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
9
+
10
+
11
+ def gather_mdi(src: Tensor, sparsity_layout_src: Tensor,
12
+ idx_bat: Tensor,
13
+ idx_row: Tensor,
14
+ idx_col: Tensor,
15
+ sparsity_layout_idx: Tensor,
16
+ sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
17
+ src = src.contiguous()
18
+ idx_bat = idx_bat.contiguous()
19
+ idx_col = idx_col.contiguous()
20
+
21
+ validate_dimensions(src, idx_bat, idx_col)
22
+ validate_contiguous(src, idx_bat, idx_col)
23
+ validate_dtype_int(idx_bat, idx_col)
24
+ validate_device(src, idx_bat, idx_col)
25
+ validate_sparsity(sparsity_block_size, (src, sparsity_layout_src),
26
+ (idx_bat, sparsity_layout_idx), (idx_col, sparsity_layout_idx))
27
+ validate_sparsity_block_size(sparsity_block_size, src, idx_bat, idx_col)
28
+ validate_triton_block_size(triton_block_size, sparsity_block_size)
29
+
30
+ sparsity_layout_x_flat = sparsity_layout_src.reshape(-1)
31
+ sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
32
+ (sparsity_layout_x_flat == 1) -
33
+ (1 * (sparsity_layout_x_flat == 0)))
34
+
35
+ sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
36
+
37
+ validate_contiguous(sparsity_layout_src, sparsity_reverse_lut_x,
38
+ sparsity_layout_idx, sparsity_lut_i)
39
+
40
+ return _BlocksparseGatherMDI.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
41
+ idx_bat, idx_col, sparsity_layout_idx, sparsity_lut_i,
42
+ sparsity_block_size, triton_block_size)
43
+
44
+
45
+ class _BlocksparseGatherMDI(torch.autograd.Function):
46
+
47
+ @staticmethod
48
+ def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
49
+ idx_bat: Tensor, idx_col: Tensor, sparsity_layout_i: Tensor, sparsity_lut_i: Tensor,
50
+ sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
51
+ output = torch.empty_like(idx_col, dtype=x.dtype)
52
+
53
+ x_b, x_r, x_c = x.size()
54
+ x_b_s, x_r_s, x_c_s = x.stride()
55
+ s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
56
+ s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_x.stride()
57
+ i_b, i_r, i_c = idx_col.size()
58
+ i_b_s, i_r_s, i_c_s = idx_col.stride()
59
+ s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
60
+ s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
61
+ o_b, o_r, o_c = output.size()
62
+ o_b_s, o_r_s, o_c_s = output.stride()
63
+
64
+ if triton_block_size is None:
65
+ triton_block_size = get_triton_block_size(sparsity_block_size)
66
+
67
+ triton_grid = lambda meta: [o_b,
68
+ triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
69
+ triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
70
+
71
+ (_BlocksparseGatherMDI.kernel_blocksparse_gather_mdi[triton_grid]
72
+ (x,
73
+ x_b, x_b_s, x_r_s, x_c_s,
74
+ s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
75
+ sparsity_reverse_lut_x,
76
+ idx_bat,
77
+ idx_col,
78
+ i_b, i_b_s, i_r_s, i_c_s,
79
+ output,
80
+ o_b, o_b_s, o_r_s, o_c_s,
81
+ sparsity_lut_i, s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
82
+ sparsity_block_size,
83
+ triton_block_size))
84
+
85
+ ctx.save_for_backward(sparsity_layout_x, idx_bat, idx_col, sparsity_layout_i)
86
+ ctx.sparsity_block_size = sparsity_block_size
87
+ ctx.triton_block_size = triton_block_size
88
+
89
+ return output
90
+
91
+ @staticmethod
92
+ def backward(ctx, grad_output):
93
+ sparsity_layout_x, idx_bat, idx_col, sparsity_layout_i = ctx.saved_tensors
94
+ sparsity_block_size = ctx.sparsity_block_size
95
+ triton_block_size = ctx.triton_block_size
96
+
97
+ return scatter_reduce_mdi(grad_output, sparsity_layout_i,
98
+ idx_bat,
99
+ None,
100
+ idx_col,
101
+ sparsity_layout_x,
102
+ sparsity_block_size,
103
+ reduce_op="sum",
104
+ triton_block_size=triton_block_size), None, None, None, None, None, None, None, None
105
+
106
+ @staticmethod
107
+ @triton.jit
108
+ def kernel_blocksparse_gather_mdi(x,
109
+ x_b, x_b_s, x_r_s, x_c_s,
110
+ s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
111
+ r_lut_x,
112
+ idx_bat,
113
+ idx_col,
114
+ i_b, i_b_s, i_r_s, i_c_s,
115
+ o,
116
+ o_b, o_b_s, o_r_s, o_c_s,
117
+ s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
118
+ sparsity_block_size,
119
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
120
+ # Get triton block indices
121
+ pid_blk = tl.program_id(axis=0)
122
+ pid_row = tl.program_id(axis=1)
123
+ pid_col = tl.program_id(axis=2)
124
+
125
+ # Load batch index values
126
+ blk_idx_bat_idx = ((pid_blk * i_b_s) +
127
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
128
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
129
+ blk_idx_bat_msk = (blk_idx_bat_idx < i_b * i_b_s)
130
+ blk_idx_bat = tl.load(idx_bat + blk_idx_bat_idx, mask=blk_idx_bat_msk).to(tl.int32)
131
+
132
+ # Get position of current sparsity block row
133
+ spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
134
+ spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
135
+ spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
136
+
137
+ # Load column index values
138
+ blk_idx_col_idx = ((pid_blk * i_b_s) +
139
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
140
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
141
+ blk_idx_col_msk = (blk_idx_col_idx < i_b * i_b_s)
142
+ blk_idx_col = tl.load(idx_col + blk_idx_col_idx, mask=blk_idx_col_msk).to(tl.int32)
143
+
144
+ # Get positions of sparsity blocks
145
+ pos_spa_blk_x = blk_idx_col // sparsity_block_size
146
+ pos_spa_col_x = blk_idx_col % sparsity_block_size
147
+
148
+ # Load reverse sparsity indices for x
149
+ rev_idx_spa_x_idx = ((blk_idx_bat * s_l_x_b_s) +
150
+ (spa_row_o * s_l_x_r_s) +
151
+ (pos_spa_blk_x * s_l_x_c_s))
152
+ rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
153
+ rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
154
+
155
+ # Load x values
156
+ blk_x_idx = ((rev_idx_spa_x * x_b_s) +
157
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
158
+ (pos_spa_col_x * x_c_s))
159
+ blk_x_msk = (blk_x_idx < x_b * x_b_s)
160
+ blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
161
+
162
+ # Store output
163
+ blk_o_idx = ((pid_blk * o_b_s) +
164
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
165
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
166
+ blk_o_msk = (blk_o_idx < o_b * o_b_s)
167
+ tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
168
+
169
+
170
+ def scatter_reduce_mdi(src: Tensor, sparsity_layout_src: Tensor,
171
+ idx_bat: Tensor,
172
+ idx_row: Tensor,
173
+ idx_col: Tensor,
174
+ sparsity_layout_tgt: Tensor,
175
+ sparsity_block_size: int,
176
+ reduce_op: str = "sum", triton_block_size: int = None) -> Tensor:
177
+ src = src.contiguous()
178
+ idx_bat = idx_bat.contiguous()
179
+ idx_col = idx_col.contiguous()
180
+
181
+ validate_dimensions(src, idx_bat, idx_col)
182
+ validate_contiguous(src, idx_bat, idx_col)
183
+ validate_dtype_int(idx_bat, idx_col)
184
+ validate_device(src, idx_bat, idx_col)
185
+ validate_sparsity(sparsity_block_size, (src, sparsity_layout_src),
186
+ (idx_bat, sparsity_layout_src),
187
+ (idx_col, sparsity_layout_src))
188
+ validate_sparsity_block_size(sparsity_block_size, src, idx_bat, idx_col)
189
+ validate_triton_block_size(triton_block_size, sparsity_block_size)
190
+
191
+ if reduce_op not in ["none", "sum"]:
192
+ raise ValueError(f"Reduction operation '{reduce_op}' is not supported")
193
+
194
+ sparsity_lut_x = torch.nonzero(sparsity_layout_src).contiguous()
195
+
196
+ sparsity_layout_o_flat = sparsity_layout_tgt.reshape(-1)
197
+ sparsity_reverse_lut_o = ((torch.cumsum(sparsity_layout_o_flat, dim=-1) - 1) *
198
+ (sparsity_layout_o_flat == 1) -
199
+ (1 * (sparsity_layout_o_flat == 0)))
200
+
201
+ n_sparse_blocks = torch.sum(sparsity_layout_tgt.to(torch.int)).item()
202
+
203
+ validate_contiguous(sparsity_layout_src, sparsity_lut_x,
204
+ sparsity_layout_tgt, sparsity_reverse_lut_o)
205
+
206
+ return _BlocksparseScatterReduceMDI.apply(src, sparsity_layout_src, sparsity_lut_x,
207
+ idx_bat,
208
+ idx_col,
209
+ sparsity_layout_tgt, sparsity_reverse_lut_o,
210
+ sparsity_block_size, n_sparse_blocks,
211
+ reduce_op, triton_block_size)
212
+
213
+
214
+ class _BlocksparseScatterReduceMDI(torch.autograd.Function):
215
+
216
+ @staticmethod
217
+ def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_lut_x: Tensor,
218
+ idx_bat: Tensor,
219
+ idx_col: Tensor,
220
+ sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
221
+ sparsity_block_size: int, n_sparse_blocks: int,
222
+ reduce_op: str, triton_block_size: int) -> Tensor:
223
+ output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
224
+ dtype=x.dtype, device=x.device)
225
+
226
+ x_b, x_r, x_c = x.size()
227
+ x_b_s, x_r_s, x_c_s = x.stride()
228
+ s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
229
+ s_lut_x_r_s, s_lut_x_c_s = sparsity_lut_x.stride()
230
+ i_b, i_r, i_c = idx_col.size()
231
+ i_b_s, i_r_s, i_c_s = idx_col.stride()
232
+ o_b, o_r, o_c = output.size()
233
+ o_b_s, o_r_s, o_c_s = output.stride()
234
+ s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
235
+ s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_o.stride()
236
+
237
+ if triton_block_size is None:
238
+ triton_block_size = get_triton_block_size(sparsity_block_size)
239
+
240
+ triton_grid = lambda meta: [x_b,
241
+ triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
242
+ triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
243
+
244
+ reduce_op_ind = 0
245
+ if reduce_op == "sum":
246
+ reduce_op_ind = 1
247
+
248
+ (_BlocksparseScatterReduceMDI.kernel_blocksparse_scatter_mdi[triton_grid]
249
+ (x,
250
+ x_b, x_b_s, x_r_s, x_c_s,
251
+ sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
252
+ idx_bat,
253
+ idx_col,
254
+ i_b, i_b_s, i_r_s, i_c_s,
255
+ output,
256
+ o_b, o_b_s, o_r_s, o_c_s,
257
+ s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
258
+ sparsity_reverse_lut_o,
259
+ reduce_op_ind,
260
+ sparsity_block_size,
261
+ triton_block_size))
262
+
263
+ ctx.save_for_backward(sparsity_layout_x, idx_bat, idx_col, sparsity_layout_o)
264
+ ctx.sparsity_block_size = sparsity_block_size
265
+ ctx.reduce_op = reduce_op
266
+ ctx.triton_block_size = triton_block_size
267
+
268
+ return output
269
+
270
+ @staticmethod
271
+ def backward(ctx, grad_output):
272
+ sparsity_layout_x, idx_bat, idx_col, sparsity_layout_o = ctx.saved_tensors
273
+ sparsity_block_size = ctx.sparsity_block_size
274
+ reduce_op = ctx.reduce_op
275
+ triton_block_size = ctx.triton_block_size
276
+
277
+ if reduce_op == "sum":
278
+ return gather_mdi(grad_output, sparsity_layout_o,
279
+ idx_bat,
280
+ None,
281
+ idx_col,
282
+ sparsity_layout_x, sparsity_block_size,
283
+ triton_block_size=triton_block_size), None, None, None, None, None, None, None, None, None, None
284
+ else:
285
+ raise ValueError(f"Reduction operation '{reduce_op}' does not support backward pass")
286
+
287
+ @staticmethod
288
+ @triton.jit
289
+ def kernel_blocksparse_scatter_mdi(x,
290
+ x_b, x_b_s, x_r_s, x_c_s,
291
+ s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
292
+ idx_bat,
293
+ idx_col,
294
+ i_b, i_b_s, i_r_s, i_c_s,
295
+ o,
296
+ o_b, o_b_s, o_r_s, o_c_s,
297
+ s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
298
+ r_lut_o,
299
+ reduce_op_ind,
300
+ sparsity_block_size,
301
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
302
+ # Get triton block indices
303
+ pid_blk = tl.program_id(axis=0)
304
+ pid_row = tl.program_id(axis=1)
305
+ pid_col = tl.program_id(axis=2)
306
+
307
+ # Load batch index values
308
+ blk_idx_bat_idx = ((pid_blk * i_b_s) +
309
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
310
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
311
+ blk_idx_bat_msk = (blk_idx_bat_idx < i_b * i_b_s)
312
+ blk_idx_bat = tl.load(idx_bat + blk_idx_bat_idx, mask=blk_idx_bat_msk).to(tl.int32)
313
+
314
+ # Get position of current sparsity block row
315
+ spa_row_x_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
316
+ spa_row_x_msk = (spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
317
+ spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)
318
+
319
+ # Load x values
320
+ blk_x_idx = ((pid_blk * x_b_s) +
321
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
322
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
323
+ blk_x_msk = (blk_x_idx < x_b * x_b_s)
324
+ blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
325
+
326
+ # Load column index values
327
+ blk_idx_col_idx = ((pid_blk * i_b_s) +
328
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
329
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
330
+ blk_idx_col_msk = (blk_idx_col_idx < i_b * i_b_s)
331
+ blk_idx_col = tl.load(idx_col + blk_idx_col_idx, mask=blk_idx_col_msk).to(tl.int32)
332
+
333
+ # Get positions of sparsity blocks
334
+ pos_spa_blk_o = blk_idx_col // sparsity_block_size
335
+ pos_spa_col_o = blk_idx_col % sparsity_block_size
336
+
337
+ # Load reverse sparsity indices for o
338
+ rev_idx_spa_o_idx = ((blk_idx_bat * s_l_o_b_s) +
339
+ (spa_row_x * s_l_o_r_s) +
340
+ (pos_spa_blk_o * s_l_o_c_s))
341
+ rev_idx_spa_o_msk = (rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
342
+ rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
343
+
344
+ # Store output
345
+ blk_o_idx = ((rev_idx_spa_o * o_b_s) +
346
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
347
+ (pos_spa_col_o * o_c_s))
348
+ blk_o_msk = (blk_o_idx < o_b * o_b_s)
349
+
350
+ if reduce_op_ind == 0:
351
+ tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
352
+ elif reduce_op_ind == 1:
353
+ tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
354
+
355
+
356
+ def build_distribution_layout_mdi(idx_bat: Tensor, idx_row: Tensor, idx_col: Tensor, sparsity_layout_idx: Tensor,
357
+ size_target: torch.Size,
358
+ sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
359
+ validate_dimensions(idx_bat, idx_col)
360
+ validate_contiguous(idx_bat, idx_col)
361
+ validate_device(idx_bat, idx_col)
362
+
363
+ sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
364
+
365
+ output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size, size_target[2] // sparsity_block_size,
366
+ dtype=torch.bool, device=idx_col.device)
367
+
368
+ i_b, i_r, i_c = idx_col.size()
369
+ i_b_s, i_r_s, i_c_s = idx_col.stride()
370
+ s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
371
+ s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
372
+ o_b, o_r, o_c = output.size()
373
+ o_b_s, o_r_s, o_c_s = output.stride()
374
+
375
+ if triton_block_size is None:
376
+ triton_block_size = get_triton_block_size(sparsity_block_size)
377
+
378
+ validate_triton_block_size(triton_block_size, sparsity_block_size)
379
+
380
+ triton_grid = lambda meta: [i_b,
381
+ triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
382
+ triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]
383
+
384
+ (kernel_distribution_layout_mdi[triton_grid]
385
+ (idx_bat,
386
+ idx_col,
387
+ i_b, i_b_s, i_r_s, i_c_s,
388
+ sparsity_lut_i,
389
+ s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
390
+ output,
391
+ o_b, o_b_s, o_r_s, o_c_s,
392
+ sparsity_block_size,
393
+ triton_block_size))
394
+
395
+ return output
396
+
397
+
398
+ @triton.jit
399
+ def kernel_distribution_layout_mdi(idx_bat,
400
+ idx_col,
401
+ i_b, i_b_s, i_r_s, i_c_s,
402
+ s_lut_i,
403
+ s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
404
+ o,
405
+ o_b, o_b_s, o_r_s, o_c_s,
406
+ sparsity_block_size,
407
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
408
+ # Get triton block indices
409
+ pid_blk = tl.program_id(axis=0)
410
+ pid_row = tl.program_id(axis=1)
411
+ pid_col = tl.program_id(axis=2)
412
+
413
+ # Load batch index values
414
+ blk_idx_bat_idx = ((pid_blk * i_b_s) +
415
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
416
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
417
+ blk_idx_bat_msk = (blk_idx_bat_idx < i_b * i_b_s)
418
+ blk_idx_bat = tl.load(idx_bat + blk_idx_bat_idx, mask=blk_idx_bat_msk).to(tl.int32)
419
+
420
+ # Get position of current sparsity block row
421
+ spa_row_i_idx = (pid_blk * s_lut_i_r_s + 1 * s_lut_i_c_s)
422
+ spa_row_i_msk = (spa_row_i_idx < s_lut_i_r * s_lut_i_r_s)
423
+ spa_row_i = tl.load(s_lut_i + spa_row_i_idx, mask=spa_row_i_msk)
424
+
425
+ blk_i_idx = (pid_blk * i_b_s +
426
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
427
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
428
+ blk_i_msk = (blk_i_idx < i_b * i_b_s)
429
+ blk_i = tl.load(idx_col + blk_i_idx, mask=blk_i_msk)
430
+
431
+ blk_i = blk_i // sparsity_block_size
432
+ blk_v = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), 1, dtype=tl.int32)
433
+
434
+ blk_o_idx = ((blk_idx_bat * o_b_s) +
435
+ (spa_row_i * o_r_s) +
436
+ (blk_i * o_c_s))
437
+ blk_o_msk = (blk_o_idx < o_b * o_b_s)
438
+ tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
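
The gather/scatter wrappers above (and `split`/`merge` further below) build their reverse look-up tables with the same cumulative-sum expression. A small standalone sketch of what that expression computes, using plain PyTorch on the CPU:

```python
import torch

# Flattened sparsity layout: 1 = block present, 0 = block absent.
layout_flat = torch.tensor([1, 0, 1, 1, 0, 1])

# Same expression as in the code above: maps each layout position to the index
# of its block in the compressed tensor, or to -1 if the block is absent.
reverse_lut = ((torch.cumsum(layout_flat, dim=-1) - 1) * (layout_flat == 1)
               - (1 * (layout_flat == 0)))

print(reverse_lut)  # tensor([ 0, -1,  1,  2, -1,  3])
```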
@@ -35,8 +35,6 @@ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,

  i_b, i_r, i_c = indices.size()
  i_b_s, i_r_s, i_c_s = indices.stride()
- s_l_i_b, s_l_i_r, s_l_i_c = sparsity_layout_indices.size()
- s_l_i_b_s, s_l_i_r_s, s_l_i_c_s = sparsity_layout_indices.stride()
  s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
  s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
  o_b, o_r, o_c = output.size()
@@ -54,12 +52,10 @@ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
  (kernel_distribution_layout[triton_grid]
  (indices,
  i_b, i_b_s, i_r_s, i_c_s,
- sparsity_layout_indices,
- s_l_i_b, s_l_i_b_s, s_l_i_r, s_l_i_r_s, s_l_i_c, s_l_i_c_s,
  sparsity_lut_i,
- s_lut_i_r, s_lut_i_r_s, s_lut_i_c, s_lut_i_c_s,
+ s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
  output,
- o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
+ o_b, o_b_s, o_r_s, o_c_s,
  sparsity_block_size,
  triton_block_size))

@@ -69,12 +65,10 @@ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
  @triton.jit
  def kernel_distribution_layout(i,
  i_b, i_b_s, i_r_s, i_c_s,
- s_l_i,
- s_l_i_b, s_l_i_b_s, s_l_i_r, s_l_i_r_s, s_l_i_c, s_l_i_c_s,
  s_lut_i,
- s_lut_i_r, s_lut_i_r_s, s_lut_i_c, s_lut_i_c_s,
+ s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
  o,
- o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
+ o_b, o_b_s, o_r_s, o_c_s,
  sparsity_block_size,
  TRITON_BLOCK_SIZE: tl.constexpr) -> None:
  # Get triton block indices
@@ -105,10 +99,3 @@ def kernel_distribution_layout(i,
  (blk_i * o_c_s))
  blk_o_msk = (blk_o_idx < o_b * o_b_s)
  tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
-
- # if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
- # blk_o_idx = (pid_bat * o_b_s +
- # (((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +
- # ((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))
- # blk_o_msk = (blk_o_idx < o_b * o_b_s)
- # tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
@@ -188,3 +188,39 @@ def kernel_sparsity_layout_adaption(x,
  // sparsity_block_size_to) * o_c_s))
  blk_o_msk = (blk_o_idx < o_b * o_b_s)
  tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
+
+
+ def build_sparsity_layout_matmul(sparsity_layout_x: Tensor, sparsity_layout_y: Tensor):
+ """Builds the precise sparsity layout of the result of a matrix multiplication between the two input tensors.
+
+ Args:
+ sparsity_layout_x (Tensor): The sparsity layout of the first block-sparse tensor.
+ sparsity_layout_y (Tensor): The sparsity layout of the second block-sparse tensor.
+
+ Returns:
+ Tensor: The precise sparsity layout of the result of a matrix multiplication between the two input tensors.
+
+ """
+ return torch.matmul(sparsity_layout_x.to(torch.float), sparsity_layout_y.to(torch.float)).to(torch.bool)
+
+
+ def build_sparsity_layout_matmul_fast(sparsity_layout_x: Tensor, sparsity_layout_y: Tensor):
+ """Builds the approximate sparsity layout of the result of a matrix multiplication between the two input tensors.
+
+ Note:
+ This function is faster than the ``build_sparsity_layout_matmul`` function due to the fact that it only checks
+ whether at least one of the blocks in either of the vectors participating in the matmul is non-sparse. The
+ resulting sparsity layout may thus overestimate the actual sparsity of the result.
+
+ Args:
+ sparsity_layout_x (Tensor): The sparsity layout of the first block-sparse tensor.
+ sparsity_layout_y (Tensor): The sparsity layout of the second block-sparse tensor.
+
+ Returns:
+ Tensor: The approximate sparsity layout of the result of a matrix multiplication between the two input tensors.
+
+ """
+ sparsity_layout_x_slice = torch.max(sparsity_layout_x, dim=-1).values.unsqueeze(-1)
+ sparsity_layout_y_slice = torch.max(sparsity_layout_y, dim=-2).values.unsqueeze(1)
+
+ return torch.logical_or(sparsity_layout_x_slice, sparsity_layout_y_slice)
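
A small sketch contrasting the two helpers added above; both are plain PyTorch ops, so CPU tensors are sufficient (the example layouts are made up for illustration):

```python
import torch
from blksprs.layouting.sparsity_layout import (build_sparsity_layout_matmul,
                                               build_sparsity_layout_matmul_fast)

# One batch of 2 x 2 block layouts.
layout_x = torch.tensor([[[1, 0],
                          [0, 0]]], dtype=torch.bool)
layout_y = torch.tensor([[[0, 0],
                          [0, 1]]], dtype=torch.bool)

# Precise: an output block is marked only if some k exists with
# layout_x[b, i, k] and layout_y[b, k, j] both set -> all False here.
precise = build_sparsity_layout_matmul(layout_x, layout_y)

# Fast: marks a block whenever its row in layout_x or its column in layout_y
# contains any block at all -> overestimates (three True blocks here).
fast = build_sparsity_layout_matmul_fast(layout_x, layout_y)
```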
@@ -35,7 +35,7 @@ def repeat_interleave(x: Tensor, sparsity_layout: Tensor, repeats: int,
  validate_sparsity_block_size(sparsity_block_size, x)
  validate_triton_block_size(triton_block_size, sparsity_block_size)

- sparsity_layout_output = torch.repeat_interleave(sparsity_layout, 3, dim=0).contiguous()
+ sparsity_layout_output = torch.repeat_interleave(sparsity_layout, repeats, dim=0).contiguous()

  sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
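
The fix above replaces a hard-coded repeat count of 3 with the caller's `repeats` argument. The affected line in isolation (a toy layout, just to show the shapes):

```python
import torch

sparsity_layout = torch.tensor([[[1, 0],
                                 [0, 1]]], dtype=torch.bool)  # shape (1, 2, 2)
repeats = 2

# Previous behaviour: always three copies along the batch dimension,
# regardless of the requested repeat count.
old = torch.repeat_interleave(sparsity_layout, 3, dim=0)        # shape (3, 2, 2)

# Fixed behaviour: honours `repeats`.
new = torch.repeat_interleave(sparsity_layout, repeats, dim=0)  # shape (2, 2, 2)
```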
@@ -0,0 +1,244 @@
+ import torch
+ import triton
+ from sympy.utilities.iterables import partitions
+ from torch import Tensor
+ from triton import language as tl
+
+ from blksprs.utils.tools import get_triton_block_size
+
+ from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
+ validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
+
+
+ def split(x: Tensor, sparsity_layout: Tensor, partitions: int,
+ sparsity_block_size: int, triton_block_size: int = None) -> (Tensor, Tensor):
+ """Splits a block-sparse tensor in compressed form along the last dimension into partitions.
+
+ Args:
+ x (Tensor): A block-sparse tensor in compressed form.
+ sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+ partitions (int): The number of partitions to split the block-sparse tensor into.
+ sparsity_block_size (int): The size of the sparsity blocks.
+ triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+ Returns:
+ Tensor: The block-sparse tensor split into partitions in compressed form.
+ Tensor: The sparsity layout of the output tensor.
+
+ """
+ x = x.contiguous()
+
+ validate_dimensions(x)
+ validate_contiguous(x)
+ validate_device(x)
+ validate_sparsity(sparsity_block_size, (x, sparsity_layout))
+ validate_sparsity_block_size(sparsity_block_size, x)
+ validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+ sparsity_layout_output = (sparsity_layout
+ .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
+ sparsity_layout.size(2) // partitions)
+ .permute(0, 2, 1, 3)
+ .reshape(sparsity_layout.size(0) * partitions, sparsity_layout.size(1),
+ sparsity_layout.size(2) // partitions).contiguous())
+
+ sparsity_lut = torch.nonzero(sparsity_layout_output).contiguous()
+
+ sparsity_layout_flat = sparsity_layout.reshape(-1)
+ sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+ (sparsity_layout_flat == 1) -
+ (1 * (sparsity_layout_flat == 0)))
+ .reshape(sparsity_layout.size())
+ .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
+ sparsity_layout.size(2) // partitions)
+ .permute(0, 2, 1, 3).reshape(-1).contiguous())
+
+ n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
+
+ validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
+
+ return _BlocksparseSplit.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
+ sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_output
+
+
+ class _BlocksparseSplit(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+ num_partitions: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+ ctx.num_partitions = num_partitions
+
+ return forward_reorder(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+ n_sparse_blocks, triton_block_size)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ sparsity_layout = ctx.saved_tensors[0]
+ num_partitions = ctx.num_partitions
+ sparsity_block_size = ctx.sparsity_block_size
+ triton_block_size = ctx.triton_block_size
+
+ return merge(grad_output, sparsity_layout, num_partitions,
+ sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None
+
+
+ def merge(x: Tensor, sparsity_layout: Tensor, partitions: int,
+ sparsity_block_size: int, triton_block_size: int = None) -> (Tensor, Tensor):
+ """Merges the specified partitions of a block-sparse tensor in compressed form along the last dimension.
+
+ Args:
+ x (Tensor): A block-sparse tensor in compressed form.
+ sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+ partitions (int): The number of partitions to be merged.
+ sparsity_block_size (int): The size of the sparsity blocks.
+ triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+ Returns:
+ Tensor: The merged block-sparse tensor in compressed form.
+ Tensor: The sparsity layout of the output tensor.
+
+ """
+ x = x.contiguous()
+
+ validate_dimensions(x)
+ validate_contiguous(x)
+ validate_device(x)
+ validate_sparsity(sparsity_block_size, (x, sparsity_layout))
+ validate_sparsity_block_size(sparsity_block_size, x)
+ validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+ sparsity_layout_output = (sparsity_layout.reshape(sparsity_layout.size(0) // partitions, partitions,
+ sparsity_layout.size(1), sparsity_layout.size(2))
+ .permute(0, 2, 1, 3)
+ .reshape(sparsity_layout.size(0) // partitions,
+ sparsity_layout.size(1), sparsity_layout.size(2) * partitions).contiguous())
+
+ sparsity_lut = torch.nonzero(sparsity_layout_output).contiguous()
+
+ sparsity_layout_flat = sparsity_layout.reshape(-1)
+ sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+ (sparsity_layout_flat == 1) -
+ (1 * (sparsity_layout_flat == 0)))
+ .reshape(sparsity_layout.size(0) // partitions, partitions,
+ sparsity_layout.size(1), sparsity_layout.size(2))
+ .permute(0, 2, 1, 3)
+ .reshape(sparsity_layout.size(0) // partitions,
+ sparsity_layout.size(1), sparsity_layout.size(2) * partitions)
+ .reshape(-1).contiguous())
+
+ n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
+
+ validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
+
+ return _BlocksparseMerge.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
+ sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_output
+
+
+ class _BlocksparseMerge(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+ num_partitions: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+ ctx.num_partitions = num_partitions
+
+ return forward_reorder(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+ n_sparse_blocks, triton_block_size)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ sparsity_layout = ctx.saved_tensors[0]
+ num_partitions = ctx.num_partitions
+ sparsity_block_size = ctx.sparsity_block_size
+ triton_block_size = ctx.triton_block_size
+
+ return split(grad_output, sparsity_layout, num_partitions,
+ sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None
+
+
+ def forward_reorder(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+ sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+ output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+ dtype=x.dtype, device=x.device)
+
+ x_b, x_r, x_c = x.size()
+ x_b_s, x_r_s, x_c_s = x.stride()
+ s_l_b, s_l_r, s_l_c = sparsity_layout_o.size()
+ s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout_o.stride()
+ s_lut_r, s_lut_c = sparsity_lut.shape
+ s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
+ o_b, o_r, o_c = output.size()
+ o_b_s, o_r_s, o_c_s = output.stride()
+
+ if triton_block_size is None:
+ triton_block_size = get_triton_block_size(sparsity_block_size)
+
+ triton_grid = lambda meta: [o_b,
+ triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+ triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+ (kernel_blocksparse_reorder[triton_grid]
+ (x,
+ x_b, x_b_s, x_r_s, x_c_s,
+ s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+ sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+ sparsity_reverse_lut,
+ output,
+ o_b, o_b_s,
+ triton_block_size))
+
+ # Save for backward pass
+ ctx.save_for_backward(sparsity_layout_o)
+ ctx.sparsity_block_size = sparsity_block_size
+ ctx.triton_block_size = triton_block_size
+
+ return output
+
+
+ @triton.jit
+ def kernel_blocksparse_reorder(x,
+ x_b, x_b_s, x_r_s, x_c_s,
+ s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+ s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+ r_lut,
+ o,
+ o_b, o_b_s,
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+ # Get triton block indices
+ pid_blk = tl.program_id(axis=0)
+ pid_row = tl.program_id(axis=1)
+ pid_col = tl.program_id(axis=2)
+
+ # Get sparsity index of current output block consisting of its batch, row, and column index
+ spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+ spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+ spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+ spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+ spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+ spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+ spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
+ spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
+ spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+
+ # Get reverse sparsity index
+ rev_idx_spa_idx = (spa_bat * s_l_b_s +
+ spa_row * s_l_r_s +
+ spa_col * s_l_c_s)
+ rev_idx_spa_msk = (rev_idx_spa_idx < s_l_b * s_l_b_s)
+ rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
+
+ if rev_idx_spa == -1:
+ assert False, "Invalid sparsity block"
+
+ blk_x_idx = (rev_idx_spa * x_b_s +
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+ blk_x_msk = (blk_x_idx < x_b * x_b_s)
+ blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+ blk_o_idx = (pid_blk * o_b_s +
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+ blk_o_msk = (blk_o_idx < o_b * o_b_s)
+ tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
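
The sparsity-layout bookkeeping in `split` above amounts to a reshape/permute of the layout (the compressed data itself is reordered by the Triton kernel). A standalone sketch of that transformation with integer markers instead of a real layout, so the reordering is visible:

```python
import torch

partitions = 2
# (batch, row-blocks, col-blocks) layout with 1 batch of 2 x 4 blocks;
# the values are just position markers.
layout = torch.arange(8).reshape(1, 2, 4)

# Same reshape/permute chain as in split(): the column blocks are divided into
# `partitions` groups, and each group becomes its own batch entry.
layout_split = (layout
                .reshape(1, 2, partitions, 4 // partitions)
                .permute(0, 2, 1, 3)
                .reshape(1 * partitions, 2, 4 // partitions))

# layout_split[0] now holds column blocks 0..1, layout_split[1] holds 2..3.
```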
@@ -56,16 +56,16 @@ def transpose(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
  class _BlocksparseTranspose(torch.autograd.Function):

  @staticmethod
- def forward(ctx, x: Tensor,
- sparsity_layout: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor, sparsity_block_size: int,
- n_sparse_blocks: int, triton_block_size: int) -> (Tensor, Tensor):
+ def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+ sparsity_block_size: int,
+ n_sparse_blocks: int, triton_block_size: int) -> Tensor:
  output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
  dtype=x.dtype, device=x.device)

  x_b, x_r, x_c = x.size()
  x_b_s, x_r_s, x_c_s = x.stride()
- s_l_b, s_l_r, s_l_c = sparsity_layout.size()
- s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout.stride()
+ s_l_b, s_l_r, s_l_c = sparsity_layout_o.size()
+ s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout_o.stride()
  s_lut_r, s_lut_c = sparsity_lut.shape
  s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
  o_b, o_r, o_c = output.size()
@@ -89,8 +89,7 @@ class _BlocksparseTranspose(torch.autograd.Function):
  triton_block_size))

  # Save for backward pass
- ctx.save_for_backward(sparsity_layout)
- ctx.sparsity_layout = sparsity_layout
+ ctx.save_for_backward(sparsity_layout_o)
  ctx.sparsity_block_size = sparsity_block_size
  ctx.triton_block_size = triton_block_size
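
The second hunk drops the redundant `ctx.sparsity_layout` attribute in favour of `save_for_backward`, which is the mechanism `torch.autograd.Function` expects for tensors that are needed again in `backward`. Schematically (a generic sketch, not blksprs code):

```python
import torch

class Example(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, layout):
        # Tensors required by backward() go through save_for_backward ...
        ctx.save_for_backward(layout)
        # ... while plain Python values may be stored as ctx attributes.
        ctx.sparsity_block_size = 16
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        (layout,) = ctx.saved_tensors
        return grad_output, None
```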
@@ -63,6 +63,8 @@ def validate_sparsity(sparsity_block_size: int, *tensor_sparsity_layout_tuples:
  for (tensor, sparsity_layout) in tensor_sparsity_layout_tuples:
  _validate_sparsity_layout_values(sparsity_layout)

+ if not sparsity_layout.dim() == 3:
+ raise ValueError("Sparsity layout must have exactly 3 dimensions")
  if not (tensor.size(-1) == tensor.size(-2) == sparsity_block_size):
  raise ValueError("Blocks not conforming to sparsity block size")
  if not tensor.size(0) == torch.sum(sparsity_layout.reshape(-1)):
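
With this check, sparsity layouts must always carry an explicit batch dimension, i.e. have shape `(batch, row_blocks, col_blocks)`. On the caller side that typically just means unsqueezing a 2D layout (an assumption about typical usage, not taken from the upstream docs):

```python
import torch

layout_2d = torch.ones(4, 4, dtype=torch.bool)

# validate_sparsity now rejects 2D layouts; add the batch dimension explicitly.
layout_3d = layout_2d.unsqueeze(0)  # shape (1, 4, 4)
assert layout_3d.dim() == 3
```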
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: blksprs
- Version: 1.4.2
+ Version: 1.6.1
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -31,9 +31,10 @@ Currently supported operations (includes gradient calculation):
  - Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
  for `sparse = sparse @ sparse` matmul_)
  - Softmax
- - Transposition
+ - Transpose
  - Gather
  - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
+ - Splitting and merging of matrices along the last dimension
  - Conversion to and from sparse form
  - Conversion to different sparsity layouts and different sparsity block sizes
@@ -6,6 +6,7 @@ blksprs.egg-info/SOURCES.txt
  blksprs.egg-info/dependency_links.txt
  blksprs.egg-info/requires.txt
  blksprs.egg-info/top_level.txt
+ blksprs/experimental/distribution_mdi.py
  blksprs/layouting/distribution_layout.py
  blksprs/layouting/sparsity_layout.py
  blksprs/misc/broadcast_ops.py
@@ -15,6 +16,7 @@ blksprs/ops/conversion.py
  blksprs/ops/distribution.py
  blksprs/ops/exp.py
  blksprs/ops/matmul.py
+ blksprs/ops/partitioning.py
  blksprs/ops/softmax.py
  blksprs/ops/transpose.py
  blksprs/utils/benchmarking.py
@@ -1,6 +1,6 @@
  [project]
  name = "blksprs"
- version = "1.4.2"
+ version = "1.6.1"
  authors = [{ name = "Felix Schön", email = "schoen@kr.tuwien.ac.at" }]
  description = "A lightweight library for operations on blocksparse matrices in PyTorch."
  readme = "README.md"