blksprs 2.1.3.tar.gz → 2.1.5.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. {blksprs-2.1.3 → blksprs-2.1.5}/PKG-INFO +7 -11
  2. {blksprs-2.1.3 → blksprs-2.1.5}/README.md +6 -10
  3. {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/__init__.py +2 -2
  4. {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/conversion.py +12 -20
  5. {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/distribution.py +12 -20
  6. {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/flow.py +12 -20
  7. {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/matmul.py +6 -10
  8. {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/misc/broadcast_ops.py +6 -10
  9. {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/misc/row_wise.py +35 -35
  10. {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/repeat.py +2 -2
  11. {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/softmax.py +10 -12
  12. {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/utils/autotuning.py +2 -2
  13. {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/utils/validation.py +21 -0
  14. {blksprs-2.1.3 → blksprs-2.1.5}/blksprs.egg-info/PKG-INFO +7 -11
  15. {blksprs-2.1.3 → blksprs-2.1.5}/pyproject.toml +1 -1
  16. {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/layouting/distribution_layout.py +0 -0
  17. {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/layouting/sparsity_layout.py +0 -0
  18. {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/partitioning.py +0 -0
  19. {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/transpose.py +0 -0
  20. {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/utils/benchmarking.py +0 -0
  21. {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/utils/blksprs_tensor.py +0 -0
  22. {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/utils/processing.py +0 -0
  23. {blksprs-2.1.3 → blksprs-2.1.5}/blksprs/utils/tools.py +0 -0
  24. {blksprs-2.1.3 → blksprs-2.1.5}/blksprs.egg-info/SOURCES.txt +0 -0
  25. {blksprs-2.1.3 → blksprs-2.1.5}/blksprs.egg-info/dependency_links.txt +0 -0
  26. {blksprs-2.1.3 → blksprs-2.1.5}/blksprs.egg-info/requires.txt +0 -0
  27. {blksprs-2.1.3 → blksprs-2.1.5}/blksprs.egg-info/top_level.txt +0 -0
  28. {blksprs-2.1.3 → blksprs-2.1.5}/setup.cfg +0 -0

{blksprs-2.1.3 → blksprs-2.1.5}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: blksprs
- Version: 2.1.3
+ Version: 2.1.5
  Summary: A lightweight library for operations on block-sparse matrices in PyTorch.
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -20,7 +20,8 @@ Requires-Dist: matplotlib; extra == "test"
  # blksprs

  [![GitHub Release](https://img.shields.io/github/v/release/FelixSchoen/blksprs?include_prereleases&label=Latest%20Release)](https://github.com/FelixSchoen/blksprs/releases)
- [![Python Version](https://img.shields.io/badge/Python%20Version-3.11-blue)](https://www.python.org/downloads/release/python-3119/)
+ [![Python 3.11](https://img.shields.io/badge/Python%20Version-3.11-blue)](https://www.python.org/downloads/release/python-3119/)
+ [![Python 3.12](https://img.shields.io/badge/Python%20Version-3.12-blue)](https://www.python.org/downloads/release/python-31210/)

  ## Overview

@@ -75,9 +76,7 @@ _* see the [Roadmap](#roadmap) section for more information_

  ## Installation

- Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible
- with
- the Linux platform**.
+ Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible with the Linux platform**.
  Keep track of this [issue](https://github.com/triton-lang/triton/issues/1640) for updates.

  We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) using pip:
@@ -86,8 +85,8 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u

  ### Dependencies

- - [PyTorch](https://pytorch.org/) (built with v2.6)
- - _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.2.4)_
+ - [PyTorch](https://pytorch.org/) (built with v2.7.1)
+ - _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.3.1)_
  - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_

  ## Changelog
@@ -103,7 +102,7 @@ We will continue to maintain the library and fix any issues that arise.
  Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
  We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).

- It might be that this changes with future projects, but as of March 2025, we are content with the current state of the
+ It might be that this changes with future projects, but as of June 2025, we are content with the current state of the
  library.

  ## Known Limitations and Issues
@@ -112,9 +111,6 @@ library.
  In order to work around this bug a manual conversion of some values is needed, (slightly) negatively impacting
  performance.
  Watch the [issue](https://github.com/triton-lang/triton/issues/6376) on Triton's issue tracker for more information.
- - PyTorch's `wrap_triton()` currently does not support config pruning. It thus cannot be used for some of the kernels,
- which could impact graph compilation.
- - There seem to be some issues with autocasting, forcing some operations to manually cast.
  - There will be some slight numerical differences between vanilla and blksprs operations.
  These instabilities are due to Triton and thus cannot be fixed by this library alone.
  However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.

{blksprs-2.1.3 → blksprs-2.1.5}/README.md

@@ -1,7 +1,8 @@
  # blksprs

  [![GitHub Release](https://img.shields.io/github/v/release/FelixSchoen/blksprs?include_prereleases&label=Latest%20Release)](https://github.com/FelixSchoen/blksprs/releases)
- [![Python Version](https://img.shields.io/badge/Python%20Version-3.11-blue)](https://www.python.org/downloads/release/python-3119/)
+ [![Python 3.11](https://img.shields.io/badge/Python%20Version-3.11-blue)](https://www.python.org/downloads/release/python-3119/)
+ [![Python 3.12](https://img.shields.io/badge/Python%20Version-3.12-blue)](https://www.python.org/downloads/release/python-31210/)

  ## Overview

@@ -56,9 +57,7 @@ _* see the [Roadmap](#roadmap) section for more information_

  ## Installation

- Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible
- with
- the Linux platform**.
+ Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible with the Linux platform**.
  Keep track of this [issue](https://github.com/triton-lang/triton/issues/1640) for updates.

  We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) using pip:
@@ -67,8 +66,8 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u

  ### Dependencies

- - [PyTorch](https://pytorch.org/) (built with v2.6)
- - _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.2.4)_
+ - [PyTorch](https://pytorch.org/) (built with v2.7.1)
+ - _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.3.1)_
  - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_

  ## Changelog
@@ -84,7 +83,7 @@ We will continue to maintain the library and fix any issues that arise.
  Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
  We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).

- It might be that this changes with future projects, but as of March 2025, we are content with the current state of the
+ It might be that this changes with future projects, but as of June 2025, we are content with the current state of the
  library.

  ## Known Limitations and Issues
@@ -93,9 +92,6 @@ library.
  In order to work around this bug a manual conversion of some values is needed, (slightly) negatively impacting
  performance.
  Watch the [issue](https://github.com/triton-lang/triton/issues/6376) on Triton's issue tracker for more information.
- - PyTorch's `wrap_triton()` currently does not support config pruning. It thus cannot be used for some of the kernels,
- which could impact graph compilation.
- - There seem to be some issues with autocasting, forcing some operations to manually cast.
  - There will be some slight numerical differences between vanilla and blksprs operations.
  These instabilities are due to Triton and thus cannot be fixed by this library alone.
  However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.

{blksprs-2.1.3 → blksprs-2.1.5}/blksprs/__init__.py

@@ -1,6 +1,6 @@
  from blksprs.utils.blksprs_tensor import BlksprsTensor

- __version__ = "2.1.2"
+ __version__ = "2.1.5"


  class ops:
@@ -27,9 +27,9 @@ class utils:
  from blksprs.utils.processing import apply_torch_linear, apply_torch_normalisation, apply_torch_dropout, \
  apply_function_applicable_row_wise
  from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
+ from blksprs.utils.validation import disable_contiguous, disable_validation

  class validation:
- from blksprs.utils.validation import disable_validation
  from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_dtype_float, \
  validate_dtype_int, validate_device, validate_sparsity, validate_sparsity_dense, \
  validate_sparsity_block_size
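
With 2.1.5 the toggles move: `disable_validation` is no longer re-exported under `blksprs.validation` but, together with the new `disable_contiguous`, under `blksprs.utils`. A minimal usage sketch of the relocated calls (call sites are illustrative only, assuming an environment where blksprs imports, i.e. Linux with a Triton-enabled PyTorch build):

```python
import blksprs as bs

# Globally switch off the contiguity-enforcement and input-validation passes,
# e.g. once inputs are known to be well-formed (locations as of 2.1.5).
bs.utils.disable_contiguous()
bs.utils.disable_validation()
```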

{blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/conversion.py

@@ -106,17 +106,13 @@ def to_sparse_kernel(x,
  pid_col = tl.program_id(axis=2)

  # Get sparsity index of current output block consisting of its batch, row, and column index
- spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
- spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
- spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+ spa_val_idx = pid_blk * s_lut_r_s + tl.arange(0, 4) * s_lut_c_s
+ spa_val_msk = (tl.arange(0, 4) < 3)
+ spa_val = tl.load(s_lut + spa_val_idx, mask=spa_val_msk)

- spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
- spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
- spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
-
- spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
- spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
- spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+ spa_bat = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+ spa_row = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+ spa_col = tl.sum(spa_val * (tl.arange(0, 4) == 2))

  # Load block from dense tensor
  blk_d_idx = (spa_bat * x_b_s +
@@ -445,17 +441,13 @@ def adapt_layout_kernel(x,
  pid_col = tl.program_id(axis=2)

  # Get position of current sparsity block consisting of its batch, row, and column index
- spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
- spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
- spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
-
- spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
- spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
- spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+ spa_val_idx = pid_blk * s_lut_o_r_s + tl.arange(0, 4) * s_lut_o_c_s
+ spa_val_msk = (tl.arange(0, 4) < 3)
+ spa_val = tl.load(s_lut_o + spa_val_idx, mask=spa_val_msk)

- spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
- spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
- spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+ spa_bat_o = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+ spa_row_o = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+ spa_col_o = tl.sum(spa_val * (tl.arange(0, 4) == 2))

  # Get equivalent sparsity block in from layout
  spa_bat_x = spa_bat_o
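
Across these kernels (and the analogous hunks in distribution.py, flow.py, matmul.py, broadcast_ops.py, row_wise.py, and softmax.py below), three separately bounds-checked scalar loads from the sparsity LUT are replaced by one masked vector load of the (batch, row, column) triple, from which each component is then extracted with a masked sum. A minimal PyTorch sketch of the same arithmetic, with illustrative values rather than anything taken from the library:

```python
import torch

# Hypothetical sparsity LUT: one (batch, row, col) triple per sparse block.
s_lut = torch.tensor([[0, 2, 5],
                      [1, 0, 3]], dtype=torch.int32)
s_lut_r_s, s_lut_c_s = s_lut.stride()   # row stride = 3, column stride = 1
flat = s_lut.flatten()

pid_blk = 1                      # block handled by this "program"
lanes = torch.arange(0, 4)       # power-of-two lane count, as tl.arange requires
idx = pid_blk * s_lut_r_s + lanes * s_lut_c_s
msk = lanes < 3                  # only the first three lanes hold real data

# Masked gather: the out-of-range lane reads a safe dummy index and is zeroed,
# mirroring spa_val = tl.load(s_lut + spa_val_idx, mask=spa_val_msk).
safe_idx = torch.where(msk, idx, torch.zeros_like(idx))
spa_val = torch.where(msk, flat[safe_idx], torch.zeros_like(flat[safe_idx]))

# Scalar extraction via masked sums, mirroring tl.sum(spa_val * (lanes == k)).
spa_bat = int((spa_val * (lanes == 0)).sum())
spa_row = int((spa_val * (lanes == 1)).sum())
spa_col = int((spa_val * (lanes == 2)).sum())
assert (spa_bat, spa_row, spa_col) == (1, 0, 3)   # the triple stored in row 1
```

A single vectorized load per block replaces three round trips to the LUT and the per-scalar bounds checks, which is presumably the motivation for the change.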

{blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/distribution.py

@@ -125,17 +125,13 @@ def gather_kernel(x,
  pid_col = tl.program_id(axis=2)

  # Get position of current sparsity block consisting of its batch, row, and column index
- spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
- spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
- spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
+ spa_val_idx = pid_blk * s_lut_o_r_s + tl.arange(0, 4) * s_lut_o_c_s
+ spa_val_msk = (tl.arange(0, 4) < 3)
+ spa_val = tl.load(s_lut_o + spa_val_idx, mask=spa_val_msk)

- spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
- spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
- spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
-
- spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
- spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
- spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+ spa_bat_o = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+ spa_row_o = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+ spa_col_o = tl.sum(spa_val * (tl.arange(0, 4) == 2))

  # Load index values
  blk_i_idx = ((pid_blk * i_b_s) +
@@ -374,17 +370,13 @@ def scatter_reduce_kernel(x,
  pid_col = tl.program_id(axis=2)

  # Get position of current sparsity block consisting of its batch, row, and column index
- spa_bat_x_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
- spa_bat_x_msk = (spa_bat_x_idx >= 0 and spa_bat_x_idx < s_lut_x_r * s_lut_x_r_s)
- spa_bat_x = tl.load(s_lut_x + spa_bat_x_idx, mask=spa_bat_x_msk)
-
- spa_row_x_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
- spa_row_x_msk = (spa_row_x_idx >= 0 and spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
- spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)
+ spa_val_idx = pid_blk * s_lut_x_r_s + tl.arange(0, 4) * s_lut_x_c_s
+ spa_val_msk = (tl.arange(0, 4) < 3)
+ spa_val = tl.load(s_lut_x + spa_val_idx, mask=spa_val_msk)

- spa_col_x_idx = (pid_blk * s_lut_x_r_s + 2 * s_lut_x_c_s)
- spa_col_x_msk = (spa_col_x_idx >= 0 and spa_col_x_idx < s_lut_x_r * s_lut_x_r_s)
- spa_col_x = tl.load(s_lut_x + spa_col_x_idx, mask=spa_col_x_msk)
+ spa_bat_x = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+ spa_row_x = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+ spa_col_x = tl.sum(spa_val * (tl.arange(0, 4) == 2))

  # Load x values
  blk_x_idx = ((pid_blk * x_b_s) +

{blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/flow.py

@@ -66,17 +66,13 @@ def flow_pull_kernel(x,
  pid_col = tl.program_id(axis=2)

  # Get sparsity index of current output block consisting of its batch, row, and column index
- spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
- spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
- spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+ spa_val_idx = pid_blk * s_lut_r_s + tl.arange(0, 4) * s_lut_c_s
+ spa_val_msk = (tl.arange(0, 4) < 3)
+ spa_val = tl.load(s_lut + spa_val_idx, mask=spa_val_msk)

- spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
- spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
- spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
-
- spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
- spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
- spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+ spa_bat = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+ spa_row = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+ spa_col = tl.sum(spa_val * (tl.arange(0, 4) == 2))

  # Load reverse sparsity index
  rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
@@ -157,17 +153,13 @@ def flow_push_kernel(x,
  pid_col = tl.program_id(axis=2)

  # Get sparsity index of current input block consisting of its batch, row, and column index
- spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
- spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
- spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
-
- spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
- spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
- spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+ spa_val_idx = pid_blk * s_lut_r_s + tl.arange(0, 4) * s_lut_c_s
+ spa_val_msk = (tl.arange(0, 4) < 3)
+ spa_val = tl.load(s_lut + spa_val_idx, mask=spa_val_msk)

- spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
- spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
- spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+ spa_bat = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+ spa_row = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+ spa_col = tl.sum(spa_val * (tl.arange(0, 4) == 2))

  # Get reverse sparsity index
  rev_idx_spa_idx = (spa_bat * s_l_x_b_s +

{blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/matmul.py

@@ -145,17 +145,13 @@ def matmul_kernel(x,
  pid_col = tl.program_id(axis=2)

  # Get position of current sparsity block consisting of its batch, row, and column index
- spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
- spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
- spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
+ spa_val_idx = pid_blk * s_lut_o_r_s + tl.arange(0, 4) * s_lut_o_c_s
+ spa_val_msk = (tl.arange(0, 4) < 3)
+ spa_val = tl.load(s_lut_o + spa_val_idx, mask=spa_val_msk)

- spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
- spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
- spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
-
- spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
- spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
- spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+ spa_bat_o = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+ spa_row_o = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+ spa_col_o = tl.sum(spa_val * (tl.arange(0, 4) == 2))

  # Setup buffer
  buf = tl.zeros(shape=(TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), dtype=tl.float32)

{blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/misc/broadcast_ops.py

@@ -110,17 +110,13 @@ def broadcast_add_kernel(x,
  pid_col = tl.program_id(axis=2)

  # Get position of current sparsity block consisting of its batch, row, and column index
- spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
- spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
- spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
+ spa_val_idx = pid_blk * s_lut_o_r_s + tl.arange(0, 4) * s_lut_o_c_s
+ spa_val_msk = (tl.arange(0, 4) < 3)
+ spa_val = tl.load(s_lut_o + spa_val_idx, mask=spa_val_msk)

- spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
- spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
- spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
-
- spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
- spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
- spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+ spa_bat_o = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+ spa_row_o = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+ spa_col_o = tl.sum(spa_val * (tl.arange(0, 4) == 2))

  # Load x block
  blk_x_idx = (spa_bat_o * x_b_s +

{blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/misc/row_wise.py

@@ -119,17 +119,17 @@ def row_wise_sum_kernel(x,
  pid_col = tl.program_id(axis=2)

  # Get position of current sparsity block consisting of its batch and row index
- spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
- spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
- spa_bat = tl.load(s_lut_x + spa_bat_idx, mask=spa_bat_msk)
+ spa_val_idx = pid_blk * s_lut_x_r_s + tl.arange(0, 4) * s_lut_x_c_s
+ spa_val_msk = (tl.arange(0, 4) < 3)
+ spa_val = tl.load(s_lut_x + spa_val_idx, mask=spa_val_msk)

- spa_row_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
- spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_x_r * s_lut_x_r_s)
- spa_row = tl.load(s_lut_x + spa_row_idx, mask=spa_row_msk)
+ spa_bat_x = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+ spa_row_x = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+ spa_col_x = tl.sum(spa_val * (tl.arange(0, 4) == 2))

  # Load reverse sparsity index for current block
- rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
- spa_row * s_l_o_r_s)
+ rev_idx_spa_idx = (spa_bat_x * s_l_o_b_s +
+ spa_row_x * s_l_o_r_s)
  rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
  rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

@@ -263,17 +263,17 @@ def row_wise_max_kernel(x,
  pid_col = tl.program_id(axis=2)

  # Get position of current sparsity block consisting of its batch and row index
- spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
- spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
- spa_bat = tl.load(s_lut_x + spa_bat_idx, mask=spa_bat_msk)
+ spa_val_idx = pid_blk * s_lut_x_r_s + tl.arange(0, 4) * s_lut_x_c_s
+ spa_val_msk = (tl.arange(0, 4) < 3)
+ spa_val = tl.load(s_lut_x + spa_val_idx, mask=spa_val_msk)

- spa_row_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
- spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_x_r * s_lut_x_r_s)
- spa_row = tl.load(s_lut_x + spa_row_idx, mask=spa_row_msk)
+ spa_bat_x = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+ spa_row_x = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+ spa_col_x = tl.sum(spa_val * (tl.arange(0, 4) == 2))

  # Load reverse sparsity index for current block
- rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
- spa_row * s_l_o_r_s)
+ rev_idx_spa_idx = (spa_bat_x * s_l_o_b_s +
+ spa_row_x * s_l_o_r_s)
  rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
  rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

@@ -361,7 +361,7 @@ def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
  triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
  triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

- (wrap_triton(kernel_blocksparse_row_wise_add)[triton_grid]
+ (wrap_triton(row_wise_add_kernel)[triton_grid]
  (x,
  x_b, x_b_s, x_r_s, x_c_s,
  sparsity_lut_x, s_lut_r, s_lut_r_s, s_lut_c_s,
@@ -383,33 +383,33 @@ def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
  reset_to_zero=["o"]
  )
  @triton.jit
- def kernel_blocksparse_row_wise_add(x,
- x_b, x_b_s, x_r_s, x_c_s,
- s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
- y, y_b, y_b_s, y_r_s, y_c_s,
- s_l_y_b, s_l_y_b_s, s_l_y_r_s,
- r_lut_y,
- o,
- o_b, o_b_s, o_r_s, o_c_s,
- sparsity_block_size,
- TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+ def row_wise_add_kernel(x,
+ x_b, x_b_s, x_r_s, x_c_s,
+ s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+ y, y_b, y_b_s, y_r_s, y_c_s,
+ s_l_y_b, s_l_y_b_s, s_l_y_r_s,
+ r_lut_y,
+ o,
+ o_b, o_b_s, o_r_s, o_c_s,
+ sparsity_block_size,
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
  # Get triton block indices
  pid_blk = tl.program_id(axis=0)
  pid_row = tl.program_id(axis=1)
  pid_col = tl.program_id(axis=2)

  # Get position of current sparsity block consisting of its batch and row index
- spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
- spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
- spa_bat = tl.load(s_lut_x + spa_bat_idx, mask=spa_bat_msk)
+ spa_val_idx = pid_blk * s_lut_x_r_s + tl.arange(0, 4) * s_lut_x_c_s
+ spa_val_msk = (tl.arange(0, 4) < 3)
+ spa_val = tl.load(s_lut_x + spa_val_idx, mask=spa_val_msk)

- spa_row_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
- spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_x_r * s_lut_x_r_s)
- spa_row = tl.load(s_lut_x + spa_row_idx, mask=spa_row_msk)
+ spa_bat_x = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+ spa_row_x = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+ spa_col_x = tl.sum(spa_val * (tl.arange(0, 4) == 2))

  # Get reverse sparsity indices for s
- rev_idx_spa_s_idx = (spa_bat * s_l_y_b_s +
- spa_row * s_l_y_r_s)
+ rev_idx_spa_s_idx = (spa_bat_x * s_l_y_b_s +
+ spa_row_x * s_l_y_r_s)
  rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_y_b * s_l_y_b_s)
  rev_idx_spa_s = tl.load(r_lut_y + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)


{blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/repeat.py

@@ -142,7 +142,7 @@ def repeat_build_lut(lut: dict, sparsity_layout_x: Tensor, repeats: tuple[int, i
  n_sparse_blocks = torch.sum(lut["sparsity_layout_o"].to(torch.int)).item()
  lut["n_sparse_blocks"] = n_sparse_blocks

- validate_contiguous(sparsity_layout_o, lut["sparsity_lut"], lut["sparsity_reverse_lut"])
+ validate_contiguous(lut["sparsity_layout_o"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])

  return lut

@@ -178,7 +178,7 @@ def repeat_interleave_build_lut(lut: dict, sparsity_layout_x: Tensor, repeats: i
  n_sparse_blocks = torch.sum(lut["sparsity_layout_o"].to(torch.int)).item()
  lut["n_sparse_blocks"] = n_sparse_blocks

- validate_contiguous(sparsity_layout_o, lut["sparsity_lut"], lut["sparsity_reverse_lut"])
+ validate_contiguous(lut["sparsity_layout_o"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])

  return lut


{blksprs-2.1.3 → blksprs-2.1.5}/blksprs/ops/softmax.py

@@ -176,13 +176,12 @@ def softmax_kernel(x,
  pid_col = tl.program_id(axis=2)

  # Get position of current sparsity block consisting of its batch and row index
- spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
- spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
- spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+ spa_val_idx = pid_blk * s_lut_r_s + tl.arange(0, 4) * s_lut_c_s
+ spa_val_msk = (tl.arange(0, 4) < 3)
+ spa_val = tl.load(s_lut + spa_val_idx, mask=spa_val_msk)

- spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
- spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
- spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+ spa_bat = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+ spa_row = tl.sum(spa_val * (tl.arange(0, 4) == 1))

  # Get reverse sparsity indices for s
  rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
@@ -241,13 +240,12 @@ def softmax_kernel_grad(g,
  pid_col = tl.program_id(axis=2)

  # Get position of current sparsity block consisting of its batch and row index
- spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
- spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
- spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+ spa_val_idx = pid_blk * s_lut_r_s + tl.arange(0, 4) * s_lut_c_s
+ spa_val_msk = (tl.arange(0, 4) < 3)
+ spa_val = tl.load(s_lut + spa_val_idx, mask=spa_val_msk)

- spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
- spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
- spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+ spa_bat = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+ spa_row = tl.sum(spa_val * (tl.arange(0, 4) == 1))

  rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
  spa_row * s_l_s_r_s)

{blksprs-2.1.3 → blksprs-2.1.5}/blksprs/utils/autotuning.py

@@ -14,11 +14,11 @@ if blksprs_autotune_mode == "DEFAULT":

  (64, 3, 8),
  (64, 4, 4),
- (64, 5, 2),
+ (64, 4, 8),

  (128, 3, 8),
  (128, 4, 4),
- (128, 5, 2),
+ (128, 4, 8),
  ]
  elif blksprs_autotune_mode == "TEST":
  autotune_parameters = [

{blksprs-2.1.3 → blksprs-2.1.5}/blksprs/utils/validation.py

@@ -1,9 +1,17 @@
  import torch
  from torch import Tensor

+ CONTIGUOUS = True
  VALIDATION = True


+ def ensure_contiguous(*tensors: Tensor) -> tuple[Tensor, ...]:
+ if _check_skip_contiguous():
+ return tensors
+
+ return tuple(tensor.contiguous() for tensor in tensors)
+
+
  def validate_dimensions(*tensors: Tensor, dims=3) -> None:
  if _check_skip_validation():
  return
@@ -124,6 +132,19 @@ def validate_sparsity_block_size(sparsity_block_size: int, *tensors):
  raise ValueError("Tensor sizes must be divisible by sparsity block size")


+ def _check_skip_contiguous():
+ return not CONTIGUOUS
+
+
+ def _set_skip_contiguous(skip_contiguous: bool):
+ global CONTIGUOUS
+ CONTIGUOUS = not skip_contiguous
+
+
+ def disable_contiguous():
+ _set_skip_contiguous(True)
+
+
  def _check_skip_validation():
  return not VALIDATION
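
The new helpers make the contiguity pass skippable in the same way validation already was. A small usage sketch of the behaviour implied by the code above (illustrative only; it assumes an environment in which blksprs imports, i.e. Linux with a Triton-enabled PyTorch build):

```python
import torch
from blksprs.utils.validation import ensure_contiguous, disable_contiguous

x = torch.arange(12.0).reshape(3, 4).t()   # a transposed view is non-contiguous
(y,) = ensure_contiguous(x)                # returns contiguous copies by default
assert y.is_contiguous()

disable_contiguous()                       # flips the module-level CONTIGUOUS flag
(z,) = ensure_contiguous(x)                # now a pass-through: tensors come back untouched
assert z is x and not z.is_contiguous()
```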

{blksprs-2.1.3 → blksprs-2.1.5}/blksprs.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: blksprs
- Version: 2.1.3
+ Version: 2.1.5
  Summary: A lightweight library for operations on block-sparse matrices in PyTorch.
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -20,7 +20,8 @@ Requires-Dist: matplotlib; extra == "test"
  # blksprs

  [![GitHub Release](https://img.shields.io/github/v/release/FelixSchoen/blksprs?include_prereleases&label=Latest%20Release)](https://github.com/FelixSchoen/blksprs/releases)
- [![Python Version](https://img.shields.io/badge/Python%20Version-3.11-blue)](https://www.python.org/downloads/release/python-3119/)
+ [![Python 3.11](https://img.shields.io/badge/Python%20Version-3.11-blue)](https://www.python.org/downloads/release/python-3119/)
+ [![Python 3.12](https://img.shields.io/badge/Python%20Version-3.12-blue)](https://www.python.org/downloads/release/python-31210/)

  ## Overview

@@ -75,9 +76,7 @@ _* see the [Roadmap](#roadmap) section for more information_

  ## Installation

- Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible
- with
- the Linux platform**.
+ Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible with the Linux platform**.
  Keep track of this [issue](https://github.com/triton-lang/triton/issues/1640) for updates.

  We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) using pip:
@@ -86,8 +85,8 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u

  ### Dependencies

- - [PyTorch](https://pytorch.org/) (built with v2.6)
- - _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.2.4)_
+ - [PyTorch](https://pytorch.org/) (built with v2.7.1)
+ - _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.3.1)_
  - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_

  ## Changelog
@@ -103,7 +102,7 @@ We will continue to maintain the library and fix any issues that arise.
  Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
  We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).

- It might be that this changes with future projects, but as of March 2025, we are content with the current state of the
+ It might be that this changes with future projects, but as of June 2025, we are content with the current state of the
  library.

  ## Known Limitations and Issues
@@ -112,9 +111,6 @@ library.
  In order to work around this bug a manual conversion of some values is needed, (slightly) negatively impacting
  performance.
  Watch the [issue](https://github.com/triton-lang/triton/issues/6376) on Triton's issue tracker for more information.
- - PyTorch's `wrap_triton()` currently does not support config pruning. It thus cannot be used for some of the kernels,
- which could impact graph compilation.
- - There seem to be some issues with autocasting, forcing some operations to manually cast.
  - There will be some slight numerical differences between vanilla and blksprs operations.
  These instabilities are due to Triton and thus cannot be fixed by this library alone.
  However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.

{blksprs-2.1.3 → blksprs-2.1.5}/pyproject.toml

@@ -1,6 +1,6 @@
  [project]
  name = "blksprs"
- version = "2.1.3"
+ version = "2.1.5"
  authors = [{ name = "Felix Schön", email = "schoen@kr.tuwien.ac.at" }]
  description = "A lightweight library for operations on block-sparse matrices in PyTorch."
  readme = "README.md"