blksprs 1.6.1__tar.gz → 1.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. {blksprs-1.6.1 → blksprs-1.7}/PKG-INFO +3 -2
  2. {blksprs-1.6.1 → blksprs-1.7}/README.md +2 -1
  3. {blksprs-1.6.1 → blksprs-1.7}/blksprs/experimental/distribution_mdi.py +14 -14
  4. {blksprs-1.6.1 → blksprs-1.7}/blksprs/layouting/distribution_layout.py +4 -4
  5. {blksprs-1.6.1 → blksprs-1.7}/blksprs/layouting/sparsity_layout.py +6 -6
  6. {blksprs-1.6.1 → blksprs-1.7}/blksprs/misc/broadcast_ops.py +5 -5
  7. {blksprs-1.6.1 → blksprs-1.7}/blksprs/misc/repeat_interleave.py +5 -5
  8. {blksprs-1.6.1 → blksprs-1.7}/blksprs/misc/row_wise.py +16 -15
  9. {blksprs-1.6.1 → blksprs-1.7}/blksprs/ops/conversion.py +11 -11
  10. {blksprs-1.6.1 → blksprs-1.7}/blksprs/ops/distribution.py +11 -11
  11. {blksprs-1.6.1 → blksprs-1.7}/blksprs/ops/exp.py +3 -3
  12. {blksprs-1.6.1 → blksprs-1.7}/blksprs/ops/matmul.py +7 -7
  13. {blksprs-1.6.1 → blksprs-1.7}/blksprs/ops/partitioning.py +9 -98
  14. blksprs-1.7/blksprs/ops/repeat.py +241 -0
  15. {blksprs-1.6.1 → blksprs-1.7}/blksprs/ops/softmax.py +11 -10
  16. {blksprs-1.6.1 → blksprs-1.7}/blksprs/ops/transpose.py +7 -6
  17. {blksprs-1.6.1 → blksprs-1.7}/blksprs/utils/tools.py +3 -0
  18. {blksprs-1.6.1 → blksprs-1.7}/blksprs.egg-info/PKG-INFO +3 -2
  19. {blksprs-1.6.1 → blksprs-1.7}/blksprs.egg-info/SOURCES.txt +1 -0
  20. {blksprs-1.6.1 → blksprs-1.7}/pyproject.toml +1 -1
  21. {blksprs-1.6.1 → blksprs-1.7}/blksprs/__init__.py +0 -0
  22. {blksprs-1.6.1 → blksprs-1.7}/blksprs/utils/benchmarking.py +0 -0
  23. {blksprs-1.6.1 → blksprs-1.7}/blksprs/utils/validation.py +0 -0
  24. {blksprs-1.6.1 → blksprs-1.7}/blksprs.egg-info/dependency_links.txt +0 -0
  25. {blksprs-1.6.1 → blksprs-1.7}/blksprs.egg-info/requires.txt +0 -0
  26. {blksprs-1.6.1 → blksprs-1.7}/blksprs.egg-info/top_level.txt +0 -0
  27. {blksprs-1.6.1 → blksprs-1.7}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: blksprs
- Version: 1.6.1
+ Version: 1.7
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -34,6 +34,7 @@ Currently supported operations (includes gradient calculation):
  - Transpose
  - Gather
  - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
+ - Repeat (_supports target sparsity layout_)
  - Splitting and merging of matrices along the last dimension
  - Conversion to and from sparse form
  - Conversion to different sparsity layouts and different sparsity block sizes
@@ -64,7 +65,7 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u

  ### Dependencies

- - [PyTorch](https://pytorch.org/) (built with v2.4.0)
+ - [PyTorch](https://pytorch.org/) (built with v2.5.0)
  - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_

  ## Changelog
@@ -15,6 +15,7 @@ Currently supported operations (includes gradient calculation):
  - Transpose
  - Gather
  - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
+ - Repeat (_supports target sparsity layout_)
  - Splitting and merging of matrices along the last dimension
  - Conversion to and from sparse form
  - Conversion to different sparsity layouts and different sparsity block sizes
@@ -45,7 +46,7 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u

  ### Dependencies

- - [PyTorch](https://pytorch.org/) (built with v2.4.0)
+ - [PyTorch](https://pytorch.org/) (built with v2.5.0)
  - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_

  ## Changelog
@@ -3,7 +3,7 @@ import triton
  from torch import Tensor
  from triton import language as tl

- from blksprs.utils.tools import get_triton_block_size
+ from blksprs.utils.tools import get_triton_block_size, stride
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
      validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size

@@ -51,15 +51,15 @@ class _BlocksparseGatherMDI(torch.autograd.Function):
  output = torch.empty_like(idx_col, dtype=x.dtype)

  x_b, x_r, x_c = x.size()
- x_b_s, x_r_s, x_c_s = x.stride()
+ x_b_s, x_r_s, x_c_s = stride(x)
  s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
- s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_x.stride()
+ s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
  i_b, i_r, i_c = idx_col.size()
- i_b_s, i_r_s, i_c_s = idx_col.stride()
+ i_b_s, i_r_s, i_c_s = stride(idx_col)
  s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
- s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
+ s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
  o_b, o_r, o_c = output.size()
- o_b_s, o_r_s, o_c_s = output.stride()
+ o_b_s, o_r_s, o_c_s = stride(output)

  if triton_block_size is None:
      triton_block_size = get_triton_block_size(sparsity_block_size)
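The same mechanical substitution repeats across most hunks below: each direct `Tensor.stride()` call is replaced by the new `stride()` helper from `blksprs.utils.tools` (its definition appears in the tools.py hunk near the end of this diff). A minimal illustration of the size/stride unpacking pattern these functions feed to the Triton kernels (tensor shape illustrative):

```python
import torch

def stride(x: torch.Tensor):
    # Definition as added to blksprs/utils/tools.py in this release.
    return x.view(x.shape).stride()

x = torch.randn(4, 16, 16)       # (sparse blocks, block rows, block cols)
x_b, x_r, x_c = x.size()         # sizes, forwarded to the Triton kernel
x_b_s, x_r_s, x_c_s = stride(x)  # strides; 1.6.1 called x.stride() directly
print(x_b, x_r, x_c, x_b_s, x_r_s, x_c_s)  # 4 16 16 256 16 1
```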
@@ -224,15 +224,15 @@ class _BlocksparseScatterReduceMDI(torch.autograd.Function):
      dtype=x.dtype, device=x.device)

  x_b, x_r, x_c = x.size()
- x_b_s, x_r_s, x_c_s = x.stride()
+ x_b_s, x_r_s, x_c_s = stride(x)
  s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
- s_lut_x_r_s, s_lut_x_c_s = sparsity_lut_x.stride()
+ s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut_x)
  i_b, i_r, i_c = idx_col.size()
- i_b_s, i_r_s, i_c_s = idx_col.stride()
+ i_b_s, i_r_s, i_c_s = stride(idx_col)
  o_b, o_r, o_c = output.size()
- o_b_s, o_r_s, o_c_s = output.stride()
+ o_b_s, o_r_s, o_c_s = stride(output)
  s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
- s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_o.stride()
+ s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)

  if triton_block_size is None:
      triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -366,11 +366,11 @@ def build_distribution_layout_mdi(idx_bat: Tensor, idx_row: Tensor, idx_col: Ten
      dtype=torch.bool, device=idx_col.device)

  i_b, i_r, i_c = idx_col.size()
- i_b_s, i_r_s, i_c_s = idx_col.stride()
+ i_b_s, i_r_s, i_c_s = stride(idx_col)
  s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
- s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
+ s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
  o_b, o_r, o_c = output.size()
- o_b_s, o_r_s, o_c_s = output.stride()
+ o_b_s, o_r_s, o_c_s = stride(output)

  if triton_block_size is None:
      triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -3,7 +3,7 @@ import triton
  from torch import Tensor
  from triton import language as tl

- from blksprs.utils.tools import get_triton_block_size
+ from blksprs.utils.tools import get_triton_block_size, stride
  from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
      validate_contiguous

@@ -34,11 +34,11 @@ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
      dtype=torch.bool, device=indices.device)

  i_b, i_r, i_c = indices.size()
- i_b_s, i_r_s, i_c_s = indices.stride()
+ i_b_s, i_r_s, i_c_s = stride(indices)
  s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
- s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
+ s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
  o_b, o_r, o_c = output.size()
- o_b_s, o_r_s, o_c_s = output.stride()
+ o_b_s, o_r_s, o_c_s = stride(output)

  if triton_block_size is None:
      triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -5,7 +5,7 @@ import triton
  from torch import Tensor
  from triton import language as tl

- from blksprs.utils.tools import get_triton_block_size
+ from blksprs.utils.tools import get_triton_block_size, stride
  from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
      validate_contiguous, validate_sparsity, validate_sparsity_block_size

@@ -30,9 +30,9 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_size
      dtype=torch.bool, device=x.device)

  x_b, x_r, x_c = x.size()
- x_b_s, x_r_s, x_c_s = x.stride()
+ x_b_s, x_r_s, x_c_s = stride(x)
  o_b, o_r, o_c = output.size()
- o_b_s, o_r_s, o_c_s = output.stride()
+ o_b_s, o_r_s, o_c_s = stride(output)

  if triton_block_size is None:
      triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -120,10 +120,10 @@ def build_sparsity_layout_adaption(x: Tensor, sparsity_layout_from: Tensor,
  output = torch.zeros(o_b, o_r, o_c, dtype=torch.bool, device=x.device)

  x_b, x_r, x_c = x.size()
- x_b_s, x_r_s, x_c_s = x.stride()
+ x_b_s, x_r_s, x_c_s = stride(x)
  s_lut_r, s_lut_c = sparsity_lut.size()
- s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
- o_b_s, o_r_s, o_c_s = output.stride()
+ s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+ o_b_s, o_r_s, o_c_s = stride(output)

  if triton_block_size is None:
      triton_block_size = get_triton_block_size(sparsity_block_size_from)
@@ -3,7 +3,7 @@ import triton
  from torch import Tensor
  from triton import language as tl

- from blksprs.utils.tools import get_triton_block_size
+ from blksprs.utils.tools import get_triton_block_size, stride
  from blksprs.utils.validation import validate_contiguous, validate_device, \
      validate_sparsity_block_size, validate_triton_block_size

@@ -44,13 +44,13 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
  output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, dtype=x.dtype, device=x.device)

  x_b, x_c = x.size()
- x_b_s, x_c_s = x.stride()
+ x_b_s, x_c_s = stride(x)
  y_b, y_c = y.size()
- y_b_s, y_c_s = y.stride()
+ y_b_s, y_c_s = stride(y)
  o_b, o_r, o_c = output.size()
- o_b_s, o_r_s, o_c_s = output.stride()
+ o_b_s, o_r_s, o_c_s = stride(output)
  s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
- s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_o.stride()
+ s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)

  if triton_block_size is None:
      triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -3,7 +3,7 @@ import triton
  from torch import Tensor
  from triton import language as tl

- from blksprs.utils.tools import get_triton_block_size
+ from blksprs.utils.tools import get_triton_block_size, stride
  from blksprs.utils.validation import validate_contiguous, validate_device, \
      validate_sparsity_block_size, validate_triton_block_size, validate_dimensions

@@ -52,13 +52,13 @@ def repeat_interleave(x: Tensor, sparsity_layout: Tensor, repeats: int,
      dtype=x.dtype, device=x.device)

  x_b, x_r, x_c = x.size()
- x_b_s, x_r_s, x_c_s = x.stride()
+ x_b_s, x_r_s, x_c_s = stride(x)
  s_lut_r, s_lut_c = sparsity_lut.size()
- s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
+ s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
  o_b, o_r, o_c = output.size()
- o_b_s, o_r_s, o_c_s = output.stride()
+ o_b_s, o_r_s, o_c_s = stride(output)
  s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
- s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_output.stride()
+ s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)

  if triton_block_size is None:
      triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -3,7 +3,7 @@ import triton
  from torch import Tensor
  from triton import language as tl

- from blksprs.utils.tools import get_triton_block_size
+ from blksprs.utils.tools import get_triton_block_size, stride
  from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, validate_sparsity, \
      validate_sparsity_block_size, validate_triton_block_size

@@ -60,13 +60,13 @@ def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
      device=x.device)

  x_b, x_r, x_c = x.size()
- x_b_s, x_r_s, x_c_s = x.stride()
+ x_b_s, x_r_s, x_c_s = stride(x)
  s_lut_x_r, s_lut_x_c = sparsity_lut.size()
- s_lut_x_r_s, s_lut_x_c_s = sparsity_lut.stride()
+ s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
  o_b, o_r, o_c = output.size()
- o_b_s, o_r_s, o_c_s = output.stride()
+ o_b_s, o_r_s, o_c_s = stride(output)
  s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
- s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_output.stride()
+ s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)

  if triton_block_size is None:
      triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -183,13 +183,13 @@ def row_wise_max(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
      device=x.device)

  x_b, x_r, x_c = x.size()
- x_b_s, x_r_s, x_c_s = x.stride()
+ x_b_s, x_r_s, x_c_s = stride(x)
  s_lut_x_r, s_lut_x_c = sparsity_lut.size()
- s_lut_x_r_s, s_lut_x_c_s = sparsity_lut.stride()
+ s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
  o_b, o_r, o_c = output.size()
- o_b_s, o_r_s, o_c_s = output.stride()
+ o_b_s, o_r_s, o_c_s = stride(output)
  s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
- s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_output.stride()
+ s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)

  if triton_block_size is None:
      triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -290,15 +290,15 @@ def row_wise_add(x: Tensor, sparsity_layout_x: Tensor, y: Tensor,
  output = torch.empty_like(x)

  x_b, x_r, x_c = x.size()
- x_b_s, x_r_s, x_c_s = x.stride()
+ x_b_s, x_r_s, x_c_s = stride(x)
  s_lut_r, s_lut_c = sparsity_lut.size()
- s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
+ s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
  y_b, y_r, y_c = y.size()
- y_b_s, y_r_s, y_c_s = y.stride()
+ y_b_s, y_r_s, y_c_s = stride(y)
  s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_rwm.size()
- s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = sparsity_layout_rwm.stride()
+ s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_rwm)
  o_b, o_r, o_c = output.size()
- o_b_s, o_r_s, o_c_s = output.stride()
+ o_b_s, o_r_s, o_c_s = stride(output)

  if triton_block_size is None:
      triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -361,7 +361,8 @@ def kernel_blocksparse_row_wise_add(x,
  rev_idx_spa_s = tl.load(r_lut_y + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)

  if rev_idx_spa_s == -1:
-     assert False, "Invalid sparsity block"
+     tl.device_assert(False)
+     return

  # Load x block
  blk_x_idx = ((pid_blk * x_b_s) +
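For context on this change: a plain Python `assert` inside an `@triton.jit` kernel is evaluated at trace time rather than on the GPU, whereas `tl.device_assert` emits a device-side check, and the added `return` ends the program instance so the invalid block is skipped. A standalone sketch of the new guard style (illustrative kernel, not part of blksprs):

```python
import triton
import triton.language as tl

@triton.jit
def guarded_copy(src, dst, lut, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    idx = tl.load(lut + pid).to(tl.int32)  # -1 marks an unmapped block
    if idx == -1:
        tl.device_assert(False)  # reports only when device assertions are enabled
        return                   # exit this program instance either way
    offs = tl.arange(0, BLOCK)
    tl.store(dst + pid * BLOCK + offs, tl.load(src + idx * BLOCK + offs))
```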
@@ -6,7 +6,7 @@ from torch import Tensor
  from triton import language as tl

  from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
- from blksprs.utils.tools import get_triton_block_size
+ from blksprs.utils.tools import get_triton_block_size, stride
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
      validate_sparsity, validate_sparsity_block_size, validate_triton_block_size

@@ -65,11 +65,11 @@ class _BlocksparseToDense(torch.autograd.Function):
      dtype=x.dtype, device=x.device)

  x_b, x_r, x_c = x.shape
- x_b_s, x_r_s, x_c_s = x.stride()
+ x_b_s, x_r_s, x_c_s = stride(x)
  s_l_b, s_l_r, s_l_c = sparsity_layout.size()
- s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout.stride()
+ s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout)
  o_b, o_r, o_c = output.size()
- o_b_s, o_r_s, o_c_s = output.stride()
+ o_b_s, o_r_s, o_c_s = stride(output)

  if triton_block_size is None:
      triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -190,11 +190,11 @@ class _BlocksparseToSparse(torch.autograd.Function):
      dtype=x.dtype, device=x.device)

  x_b, x_r, x_c = x.size()
- x_b_s, x_r_s, x_c_s = x.stride()
+ x_b_s, x_r_s, x_c_s = stride(x)
  s_lut_r, s_lut_c = sparsity_lut.size()
- s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
+ s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
  o_b, o_r, o_c = output.size()
- o_b_s, o_r_s, o_c_s = output.stride()
+ o_b_s, o_r_s, o_c_s = stride(output)

  if triton_block_size is None:
      triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -347,13 +347,13 @@ class _BlocksparseAdaptLayout(torch.autograd.Function):
      dtype=x.dtype, device=x.device)

  x_b, x_r, x_c = x.size()
- x_b_s, x_r_s, x_c_s = x.stride()
+ x_b_s, x_r_s, x_c_s = stride(x)
  s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_from.size()
- s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_from.stride()
+ s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_from)
  o_b, o_r, o_c = output.size()
- o_b_s, o_r_s, o_c_s = output.stride()
+ o_b_s, o_r_s, o_c_s = stride(output)
  s_lut_o_r, s_lut_o_c = sparsity_lut_to.size()
- s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_to.stride()
+ s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_to)

  if triton_block_size is None:
      triton_block_size = get_triton_block_size(min_sparsity_block_size)
@@ -3,7 +3,7 @@ import triton
  from torch import Tensor
  from triton import language as tl

- from blksprs.utils.tools import get_triton_block_size
+ from blksprs.utils.tools import get_triton_block_size, stride
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
      validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size

@@ -59,15 +59,15 @@ class _BlocksparseGather(torch.autograd.Function):
  output = torch.empty_like(i, dtype=x.dtype)

  x_b, x_r, x_c = x.size()
- x_b_s, x_r_s, x_c_s = x.stride()
+ x_b_s, x_r_s, x_c_s = stride(x)
  s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
- s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_x.stride()
+ s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
  i_b, i_r, i_c = i.size()
- i_b_s, i_r_s, i_c_s = i.stride()
+ i_b_s, i_r_s, i_c_s = stride(i)
  s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
- s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
+ s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
  o_b, o_r, o_c = output.size()
- o_b_s, o_r_s, o_c_s = output.stride()
+ o_b_s, o_r_s, o_c_s = stride(output)

  if triton_block_size is None:
      triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -248,15 +248,15 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
      dtype=x.dtype, device=x.device)

  x_b, x_r, x_c = x.size()
- x_b_s, x_r_s, x_c_s = x.stride()
+ x_b_s, x_r_s, x_c_s = stride(x)
  s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
- s_lut_x_r_s, s_lut_x_c_s = sparsity_lut_x.stride()
+ s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut_x)
  i_b, i_r, i_c = i.size()
- i_b_s, i_r_s, i_c_s = i.stride()
+ i_b_s, i_r_s, i_c_s = stride(i)
  o_b, o_r, o_c = output.size()
- o_b_s, o_r_s, o_c_s = output.stride()
+ o_b_s, o_r_s, o_c_s = stride(output)
  s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
- s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_o.stride()
+ s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)

  if triton_block_size is None:
      triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -3,7 +3,7 @@ import triton
  from torch import Tensor
  from triton import language as tl

- from blksprs.utils.tools import get_triton_block_size
+ from blksprs.utils.tools import get_triton_block_size, stride
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
      validate_sparsity_block_size, validate_triton_block_size

@@ -43,9 +43,9 @@ class _BlocksparseExp(torch.autograd.Function):
  output = torch.empty_like(x)

  x_b, x_r, x_c = x.shape
- x_b_s, x_r_s, x_c_s = x.stride()
+ x_b_s, x_r_s, x_c_s = stride(x)
  o_b, o_r, o_c = output.shape
- o_b_s, o_r_s, o_c_s = output.stride()
+ o_b_s, o_r_s, o_c_s = stride(output)

  if triton_block_size is None:
      triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -4,7 +4,7 @@ from torch import Tensor
  from triton import language as tl

  from blksprs.ops.transpose import transpose
- from blksprs.utils.tools import get_triton_block_size
+ from blksprs.utils.tools import get_triton_block_size, stride
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
      validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_dtype_float

@@ -82,17 +82,17 @@ class _BlocksparseMatmulSSS(torch.autograd.Function):
      dtype=x.dtype, device=x.device)

  x_b, x_r, x_c = x.size()
- x_b_s, x_r_s, x_c_s = x.stride()
+ x_b_s, x_r_s, x_c_s = stride(x)
  s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
- s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_x.stride()
+ s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
  y_b, y_r, y_c = y.size()
- y_b_s, y_r_s, y_c_s = y.stride()
+ y_b_s, y_r_s, y_c_s = stride(y)
  s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_y.size()
- s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = sparsity_layout_y.stride()
+ s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_y)
  o_b, o_r, o_c = output.size()
- o_b_s, o_r_s, o_c_s = output.stride()
+ o_b_s, o_r_s, o_c_s = stride(output)
  s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
- s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_o.stride()
+ s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)

  if triton_block_size is None:
      triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -1,10 +1,7 @@
  import torch
- import triton
- from sympy.utilities.iterables import partitions
  from torch import Tensor
- from triton import language as tl

- from blksprs.utils.tools import get_triton_block_size
+ from blksprs.ops.repeat import forward_flow

  from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
      validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
@@ -48,12 +45,11 @@ def split(x: Tensor, sparsity_layout: Tensor, partitions: int,
  sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
                           (sparsity_layout_flat == 1) -
                           (1 * (sparsity_layout_flat == 0)))
-                         .reshape(sparsity_layout.size())
                          .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
                                   sparsity_layout.size(2) // partitions)
                          .permute(0, 2, 1, 3).reshape(-1).contiguous())

- n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
+ n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()

  validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)

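The `sparsity_reverse_lut` expression above is the library's recurring trick for inverting a boolean layout: a cumulative sum over the flattened layout numbers the active blocks 0, 1, 2, …, and inactive positions are forced to -1. A small worked example (values illustrative):

```python
import torch

sparsity_layout_flat = torch.tensor([1, 0, 1, 1, 0])
rev = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
       (sparsity_layout_flat == 1) -
       (1 * (sparsity_layout_flat == 0)))
print(rev)  # tensor([ 0, -1,  1,  2, -1])
```

Each entry is the index of the corresponding sparse block in the flattened block tensor, or -1 where the layout has no block; the kernels in this diff branch on that -1.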
@@ -66,10 +62,11 @@ class _BlocksparseSplit(torch.autograd.Function):
  @staticmethod
  def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
              num_partitions: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+     ctx.save_for_backward(sparsity_layout_o)
      ctx.num_partitions = num_partitions

-     return forward_reorder(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
-                            n_sparse_blocks, triton_block_size)
+     return forward_flow(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                         n_sparse_blocks, triton_block_size)

  @staticmethod
  def backward(ctx, grad_output):
@@ -126,7 +123,7 @@ def merge(x: Tensor, sparsity_layout: Tensor, partitions: int,
           sparsity_layout.size(1), sparsity_layout.size(2) * partitions)
      .reshape(-1).contiguous())

- n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
+ n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()

  validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)

@@ -139,10 +136,11 @@ class _BlocksparseMerge(torch.autograd.Function):
  @staticmethod
  def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
              num_partitions: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+     ctx.save_for_backward(sparsity_layout_o)
      ctx.num_partitions = num_partitions

-     return forward_reorder(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
-                            n_sparse_blocks, triton_block_size)
+     return forward_flow(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                         n_sparse_blocks, triton_block_size)

  @staticmethod
  def backward(ctx, grad_output):
@@ -155,90 +153,3 @@ class _BlocksparseMerge(torch.autograd.Function):
      sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None


- def forward_reorder(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
-                     sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
-     output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
-                          dtype=x.dtype, device=x.device)
-
-     x_b, x_r, x_c = x.size()
-     x_b_s, x_r_s, x_c_s = x.stride()
-     s_l_b, s_l_r, s_l_c = sparsity_layout_o.size()
-     s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout_o.stride()
-     s_lut_r, s_lut_c = sparsity_lut.shape
-     s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = output.stride()
-
-     if triton_block_size is None:
-         triton_block_size = get_triton_block_size(sparsity_block_size)
-
-     triton_grid = lambda meta: [o_b,
-                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-     (kernel_blocksparse_reorder[triton_grid]
-      (x,
-       x_b, x_b_s, x_r_s, x_c_s,
-       s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
-       sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-       sparsity_reverse_lut,
-       output,
-       o_b, o_b_s,
-       triton_block_size))
-
-     # Save for backward pass
-     ctx.save_for_backward(sparsity_layout_o)
-     ctx.sparsity_block_size = sparsity_block_size
-     ctx.triton_block_size = triton_block_size
-
-     return output
-
-
- @triton.jit
- def kernel_blocksparse_reorder(x,
-                                x_b, x_b_s, x_r_s, x_c_s,
-                                s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
-                                s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-                                r_lut,
-                                o,
-                                o_b, o_b_s,
-                                TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-     # Get triton block indices
-     pid_blk = tl.program_id(axis=0)
-     pid_row = tl.program_id(axis=1)
-     pid_col = tl.program_id(axis=2)
-
-     # Get sparsity index of current output block consisting of its batch, row, and column index
-     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-     spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
-     spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
-
-     spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-     spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
-     spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
-
-     spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
-     spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
-     spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
-
-     # Get reverse sparsity index
-     rev_idx_spa_idx = (spa_bat * s_l_b_s +
-                        spa_row * s_l_r_s +
-                        spa_col * s_l_c_s)
-     rev_idx_spa_msk = (rev_idx_spa_idx < s_l_b * s_l_b_s)
-     rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
-
-     if rev_idx_spa == -1:
-         assert False, "Invalid sparsity block"
-
-     blk_x_idx = (rev_idx_spa * x_b_s +
-                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-     blk_x_msk = (blk_x_idx < x_b * x_b_s)
-     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
-     blk_o_idx = (pid_blk * o_b_s +
-                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-     blk_o_msk = (blk_o_idx < o_b * o_b_s)
-     tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
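The removed `forward_reorder`/`kernel_blocksparse_reorder` pair is superseded by `forward_flow`/`kernel_blocksparse_flow_pull` in the new `blksprs/ops/repeat.py` (next hunk), which reuse the same reverse-LUT indexing. Stripped of strides and masking, the "pull" scheme amounts to this gather (a plain-PyTorch emulation under simplified assumptions, one LUT entry per output block):

```python
import torch

def flow_pull_reference(x: torch.Tensor, rev_lut: torch.Tensor) -> torch.Tensor:
    # x: (n_source_blocks, bs, bs); rev_lut[i]: source block index for output
    # block i, or -1 where the layout has no block.
    out = torch.zeros(rev_lut.numel(), x.size(1), x.size(2), dtype=x.dtype)
    for i, src in enumerate(rev_lut.tolist()):
        if src != -1:
            out[i] = x[src]  # each output block pulls from its source block
    return out
```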
@@ -0,0 +1,241 @@
+ import torch
+ import triton
+ from triton import language as tl
+ from torch import Tensor
+
+ from blksprs.utils.tools import get_triton_block_size, stride
+ from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
+     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
+
+
+ def repeat(x: Tensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
+            sparsity_block_size: int, sparsity_layout_output: Tensor = None, triton_block_size: int = None) -> (
+         Tensor, Tensor):
+     x = x.contiguous()
+
+     validate_dimensions(x)
+     validate_contiguous(x)
+     validate_device(x)
+     validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
+     validate_sparsity_block_size(sparsity_block_size, x)
+     validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+     sparsity_layout_o = sparsity_layout_x.repeat(repeats[0], repeats[1], repeats[2])
+
+     if sparsity_layout_output is not None:
+         sparsity_layout_o = torch.logical_and(sparsity_layout_o, sparsity_layout_output)
+
+     sparsity_lut = torch.nonzero(sparsity_layout_o).contiguous()
+
+     sparsity_layout_flat = sparsity_layout_x.reshape(-1)
+     sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                              (sparsity_layout_flat == 1) -
+                              (1 * (sparsity_layout_flat == 0)))
+                             .reshape(sparsity_layout_x.size())
+                             .repeat(repeats[0], repeats[1], repeats[2])
+                             .reshape(-1).contiguous())
+
+     n_sparse_blocks = torch.sum(sparsity_layout_o.to(torch.int)).item()
+
+     validate_contiguous(sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
+
+     return _BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
+                                     sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_o
+
+
+ class _BlocksparseRepeat(torch.autograd.Function):
+
+     @staticmethod
+     def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor,
+                 sparsity_reverse_lut: Tensor,
+                 sparsity_block_size: int, n_sparse_blocks: int,
+                 triton_block_size: int) -> Tensor:
+         ctx.save_for_backward(sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
+         ctx.x_size = x.size()
+         ctx.x_stride = stride(x)
+
+         return forward_flow(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                             n_sparse_blocks, triton_block_size)
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut = ctx.saved_tensors
+         x_size = ctx.x_size
+         x_stride = ctx.x_stride
+         sparsity_block_size = ctx.sparsity_block_size
+         triton_block_size = ctx.triton_block_size
+
+         n_sparse_blocks = torch.sum(sparsity_layout_x.to(torch.int)).item()
+
+         output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                              dtype=grad_output.dtype, device=grad_output.device)
+
+         x_b, x_r, x_c = grad_output.size()
+         x_b_s, x_r_s, x_c_s = stride(grad_output)
+         s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_o.size()
+         s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_o)
+         s_lut_r, s_lut_c = sparsity_lut.size()
+         s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+         o_b, o_r, o_c = x_size
+         o_b_s, o_r_s, o_c_s = x_stride
+
+         if triton_block_size is None:
+             triton_block_size = get_triton_block_size(sparsity_block_size)
+
+         triton_grid = lambda meta: [x_b,
+                                     triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (kernel_blocksparse_flow_push[triton_grid]
+          (grad_output,
+           x_b, x_b_s, x_r_s, x_c_s,
+           s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+           sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+           sparsity_reverse_lut,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           triton_block_size))
+
+         return output, None, None, None, None, None, None, None
+
+
+ @triton.jit
+ def kernel_blocksparse_flow_pull(x,
+                                  x_b, x_b_s, x_r_s, x_c_s,
+                                  o,
+                                  o_b, o_b_s, o_r_s, o_c_s,
+                                  s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+                                  s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                                  r_lut,
+                                  TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     # Get triton block indices
+     pid_blk = tl.program_id(axis=0)
+     pid_row = tl.program_id(axis=1)
+     pid_col = tl.program_id(axis=2)
+
+     # Get sparsity index of current output block consisting of its batch, row, and column index
+     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+     spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+     spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+     spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+     spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+     spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+     spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
+     spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
+     spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+
+     # Get reverse sparsity index
+     rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
+                        spa_row * s_l_o_r_s +
+                        spa_col * s_l_o_c_s)
+     rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
+     rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
+
+     if rev_idx_spa == -1:
+         tl.device_assert(False)
+         return
+
+     blk_x_idx = (rev_idx_spa * x_b_s +
+                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+     blk_x_msk = (blk_x_idx < x_b * x_b_s)
+     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+     blk_o_idx = (pid_blk * o_b_s +
+                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+     blk_o_msk = (blk_o_idx < o_b * o_b_s)
+     tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+ @triton.jit
+ def kernel_blocksparse_flow_push(x,
+                                  x_b, x_b_s, x_r_s, x_c_s,
+                                  s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+                                  s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                                  r_lut,
+                                  o,
+                                  o_b, o_b_s, o_r_s, o_c_s,
+                                  TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     # Get triton block indices
+     pid_blk = tl.program_id(axis=0)
+     pid_row = tl.program_id(axis=1)
+     pid_col = tl.program_id(axis=2)
+
+     # Get sparsity index of current input block consisting of its batch, row, and column index
+     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+     spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+     spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+     spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+     spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+     spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+     spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
+     spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
+     spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+
+     # Get reverse sparsity index
+     rev_idx_spa_idx = (spa_bat * s_l_x_b_s +
+                        spa_row * s_l_x_r_s +
+                        spa_col * s_l_x_c_s)
+     rev_idx_spa_msk = (rev_idx_spa_idx < s_l_x_b * s_l_x_b_s)
+     rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
+
+     if rev_idx_spa == -1:
+         tl.device_assert(False)
+         return
+
+     blk_x_idx = (pid_blk * x_b_s +
+                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+     blk_x_msk = (blk_x_idx < x_b * x_b_s)
+     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+     blk_o_idx = (rev_idx_spa * o_b_s +
+                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+     blk_o_msk = (blk_o_idx < o_b * o_b_s)
+     tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+ def forward_flow(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                  sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+     output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                          dtype=x.dtype, device=x.device)
+     output = torch.zeros_like(output)
+
+     x_b, x_r, x_c = x.size()
+     x_b_s, x_r_s, x_c_s = stride(x)
+     o_b, o_r, o_c = output.size()
+     o_b_s, o_r_s, o_c_s = stride(output)
+     s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
+     s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
+     s_lut_r, s_lut_c = sparsity_lut.size()
+     s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+     asdf = torch.tensor(sparsity_lut).stride()
+
+     if triton_block_size is None:
+         triton_block_size = get_triton_block_size(sparsity_block_size)
+
+     triton_grid = lambda meta: [o_b,
+                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                 triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+     (kernel_blocksparse_flow_pull[triton_grid]
+      (x,
+       x_b, x_b_s, x_r_s, x_c_s,
+       output,
+       o_b, o_b_s, o_r_s, o_c_s,
+       s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+       sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+       sparsity_reverse_lut,
+       triton_block_size))
+
+     # Save for backward pass
+     ctx.sparsity_block_size = sparsity_block_size
+     ctx.triton_block_size = triton_block_size
+
+     return output
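Taken together, this new module provides the `repeat` operation advertised in the README hunks above, returning both the repeated block tensor and the resulting sparsity layout; an optional `sparsity_layout_output` prunes the result via logical AND. A hedged usage sketch based on the signature in this hunk (shapes illustrative; blksprs kernels expect CUDA tensors):

```python
import torch
from blksprs.ops.repeat import repeat

sparsity_block_size = 16
# Two 16x16 blocks, laid out as a 1x2x1 grid of block positions (all occupied).
x = torch.randn(2, sparsity_block_size, sparsity_block_size, device="cuda")
sparsity_layout_x = torch.ones(1, 2, 1, dtype=torch.bool, device="cuda")

# Repeat twice along the batch dimension of the layout.
out, sparsity_layout_o = repeat(x, sparsity_layout_x, (2, 1, 1), sparsity_block_size)
assert out.size(0) == 4 and tuple(sparsity_layout_o.shape) == (2, 2, 1)
```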
@@ -5,7 +5,7 @@ from triton import language as tl

  from blksprs.ops.exp import exp
  from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
- from blksprs.utils.tools import get_triton_block_size
+ from blksprs.utils.tools import get_triton_block_size, stride
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
      validate_sparsity, validate_sparsity_block_size, validate_triton_block_size

@@ -61,9 +61,9 @@ class _BlocksparseSoftmax(torch.autograd.Function):
  output = torch.empty_like(x)

  x_b, x_r, x_c = x.size()
- x_b_s, x_r_s, x_c_s = x.stride()
+ x_b_s, x_r_s, x_c_s = stride(x)
  s_lut_r, s_lut_c = sparsity_lut.size()
- s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
+ s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
  o_b, o_r, o_c = output.size()

  x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,
@@ -76,9 +76,9 @@ class _BlocksparseSoftmax(torch.autograd.Function):
      triton_block_size=triton_block_size)

  s_b, s_r, s_c = x_exp_row_wise_sum.shape
- s_b_s, s_r_s, s_c_s = x_exp_row_wise_sum.stride()
+ s_b_s, s_r_s, s_c_s = stride(x_exp_row_wise_sum)
  s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_rws.shape
- s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = sparsity_layout_rws.stride()
+ s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_rws)

  if triton_block_size is None:
      triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -119,13 +119,13 @@ class _BlocksparseSoftmax(torch.autograd.Function):
      (1 * (sparsity_layout_s_flat == 0)))

  o_b, o_r, o_c = o.size()
- o_b_s, o_r_s, o_c_s = o.stride()
+ o_b_s, o_r_s, o_c_s = stride(o)
  s_lut_r, s_lut_c = sparsity_lut.size()
- s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
+ s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
  s_b, s_r, s_c = s.size()
- s_b_s, s_r_s, s_c_s = s.stride()
+ s_b_s, s_r_s, s_c_s = stride(s)
  s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
- s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = sparsity_layout_s.stride()
+ s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)

  grad_x = torch.empty_like(o, dtype=torch.float)

@@ -181,7 +181,8 @@ class _BlocksparseSoftmax(torch.autograd.Function):
  rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)

  if rev_idx_spa_s == -1:
-     assert False, "Invalid sparsity block"
+     tl.device_assert(False)
+     return

  # Load x block
  blk_x_idx = ((pid_blk * x_b_s) +
@@ -3,7 +3,7 @@ import triton
  from torch import Tensor
  from triton import language as tl

- from blksprs.utils.tools import get_triton_block_size
+ from blksprs.utils.tools import get_triton_block_size, stride
  from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
      validate_sparsity, validate_sparsity_block_size, validate_triton_block_size

@@ -63,13 +63,13 @@ class _BlocksparseTranspose(torch.autograd.Function):
      dtype=x.dtype, device=x.device)

  x_b, x_r, x_c = x.size()
- x_b_s, x_r_s, x_c_s = x.stride()
+ x_b_s, x_r_s, x_c_s = stride(x)
  s_l_b, s_l_r, s_l_c = sparsity_layout_o.size()
- s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout_o.stride()
+ s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout_o)
  s_lut_r, s_lut_c = sparsity_lut.shape
- s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
+ s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
  o_b, o_r, o_c = output.size()
- o_b_s, o_r_s, o_c_s = output.stride()
+ o_b_s, o_r_s, o_c_s = stride(output)

  if triton_block_size is None:
      triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -140,7 +140,8 @@ class _BlocksparseTranspose(torch.autograd.Function):
  rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

  if rev_idx_spa == -1:
-     assert False, "Invalid sparsity block"
+     tl.device_assert(False)
+     return

  blk_x_idx = (rev_idx_spa * x_b_s +
      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
@@ -23,3 +23,6 @@ def get_triton_block_size(sparsity_block_size: int, limit: int = 128):

  def disable_validation():
      _set_skip_validation(True)
+
+ def stride(x: Tensor):
+     return x.view(x.shape).stride()
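For ordinary tensors the new helper is behaviorally identical to calling `Tensor.stride()` directly, since a `view` onto the tensor's own shape is a no-op there; routing every stride query through one helper plausibly gives the library a single seam for tracing or compilation machinery, though the diff itself does not state the motivation. A quick check:

```python
import torch
from torch import Tensor

def stride(x: Tensor):
    # Same definition as the helper added to blksprs/utils/tools.py above.
    return x.view(x.shape).stride()

x = torch.randn(2, 3, 4)
assert stride(x) == x.stride() == (12, 4, 1)
```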
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: blksprs
- Version: 1.6.1
+ Version: 1.7
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -34,6 +34,7 @@ Currently supported operations (includes gradient calculation):
  - Transpose
  - Gather
  - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
+ - Repeat (_supports target sparsity layout_)
  - Splitting and merging of matrices along the last dimension
  - Conversion to and from sparse form
  - Conversion to different sparsity layouts and different sparsity block sizes
@@ -64,7 +65,7 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u

  ### Dependencies

- - [PyTorch](https://pytorch.org/) (built with v2.4.0)
+ - [PyTorch](https://pytorch.org/) (built with v2.5.0)
  - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_

  ## Changelog
@@ -17,6 +17,7 @@ blksprs/ops/distribution.py
  blksprs/ops/exp.py
  blksprs/ops/matmul.py
  blksprs/ops/partitioning.py
+ blksprs/ops/repeat.py
  blksprs/ops/softmax.py
  blksprs/ops/transpose.py
  blksprs/utils/benchmarking.py
@@ -1,6 +1,6 @@
  [project]
  name = "blksprs"
- version = "1.6.1"
+ version = "1.7"
  authors = [{ name = "Felix Schön", email = "schoen@kr.tuwien.ac.at" }]
  description = "A lightweight library for operations on blocksparse matrices in PyTorch."
  readme = "README.md"