blksprs 2.1.8__tar.gz → 2.1.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. {blksprs-2.1.8 → blksprs-2.1.9}/PKG-INFO +2 -2
  2. {blksprs-2.1.8 → blksprs-2.1.9}/README.md +1 -1
  3. {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/__init__.py +1 -1
  4. {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/layouting/distribution_layout.py +10 -7
  5. {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/layouting/sparsity_layout.py +14 -9
  6. {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/ops/conversion.py +16 -13
  7. {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/ops/distribution.py +19 -19
  8. {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/ops/flow.py +12 -10
  9. {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/ops/matmul.py +10 -8
  10. {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/ops/misc/broadcast_ops.py +6 -6
  11. {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/ops/misc/row_wise.py +20 -17
  12. {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/ops/softmax.py +29 -22
  13. {blksprs-2.1.8 → blksprs-2.1.9}/blksprs.egg-info/PKG-INFO +2 -2
  14. {blksprs-2.1.8 → blksprs-2.1.9}/pyproject.toml +1 -1
  15. {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/ops/partitioning.py +0 -0
  16. {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/ops/repeat.py +0 -0
  17. {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/ops/transpose.py +0 -0
  18. {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/utils/autotuning.py +0 -0
  19. {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/utils/benchmarking.py +0 -0
  20. {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/utils/blksprs_tensor.py +0 -0
  21. {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/utils/processing.py +0 -0
  22. {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/utils/tools.py +0 -0
  23. {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/utils/validation.py +0 -0
  24. {blksprs-2.1.8 → blksprs-2.1.9}/blksprs.egg-info/SOURCES.txt +0 -0
  25. {blksprs-2.1.8 → blksprs-2.1.9}/blksprs.egg-info/dependency_links.txt +0 -0
  26. {blksprs-2.1.8 → blksprs-2.1.9}/blksprs.egg-info/requires.txt +0 -0
  27. {blksprs-2.1.8 → blksprs-2.1.9}/blksprs.egg-info/top_level.txt +0 -0
  28. {blksprs-2.1.8 → blksprs-2.1.9}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: blksprs
3
- Version: 2.1.8
3
+ Version: 2.1.9
4
4
  Summary: A lightweight library for operations on block-sparse matrices in PyTorch.
5
5
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
6
6
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -102,7 +102,7 @@ We will continue to maintain the library and fix any issues that arise.
102
102
  Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
103
103
  We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
104
104
 
105
- It might be that this changes with future projects, but as of June 2025, we are content with the current state of the
105
+ It might be that this changes with future projects, but as of August 2025, we are content with the current state of the
106
106
  library.
107
107
 
108
108
  ## Known Limitations and Issues
@@ -83,7 +83,7 @@ We will continue to maintain the library and fix any issues that arise.
83
83
  Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
84
84
  We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
85
85
 
86
- It might be that this changes with future projects, but as of June 2025, we are content with the current state of the
86
+ It might be that this changes with future projects, but as of August 2025, we are content with the current state of the
87
87
  library.
88
88
 
89
89
  ## Known Limitations and Issues
@@ -4,7 +4,7 @@ import torch
4
4
  # Capture scalar outputs for JIT compilation
5
5
  torch._dynamo.config.capture_scalar_outputs = True
6
6
  # Set version
7
- __version__ = "2.1.8"
7
+ __version__ = "2.1.9"
8
8
 
9
9
  # Imports
10
10
 
@@ -98,22 +98,25 @@ def build_distribution_layout_kernel(i,
98
98
 
99
99
  # Get position of current sparsity block consisting of its batch, row, and column index
100
100
  spa_bat_i_idx = (pid_blk * s_lut_i_r_s + 0 * s_lut_i_c_s)
101
- spa_bat_i_msk = (spa_bat_i_idx >= 0 and spa_bat_i_idx < s_lut_i_r * s_lut_i_r_s)
101
+ spa_bat_i_msk = ((spa_bat_i_idx >= 0) &
102
+ (spa_bat_i_idx < s_lut_i_r * s_lut_i_r_s))
102
103
  spa_bat_i = tl.load(s_lut_i + spa_bat_i_idx, mask=spa_bat_i_msk)
103
104
 
104
105
  spa_row_i_idx = (pid_blk * s_lut_i_r_s + 1 * s_lut_i_c_s)
105
- spa_row_i_msk = (spa_row_i_idx >= 0 and spa_row_i_idx < s_lut_i_r * s_lut_i_r_s)
106
+ spa_row_i_msk = ((spa_row_i_idx >= 0) &
107
+ (spa_row_i_idx < s_lut_i_r * s_lut_i_r_s))
106
108
  spa_row_i = tl.load(s_lut_i + spa_row_i_idx, mask=spa_row_i_msk)
107
109
 
108
110
  spa_col_i_idx = (pid_blk * s_lut_i_r_s + 2 * s_lut_i_c_s)
109
- spa_col_i_msk = (spa_col_i_idx >= 0 and spa_col_i_idx < s_lut_i_r * s_lut_i_r_s)
111
+ spa_col_i_msk = ((spa_col_i_idx >= 0) &
112
+ (spa_col_i_idx < s_lut_i_r * s_lut_i_r_s))
110
113
  spa_col_i = tl.load(s_lut_i + spa_col_i_idx, mask=spa_col_i_msk)
111
114
 
112
115
  blk_i_idx = (pid_blk * i_b_s +
113
116
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
114
117
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
115
- blk_i_msk = (blk_i_idx >= 0 and
116
- blk_i_idx < i_b * i_b_s)
118
+ blk_i_msk = ((blk_i_idx >= 0) &
119
+ (blk_i_idx < i_b * i_b_s))
117
120
  blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)
118
121
 
119
122
  dst_bat_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_i, dtype=tl.int32)
@@ -131,6 +134,6 @@ def build_distribution_layout_kernel(i,
131
134
  blk_o_idx = ((dst_bat_idx * o_b_s) +
132
135
  (dst_row_idx * o_r_s) +
133
136
  (dst_col_idx * o_c_s))
134
- blk_o_msk = (blk_o_idx >= 0 and
135
- blk_o_idx < o_b * o_b_s)
137
+ blk_o_msk = ((blk_o_idx >= 0) &
138
+ (blk_o_idx < o_b * o_b_s))
136
139
  tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
@@ -79,8 +79,8 @@ def build_sparsity_layout_kernel(x,
79
79
  blk_x_idx = (pid_bat * x_b_s +
80
80
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
81
81
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
82
- blk_x_msk = (blk_x_idx >= 0 and
83
- blk_x_idx < x_b * x_b_s)
82
+ blk_x_msk = ((blk_x_idx >= 0) &
83
+ (blk_x_idx < x_b * x_b_s))
84
84
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
85
85
 
86
86
  # Store sparsity layout value
@@ -88,7 +88,8 @@ def build_sparsity_layout_kernel(x,
88
88
  blk_o_idx = (pid_bat * o_b_s +
89
89
  (((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +
90
90
  ((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))
91
- blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
91
+ blk_o_msk = ((blk_o_idx >= 0) &
92
+ (blk_o_idx < o_b * o_b_s))
92
93
  tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
93
94
 
94
95
 
@@ -178,23 +179,26 @@ def build_sparsity_layout_adaption_kernel(x,
178
179
 
179
180
  # Get sparsity index of current output block consisting of its batch, row, and column index
180
181
  spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
181
- spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
182
+ spa_bat_msk = ((spa_bat_idx >= 0) &
183
+ (spa_bat_idx < s_lut_r * s_lut_r_s))
182
184
  spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
183
185
 
184
186
  spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
185
- spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
187
+ spa_row_msk = ((spa_row_idx >= 0) &
188
+ (spa_row_idx < s_lut_r * s_lut_r_s))
186
189
  spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
187
190
 
188
191
  spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
189
- spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
192
+ spa_col_msk = ((spa_col_idx >= 0) &
193
+ (spa_col_idx < s_lut_r * s_lut_r_s))
190
194
  spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
191
195
 
192
196
  # Load x values
193
197
  blk_x_idx = ((pid_blk * x_b_s) +
194
198
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
195
199
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
196
- blk_x_msk = (blk_x_idx >= 0 and
197
- blk_x_idx < x_b * x_b_s)
200
+ blk_x_msk = ((blk_x_idx >= 0) &
201
+ (blk_x_idx < x_b * x_b_s))
198
202
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
199
203
 
200
204
  # Store sparsity layout value
@@ -204,7 +208,8 @@ def build_sparsity_layout_adaption_kernel(x,
204
208
  // sparsity_block_size_to) * o_r_s) +
205
209
  (((pid_col * TRITON_BLOCK_SIZE + spa_col * sparsity_block_size_from)
206
210
  // sparsity_block_size_to) * o_c_s))
207
- blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
211
+ blk_o_msk = ((blk_o_idx >= 0) &
212
+ (blk_o_idx < o_b * o_b_s))
208
213
  tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
209
214
 
210
215
 
@@ -120,16 +120,16 @@ def to_sparse_kernel(x,
120
120
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
121
121
  ((pid_col * TRITON_BLOCK_SIZE + spa_col * sparsity_block_size +
122
122
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
123
- blk_d_msk = (blk_d_idx >= 0 and
124
- blk_d_idx < x_b * x_b_s)
123
+ blk_d_msk = ((blk_d_idx >= 0) &
124
+ (blk_d_idx < x_b * x_b_s))
125
125
  blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)
126
126
 
127
127
  # Store block in sparse tensor
128
128
  blk_o_idx = ((pid_blk * o_b_s) +
129
129
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
130
130
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
131
- blk_o_msk = (blk_o_idx >= 0 and
132
- blk_o_idx < (pid_blk + 1) * o_b_s)
131
+ blk_o_msk = ((blk_o_idx >= 0) &
132
+ (blk_o_idx < (pid_blk + 1) * o_b_s))
133
133
  tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
134
134
 
135
135
 
@@ -269,7 +269,8 @@ def to_dense_kernel(x,
269
269
 
270
270
  # Get reverse sparsity index for current block
271
271
  rev_idx_spa_idx = (pid_blk * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)
272
- rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_b * s_l_b_s)
272
+ rev_idx_spa_msk = ((rev_idx_spa_idx >= 0) &
273
+ (rev_idx_spa_idx < s_l_b * s_l_b_s))
273
274
  rev_idx_spa = tl.load(sparsity_reverse_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
274
275
 
275
276
  # If block is present commence operations
@@ -279,14 +280,15 @@ def to_dense_kernel(x,
279
280
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
280
281
  (((pid_col % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
281
282
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
282
- blk_msk = (blk_idx >= 0 and
283
- blk_idx < x_b * x_b_s)
283
+ blk_msk = ((blk_idx >= 0) &
284
+ (blk_idx < x_b * x_b_s))
284
285
  blk = tl.load(x + blk_idx, mask=blk_msk)
285
286
 
286
287
  o_idx = (pid_blk * o_b_s +
287
288
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
288
289
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
289
- o_msk = (o_idx >= 0 and o_idx < o_b * o_b_s)
290
+ o_msk = ((o_idx >= 0) &
291
+ (o_idx < o_b * o_b_s))
290
292
  tl.store(o + o_idx, blk, o_msk)
291
293
 
292
294
 
@@ -458,7 +460,8 @@ def adapt_layout_kernel(x,
458
460
  rev_idx_spa_x_idx = (spa_bat_x * s_l_x_b_s +
459
461
  spa_row_x * s_l_x_r_s +
460
462
  spa_col_x * s_l_x_c_s)
461
- rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
463
+ rev_idx_spa_x_msk = ((rev_idx_spa_x_idx >= 0) &
464
+ (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s))
462
465
  rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
463
466
 
464
467
  # If block is present commence operations
@@ -473,16 +476,16 @@ def adapt_layout_kernel(x,
473
476
  blk_x_idx = ((rev_idx_spa_x * x_b_s) +
474
477
  ((shift_row_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
475
478
  ((shift_col_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
476
- blk_x_msk = (blk_x_idx >= 0 and
477
- blk_x_idx < x_b * x_b_s)
479
+ blk_x_msk = ((blk_x_idx >= 0) &
480
+ (blk_x_idx < x_b * x_b_s))
478
481
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
479
482
 
480
483
  # Store output
481
484
  blk_o_idx = ((pid_blk * o_b_s) +
482
485
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
483
486
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
484
- blk_o_msk = (blk_o_idx >= 0 and
485
- blk_o_idx < o_b * o_b_s)
487
+ blk_o_msk = ((blk_o_idx >= 0) &
488
+ (blk_o_idx < o_b * o_b_s))
486
489
  tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
487
490
 
488
491
 
@@ -136,8 +136,8 @@ def gather_kernel(x,
136
136
  blk_i_idx = ((pid_blk * i_b_s) +
137
137
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
138
138
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
139
- blk_i_msk = (blk_i_idx >= 0 and
140
- blk_i_idx < i_b * i_b_s)
139
+ blk_i_msk = ((blk_i_idx >= 0) &
140
+ (blk_i_idx < i_b * i_b_s))
141
141
  blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
142
142
 
143
143
  # Get indices of sparsity blocks and positions within the blocks
@@ -164,26 +164,26 @@ def gather_kernel(x,
164
164
  rev_idx_spa_x_idx = ((rev_dst_bat_x * s_l_x_b_s) +
165
165
  (rev_dst_row_x * s_l_x_r_s) +
166
166
  (rev_dst_col_x * s_l_x_c_s))
167
- rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and
168
- rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
167
+ rev_idx_spa_x_msk = ((rev_idx_spa_x_idx >= 0) &
168
+ (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s))
169
169
  rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
170
170
 
171
171
  # Load x values
172
172
  blk_x_idx = ((rev_idx_spa_x * x_b_s) +
173
173
  dst_row_x +
174
174
  dst_col_x)
175
- blk_x_msk = ((blk_x_idx >= 0 and
176
- blk_x_idx < x_b * x_b_s) and
177
- rev_idx_spa_x_msk != -1)
175
+ blk_x_msk = (((blk_x_idx >= 0) &
176
+ (blk_x_idx < x_b * x_b_s)) &
177
+ (rev_idx_spa_x_msk != -1))
178
178
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
179
179
 
180
180
  # Store output
181
181
  blk_o_idx = ((pid_blk * o_b_s) +
182
182
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
183
183
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
184
- blk_o_msk = ((blk_o_idx >= 0 and
185
- blk_o_idx < o_b * o_b_s) and
186
- rev_idx_spa_x_msk != -1)
184
+ blk_o_msk = (((blk_o_idx >= 0) &
185
+ (blk_o_idx < o_b * o_b_s)) &
186
+ (rev_idx_spa_x_msk != -1))
187
187
  tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
188
188
 
189
189
 
@@ -380,16 +380,16 @@ def scatter_reduce_kernel(x,
380
380
  blk_x_idx = ((pid_blk * x_b_s) +
381
381
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
382
382
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
383
- blk_x_msk = (blk_x_idx >= 0 and
384
- blk_x_idx < x_b * x_b_s)
383
+ blk_x_msk = ((blk_x_idx >= 0) &
384
+ (blk_x_idx < x_b * x_b_s))
385
385
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
386
386
 
387
387
  # Load index values
388
388
  blk_i_idx = ((pid_blk * i_b_s) +
389
389
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
390
390
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
391
- blk_i_msk = (blk_i_idx >= 0 and
392
- blk_i_idx < i_b * i_b_s)
391
+ blk_i_msk = ((blk_i_idx >= 0) &
392
+ (blk_i_idx < i_b * i_b_s))
393
393
  blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
394
394
 
395
395
  # Get indices of sparsity blocks and positions within the blocks
@@ -416,17 +416,17 @@ def scatter_reduce_kernel(x,
416
416
  rev_idx_spa_o_idx = ((rev_dst_bat_o * s_l_o_b_s) +
417
417
  (rev_dst_row_o * s_l_o_r_s) +
418
418
  (rev_dst_col_o * s_l_o_c_s))
419
- rev_idx_spa_o_msk = (rev_idx_spa_o_idx >= 0 and
420
- rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
419
+ rev_idx_spa_o_msk = ((rev_idx_spa_o_idx >= 0) &
420
+ (rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s))
421
421
  rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
422
422
 
423
423
  # Store output
424
424
  blk_o_idx = ((rev_idx_spa_o * o_b_s) +
425
425
  dst_row_o +
426
426
  dst_col_o)
427
- blk_o_msk = ((blk_o_idx >= 0 and
428
- blk_o_idx < o_b * o_b_s) and
429
- rev_idx_spa_o_msk != -1)
427
+ blk_o_msk = (((blk_o_idx >= 0) &
428
+ (blk_o_idx < o_b * o_b_s)) &
429
+ (rev_idx_spa_o_msk != -1))
430
430
 
431
431
  if reduce_op_ind == 0:
432
432
  tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
@@ -78,22 +78,23 @@ def flow_pull_kernel(x,
78
78
  rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
79
79
  spa_row * s_l_o_r_s +
80
80
  spa_col * s_l_o_c_s)
81
- rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
81
+ rev_idx_spa_msk = ((rev_idx_spa_idx >= 0) &
82
+ (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s))
82
83
  rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
83
84
 
84
85
  if rev_idx_spa >= 0:
85
86
  blk_x_idx = (rev_idx_spa * x_b_s +
86
87
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
87
88
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
88
- blk_x_msk = (blk_x_idx >= 0 and
89
- blk_x_idx < x_b * x_b_s)
89
+ blk_x_msk = ((blk_x_idx >= 0) &
90
+ (blk_x_idx < x_b * x_b_s))
90
91
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
91
92
 
92
93
  blk_o_idx = (pid_blk * o_b_s +
93
94
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
94
95
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
95
- blk_o_msk = (blk_o_idx >= 0 and
96
- blk_o_idx < o_b * o_b_s)
96
+ blk_o_msk = ((blk_o_idx >= 0) &
97
+ (blk_o_idx < o_b * o_b_s))
97
98
  tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
98
99
 
99
100
 
@@ -165,20 +166,21 @@ def flow_push_kernel(x,
165
166
  rev_idx_spa_idx = (spa_bat * s_l_x_b_s +
166
167
  spa_row * s_l_x_r_s +
167
168
  spa_col * s_l_x_c_s)
168
- rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_x_b * s_l_x_b_s)
169
+ rev_idx_spa_msk = ((rev_idx_spa_idx >= 0) &
170
+ (rev_idx_spa_idx < s_l_x_b * s_l_x_b_s))
169
171
  rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
170
172
 
171
173
  if rev_idx_spa >= 0:
172
174
  blk_x_idx = (pid_blk * x_b_s +
173
175
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
174
176
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
175
- blk_x_msk = (blk_x_idx >= 0 and
176
- blk_x_idx < x_b * x_b_s)
177
+ blk_x_msk = ((blk_x_idx >= 0) &
178
+ (blk_x_idx < x_b * x_b_s))
177
179
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
178
180
 
179
181
  blk_o_idx = (rev_idx_spa * o_b_s +
180
182
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
181
183
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
182
- blk_o_msk = (blk_o_idx >= 0 and
183
- blk_o_idx < o_b * o_b_s)
184
+ blk_o_msk = ((blk_o_idx >= 0) &
185
+ (blk_o_idx < o_b * o_b_s))
184
186
  tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
@@ -169,12 +169,14 @@ def matmul_kernel(x,
169
169
  rev_idx_spa_x_idx = (spa_bat_o * s_l_x_b_s +
170
170
  spa_row_o * s_l_x_r_s +
171
171
  i_seg_spa * s_l_x_c_s)
172
- rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
172
+ rev_idx_spa_x_msk = ((rev_idx_spa_x_idx >= 0) &
173
+ (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s))
173
174
  rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
174
175
 
175
176
  # Get reverse sparsity indices for y
176
177
  rev_idx_spa_y_idx = (spa_bat_o * s_l_y_b_s + i_seg_spa * s_l_y_r_s + spa_col_o * s_l_y_c_s)
177
- rev_idx_spa_y_msk = (rev_idx_spa_y_idx >= 0 and rev_idx_spa_y_idx < s_l_y_b * s_l_y_b_s)
178
+ rev_idx_spa_y_msk = ((rev_idx_spa_y_idx >= 0) &
179
+ (rev_idx_spa_y_idx < s_l_y_b * s_l_y_b_s))
178
180
  rev_idx_spa_y = tl.load(r_lut_y + rev_idx_spa_y_idx, mask=rev_idx_spa_y_msk).to(tl.int32)
179
181
 
180
182
  # If both blocks are present commence calculation
@@ -183,16 +185,16 @@ def matmul_kernel(x,
183
185
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
184
186
  ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
185
187
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
186
- blk_x_msk = (blk_x_idx >= 0 and
187
- blk_x_idx < x_b * x_b_s)
188
+ blk_x_msk = ((blk_x_idx >= 0) &
189
+ (blk_x_idx < x_b * x_b_s))
188
190
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
189
191
 
190
192
  blk_y_idx = ((rev_idx_spa_y * y_b_s) +
191
193
  ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
192
194
  tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
193
195
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
194
- blk_y_msk = (blk_y_idx >= 0 and
195
- blk_y_idx < y_b * y_b_s)
196
+ blk_y_msk = ((blk_y_idx >= 0) &
197
+ (blk_y_idx < y_b * y_b_s))
196
198
  blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
197
199
 
198
200
  # Perform matrix multiplication
@@ -205,8 +207,8 @@ def matmul_kernel(x,
205
207
  blk_o_idx = ((pid_blk * o_b_s) +
206
208
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
207
209
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
208
- blk_o_msk = (blk_o_idx >= 0 and
209
- blk_o_idx < o_b * o_b_s)
210
+ blk_o_msk = ((blk_o_idx >= 0) &
211
+ (blk_o_idx < o_b * o_b_s))
210
212
  tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
211
213
 
212
214
 
@@ -121,16 +121,16 @@ def broadcast_add_kernel(x,
121
121
  blk_x_idx = (spa_bat_o * x_b_s +
122
122
  ((pid_row * TRITON_BLOCK_SIZE + spa_row_o * sparsity_block_size +
123
123
  tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
124
- blk_x_msk = (blk_x_idx >= 0 and
125
- blk_x_idx < x_b * x_b_s)
124
+ blk_x_msk = ((blk_x_idx >= 0) &
125
+ (blk_x_idx < x_b * x_b_s))
126
126
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
127
127
 
128
128
  # Load y block
129
129
  blk_y_idx = (spa_bat_o * y_b_s +
130
130
  ((pid_col * TRITON_BLOCK_SIZE + spa_col_o * sparsity_block_size +
131
131
  tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
132
- blk_y_msk = (blk_y_idx >= 0 and
133
- blk_y_idx < y_b * y_b_s)
132
+ blk_y_msk = ((blk_y_idx >= 0) &
133
+ (blk_y_idx < y_b * y_b_s))
134
134
  blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
135
135
 
136
136
  # Compute sum
@@ -141,6 +141,6 @@ def broadcast_add_kernel(x,
141
141
  blk_o_idx = ((pid_blk * o_b_s) +
142
142
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
143
143
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
144
- blk_o_msk = (blk_o_idx >= 0 and
145
- blk_o_idx < o_b * o_b_s)
144
+ blk_o_msk = ((blk_o_idx >= 0) &
145
+ (blk_o_idx < o_b * o_b_s))
146
146
  tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
@@ -130,15 +130,16 @@ def row_wise_sum_kernel(x,
130
130
  # Load reverse sparsity index for current block
131
131
  rev_idx_spa_idx = (spa_bat_x * s_l_o_b_s +
132
132
  spa_row_x * s_l_o_r_s)
133
- rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
133
+ rev_idx_spa_msk = ((rev_idx_spa_idx >= 0) &
134
+ (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s))
134
135
  rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
135
136
 
136
137
  if rev_idx_spa >= 0:
137
138
  blk_idx = ((pid_blk * x_b_s) +
138
139
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
139
140
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
140
- blk_msk = (blk_idx >= 0 and
141
- blk_idx < x_b * x_b_s)
141
+ blk_msk = ((blk_idx >= 0) &
142
+ (blk_idx < x_b * x_b_s))
142
143
  blk = tl.load(x + blk_idx, mask=blk_msk)
143
144
 
144
145
  buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
@@ -146,8 +147,8 @@ def row_wise_sum_kernel(x,
146
147
  o_idx = (rev_idx_spa * o_b_s +
147
148
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
148
149
  (tl.arange(0, 1))[None, :])
149
- o_msk = (o_idx >= 0 and
150
- o_idx < o_b * o_b_s)
150
+ o_msk = ((o_idx >= 0) &
151
+ (o_idx < o_b * o_b_s))
151
152
  tl.atomic_add(o + o_idx, buf, o_msk)
152
153
 
153
154
 
@@ -272,15 +273,16 @@ def row_wise_max_kernel(x,
272
273
  # Load reverse sparsity index for current block
273
274
  rev_idx_spa_idx = (spa_bat_x * s_l_o_b_s +
274
275
  spa_row_x * s_l_o_r_s)
275
- rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
276
+ rev_idx_spa_msk = ((rev_idx_spa_idx >= 0) &
277
+ (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s))
276
278
  rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
277
279
 
278
280
  if rev_idx_spa >= 0:
279
281
  blk_idx = ((pid_blk * x_b_s) +
280
282
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
281
283
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
282
- blk_msk = (blk_idx >= 0 and
283
- blk_idx < x_b * x_b_s)
284
+ blk_msk = ((blk_idx >= 0) &
285
+ (blk_idx < x_b * x_b_s))
284
286
  blk = tl.load(x + blk_idx, mask=blk_msk)
285
287
 
286
288
  buf = tl.reshape(tl.max(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
@@ -288,8 +290,8 @@ def row_wise_max_kernel(x,
288
290
  o_idx = (rev_idx_spa * o_b_s +
289
291
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
290
292
  (tl.arange(0, 1))[None, :])
291
- o_msk = (o_idx >= 0 and
292
- o_idx < o_b * o_b_s)
293
+ o_msk = ((o_idx >= 0) &
294
+ (o_idx < o_b * o_b_s))
293
295
  tl.atomic_max(o + o_idx, buf, o_msk)
294
296
 
295
297
 
@@ -410,7 +412,8 @@ def row_wise_add_kernel(x,
410
412
  # Get reverse sparsity indices for s
411
413
  rev_idx_spa_s_idx = (spa_bat_x * s_l_y_b_s +
412
414
  spa_row_x * s_l_y_r_s)
413
- rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_y_b * s_l_y_b_s)
415
+ rev_idx_spa_s_msk = ((rev_idx_spa_s_idx >= 0) &
416
+ (rev_idx_spa_s_idx < s_l_y_b * s_l_y_b_s))
414
417
  rev_idx_spa_s = tl.load(r_lut_y + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
415
418
 
416
419
  if rev_idx_spa_s == -1:
@@ -421,16 +424,16 @@ def row_wise_add_kernel(x,
421
424
  blk_x_idx = ((pid_blk * x_b_s) +
422
425
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
423
426
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
424
- blk_x_msk = (blk_x_idx >= 0 and
425
- blk_x_idx < x_b * x_b_s)
427
+ blk_x_msk = ((blk_x_idx >= 0) &
428
+ (blk_x_idx < x_b * x_b_s))
426
429
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
427
430
 
428
431
  # Load sum block
429
432
  blk_s_idx = (rev_idx_spa_s * y_b_s +
430
433
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
431
434
  (tl.arange(0, 1) * y_c_s)[None, :])
432
- blk_s_msk = (blk_s_idx >= 0 and
433
- blk_s_idx < y_b * y_b_s)
435
+ blk_s_msk = ((blk_s_idx >= 0) &
436
+ (blk_s_idx < y_b * y_b_s))
434
437
  blk_s = tl.load(y + blk_s_idx, mask=blk_s_msk)
435
438
 
436
439
  # Compute exp
@@ -440,6 +443,6 @@ def row_wise_add_kernel(x,
440
443
  blk_o_idx = ((pid_blk * o_b_s) +
441
444
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
442
445
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
443
- blk_o_msk = (blk_o_idx >= 0 and
444
- blk_o_idx < o_b * o_b_s)
446
+ blk_o_msk = ((blk_o_idx >= 0) &
447
+ (blk_o_idx < o_b * o_b_s))
445
448
  tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
@@ -184,7 +184,8 @@ def softmax_kernel(x,
184
184
  # Get reverse sparsity indices for s
185
185
  rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
186
186
  spa_row * s_l_s_r_s)
187
- rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
187
+ rev_idx_spa_s_msk = ((rev_idx_spa_s_idx >= 0) &
188
+ (rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s))
188
189
  rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
189
190
 
190
191
  if rev_idx_spa_s >= 0:
@@ -192,16 +193,16 @@ def softmax_kernel(x,
192
193
  blk_x_idx = ((pid_blk * x_b_s) +
193
194
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
194
195
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
195
- blk_x_msk = (blk_x_idx >= 0 and
196
- blk_x_idx < x_b * x_b_s)
196
+ blk_x_msk = ((blk_x_idx >= 0) &
197
+ (blk_x_idx < x_b * x_b_s))
197
198
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
198
199
 
199
200
  # Load sum block
200
201
  blk_s_idx = (rev_idx_spa_s * s_b_s +
201
202
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
202
203
  (tl.arange(0, 1) * s_c_s)[None, :])
203
- blk_s_msk = (blk_s_idx >= 0 and
204
- blk_s_idx < s_b * s_b_s)
204
+ blk_s_msk = ((blk_s_idx >= 0) &
205
+ (blk_s_idx < s_b * s_b_s))
205
206
  blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
206
207
 
207
208
  # Compute softmax
@@ -247,29 +248,30 @@ def softmax_kernel_grad(g,
247
248
 
248
249
  rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
249
250
  spa_row * s_l_s_r_s)
250
- rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
251
+ rev_idx_spa_s_msk = ((rev_idx_spa_s_idx >= 0) &
252
+ (rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s))
251
253
  rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
252
254
 
253
255
  if rev_idx_spa_s >= 0:
254
256
  blk_s_idx = (rev_idx_spa_s * s_b_s +
255
257
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
256
258
  (tl.arange(0, 1) * s_c_s)[None, :])
257
- blk_s_msk = (blk_s_idx >= 0 and
258
- blk_s_idx < s_b * s_b_s)
259
+ blk_s_msk = ((blk_s_idx >= 0) &
260
+ (blk_s_idx < s_b * s_b_s))
259
261
  blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
260
262
 
261
263
  blk_g_idx = ((pid_blk * g_b_s) +
262
264
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +
263
265
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])
264
- blk_g_msk = (blk_g_idx >= 0 and
265
- blk_g_idx < g_b * g_b_s)
266
+ blk_g_msk = ((blk_g_idx >= 0) &
267
+ (blk_g_idx < g_b * g_b_s))
266
268
  blk_g = tl.load(g + blk_g_idx, mask=blk_g_msk)
267
269
 
268
270
  blk_x_idx = ((pid_blk * x_b_s) +
269
271
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
270
272
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
271
- blk_x_msk = (blk_x_idx >= 0 and
272
- blk_x_idx < x_b * x_b_s)
273
+ blk_x_msk = ((blk_x_idx >= 0) &
274
+ (blk_x_idx < x_b * x_b_s))
273
275
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
274
276
 
275
277
  buf = blk_x * (blk_g - blk_s)
@@ -277,8 +279,8 @@ def softmax_kernel_grad(g,
277
279
  blk_o_idx = ((pid_blk * o_b_s) +
278
280
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
279
281
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
280
- blk_o_msk = (blk_o_idx >= 0 and
281
- blk_o_idx < o_b * o_b_s)
282
+ blk_o_msk = ((blk_o_idx >= 0) &
283
+ (blk_o_idx < o_b * o_b_s))
282
284
  tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
283
285
 
284
286
 
@@ -447,7 +449,8 @@ def softmax_fused_kernel(x,
447
449
  blk_rev_idx = (pid_bat * s_l_b_s +
448
450
  pid_row * s_l_r_s +
449
451
  (tl.arange(0, mbs) * s_l_c_s))
450
- blk_rev_msk = ((blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s) and
452
+ blk_rev_msk = (((blk_rev_idx >= 0) &
453
+ (blk_rev_idx < s_l_b * s_l_b_s)) &
451
454
  (tl.arange(0, mbs) < s_l_c))
452
455
  blk_rev = tl.load(r_lut_s + blk_rev_idx, mask=blk_rev_msk, other=-1).to(tl.int32)
453
456
 
@@ -462,8 +465,9 @@ def softmax_fused_kernel(x,
462
465
  blk_x_idx = (blk_rev_ext * x_b_s +
463
466
  pid_lin * x_r_s +
464
467
  (tl.arange(0, mbs * sparsity_block_size) % sparsity_block_size) * x_c_s)
465
- blk_x_mask = ((blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
466
- and blk_rev_ext != -1)
468
+ blk_x_mask = (((blk_x_idx >= 0) &
469
+ (blk_x_idx < x_b * x_b_s)) &
470
+ (blk_rev_ext != -1))
467
471
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_mask, other=float("-inf"))
468
472
 
469
473
  # Compute softmax
@@ -500,7 +504,8 @@ def softmax_fused_kernel_grad(g,
500
504
  blk_rev_idx = (pid_bat * s_l_b_s +
501
505
  pid_row * s_l_r_s +
502
506
  (tl.arange(0, mbs) * s_l_c_s))
503
- blk_rev_msk = ((blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s) and
507
+ blk_rev_msk = (((blk_rev_idx >= 0) &
508
+ (blk_rev_idx < s_l_b * s_l_b_s)) &
504
509
  (tl.arange(0, mbs) < s_l_c))
505
510
  blk_rev = tl.load(r_lut_s + blk_rev_idx, mask=blk_rev_msk, other=-1).to(tl.int32)
506
511
 
@@ -515,16 +520,18 @@ def softmax_fused_kernel_grad(g,
515
520
  blk_g_idx = (blk_rev_ext * g_b_s +
516
521
  pid_lin * g_r_s +
517
522
  (tl.arange(0, mbs * sparsity_block_size) % sparsity_block_size) * g_c_s)
518
- blk_g_mask = ((blk_g_idx >= 0 and blk_g_idx < g_b * g_b_s)
519
- and blk_rev_ext != -1)
523
+ blk_g_mask = (((blk_g_idx >= 0) &
524
+ (blk_g_idx < g_b * g_b_s)) &
525
+ (blk_rev_ext != -1))
520
526
  blk_g = tl.load(g + blk_g_idx, mask=blk_g_mask)
521
527
 
522
528
  # Load line of x
523
529
  blk_x_idx = (blk_rev_ext * x_b_s +
524
530
  pid_lin * x_r_s +
525
531
  (tl.arange(0, mbs * sparsity_block_size) % sparsity_block_size) * x_c_s)
526
- blk_x_mask = ((blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
527
- and blk_rev_ext != -1)
532
+ blk_x_mask = (((blk_x_idx >= 0) &
533
+ (blk_x_idx < x_b * x_b_s)) &
534
+ (blk_rev_ext != -1))
528
535
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_mask)
529
536
 
530
537
  # Compute gradients
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: blksprs
3
- Version: 2.1.8
3
+ Version: 2.1.9
4
4
  Summary: A lightweight library for operations on block-sparse matrices in PyTorch.
5
5
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
6
6
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -102,7 +102,7 @@ We will continue to maintain the library and fix any issues that arise.
102
102
  Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
103
103
  We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
104
104
 
105
- It might be that this changes with future projects, but as of June 2025, we are content with the current state of the
105
+ It might be that this changes with future projects, but as of August 2025, we are content with the current state of the
106
106
  library.
107
107
 
108
108
  ## Known Limitations and Issues
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "blksprs"
3
- version = "2.1.8"
3
+ version = "2.1.9"
4
4
  authors = [{ name = "Felix Schön", email = "schoen@kr.tuwien.ac.at" }]
5
5
  description = "A lightweight library for operations on block-sparse matrices in PyTorch."
6
6
  readme = "README.md"
File without changes
File without changes
File without changes