blksprs 2.1.8__tar.gz → 2.1.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {blksprs-2.1.8 → blksprs-2.1.9}/PKG-INFO +2 -2
- {blksprs-2.1.8 → blksprs-2.1.9}/README.md +1 -1
- {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/__init__.py +1 -1
- {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/layouting/distribution_layout.py +10 -7
- {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/layouting/sparsity_layout.py +14 -9
- {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/ops/conversion.py +16 -13
- {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/ops/distribution.py +19 -19
- {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/ops/flow.py +12 -10
- {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/ops/matmul.py +10 -8
- {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/ops/misc/broadcast_ops.py +6 -6
- {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/ops/misc/row_wise.py +20 -17
- {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/ops/softmax.py +29 -22
- {blksprs-2.1.8 → blksprs-2.1.9}/blksprs.egg-info/PKG-INFO +2 -2
- {blksprs-2.1.8 → blksprs-2.1.9}/pyproject.toml +1 -1
- {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/ops/partitioning.py +0 -0
- {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/ops/repeat.py +0 -0
- {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/ops/transpose.py +0 -0
- {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/utils/autotuning.py +0 -0
- {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/utils/benchmarking.py +0 -0
- {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/utils/blksprs_tensor.py +0 -0
- {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/utils/processing.py +0 -0
- {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/utils/tools.py +0 -0
- {blksprs-2.1.8 → blksprs-2.1.9}/blksprs/utils/validation.py +0 -0
- {blksprs-2.1.8 → blksprs-2.1.9}/blksprs.egg-info/SOURCES.txt +0 -0
- {blksprs-2.1.8 → blksprs-2.1.9}/blksprs.egg-info/dependency_links.txt +0 -0
- {blksprs-2.1.8 → blksprs-2.1.9}/blksprs.egg-info/requires.txt +0 -0
- {blksprs-2.1.8 → blksprs-2.1.9}/blksprs.egg-info/top_level.txt +0 -0
- {blksprs-2.1.8 → blksprs-2.1.9}/setup.cfg +0 -0
--- blksprs-2.1.8/PKG-INFO
+++ blksprs-2.1.9/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: blksprs
-Version: 2.1.8
+Version: 2.1.9
 Summary: A lightweight library for operations on block-sparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -102,7 +102,7 @@ We will continue to maintain the library and fix any issues that arise.
 Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
 We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
 
-It might be that this changes with future projects, but as of
+It might be that this changes with future projects, but as of August 2025, we are content with the current state of the
 library.
 
 ## Known Limitations and Issues
--- blksprs-2.1.8/README.md
+++ blksprs-2.1.9/README.md
@@ -83,7 +83,7 @@ We will continue to maintain the library and fix any issues that arise.
 Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
 We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
 
-It might be that this changes with future projects, but as of
+It might be that this changes with future projects, but as of August 2025, we are content with the current state of the
 library.
 
 ## Known Limitations and Issues
--- blksprs-2.1.8/blksprs/layouting/distribution_layout.py
+++ blksprs-2.1.9/blksprs/layouting/distribution_layout.py
@@ -98,22 +98,25 @@ def build_distribution_layout_kernel(i,
 
     # Get position of current sparsity block consisting of its batch, row, and column index
     spa_bat_i_idx = (pid_blk * s_lut_i_r_s + 0 * s_lut_i_c_s)
-    spa_bat_i_msk = (spa_bat_i_idx >= 0
+    spa_bat_i_msk = ((spa_bat_i_idx >= 0) &
+                     (spa_bat_i_idx < s_lut_i_r * s_lut_i_r_s))
     spa_bat_i = tl.load(s_lut_i + spa_bat_i_idx, mask=spa_bat_i_msk)
 
     spa_row_i_idx = (pid_blk * s_lut_i_r_s + 1 * s_lut_i_c_s)
-    spa_row_i_msk = (spa_row_i_idx >= 0
+    spa_row_i_msk = ((spa_row_i_idx >= 0) &
+                     (spa_row_i_idx < s_lut_i_r * s_lut_i_r_s))
     spa_row_i = tl.load(s_lut_i + spa_row_i_idx, mask=spa_row_i_msk)
 
     spa_col_i_idx = (pid_blk * s_lut_i_r_s + 2 * s_lut_i_c_s)
-    spa_col_i_msk = (spa_col_i_idx >= 0
+    spa_col_i_msk = ((spa_col_i_idx >= 0) &
+                     (spa_col_i_idx < s_lut_i_r * s_lut_i_r_s))
     spa_col_i = tl.load(s_lut_i + spa_col_i_idx, mask=spa_col_i_msk)
 
     blk_i_idx = (pid_blk * i_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
-    blk_i_msk = (blk_i_idx >= 0
-                 blk_i_idx < i_b * i_b_s)
+    blk_i_msk = ((blk_i_idx >= 0) &
+                 (blk_i_idx < i_b * i_b_s))
     blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)
 
     dst_bat_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_i, dtype=tl.int32)
@@ -131,6 +134,6 @@ def build_distribution_layout_kernel(i,
     blk_o_idx = ((dst_bat_idx * o_b_s) +
                  (dst_row_idx * o_r_s) +
                  (dst_col_idx * o_c_s))
-    blk_o_msk = (blk_o_idx >= 0
-                 blk_o_idx < o_b * o_b_s)
+    blk_o_msk = ((blk_o_idx >= 0) &
+                 (blk_o_idx < o_b * o_b_s))
     tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
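The change repeated throughout the kernel hunks in this release is the same everywhere: each Triton load/store mask is rewritten as a conjunction of a lower-bound check and an upper-bound check, with every comparison parenthesized before being combined via the elementwise `&`. The parentheses matter because Python, and therefore Triton's Python frontend, gives `&` higher precedence than comparison operators, so an unparenthesized `idx >= 0 & idx < n` would not evaluate as two separate bound checks. A minimal sketch of the pattern in isolation — this is not blksprs code; the kernel name, the `BLOCK` constant, and the copy task are invented for illustration:

```python
# Minimal sketch of the 2.1.9 mask pattern (assumes Triton and a CUDA device).
import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(x_ptr, o_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    # Both bounds are checked and each comparison is parenthesized before
    # being combined with the elementwise "&".
    msk = ((offs >= 0) &
           (offs < n_elements))
    blk = tl.load(x_ptr + offs, mask=msk, other=0.0)
    tl.store(o_ptr + offs, blk, mask=msk)


x = torch.randn(1000, device="cuda")
o = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256),)
copy_kernel[grid](x, o, x.numel(), BLOCK=256)
```

The same parenthesized-`&` shape recurs in every mask touched below, including the sparsity-LUT lookups and the combined masks that also test a reverse-index sentinel.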
--- blksprs-2.1.8/blksprs/layouting/sparsity_layout.py
+++ blksprs-2.1.9/blksprs/layouting/sparsity_layout.py
@@ -79,8 +79,8 @@ def build_sparsity_layout_kernel(x,
     blk_x_idx = (pid_bat * x_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx >= 0
-                 blk_x_idx < x_b * x_b_s)
+    blk_x_msk = ((blk_x_idx >= 0) &
+                 (blk_x_idx < x_b * x_b_s))
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Store sparsity layout value
@@ -88,7 +88,8 @@ def build_sparsity_layout_kernel(x,
     blk_o_idx = (pid_bat * o_b_s +
                  (((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +
                   ((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))
-    blk_o_msk = (blk_o_idx >= 0
+    blk_o_msk = ((blk_o_idx >= 0) &
+                 (blk_o_idx < o_b * o_b_s))
     tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
 
 
@@ -178,23 +179,26 @@ def build_sparsity_layout_adaption_kernel(x,
 
     # Get sparsity index of current output block consisting of its batch, row, and column index
     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-    spa_bat_msk = (spa_bat_idx >= 0
+    spa_bat_msk = ((spa_bat_idx >= 0) &
+                   (spa_bat_idx < s_lut_r * s_lut_r_s))
     spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
 
     spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-    spa_row_msk = (spa_row_idx >= 0
+    spa_row_msk = ((spa_row_idx >= 0) &
+                   (spa_row_idx < s_lut_r * s_lut_r_s))
     spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
 
     spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
-    spa_col_msk = (spa_col_idx >= 0
+    spa_col_msk = ((spa_col_idx >= 0) &
+                   (spa_col_idx < s_lut_r * s_lut_r_s))
     spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
 
     # Load x values
     blk_x_idx = ((pid_blk * x_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx >= 0
-                 blk_x_idx < x_b * x_b_s)
+    blk_x_msk = ((blk_x_idx >= 0) &
+                 (blk_x_idx < x_b * x_b_s))
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Store sparsity layout value
@@ -204,7 +208,8 @@ def build_sparsity_layout_adaption_kernel(x,
                    // sparsity_block_size_to) * o_r_s) +
                  (((pid_col * TRITON_BLOCK_SIZE + spa_col * sparsity_block_size_from)
                    // sparsity_block_size_to) * o_c_s))
-    blk_o_msk = (blk_o_idx >= 0
+    blk_o_msk = ((blk_o_idx >= 0) &
+                 (blk_o_idx < o_b * o_b_s))
     tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
 
 
--- blksprs-2.1.8/blksprs/ops/conversion.py
+++ blksprs-2.1.9/blksprs/ops/conversion.py
@@ -120,16 +120,16 @@ def to_sparse_kernel(x,
                   tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + spa_col * sparsity_block_size +
                    tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_d_msk = (blk_d_idx >= 0
-                 blk_d_idx < x_b * x_b_s)
+    blk_d_msk = ((blk_d_idx >= 0) &
+                 (blk_d_idx < x_b * x_b_s))
     blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)
 
     # Store block in sparse tensor
     blk_o_idx = ((pid_blk * o_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
-    blk_o_msk = (blk_o_idx >= 0
-                 blk_o_idx < (pid_blk + 1) * o_b_s)
+    blk_o_msk = ((blk_o_idx >= 0) &
+                 (blk_o_idx < (pid_blk + 1) * o_b_s))
     tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
 
 
@@ -269,7 +269,8 @@ def to_dense_kernel(x,
 
     # Get reverse sparsity index for current block
     rev_idx_spa_idx = (pid_blk * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)
-    rev_idx_spa_msk = (rev_idx_spa_idx >= 0
+    rev_idx_spa_msk = ((rev_idx_spa_idx >= 0) &
+                       (rev_idx_spa_idx < s_l_b * s_l_b_s))
     rev_idx_spa = tl.load(sparsity_reverse_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
     # If block is present commence operations
@@ -279,14 +280,15 @@ def to_dense_kernel(x,
                   tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                 (((pid_col % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
                   tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_msk = (blk_idx >= 0
-               blk_idx < x_b * x_b_s)
+    blk_msk = ((blk_idx >= 0) &
+               (blk_idx < x_b * x_b_s))
     blk = tl.load(x + blk_idx, mask=blk_msk)
 
     o_idx = (pid_blk * o_b_s +
              ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
              ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-    o_msk = (o_idx >= 0
+    o_msk = ((o_idx >= 0) &
+             (o_idx < o_b * o_b_s))
     tl.store(o + o_idx, blk, o_msk)
 
 
@@ -458,7 +460,8 @@ def adapt_layout_kernel(x,
     rev_idx_spa_x_idx = (spa_bat_x * s_l_x_b_s +
                          spa_row_x * s_l_x_r_s +
                          spa_col_x * s_l_x_c_s)
-    rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0
+    rev_idx_spa_x_msk = ((rev_idx_spa_x_idx >= 0) &
+                         (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s))
     rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
 
     # If block is present commence operations
@@ -473,16 +476,16 @@ def adapt_layout_kernel(x,
     blk_x_idx = ((rev_idx_spa_x * x_b_s) +
                  ((shift_row_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                  ((shift_col_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx >= 0
-                 blk_x_idx < x_b * x_b_s)
+    blk_x_msk = ((blk_x_idx >= 0) &
+                 (blk_x_idx < x_b * x_b_s))
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Store output
     blk_o_idx = ((pid_blk * o_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-    blk_o_msk = (blk_o_idx >= 0
-                 blk_o_idx < o_b * o_b_s)
+    blk_o_msk = ((blk_o_idx >= 0) &
+                 (blk_o_idx < o_b * o_b_s))
     tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
 
 
--- blksprs-2.1.8/blksprs/ops/distribution.py
+++ blksprs-2.1.9/blksprs/ops/distribution.py
@@ -136,8 +136,8 @@ def gather_kernel(x,
     blk_i_idx = ((pid_blk * i_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
-    blk_i_msk = (blk_i_idx >= 0
-                 blk_i_idx < i_b * i_b_s)
+    blk_i_msk = ((blk_i_idx >= 0) &
+                 (blk_i_idx < i_b * i_b_s))
     blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
 
     # Get indices of sparsity blocks and positions within the blocks
@@ -164,26 +164,26 @@ def gather_kernel(x,
     rev_idx_spa_x_idx = ((rev_dst_bat_x * s_l_x_b_s) +
                          (rev_dst_row_x * s_l_x_r_s) +
                          (rev_dst_col_x * s_l_x_c_s))
-    rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0
-                         rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+    rev_idx_spa_x_msk = ((rev_idx_spa_x_idx >= 0) &
+                         (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s))
     rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
 
     # Load x values
     blk_x_idx = ((rev_idx_spa_x * x_b_s) +
                  dst_row_x +
                  dst_col_x)
-    blk_x_msk = ((blk_x_idx >= 0
-                  blk_x_idx < x_b * x_b_s)
-                 rev_idx_spa_x_msk != -1)
+    blk_x_msk = (((blk_x_idx >= 0) &
+                  (blk_x_idx < x_b * x_b_s)) &
+                 (rev_idx_spa_x_msk != -1))
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Store output
     blk_o_idx = ((pid_blk * o_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-    blk_o_msk = ((blk_o_idx >= 0
-                  blk_o_idx < o_b * o_b_s)
-                 rev_idx_spa_x_msk != -1)
+    blk_o_msk = (((blk_o_idx >= 0) &
+                  (blk_o_idx < o_b * o_b_s)) &
+                 (rev_idx_spa_x_msk != -1))
     tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
 
 
@@ -380,16 +380,16 @@ def scatter_reduce_kernel(x,
     blk_x_idx = ((pid_blk * x_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx >= 0
-                 blk_x_idx < x_b * x_b_s)
+    blk_x_msk = ((blk_x_idx >= 0) &
+                 (blk_x_idx < x_b * x_b_s))
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Load index values
     blk_i_idx = ((pid_blk * i_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
-    blk_i_msk = (blk_i_idx >= 0
-                 blk_i_idx < i_b * i_b_s)
+    blk_i_msk = ((blk_i_idx >= 0) &
+                 (blk_i_idx < i_b * i_b_s))
     blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
 
     # Get indices of sparsity blocks and positions within the blocks
@@ -416,17 +416,17 @@ def scatter_reduce_kernel(x,
     rev_idx_spa_o_idx = ((rev_dst_bat_o * s_l_o_b_s) +
                          (rev_dst_row_o * s_l_o_r_s) +
                          (rev_dst_col_o * s_l_o_c_s))
-    rev_idx_spa_o_msk = (rev_idx_spa_o_idx >= 0
-                         rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
+    rev_idx_spa_o_msk = ((rev_idx_spa_o_idx >= 0) &
+                         (rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s))
     rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
 
     # Store output
     blk_o_idx = ((rev_idx_spa_o * o_b_s) +
                  dst_row_o +
                  dst_col_o)
-    blk_o_msk = ((blk_o_idx >= 0
-                  blk_o_idx < o_b * o_b_s)
-                 rev_idx_spa_o_msk != -1)
+    blk_o_msk = (((blk_o_idx >= 0) &
+                  (blk_o_idx < o_b * o_b_s)) &
+                 (rev_idx_spa_o_msk != -1))
 
     if reduce_op_ind == 0:
         tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
--- blksprs-2.1.8/blksprs/ops/flow.py
+++ blksprs-2.1.9/blksprs/ops/flow.py
@@ -78,22 +78,23 @@ def flow_pull_kernel(x,
     rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
                        spa_row * s_l_o_r_s +
                        spa_col * s_l_o_c_s)
-    rev_idx_spa_msk = (rev_idx_spa_idx >= 0
+    rev_idx_spa_msk = ((rev_idx_spa_idx >= 0) &
+                       (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s))
     rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
     if rev_idx_spa >= 0:
         blk_x_idx = (rev_idx_spa * x_b_s +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_x_msk = (blk_x_idx >= 0
-                     blk_x_idx < x_b * x_b_s)
+        blk_x_msk = ((blk_x_idx >= 0) &
+                     (blk_x_idx < x_b * x_b_s))
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
         blk_o_idx = (pid_blk * o_b_s +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-        blk_o_msk = (blk_o_idx >= 0
-                     blk_o_idx < o_b * o_b_s)
+        blk_o_msk = ((blk_o_idx >= 0) &
+                     (blk_o_idx < o_b * o_b_s))
         tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
 
 
@@ -165,20 +166,21 @@ def flow_push_kernel(x,
     rev_idx_spa_idx = (spa_bat * s_l_x_b_s +
                        spa_row * s_l_x_r_s +
                        spa_col * s_l_x_c_s)
-    rev_idx_spa_msk = (rev_idx_spa_idx >= 0
+    rev_idx_spa_msk = ((rev_idx_spa_idx >= 0) &
+                       (rev_idx_spa_idx < s_l_x_b * s_l_x_b_s))
     rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
     if rev_idx_spa >= 0:
         blk_x_idx = (pid_blk * x_b_s +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_x_msk = (blk_x_idx >= 0
-                     blk_x_idx < x_b * x_b_s)
+        blk_x_msk = ((blk_x_idx >= 0) &
+                     (blk_x_idx < x_b * x_b_s))
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
         blk_o_idx = (rev_idx_spa * o_b_s +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-        blk_o_msk = (blk_o_idx >= 0
-                     blk_o_idx < o_b * o_b_s)
+        blk_o_msk = ((blk_o_idx >= 0) &
+                     (blk_o_idx < o_b * o_b_s))
         tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
--- blksprs-2.1.8/blksprs/ops/matmul.py
+++ blksprs-2.1.9/blksprs/ops/matmul.py
@@ -169,12 +169,14 @@ def matmul_kernel(x,
     rev_idx_spa_x_idx = (spa_bat_o * s_l_x_b_s +
                          spa_row_o * s_l_x_r_s +
                          i_seg_spa * s_l_x_c_s)
-    rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0
+    rev_idx_spa_x_msk = ((rev_idx_spa_x_idx >= 0) &
+                         (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s))
     rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
 
     # Get reverse sparsity indices for y
     rev_idx_spa_y_idx = (spa_bat_o * s_l_y_b_s + i_seg_spa * s_l_y_r_s + spa_col_o * s_l_y_c_s)
-    rev_idx_spa_y_msk = (rev_idx_spa_y_idx >= 0
+    rev_idx_spa_y_msk = ((rev_idx_spa_y_idx >= 0) &
+                         (rev_idx_spa_y_idx < s_l_y_b * s_l_y_b_s))
     rev_idx_spa_y = tl.load(r_lut_y + rev_idx_spa_y_idx, mask=rev_idx_spa_y_msk).to(tl.int32)
 
     # If both blocks are present commence calculation
@@ -183,16 +185,16 @@ def matmul_kernel(x,
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                  ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
                    tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx >= 0
-                 blk_x_idx < x_b * x_b_s)
+    blk_x_msk = ((blk_x_idx >= 0) &
+                 (blk_x_idx < x_b * x_b_s))
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     blk_y_idx = ((rev_idx_spa_y * y_b_s) +
                  ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
                    tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
-    blk_y_msk = (blk_y_idx >= 0
-                 blk_y_idx < y_b * y_b_s)
+    blk_y_msk = ((blk_y_idx >= 0) &
+                 (blk_y_idx < y_b * y_b_s))
     blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
 
     # Perform matrix multiplication
@@ -205,8 +207,8 @@ def matmul_kernel(x,
     blk_o_idx = ((pid_blk * o_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-    blk_o_msk = (blk_o_idx >= 0
-                 blk_o_idx < o_b * o_b_s)
+    blk_o_msk = ((blk_o_idx >= 0) &
+                 (blk_o_idx < o_b * o_b_s))
     tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
 
 
--- blksprs-2.1.8/blksprs/ops/misc/broadcast_ops.py
+++ blksprs-2.1.9/blksprs/ops/misc/broadcast_ops.py
@@ -121,16 +121,16 @@ def broadcast_add_kernel(x,
     blk_x_idx = (spa_bat_o * x_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + spa_row_o * sparsity_block_size +
                    tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx >= 0
-                 blk_x_idx < x_b * x_b_s)
+    blk_x_msk = ((blk_x_idx >= 0) &
+                 (blk_x_idx < x_b * x_b_s))
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Load y block
     blk_y_idx = (spa_bat_o * y_b_s +
                  ((pid_col * TRITON_BLOCK_SIZE + spa_col_o * sparsity_block_size +
                    tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
-    blk_y_msk = (blk_y_idx >= 0
-                 blk_y_idx < y_b * y_b_s)
+    blk_y_msk = ((blk_y_idx >= 0) &
+                 (blk_y_idx < y_b * y_b_s))
     blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
 
     # Compute sum
@@ -141,6 +141,6 @@ def broadcast_add_kernel(x,
     blk_o_idx = ((pid_blk * o_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-    blk_o_msk = (blk_o_idx >= 0
-                 blk_o_idx < o_b * o_b_s)
+    blk_o_msk = ((blk_o_idx >= 0) &
+                 (blk_o_idx < o_b * o_b_s))
     tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
--- blksprs-2.1.8/blksprs/ops/misc/row_wise.py
+++ blksprs-2.1.9/blksprs/ops/misc/row_wise.py
@@ -130,15 +130,16 @@ def row_wise_sum_kernel(x,
     # Load reverse sparsity index for current block
     rev_idx_spa_idx = (spa_bat_x * s_l_o_b_s +
                        spa_row_x * s_l_o_r_s)
-    rev_idx_spa_msk = (rev_idx_spa_idx >= 0
+    rev_idx_spa_msk = ((rev_idx_spa_idx >= 0) &
+                       (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s))
     rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
     if rev_idx_spa >= 0:
         blk_idx = ((pid_blk * x_b_s) +
                    ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                    ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_msk = (blk_idx >= 0
-                   blk_idx < x_b * x_b_s)
+        blk_msk = ((blk_idx >= 0) &
+                   (blk_idx < x_b * x_b_s))
         blk = tl.load(x + blk_idx, mask=blk_msk)
 
         buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
@@ -146,8 +147,8 @@ def row_wise_sum_kernel(x,
         o_idx = (rev_idx_spa * o_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  (tl.arange(0, 1))[None, :])
-        o_msk = (o_idx >= 0
-                 o_idx < o_b * o_b_s)
+        o_msk = ((o_idx >= 0) &
+                 (o_idx < o_b * o_b_s))
         tl.atomic_add(o + o_idx, buf, o_msk)
 
 
@@ -272,15 +273,16 @@ def row_wise_max_kernel(x,
     # Load reverse sparsity index for current block
     rev_idx_spa_idx = (spa_bat_x * s_l_o_b_s +
                        spa_row_x * s_l_o_r_s)
-    rev_idx_spa_msk = (rev_idx_spa_idx >= 0
+    rev_idx_spa_msk = ((rev_idx_spa_idx >= 0) &
+                       (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s))
     rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
     if rev_idx_spa >= 0:
         blk_idx = ((pid_blk * x_b_s) +
                    ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                    ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_msk = (blk_idx >= 0
-                   blk_idx < x_b * x_b_s)
+        blk_msk = ((blk_idx >= 0) &
+                   (blk_idx < x_b * x_b_s))
         blk = tl.load(x + blk_idx, mask=blk_msk)
 
         buf = tl.reshape(tl.max(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
@@ -288,8 +290,8 @@ def row_wise_max_kernel(x,
         o_idx = (rev_idx_spa * o_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  (tl.arange(0, 1))[None, :])
-        o_msk = (o_idx >= 0
-                 o_idx < o_b * o_b_s)
+        o_msk = ((o_idx >= 0) &
+                 (o_idx < o_b * o_b_s))
         tl.atomic_max(o + o_idx, buf, o_msk)
 
 
@@ -410,7 +412,8 @@ def row_wise_add_kernel(x,
     # Get reverse sparsity indices for s
     rev_idx_spa_s_idx = (spa_bat_x * s_l_y_b_s +
                          spa_row_x * s_l_y_r_s)
-    rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0
+    rev_idx_spa_s_msk = ((rev_idx_spa_s_idx >= 0) &
+                         (rev_idx_spa_s_idx < s_l_y_b * s_l_y_b_s))
     rev_idx_spa_s = tl.load(r_lut_y + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
 
     if rev_idx_spa_s == -1:
@@ -421,16 +424,16 @@ def row_wise_add_kernel(x,
     blk_x_idx = ((pid_blk * x_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx >= 0
-                 blk_x_idx < x_b * x_b_s)
+    blk_x_msk = ((blk_x_idx >= 0) &
+                 (blk_x_idx < x_b * x_b_s))
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Load sum block
     blk_s_idx = (rev_idx_spa_s * y_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
                  (tl.arange(0, 1) * y_c_s)[None, :])
-    blk_s_msk = (blk_s_idx >= 0
-                 blk_s_idx < y_b * y_b_s)
+    blk_s_msk = ((blk_s_idx >= 0) &
+                 (blk_s_idx < y_b * y_b_s))
     blk_s = tl.load(y + blk_s_idx, mask=blk_s_msk)
 
     # Compute exp
@@ -440,6 +443,6 @@ def row_wise_add_kernel(x,
     blk_o_idx = ((pid_blk * o_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-    blk_o_msk = (blk_o_idx >= 0
-                 blk_o_idx < o_b * o_b_s)
+    blk_o_msk = ((blk_o_idx >= 0) &
+                 (blk_o_idx < o_b * o_b_s))
     tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
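The row-wise kernels above keep the same bounds-checked masks but commit their per-block partial results with masked atomics (`tl.atomic_add`, `tl.atomic_max`) rather than plain stores, because several sparsity blocks of one row reduce into the same output slot. A small sketch of that accumulation step under the same mask pattern — illustrative only, not the blksprs kernels; the kernel name and shapes are made up:

```python
# Illustrative sketch: several program instances reduce into one output cell
# via an atomic add, mirroring the store step of the row-wise sum kernel.
import torch
import triton
import triton.language as tl


@triton.jit
def partial_row_sum_kernel(x_ptr, o_ptr, n_cols, BLOCK: tl.constexpr):
    row = tl.program_id(axis=0)
    col_blk = tl.program_id(axis=1)
    offs = col_blk * BLOCK + tl.arange(0, BLOCK)
    # Same parenthesized bounds check as in the diffs above.
    msk = ((offs >= 0) &
           (offs < n_cols))
    vals = tl.load(x_ptr + row * n_cols + offs, mask=msk, other=0.0)
    # Each column block adds its partial sum into the same row slot.
    tl.atomic_add(o_ptr + row, tl.sum(vals, axis=0))


x = torch.randn(4, 300, device="cuda")
o = torch.zeros(4, device="cuda")
grid = (4, triton.cdiv(300, 128))
partial_row_sum_kernel[grid](x, o, 300, BLOCK=128)
```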
--- blksprs-2.1.8/blksprs/ops/softmax.py
+++ blksprs-2.1.9/blksprs/ops/softmax.py
@@ -184,7 +184,8 @@ def softmax_kernel(x,
     # Get reverse sparsity indices for s
     rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
                          spa_row * s_l_s_r_s)
-    rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0
+    rev_idx_spa_s_msk = ((rev_idx_spa_s_idx >= 0) &
+                         (rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s))
     rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
 
     if rev_idx_spa_s >= 0:
@@ -192,16 +193,16 @@ def softmax_kernel(x,
         blk_x_idx = ((pid_blk * x_b_s) +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_x_msk = (blk_x_idx >= 0
-                     blk_x_idx < x_b * x_b_s)
+        blk_x_msk = ((blk_x_idx >= 0) &
+                     (blk_x_idx < x_b * x_b_s))
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
         # Load sum block
         blk_s_idx = (rev_idx_spa_s * s_b_s +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
                      (tl.arange(0, 1) * s_c_s)[None, :])
-        blk_s_msk = (blk_s_idx >= 0
-                     blk_s_idx < s_b * s_b_s)
+        blk_s_msk = ((blk_s_idx >= 0) &
+                     (blk_s_idx < s_b * s_b_s))
         blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
 
         # Compute softmax
@@ -247,29 +248,30 @@ def softmax_kernel_grad(g,
 
     rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
                          spa_row * s_l_s_r_s)
-    rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0
+    rev_idx_spa_s_msk = ((rev_idx_spa_s_idx >= 0) &
+                         (rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s))
     rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
 
     if rev_idx_spa_s >= 0:
         blk_s_idx = (rev_idx_spa_s * s_b_s +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
                      (tl.arange(0, 1) * s_c_s)[None, :])
-        blk_s_msk = (blk_s_idx >= 0
-                     blk_s_idx < s_b * s_b_s)
+        blk_s_msk = ((blk_s_idx >= 0) &
+                     (blk_s_idx < s_b * s_b_s))
         blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
 
         blk_g_idx = ((pid_blk * g_b_s) +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])
-        blk_g_msk = (blk_g_idx >= 0
-                     blk_g_idx < g_b * g_b_s)
+        blk_g_msk = ((blk_g_idx >= 0) &
+                     (blk_g_idx < g_b * g_b_s))
         blk_g = tl.load(g + blk_g_idx, mask=blk_g_msk)
 
         blk_x_idx = ((pid_blk * x_b_s) +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_x_msk = (blk_x_idx >= 0
-                     blk_x_idx < x_b * x_b_s)
+        blk_x_msk = ((blk_x_idx >= 0) &
+                     (blk_x_idx < x_b * x_b_s))
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
         buf = blk_x * (blk_g - blk_s)
@@ -277,8 +279,8 @@ def softmax_kernel_grad(g,
         blk_o_idx = ((pid_blk * o_b_s) +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-        blk_o_msk = (blk_o_idx >= 0
-                     blk_o_idx < o_b * o_b_s)
+        blk_o_msk = ((blk_o_idx >= 0) &
+                     (blk_o_idx < o_b * o_b_s))
         tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
 
 
@@ -447,7 +449,8 @@ def softmax_fused_kernel(x,
     blk_rev_idx = (pid_bat * s_l_b_s +
                    pid_row * s_l_r_s +
                    (tl.arange(0, mbs) * s_l_c_s))
-    blk_rev_msk = ((blk_rev_idx >= 0
+    blk_rev_msk = (((blk_rev_idx >= 0) &
+                    (blk_rev_idx < s_l_b * s_l_b_s)) &
                    (tl.arange(0, mbs) < s_l_c))
     blk_rev = tl.load(r_lut_s + blk_rev_idx, mask=blk_rev_msk, other=-1).to(tl.int32)
 
@@ -462,8 +465,9 @@ def softmax_fused_kernel(x,
     blk_x_idx = (blk_rev_ext * x_b_s +
                  pid_lin * x_r_s +
                  (tl.arange(0, mbs * sparsity_block_size) % sparsity_block_size) * x_c_s)
-    blk_x_mask = ((blk_x_idx >= 0
-
+    blk_x_mask = (((blk_x_idx >= 0) &
+                   (blk_x_idx < x_b * x_b_s)) &
+                  (blk_rev_ext != -1))
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_mask, other=float("-inf"))
 
     # Compute softmax
@@ -500,7 +504,8 @@ def softmax_fused_kernel_grad(g,
     blk_rev_idx = (pid_bat * s_l_b_s +
                    pid_row * s_l_r_s +
                    (tl.arange(0, mbs) * s_l_c_s))
-    blk_rev_msk = ((blk_rev_idx >= 0
+    blk_rev_msk = (((blk_rev_idx >= 0) &
+                    (blk_rev_idx < s_l_b * s_l_b_s)) &
                    (tl.arange(0, mbs) < s_l_c))
     blk_rev = tl.load(r_lut_s + blk_rev_idx, mask=blk_rev_msk, other=-1).to(tl.int32)
 
@@ -515,16 +520,18 @@ def softmax_fused_kernel_grad(g,
     blk_g_idx = (blk_rev_ext * g_b_s +
                  pid_lin * g_r_s +
                  (tl.arange(0, mbs * sparsity_block_size) % sparsity_block_size) * g_c_s)
-    blk_g_mask = ((blk_g_idx >= 0
-
+    blk_g_mask = (((blk_g_idx >= 0) &
+                   (blk_g_idx < g_b * g_b_s)) &
+                  (blk_rev_ext != -1))
     blk_g = tl.load(g + blk_g_idx, mask=blk_g_mask)
 
     # Load line of x
     blk_x_idx = (blk_rev_ext * x_b_s +
                  pid_lin * x_r_s +
                  (tl.arange(0, mbs * sparsity_block_size) % sparsity_block_size) * x_c_s)
-    blk_x_mask = ((blk_x_idx >= 0
-
+    blk_x_mask = (((blk_x_idx >= 0) &
+                   (blk_x_idx < x_b * x_b_s)) &
+                  (blk_rev_ext != -1))
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_mask)
 
     # Compute gradients
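In the fused softmax kernels the mask gains a third term, `blk_rev_ext != -1`, so lanes whose reverse-LUT entry marks an absent sparsity block are excluded, while the `other=` argument of `tl.load` supplies a neutral fill (`-1` for the LUT lookup itself, `float("-inf")` in the forward pass so missing blocks cannot influence the row maximum). A hedged sketch of that combination — the kernel, pointer names, and sizes are illustrative and not the blksprs API:

```python
# Illustrative sketch: bounds checks combined with a reverse-LUT sentinel,
# using a fill value so masked-off lanes cannot win the row max.
import torch
import triton
import triton.language as tl


@triton.jit
def masked_row_max_kernel(x_ptr, lut_ptr, o_ptr, n_cols, BLOCK: tl.constexpr):
    row = tl.program_id(axis=0)
    offs = tl.arange(0, BLOCK)
    # Reverse LUT: -1 marks columns that belong to no sparsity block.
    rev = tl.load(lut_ptr + offs, mask=offs < n_cols, other=-1).to(tl.int32)
    # Bounds check and sentinel check, each parenthesized, joined with "&".
    msk = (((offs >= 0) &
            (offs < n_cols)) &
           (rev != -1))
    # Masked-off lanes read as -inf, matching the forward-pass fill above.
    vals = tl.load(x_ptr + row * n_cols + offs, mask=msk, other=float("-inf"))
    tl.store(o_ptr + row, tl.max(vals, axis=0))


x = torch.randn(4, 128, device="cuda")
lut = torch.randint(-1, 8, (128,), device="cuda", dtype=torch.int32)
o = torch.empty(4, device="cuda")
masked_row_max_kernel[(4,)](x, lut, o, 128, BLOCK=128)
```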
--- blksprs-2.1.8/blksprs.egg-info/PKG-INFO
+++ blksprs-2.1.9/blksprs.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: blksprs
-Version: 2.1.8
+Version: 2.1.9
 Summary: A lightweight library for operations on block-sparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -102,7 +102,7 @@ We will continue to maintain the library and fix any issues that arise.
 Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
 We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
 
-It might be that this changes with future projects, but as of
+It might be that this changes with future projects, but as of August 2025, we are content with the current state of the
 library.
 
 ## Known Limitations and Issues