blksprs-2.1.7.tar.gz → blksprs-2.1.9.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. {blksprs-2.1.7 → blksprs-2.1.9}/PKG-INFO +3 -3
  2. {blksprs-2.1.7 → blksprs-2.1.9}/README.md +1 -1
  3. {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/__init__.py +10 -2
  4. {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/layouting/distribution_layout.py +11 -8
  5. {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/layouting/sparsity_layout.py +15 -10
  6. {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/conversion.py +28 -25
  7. {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/distribution.py +28 -28
  8. {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/flow.py +13 -11
  9. {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/matmul.py +16 -14
  10. {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/misc/broadcast_ops.py +8 -8
  11. {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/misc/row_wise.py +24 -23
  12. {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/partitioning.py +2 -2
  13. {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/repeat.py +2 -2
  14. {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/softmax.py +38 -33
  15. {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/transpose.py +3 -3
  16. {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/utils/autotuning.py +1 -1
  17. {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/utils/blksprs_tensor.py +10 -1
  18. {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/utils/processing.py +2 -1
  19. {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/utils/tools.py +2 -5
  20. {blksprs-2.1.7 → blksprs-2.1.9}/blksprs.egg-info/PKG-INFO +3 -3
  21. {blksprs-2.1.7 → blksprs-2.1.9}/blksprs.egg-info/requires.txt +1 -1
  22. {blksprs-2.1.7 → blksprs-2.1.9}/pyproject.toml +2 -2
  23. {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/utils/benchmarking.py +0 -0
  24. {blksprs-2.1.7 → blksprs-2.1.9}/blksprs/utils/validation.py +0 -0
  25. {blksprs-2.1.7 → blksprs-2.1.9}/blksprs.egg-info/SOURCES.txt +0 -0
  26. {blksprs-2.1.7 → blksprs-2.1.9}/blksprs.egg-info/dependency_links.txt +0 -0
  27. {blksprs-2.1.7 → blksprs-2.1.9}/blksprs.egg-info/top_level.txt +0 -0
  28. {blksprs-2.1.7 → blksprs-2.1.9}/setup.cfg +0 -0
{blksprs-2.1.7 → blksprs-2.1.9}/PKG-INFO
@@ -1,13 +1,13 @@
 Metadata-Version: 2.4
 Name: blksprs
-Version: 2.1.7
+Version: 2.1.9
 Summary: A lightweight library for operations on block-sparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
 Project-URL: Bugtracker, https://github.com/FelixSchoen/blksprs/issues
 Requires-Python: >=3.11
 Description-Content-Type: text/markdown
-Requires-Dist: torch
+Requires-Dist: torch>=2.8.0
 Requires-Dist: numpy
 Provides-Extra: test
 Requires-Dist: pytest; extra == "test"
@@ -102,7 +102,7 @@ We will continue to maintain the library and fix any issues that arise.
 Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
 We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
 
-It might be that this changes with future projects, but as of June 2025, we are content with the current state of the
+It might be that this changes with future projects, but as of August 2025, we are content with the current state of the
 library.
 
 ## Known Limitations and Issues

{blksprs-2.1.7 → blksprs-2.1.9}/README.md
@@ -83,7 +83,7 @@ We will continue to maintain the library and fix any issues that arise.
 Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
 We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
 
-It might be that this changes with future projects, but as of June 2025, we are content with the current state of the
+It might be that this changes with future projects, but as of August 2025, we are content with the current state of the
 library.
 
 ## Known Limitations and Issues

{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/__init__.py
@@ -1,6 +1,14 @@
-from blksprs.utils.blksprs_tensor import BlksprsTensor
+# Settings
+import torch
+
+# Capture scalar outputs for JIT compilation
+torch._dynamo.config.capture_scalar_outputs = True
+# Set version
+__version__ = "2.1.9"
 
-__version__ = "2.1.7"
+# Imports
+
+from blksprs.utils.blksprs_tensor import BlksprsTensor
 
 
 class ops:
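
Note on the relocated setting: `torch._dynamo.config.capture_scalar_outputs` (moved here from `blksprs/utils/tools.py`, see the corresponding hunk further down) is a standard PyTorch Dynamo flag that lets `torch.compile` capture scalar-producing calls such as `Tensor.item()` into the graph instead of triggering a graph break. Setting it in `__init__.py` means it now takes effect on `import blksprs`. A minimal sketch of what the flag enables, assuming PyTorch >= 2.8 with a working compile backend (the function below is illustrative, not part of the package):

```python
import torch

# Without this flag, the .item() call below would force a graph break
# inside the compiled region.
torch._dynamo.config.capture_scalar_outputs = True

@torch.compile
def count_blocks(layout: torch.Tensor) -> torch.Tensor:
    n = int(layout.sum().item())  # scalar output captured into the graph
    return torch.zeros(n)

print(count_blocks(torch.tensor([[1, 0], [1, 1]])).shape)  # torch.Size([3])
```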
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/layouting/distribution_layout.py
@@ -7,9 +7,9 @@ from torch._library import triton_op
 from torch._library.triton import wrap_triton
 from triton import language as tl
 
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import stride
-from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_dimensions, validate_device, \
     validate_contiguous
 
@@ -98,22 +98,25 @@ def build_distribution_layout_kernel(i,
 
     # Get position of current sparsity block consisting of its batch, row, and column index
     spa_bat_i_idx = (pid_blk * s_lut_i_r_s + 0 * s_lut_i_c_s)
-    spa_bat_i_msk = (spa_bat_i_idx >= 0 and spa_bat_i_idx < s_lut_i_r * s_lut_i_r_s)
+    spa_bat_i_msk = ((spa_bat_i_idx >= 0) &
+                     (spa_bat_i_idx < s_lut_i_r * s_lut_i_r_s))
     spa_bat_i = tl.load(s_lut_i + spa_bat_i_idx, mask=spa_bat_i_msk)
 
     spa_row_i_idx = (pid_blk * s_lut_i_r_s + 1 * s_lut_i_c_s)
-    spa_row_i_msk = (spa_row_i_idx >= 0 and spa_row_i_idx < s_lut_i_r * s_lut_i_r_s)
+    spa_row_i_msk = ((spa_row_i_idx >= 0) &
+                     (spa_row_i_idx < s_lut_i_r * s_lut_i_r_s))
     spa_row_i = tl.load(s_lut_i + spa_row_i_idx, mask=spa_row_i_msk)
 
     spa_col_i_idx = (pid_blk * s_lut_i_r_s + 2 * s_lut_i_c_s)
-    spa_col_i_msk = (spa_col_i_idx >= 0 and spa_col_i_idx < s_lut_i_r * s_lut_i_r_s)
+    spa_col_i_msk = ((spa_col_i_idx >= 0) &
+                     (spa_col_i_idx < s_lut_i_r * s_lut_i_r_s))
     spa_col_i = tl.load(s_lut_i + spa_col_i_idx, mask=spa_col_i_msk)
 
     blk_i_idx = (pid_blk * i_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
-    blk_i_msk = (blk_i_idx >= 0 and
-                 blk_i_idx < i_b * i_b_s)
+    blk_i_msk = ((blk_i_idx >= 0) &
+                 (blk_i_idx < i_b * i_b_s))
     blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)
 
     dst_bat_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_i, dtype=tl.int32)
@@ -131,6 +134,6 @@ def build_distribution_layout_kernel(i,
     blk_o_idx = ((dst_bat_idx * o_b_s) +
                  (dst_row_idx * o_r_s) +
                  (dst_col_idx * o_c_s))
-    blk_o_msk = (blk_o_idx >= 0 and
-                 blk_o_idx < o_b * o_b_s)
+    blk_o_msk = ((blk_o_idx >= 0) &
+                 (blk_o_idx < o_b * o_b_s))
     tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
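
Most hunks in this release follow the pattern visible above: mask expressions in the Triton kernels are rewritten from Python's short-circuiting `and` to the elementwise `&` operator. In Triton, a comparison such as `blk_i_idx >= 0` produces a tensor of per-lane booleans, and combining two such tensors with `and` is not a well-defined elementwise operation; `&` combines them lane by lane, so the resulting mask matches the shape of the indices being loaded or stored. A standalone sketch of the corrected idiom, assuming a CUDA device and Triton installed (the kernel is illustrative, not taken from blksprs):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def bounded_copy_kernel(x_ptr, o_ptr, n, BLOCK_SIZE: tl.constexpr):
    idx = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Elementwise `&` yields one mask entry per lane; Python's `and`
    # would try to collapse each comparison tensor to a single boolean.
    msk = (idx >= 0) & (idx < n)
    tl.store(o_ptr + idx, tl.load(x_ptr + idx, mask=msk), mask=msk)


x = torch.arange(10, device="cuda", dtype=torch.float32)
o = torch.zeros_like(x)
bounded_copy_kernel[(triton.cdiv(x.numel(), 8),)](x, o, x.numel(), BLOCK_SIZE=8)
assert torch.equal(x, o)
```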
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/layouting/sparsity_layout.py
@@ -6,9 +6,9 @@ from torch import Tensor
 from torch._library.triton import wrap_triton, triton_op
 from triton import language as tl
 
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs, prune_autotune_configs_conversion
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import stride
-from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs, prune_autotune_configs_conversion
 from blksprs.utils.validation import validate_dimensions, validate_device, \
     validate_contiguous, validate_sparsity, validate_sparsity_block_size
 
@@ -79,8 +79,8 @@ def build_sparsity_layout_kernel(x,
     blk_x_idx = (pid_bat * x_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx >= 0 and
-                 blk_x_idx < x_b * x_b_s)
+    blk_x_msk = ((blk_x_idx >= 0) &
+                 (blk_x_idx < x_b * x_b_s))
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Store sparsity layout value
@@ -88,7 +88,8 @@ def build_sparsity_layout_kernel(x,
     blk_o_idx = (pid_bat * o_b_s +
                  (((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +
                   ((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))
-    blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
+    blk_o_msk = ((blk_o_idx >= 0) &
+                 (blk_o_idx < o_b * o_b_s))
     tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
 
 
@@ -178,23 +179,26 @@ def build_sparsity_layout_adaption_kernel(x,
 
     # Get sparsity index of current output block consisting of its batch, row, and column index
     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-    spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
+    spa_bat_msk = ((spa_bat_idx >= 0) &
+                   (spa_bat_idx < s_lut_r * s_lut_r_s))
     spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
 
     spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-    spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
+    spa_row_msk = ((spa_row_idx >= 0) &
+                   (spa_row_idx < s_lut_r * s_lut_r_s))
     spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
 
     spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
-    spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
+    spa_col_msk = ((spa_col_idx >= 0) &
+                   (spa_col_idx < s_lut_r * s_lut_r_s))
     spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
 
     # Load x values
     blk_x_idx = ((pid_blk * x_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx >= 0 and
-                 blk_x_idx < x_b * x_b_s)
+    blk_x_msk = ((blk_x_idx >= 0) &
+                 (blk_x_idx < x_b * x_b_s))
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Store sparsity layout value
@@ -204,7 +208,8 @@ def build_sparsity_layout_adaption_kernel(x,
                    // sparsity_block_size_to) * o_r_s) +
                  (((pid_col * TRITON_BLOCK_SIZE + spa_col * sparsity_block_size_from)
                    // sparsity_block_size_to) * o_c_s))
-    blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
+    blk_o_msk = ((blk_o_idx >= 0) &
+                 (blk_o_idx < o_b * o_b_s))
     tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
 
 
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/conversion.py
@@ -5,9 +5,9 @@ from torch._library.triton import wrap_triton, triton_op
 from triton import language as tl
 
 from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs, prune_autotune_configs_conversion
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import stride
-from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs, prune_autotune_configs_conversion
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_sparsity_dense, ensure_contiguous
 
@@ -46,10 +46,10 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor,
     lut = to_sparse_build_lut(lut, sparsity_layout)
 
     if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
-        return BlksprsTensor(x)
+        return BlksprsTensor.wrap(x)
 
-    return BlksprsTensor(to_sparse_forward(x, sparsity_layout,
-                                           lut["sparsity_lut"], sparsity_block_size, lut["n_sparse_blocks"]))
+    return BlksprsTensor.wrap(to_sparse_forward(x, sparsity_layout,
+                                                lut["sparsity_lut"], sparsity_block_size, lut["n_sparse_blocks"]))
 
 
 @triton_op("blksprs::to_sparse_forward", mutates_args={})
@@ -120,16 +120,16 @@ def to_sparse_kernel(x,
                    tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + spa_col * sparsity_block_size +
                    tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_d_msk = (blk_d_idx >= 0 and
-                 blk_d_idx < x_b * x_b_s)
+    blk_d_msk = ((blk_d_idx >= 0) &
+                 (blk_d_idx < x_b * x_b_s))
     blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)
 
     # Store block in sparse tensor
     blk_o_idx = ((pid_blk * o_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
-    blk_o_msk = (blk_o_idx >= 0 and
-                 blk_o_idx < (pid_blk + 1) * o_b_s)
+    blk_o_msk = ((blk_o_idx >= 0) &
+                 (blk_o_idx < (pid_blk + 1) * o_b_s))
     tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
 
 
@@ -201,7 +201,7 @@ def to_dense(x: BlksprsTensor, sparsity_layout: Tensor,
         return x
 
     return Tensor(to_dense_forward(x, sparsity_layout,
-                  lut["sparsity_reverse_lut"], sparsity_block_size, fill_value))
+                                   lut["sparsity_reverse_lut"], sparsity_block_size, fill_value))
 
 
 @triton_op("blksprs::to_dense_forward", mutates_args={})
@@ -269,7 +269,8 @@ def to_dense_kernel(x,
 
     # Get reverse sparsity index for current block
     rev_idx_spa_idx = (pid_blk * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)
-    rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_b * s_l_b_s)
+    rev_idx_spa_msk = ((rev_idx_spa_idx >= 0) &
+                       (rev_idx_spa_idx < s_l_b * s_l_b_s))
     rev_idx_spa = tl.load(sparsity_reverse_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
     # If block is present commence operations
@@ -279,14 +280,15 @@ def to_dense_kernel(x,
                       tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                     (((pid_col % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
                       tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_msk = (blk_idx >= 0 and
-                   blk_idx < x_b * x_b_s)
+        blk_msk = ((blk_idx >= 0) &
+                   (blk_idx < x_b * x_b_s))
         blk = tl.load(x + blk_idx, mask=blk_msk)
 
         o_idx = (pid_blk * o_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-        o_msk = (o_idx >= 0 and o_idx < o_b * o_b_s)
+        o_msk = ((o_idx >= 0) &
+                 (o_idx < o_b * o_b_s))
         tl.store(o + o_idx, blk, o_msk)
 
 
@@ -360,14 +362,14 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
     validate_contiguous(sparsity_reverse_lut_from, sparsity_layout_to, sparsity_lut_to)
 
     if (sparsity_block_size_from == sparsity_block_size_to) and torch.equal(sparsity_layout_from, sparsity_layout_to):
-        return BlksprsTensor(x), sparsity_layout_to
+        return BlksprsTensor.wrap(x), sparsity_layout_to
 
-    return BlksprsTensor(adapt_layout_forward(x,
-                                              sparsity_layout_from, sparsity_reverse_lut_from,
-                                              sparsity_block_size_from,
-                                              sparsity_layout_to, sparsity_lut_to,
-                                              sparsity_block_size_to,
-                                              n_sparse_blocks_to)), sparsity_layout_to
+    return BlksprsTensor.wrap(adapt_layout_forward(x,
+                                                   sparsity_layout_from, sparsity_reverse_lut_from,
+                                                   sparsity_block_size_from,
+                                                   sparsity_layout_to, sparsity_lut_to,
+                                                   sparsity_block_size_to,
+                                                   n_sparse_blocks_to)), sparsity_layout_to
 
 
 @triton_op("blksprs::adapt_layout_forward", mutates_args={})
@@ -458,7 +460,8 @@ def adapt_layout_kernel(x,
     rev_idx_spa_x_idx = (spa_bat_x * s_l_x_b_s +
                          spa_row_x * s_l_x_r_s +
                          spa_col_x * s_l_x_c_s)
-    rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+    rev_idx_spa_x_msk = ((rev_idx_spa_x_idx >= 0) &
+                         (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s))
     rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
 
     # If block is present commence operations
@@ -473,16 +476,16 @@ def adapt_layout_kernel(x,
         blk_x_idx = ((rev_idx_spa_x * x_b_s) +
                      ((shift_row_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                      ((shift_col_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_x_msk = (blk_x_idx >= 0 and
-                     blk_x_idx < x_b * x_b_s)
+        blk_x_msk = ((blk_x_idx >= 0) &
+                     (blk_x_idx < x_b * x_b_s))
        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
        # Store output
        blk_o_idx = ((pid_blk * o_b_s) +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-       blk_o_msk = (blk_o_idx >= 0 and
-                    blk_o_idx < o_b * o_b_s)
+       blk_o_msk = ((blk_o_idx >= 0) &
+                    (blk_o_idx < o_b * o_b_s))
        tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
 
 
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/distribution.py
@@ -5,9 +5,9 @@ from torch._library import triton_op
 from torch._library.triton import wrap_triton
 from triton import language as tl
 
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import stride
-from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_dtype_int, validate_sparsity_block_size, ensure_contiguous
 
@@ -45,9 +45,9 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
 
     lut = gather_build_lut(lut, sparsity_layout_src, sparsity_layout_idx)
 
-    return BlksprsTensor(gather_forward(src, sparsity_layout_src, lut["sparsity_reverse_lut_x"],
-                                        adjusted_dim, idx, sparsity_layout_idx, lut["sparsity_lut_i"],
-                                        sparsity_block_size))
+    return BlksprsTensor.wrap(gather_forward(src, sparsity_layout_src, lut["sparsity_reverse_lut_x"],
+                                             adjusted_dim, idx, sparsity_layout_idx, lut["sparsity_lut_i"],
+                                             sparsity_block_size))
 
 
 @triton_op("blksprs::gather_forward", mutates_args={})
@@ -136,8 +136,8 @@ def gather_kernel(x,
     blk_i_idx = ((pid_blk * i_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
-    blk_i_msk = (blk_i_idx >= 0 and
-                 blk_i_idx < i_b * i_b_s)
+    blk_i_msk = ((blk_i_idx >= 0) &
+                 (blk_i_idx < i_b * i_b_s))
     blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
 
     # Get indices of sparsity blocks and positions within the blocks
@@ -164,26 +164,26 @@ def gather_kernel(x,
     rev_idx_spa_x_idx = ((rev_dst_bat_x * s_l_x_b_s) +
                          (rev_dst_row_x * s_l_x_r_s) +
                          (rev_dst_col_x * s_l_x_c_s))
-    rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and
-                         rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+    rev_idx_spa_x_msk = ((rev_idx_spa_x_idx >= 0) &
+                         (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s))
     rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
 
     # Load x values
     blk_x_idx = ((rev_idx_spa_x * x_b_s) +
                  dst_row_x +
                  dst_col_x)
-    blk_x_msk = ((blk_x_idx >= 0 and
-                  blk_x_idx < x_b * x_b_s) and
-                 rev_idx_spa_x_msk != -1)
+    blk_x_msk = (((blk_x_idx >= 0) &
+                  (blk_x_idx < x_b * x_b_s)) &
+                 (rev_idx_spa_x_msk != -1))
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Store output
     blk_o_idx = ((pid_blk * o_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-    blk_o_msk = ((blk_o_idx >= 0 and
-                  blk_o_idx < o_b * o_b_s) and
-                 rev_idx_spa_x_msk != -1)
+    blk_o_msk = (((blk_o_idx >= 0) &
+                  (blk_o_idx < o_b * o_b_s)) &
+                 (rev_idx_spa_x_msk != -1))
     tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
 
 
@@ -276,11 +276,11 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
 
     lut = scatter_reduce_build_lut(lut, sparsity_layout_src, sparsity_layout_tgt)
 
-    return BlksprsTensor(scatter_reduce_forward(src, sparsity_layout_src, lut["sparsity_lut_x"],
-                                                adjusted_dim, idx,
-                                                sparsity_layout_tgt, lut["sparsity_reverse_lut_o"],
-                                                sparsity_block_size, lut["n_sparse_blocks"],
-                                                reduce_op))
+    return BlksprsTensor.wrap(scatter_reduce_forward(src, sparsity_layout_src, lut["sparsity_lut_x"],
+                                                     adjusted_dim, idx,
+                                                     sparsity_layout_tgt, lut["sparsity_reverse_lut_o"],
+                                                     sparsity_block_size, lut["n_sparse_blocks"],
+                                                     reduce_op))
 
 
 @triton_op("blksprs::scatter_reduce_forward", mutates_args={})
@@ -380,16 +380,16 @@ def scatter_reduce_kernel(x,
     blk_x_idx = ((pid_blk * x_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx >= 0 and
-                 blk_x_idx < x_b * x_b_s)
+    blk_x_msk = ((blk_x_idx >= 0) &
+                 (blk_x_idx < x_b * x_b_s))
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Load index values
     blk_i_idx = ((pid_blk * i_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
-    blk_i_msk = (blk_i_idx >= 0 and
-                 blk_i_idx < i_b * i_b_s)
+    blk_i_msk = ((blk_i_idx >= 0) &
+                 (blk_i_idx < i_b * i_b_s))
     blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
 
     # Get indices of sparsity blocks and positions within the blocks
@@ -416,17 +416,17 @@ def scatter_reduce_kernel(x,
     rev_idx_spa_o_idx = ((rev_dst_bat_o * s_l_o_b_s) +
                          (rev_dst_row_o * s_l_o_r_s) +
                          (rev_dst_col_o * s_l_o_c_s))
-    rev_idx_spa_o_msk = (rev_idx_spa_o_idx >= 0 and
-                         rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
+    rev_idx_spa_o_msk = ((rev_idx_spa_o_idx >= 0) &
+                         (rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s))
     rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
 
     # Store output
     blk_o_idx = ((rev_idx_spa_o * o_b_s) +
                  dst_row_o +
                  dst_col_o)
-    blk_o_msk = ((blk_o_idx >= 0 and
-                  blk_o_idx < o_b * o_b_s) and
-                 rev_idx_spa_o_msk != -1)
+    blk_o_msk = (((blk_o_idx >= 0) &
+                  (blk_o_idx < o_b * o_b_s)) &
+                 (rev_idx_spa_o_msk != -1))
 
     if reduce_op_ind == 0:
         tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/flow.py
@@ -5,8 +5,8 @@ from torch._library import triton_op
 from torch._library.triton import wrap_triton
 from triton import language as tl
 
-from blksprs.utils.tools import stride
 from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
+from blksprs.utils.tools import stride
 
 
 @triton_op("blksprs::flow_pull_forward", mutates_args={})
@@ -78,22 +78,23 @@ def flow_pull_kernel(x,
     rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
                        spa_row * s_l_o_r_s +
                        spa_col * s_l_o_c_s)
-    rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
+    rev_idx_spa_msk = ((rev_idx_spa_idx >= 0) &
+                       (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s))
     rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
     if rev_idx_spa >= 0:
         blk_x_idx = (rev_idx_spa * x_b_s +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_x_msk = (blk_x_idx >= 0 and
-                     blk_x_idx < x_b * x_b_s)
+        blk_x_msk = ((blk_x_idx >= 0) &
+                     (blk_x_idx < x_b * x_b_s))
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
         blk_o_idx = (pid_blk * o_b_s +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-        blk_o_msk = (blk_o_idx >= 0 and
-                     blk_o_idx < o_b * o_b_s)
+        blk_o_msk = ((blk_o_idx >= 0) &
+                     (blk_o_idx < o_b * o_b_s))
         tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
 
 
@@ -165,20 +166,21 @@ def flow_push_kernel(x,
     rev_idx_spa_idx = (spa_bat * s_l_x_b_s +
                        spa_row * s_l_x_r_s +
                        spa_col * s_l_x_c_s)
-    rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_x_b * s_l_x_b_s)
+    rev_idx_spa_msk = ((rev_idx_spa_idx >= 0) &
+                       (rev_idx_spa_idx < s_l_x_b * s_l_x_b_s))
     rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
     if rev_idx_spa >= 0:
         blk_x_idx = (pid_blk * x_b_s +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_x_msk = (blk_x_idx >= 0 and
-                     blk_x_idx < x_b * x_b_s)
+        blk_x_msk = ((blk_x_idx >= 0) &
+                     (blk_x_idx < x_b * x_b_s))
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
         blk_o_idx = (rev_idx_spa * o_b_s +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-        blk_o_msk = (blk_o_idx >= 0 and
-                     blk_o_idx < o_b * o_b_s)
+        blk_o_msk = ((blk_o_idx >= 0) &
+                     (blk_o_idx < o_b * o_b_s))
         tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/matmul.py
@@ -5,9 +5,9 @@ from torch.library import triton_op, wrap_triton
 from triton import language as tl
 
 from blksprs.ops.transpose import transpose
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import stride
-from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_dtype_float, ensure_contiguous
 
@@ -47,11 +47,11 @@ def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
 
     lut = matmul_build_lut(lut, sparsity_layout_x, sparsity_layout_y, sparsity_layout_output)
 
-    return BlksprsTensor(matmul_forward(x, y,
-                                        sparsity_layout_x, lut["sparsity_reverse_lut_x"],
-                                        sparsity_layout_y, lut["sparsity_reverse_lut_y"],
-                                        sparsity_layout_output, lut["sparsity_lut_o"],
-                                        sparsity_block_size, lut["n_sparse_blocks"]))
+    return BlksprsTensor.wrap(matmul_forward(x, y,
+                                             sparsity_layout_x, lut["sparsity_reverse_lut_x"],
+                                             sparsity_layout_y, lut["sparsity_reverse_lut_y"],
+                                             sparsity_layout_output, lut["sparsity_lut_o"],
+                                             sparsity_block_size, lut["n_sparse_blocks"]))
 
 
 @triton_op("blksprs::matmul_forward", mutates_args={})
@@ -169,12 +169,14 @@ def matmul_kernel(x,
         rev_idx_spa_x_idx = (spa_bat_o * s_l_x_b_s +
                              spa_row_o * s_l_x_r_s +
                              i_seg_spa * s_l_x_c_s)
-        rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+        rev_idx_spa_x_msk = ((rev_idx_spa_x_idx >= 0) &
+                             (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s))
         rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
 
         # Get reverse sparsity indices for y
         rev_idx_spa_y_idx = (spa_bat_o * s_l_y_b_s + i_seg_spa * s_l_y_r_s + spa_col_o * s_l_y_c_s)
-        rev_idx_spa_y_msk = (rev_idx_spa_y_idx >= 0 and rev_idx_spa_y_idx < s_l_y_b * s_l_y_b_s)
+        rev_idx_spa_y_msk = ((rev_idx_spa_y_idx >= 0) &
+                             (rev_idx_spa_y_idx < s_l_y_b * s_l_y_b_s))
         rev_idx_spa_y = tl.load(r_lut_y + rev_idx_spa_y_idx, mask=rev_idx_spa_y_msk).to(tl.int32)
 
         # If both blocks are present commence calculation
@@ -183,16 +185,16 @@ def matmul_kernel(x,
                          ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                          ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
                            tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-            blk_x_msk = (blk_x_idx >= 0 and
-                         blk_x_idx < x_b * x_b_s)
+            blk_x_msk = ((blk_x_idx >= 0) &
+                         (blk_x_idx < x_b * x_b_s))
             blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
             blk_y_idx = ((rev_idx_spa_y * y_b_s) +
                          ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
                            tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
                          ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
-            blk_y_msk = (blk_y_idx >= 0 and
-                         blk_y_idx < y_b * y_b_s)
+            blk_y_msk = ((blk_y_idx >= 0) &
+                         (blk_y_idx < y_b * y_b_s))
             blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
 
             # Perform matrix multiplication
@@ -205,8 +207,8 @@ def matmul_kernel(x,
     blk_o_idx = ((pid_blk * o_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-    blk_o_msk = (blk_o_idx >= 0 and
-                 blk_o_idx < o_b * o_b_s)
+    blk_o_msk = ((blk_o_idx >= 0) &
+                 (blk_o_idx < o_b * o_b_s))
     tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
 
 
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/misc/broadcast_ops.py
@@ -5,9 +5,9 @@ from torch._library import triton_op
 from torch._library.triton import wrap_triton
 from triton import language as tl
 
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import stride
-from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_device, \
     validate_sparsity_block_size, ensure_contiguous
 
@@ -43,7 +43,7 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
 
     validate_contiguous(sparsity_layout_output, sparsity_lut_o)
 
-    return BlksprsTensor(broadcast_add_forward(x, y, sparsity_lut_o, sparsity_block_size, n_sparse_blocks))
+    return BlksprsTensor.wrap(broadcast_add_forward(x, y, sparsity_lut_o, sparsity_block_size, n_sparse_blocks))
 
 
 def broadcast_sub(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
@@ -121,16 +121,16 @@ def broadcast_add_kernel(x,
     blk_x_idx = (spa_bat_o * x_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + spa_row_o * sparsity_block_size +
                    tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx >= 0 and
-                 blk_x_idx < x_b * x_b_s)
+    blk_x_msk = ((blk_x_idx >= 0) &
+                 (blk_x_idx < x_b * x_b_s))
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Load y block
     blk_y_idx = (spa_bat_o * y_b_s +
                  ((pid_col * TRITON_BLOCK_SIZE + spa_col_o * sparsity_block_size +
                    tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
-    blk_y_msk = (blk_y_idx >= 0 and
-                 blk_y_idx < y_b * y_b_s)
+    blk_y_msk = ((blk_y_idx >= 0) &
+                 (blk_y_idx < y_b * y_b_s))
     blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
 
     # Compute sum
@@ -141,6 +141,6 @@ def broadcast_add_kernel(x,
     blk_o_idx = ((pid_blk * o_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-    blk_o_msk = (blk_o_idx >= 0 and
-                 blk_o_idx < o_b * o_b_s)
+    blk_o_msk = ((blk_o_idx >= 0) &
+                 (blk_o_idx < o_b * o_b_s))
     tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/misc/row_wise.py
@@ -55,7 +55,7 @@ def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
     validate_contiguous(sparsity_layout, sparsity_lut,
                         sparsity_layout_output, sparsity_reverse_lut_output)
 
-    return BlksprsTensor(row_wise_sum_forward(
+    return BlksprsTensor.wrap(row_wise_sum_forward(
         x, sparsity_lut, sparsity_layout_output, sparsity_reverse_lut_output,
         sparsity_block_size, n_sparse_blocks_output, flag_slice_only)), sparsity_layout_output
 
@@ -130,15 +130,16 @@ def row_wise_sum_kernel(x,
     # Load reverse sparsity index for current block
     rev_idx_spa_idx = (spa_bat_x * s_l_o_b_s +
                        spa_row_x * s_l_o_r_s)
-    rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
+    rev_idx_spa_msk = ((rev_idx_spa_idx >= 0) &
+                       (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s))
     rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
     if rev_idx_spa >= 0:
         blk_idx = ((pid_blk * x_b_s) +
                    ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                    ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_msk = (blk_idx >= 0 and
-                   blk_idx < x_b * x_b_s)
+        blk_msk = ((blk_idx >= 0) &
+                   (blk_idx < x_b * x_b_s))
         blk = tl.load(x + blk_idx, mask=blk_msk)
 
         buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
@@ -146,8 +147,8 @@ def row_wise_sum_kernel(x,
         o_idx = (rev_idx_spa * o_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  (tl.arange(0, 1))[None, :])
-        o_msk = (o_idx >= 0 and
-                 o_idx < o_b * o_b_s)
+        o_msk = ((o_idx >= 0) &
+                 (o_idx < o_b * o_b_s))
         tl.atomic_add(o + o_idx, buf, o_msk)
 
 
@@ -174,8 +175,6 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
     of the input and the sparsity layout of the output tensor.
 
     """
-    # TODO Fix for triton bug, see https://github.com/triton-lang/triton/issues/6376, should be fixed with the upcoming 3.4.0 release
-    x = torch.where(x == -0.0, torch.tensor(0.0), x)
     x = ensure_contiguous(x)
 
     validate_dimensions(x)
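
The two deleted lines above were a workaround for a Triton signed-zero bug in reductions (triton-lang/triton#6376) that the TODO expected to be fixed in Triton 3.4.0; removing it is consistent with the new `torch>=2.8.0` dependency floor, which bundles a newer Triton. The workaround relied on IEEE-754 signed zeros comparing equal, so `x == -0.0` also matches `+0.0`; a short illustration:

```python
import torch

x = torch.tensor([-0.0, 1.5, -2.0])

# -0.0 == 0.0 is True under IEEE-754, so this rewrites zeros of either
# sign to +0.0 before the row-wise max reduction.
normalized = torch.where(x == -0.0, torch.tensor(0.0), x)
print(normalized)  # tensor([ 0.0000,  1.5000, -2.0000])
```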
@@ -197,7 +196,7 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
     validate_contiguous(sparsity_layout, sparsity_lut,
                         sparsity_layout_output, sparsity_reverse_lut_output)
 
-    return BlksprsTensor(
+    return BlksprsTensor.wrap(
         row_wise_max_forward(x, sparsity_lut, sparsity_layout_output, sparsity_reverse_lut_output, sparsity_block_size,
                              n_sparse_blocks_output, flag_slice_only)), sparsity_layout_output
 
@@ -274,15 +273,16 @@ def row_wise_max_kernel(x,
     # Load reverse sparsity index for current block
     rev_idx_spa_idx = (spa_bat_x * s_l_o_b_s +
                        spa_row_x * s_l_o_r_s)
-    rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
+    rev_idx_spa_msk = ((rev_idx_spa_idx >= 0) &
+                       (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s))
     rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
     if rev_idx_spa >= 0:
         blk_idx = ((pid_blk * x_b_s) +
                    ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                    ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_msk = (blk_idx >= 0 and
-                   blk_idx < x_b * x_b_s)
+        blk_msk = ((blk_idx >= 0) &
+                   (blk_idx < x_b * x_b_s))
         blk = tl.load(x + blk_idx, mask=blk_msk)
 
         buf = tl.reshape(tl.max(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
@@ -290,8 +290,8 @@ def row_wise_max_kernel(x,
         o_idx = (rev_idx_spa * o_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  (tl.arange(0, 1))[None, :])
-        o_msk = (o_idx >= 0 and
-                 o_idx < o_b * o_b_s)
+        o_msk = ((o_idx >= 0) &
+                 (o_idx < o_b * o_b_s))
         tl.atomic_max(o + o_idx, buf, o_msk)
 
 
@@ -329,8 +329,8 @@ def row_wise_add(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
 
     validate_contiguous(sparsity_layout_x, sparsity_lut_x, sparsity_reverse_lut_rwm)
 
-    return BlksprsTensor(row_wise_add_forward(x, sparsity_lut_x, sparsity_layout_rwm,
-                                              sparsity_reverse_lut_rwm, y, sparsity_block_size))
+    return BlksprsTensor.wrap(row_wise_add_forward(x, sparsity_lut_x, sparsity_layout_rwm,
+                                                   sparsity_reverse_lut_rwm, y, sparsity_block_size))
 
 
 def row_wise_sub(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
@@ -412,7 +412,8 @@ def row_wise_add_kernel(x,
     # Get reverse sparsity indices for s
     rev_idx_spa_s_idx = (spa_bat_x * s_l_y_b_s +
                          spa_row_x * s_l_y_r_s)
-    rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_y_b * s_l_y_b_s)
+    rev_idx_spa_s_msk = ((rev_idx_spa_s_idx >= 0) &
+                         (rev_idx_spa_s_idx < s_l_y_b * s_l_y_b_s))
     rev_idx_spa_s = tl.load(r_lut_y + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
 
     if rev_idx_spa_s == -1:
@@ -423,16 +424,16 @@ def row_wise_add_kernel(x,
     blk_x_idx = ((pid_blk * x_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx >= 0 and
-                 blk_x_idx < x_b * x_b_s)
+    blk_x_msk = ((blk_x_idx >= 0) &
+                 (blk_x_idx < x_b * x_b_s))
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Load sum block
     blk_s_idx = (rev_idx_spa_s * y_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
                  (tl.arange(0, 1) * y_c_s)[None, :])
-    blk_s_msk = (blk_s_idx >= 0 and
-                 blk_s_idx < y_b * y_b_s)
+    blk_s_msk = ((blk_s_idx >= 0) &
+                 (blk_s_idx < y_b * y_b_s))
     blk_s = tl.load(y + blk_s_idx, mask=blk_s_msk)
 
     # Compute exp
@@ -442,6 +443,6 @@ def row_wise_add_kernel(x,
     blk_o_idx = ((pid_blk * o_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-    blk_o_msk = (blk_o_idx >= 0 and
-                 blk_o_idx < o_b * o_b_s)
+    blk_o_msk = ((blk_o_idx >= 0) &
+                 (blk_o_idx < o_b * o_b_s))
     tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/partitioning.py
@@ -41,7 +41,7 @@ def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
 
     lut = split_build_lut(lut, sparsity_layout, partitions)
 
-    return BlksprsTensor(split_forward(
+    return BlksprsTensor.wrap(split_forward(
         x, lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
         partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_output"]
 
@@ -146,7 +146,7 @@ def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
 
     lut = merge_build_lut(lut, sparsity_layout, partitions)
 
-    return BlksprsTensor(merge_forward(
+    return BlksprsTensor.wrap(merge_forward(
         x, lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
         partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_output"]
 
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/repeat.py
@@ -46,7 +46,7 @@ def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int,
 
     lut = repeat_build_lut(lut, sparsity_layout_x, repeats, sparsity_layout_output)
 
-    return BlksprsTensor(repeat_forward(
+    return BlksprsTensor.wrap(repeat_forward(
         x, sparsity_layout_x, lut["sparsity_layout_o"], lut["sparsity_lut"],
         lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_o"]
 
@@ -87,7 +87,7 @@ def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
 
     lut = repeat_interleave_build_lut(lut, sparsity_layout_x, repeats, sparsity_layout_output)
 
-    return BlksprsTensor(repeat_forward(
+    return BlksprsTensor.wrap(repeat_forward(
         x, sparsity_layout_x, lut["sparsity_layout_o"], lut["sparsity_lut"],
         lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_o"]
 
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/softmax.py
@@ -1,5 +1,3 @@
-import pdb
-
 import torch
 import triton
 from torch import Tensor
@@ -8,9 +6,9 @@ from torch._library.triton import wrap_triton
 from triton import language as tl
 
 from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import stride, ceil_pow2
-from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_dtype_float_32, ensure_contiguous
 
@@ -55,10 +53,10 @@ def softmax_regular(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_si
 
     lut = softmax_build_lut(lut, sparsity_layout)
 
-    return BlksprsTensor(softmax_forward(x, sparsity_layout,
-                                         lut["sparsity_lut"],
-                                         lut["sparsity_reverse_lut_rws"],
-                                         sparsity_block_size))
+    return BlksprsTensor.wrap(softmax_forward(x, sparsity_layout,
+                                              lut["sparsity_lut"],
+                                              lut["sparsity_reverse_lut_rws"],
+                                              sparsity_block_size))
 
 
 @triton_op("blksprs::softmax_forward", mutates_args={})
@@ -186,7 +184,8 @@ def softmax_kernel(x,
     # Get reverse sparsity indices for s
     rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
                          spa_row * s_l_s_r_s)
-    rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
+    rev_idx_spa_s_msk = ((rev_idx_spa_s_idx >= 0) &
+                         (rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s))
     rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
 
     if rev_idx_spa_s >= 0:
@@ -194,16 +193,16 @@ def softmax_kernel(x,
         blk_x_idx = ((pid_blk * x_b_s) +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_x_msk = (blk_x_idx >= 0 and
-                     blk_x_idx < x_b * x_b_s)
+        blk_x_msk = ((blk_x_idx >= 0) &
+                     (blk_x_idx < x_b * x_b_s))
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
         # Load sum block
         blk_s_idx = (rev_idx_spa_s * s_b_s +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
                      (tl.arange(0, 1) * s_c_s)[None, :])
-        blk_s_msk = (blk_s_idx >= 0 and
-                     blk_s_idx < s_b * s_b_s)
+        blk_s_msk = ((blk_s_idx >= 0) &
+                     (blk_s_idx < s_b * s_b_s))
         blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
 
         # Compute softmax
@@ -249,29 +248,30 @@ def softmax_kernel_grad(g,
 
     rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
                          spa_row * s_l_s_r_s)
-    rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
+    rev_idx_spa_s_msk = ((rev_idx_spa_s_idx >= 0) &
+                         (rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s))
     rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
 
     if rev_idx_spa_s >= 0:
         blk_s_idx = (rev_idx_spa_s * s_b_s +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
                      (tl.arange(0, 1) * s_c_s)[None, :])
-        blk_s_msk = (blk_s_idx >= 0 and
-                     blk_s_idx < s_b * s_b_s)
+        blk_s_msk = ((blk_s_idx >= 0) &
+                     (blk_s_idx < s_b * s_b_s))
         blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
 
         blk_g_idx = ((pid_blk * g_b_s) +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])
-        blk_g_msk = (blk_g_idx >= 0 and
-                     blk_g_idx < g_b * g_b_s)
+        blk_g_msk = ((blk_g_idx >= 0) &
+                     (blk_g_idx < g_b * g_b_s))
         blk_g = tl.load(g + blk_g_idx, mask=blk_g_msk)
 
         blk_x_idx = ((pid_blk * x_b_s) +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_x_msk = (blk_x_idx >= 0 and
-                     blk_x_idx < x_b * x_b_s)
+        blk_x_msk = ((blk_x_idx >= 0) &
+                     (blk_x_idx < x_b * x_b_s))
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
         buf = blk_x * (blk_g - blk_s)
@@ -279,8 +279,8 @@ def softmax_kernel_grad(g,
         blk_o_idx = ((pid_blk * o_b_s) +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-        blk_o_msk = (blk_o_idx >= 0 and
-                     blk_o_idx < o_b * o_b_s)
+        blk_o_msk = ((blk_o_idx >= 0) &
+                     (blk_o_idx < o_b * o_b_s))
         tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
 
 
@@ -346,10 +346,10 @@ def softmax_fused(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size
 
     lut = softmax_fused_build_lut(lut, sparsity_layout)
 
-    return BlksprsTensor(softmax_fused_forward(x, sparsity_layout,
-                                               lut["sparsity_reverse_lut_sorted"],
-                                               lut["max_blocks_line"],
-                                               sparsity_block_size))
+    return BlksprsTensor.wrap(softmax_fused_forward(x, sparsity_layout,
+                                                    lut["sparsity_reverse_lut_sorted"],
+                                                    lut["max_blocks_line"],
+                                                    sparsity_block_size))
 
 
 @triton_op("blksprs::softmax_fused_forward", mutates_args={})
@@ -449,7 +449,8 @@ def softmax_fused_kernel(x,
     blk_rev_idx = (pid_bat * s_l_b_s +
                    pid_row * s_l_r_s +
                    (tl.arange(0, mbs) * s_l_c_s))
-    blk_rev_msk = ((blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s) and
+    blk_rev_msk = (((blk_rev_idx >= 0) &
+                    (blk_rev_idx < s_l_b * s_l_b_s)) &
                    (tl.arange(0, mbs) < s_l_c))
     blk_rev = tl.load(r_lut_s + blk_rev_idx, mask=blk_rev_msk, other=-1).to(tl.int32)
 
@@ -464,8 +465,9 @@ def softmax_fused_kernel(x,
     blk_x_idx = (blk_rev_ext * x_b_s +
                  pid_lin * x_r_s +
                  (tl.arange(0, mbs * sparsity_block_size) % sparsity_block_size) * x_c_s)
-    blk_x_mask = ((blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
-                  and blk_rev_ext != -1)
+    blk_x_mask = (((blk_x_idx >= 0) &
+                   (blk_x_idx < x_b * x_b_s)) &
+                  (blk_rev_ext != -1))
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_mask, other=float("-inf"))
 
     # Compute softmax
@@ -502,7 +504,8 @@ def softmax_fused_kernel_grad(g,
     blk_rev_idx = (pid_bat * s_l_b_s +
                    pid_row * s_l_r_s +
                    (tl.arange(0, mbs) * s_l_c_s))
-    blk_rev_msk = ((blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s) and
+    blk_rev_msk = (((blk_rev_idx >= 0) &
+                    (blk_rev_idx < s_l_b * s_l_b_s)) &
                    (tl.arange(0, mbs) < s_l_c))
     blk_rev = tl.load(r_lut_s + blk_rev_idx, mask=blk_rev_msk, other=-1).to(tl.int32)
 
@@ -517,16 +520,18 @@ def softmax_fused_kernel_grad(g,
     blk_g_idx = (blk_rev_ext * g_b_s +
                  pid_lin * g_r_s +
                  (tl.arange(0, mbs * sparsity_block_size) % sparsity_block_size) * g_c_s)
-    blk_g_mask = ((blk_g_idx >= 0 and blk_g_idx < g_b * g_b_s)
-                  and blk_rev_ext != -1)
+    blk_g_mask = (((blk_g_idx >= 0) &
+                   (blk_g_idx < g_b * g_b_s)) &
+                  (blk_rev_ext != -1))
     blk_g = tl.load(g + blk_g_idx, mask=blk_g_mask)
 
     # Load line of x
     blk_x_idx = (blk_rev_ext * x_b_s +
                  pid_lin * x_r_s +
                  (tl.arange(0, mbs * sparsity_block_size) % sparsity_block_size) * x_c_s)
-    blk_x_mask = ((blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
-                  and blk_rev_ext != -1)
+    blk_x_mask = (((blk_x_idx >= 0) &
+                   (blk_x_idx < x_b * x_b_s)) &
+                  (blk_rev_ext != -1))
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_mask)
 
     # Compute gradients
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/ops/transpose.py
@@ -37,9 +37,9 @@ def transpose(x: BlksprsTensor, sparsity_layout: Tensor,
 
     lut = transpose_build_lut(lut, sparsity_layout)
 
-    return BlksprsTensor(transpose_forward(x, lut["sparsity_layout_t"],
-                                           lut["sparsity_lut"], lut["sparsity_reverse_lut"],
-                                           sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_t"]
+    return BlksprsTensor.wrap(transpose_forward(x, lut["sparsity_layout_t"],
+                                                lut["sparsity_lut"], lut["sparsity_reverse_lut"],
+                                                sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_t"]
 
 
 @triton_op("blksprs::transpose_forward", mutates_args={})
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/utils/autotuning.py
@@ -75,4 +75,4 @@ def get_autotune_configs():
         autotune_configs.append(
             triton.Config({"TRITON_BLOCK_SIZE": block_size}, num_stages=num_stages, num_warps=num_warps))
 
-    return autotune_configs
+    return autotune_configs
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/utils/blksprs_tensor.py
@@ -1,3 +1,5 @@
+from typing import Union
+
 import torch
 from torch import Tensor
 
@@ -7,4 +9,11 @@ class BlksprsTensor(Tensor):
     """
 
     def __repr__(self):
-        return f"BlksprsTensor({torch.Tensor(self).__repr__()})"
+        return f"BlksprsTensor({torch.Tensor(self).__repr__()})"
+
+    @staticmethod
+    def wrap(tensor: Tensor) -> Union[Tensor, "BlksprsTensor"]:
+        if torch._dynamo.is_compiling():
+            return tensor
+        else:
+            return BlksprsTensor(tensor)
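
`wrap` is the second recurring change of this release: every op now returns `BlksprsTensor.wrap(...)` rather than constructing the subclass directly. `torch._dynamo.is_compiling()` is standard PyTorch API that reports whether the call happens inside a `torch.compile` trace; in that case the plain `Tensor` is passed through, presumably to keep tensor-subclass construction out of the compiled graph, while eager callers still receive the `BlksprsTensor` marker type. Illustrative usage:

```python
import torch
from blksprs.utils.blksprs_tensor import BlksprsTensor

dense = torch.randn(2, 4, 4)

out = BlksprsTensor.wrap(dense)
print(type(out).__name__)  # "BlksprsTensor" when called eagerly

# Inside a torch.compile'd function the same call returns the plain
# Tensor, since torch._dynamo.is_compiling() is True during tracing.
```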
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/utils/processing.py
@@ -26,7 +26,8 @@ def apply_torch_linear(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block
 
     # Apply weights
     sparsity_layout_xw = build_sparsity_layout_matmul_fast(sparsity_layout, sparsity_layout_w_t)
-    xw = matmul(x, sparsity_layout, BlksprsTensor(w_t_bs.to(x.dtype)), sparsity_layout_w_t, sparsity_layout_xw, sparsity_block_size)
+    xw = matmul(x, sparsity_layout, BlksprsTensor.wrap(w_t_bs.to(x.dtype)), sparsity_layout_w_t, sparsity_layout_xw,
+                sparsity_block_size)
     interim = xw
 
     # Apply bias
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs/utils/tools.py
@@ -1,9 +1,5 @@
-import torch
 from torch import Tensor, Size
 
-# Capture scalar outputs for JIT compilation
-torch._dynamo.config.capture_scalar_outputs = True
-
 
 def do_shape_blocksparse(x: Tensor) -> tuple[Tensor, Size]:
     if x.dim() == 3:
@@ -27,7 +23,8 @@ def stride(x: Tensor):
     else:
         raise NotImplementedError
 
+
 def ceil_pow2(x: int) -> int:
     if x <= 0:
         raise ValueError("Input must be a positive integer.")
-    return 1 << (x - 1).bit_length()
+    return 1 << (x - 1).bit_length()
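
`ceil_pow2` is unchanged here apart from spacing and the trailing newline; it rounds a positive integer up to the next power of two. `(x - 1).bit_length()` counts the bits needed to represent `x - 1`, and shifting `1` by that amount yields the smallest power of two greater than or equal to `x`. A worked check, reusing the function body from the hunk above:

```python
def ceil_pow2(x: int) -> int:
    if x <= 0:
        raise ValueError("Input must be a positive integer.")
    return 1 << (x - 1).bit_length()

# e.g. ceil_pow2(5): (5 - 1).bit_length() == 3, and 1 << 3 == 8
assert [ceil_pow2(n) for n in (1, 2, 3, 5, 8, 9)] == [1, 2, 4, 8, 8, 16]
```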
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs.egg-info/PKG-INFO
@@ -1,13 +1,13 @@
 Metadata-Version: 2.4
 Name: blksprs
-Version: 2.1.7
+Version: 2.1.9
 Summary: A lightweight library for operations on block-sparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
 Project-URL: Bugtracker, https://github.com/FelixSchoen/blksprs/issues
 Requires-Python: >=3.11
 Description-Content-Type: text/markdown
-Requires-Dist: torch
+Requires-Dist: torch>=2.8.0
 Requires-Dist: numpy
 Provides-Extra: test
 Requires-Dist: pytest; extra == "test"
@@ -102,7 +102,7 @@ We will continue to maintain the library and fix any issues that arise.
 Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
 We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
 
-It might be that this changes with future projects, but as of June 2025, we are content with the current state of the
+It might be that this changes with future projects, but as of August 2025, we are content with the current state of the
 library.
 
 ## Known Limitations and Issues
{blksprs-2.1.7 → blksprs-2.1.9}/blksprs.egg-info/requires.txt
@@ -1,4 +1,4 @@
-torch
+torch>=2.8.0
 numpy
 
 [test]
{blksprs-2.1.7 → blksprs-2.1.9}/pyproject.toml
@@ -1,13 +1,13 @@
 [project]
 name = "blksprs"
-version = "2.1.7"
+version = "2.1.9"
 authors = [{ name = "Felix Schön", email = "schoen@kr.tuwien.ac.at" }]
 description = "A lightweight library for operations on block-sparse matrices in PyTorch."
 readme = "README.md"
 requires-python = ">=3.11"
 license = { file = "LICENSE.md" }
 dependencies = [
-    "torch",
+    "torch >= 2.8.0",
     "numpy"
 ]
 
The remaining files (items 23-28 in the list above) contain no changes.