blksprs 2.1.4__py3-none-any.whl → 2.1.5__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- blksprs/__init__.py +2 -2
- blksprs/ops/conversion.py +12 -20
- blksprs/ops/distribution.py +12 -20
- blksprs/ops/flow.py +12 -20
- blksprs/ops/matmul.py +6 -10
- blksprs/ops/misc/broadcast_ops.py +6 -10
- blksprs/ops/misc/row_wise.py +35 -35
- blksprs/ops/repeat.py +2 -2
- blksprs/ops/softmax.py +10 -12
- blksprs/utils/autotuning.py +2 -2
- blksprs/utils/validation.py +21 -0
- {blksprs-2.1.4.dist-info → blksprs-2.1.5.dist-info}/METADATA +1 -1
- blksprs-2.1.5.dist-info/RECORD +23 -0
- blksprs-2.1.4.dist-info/RECORD +0 -23
- {blksprs-2.1.4.dist-info → blksprs-2.1.5.dist-info}/WHEEL +0 -0
- {blksprs-2.1.4.dist-info → blksprs-2.1.5.dist-info}/top_level.txt +0 -0
blksprs/__init__.py
CHANGED

@@ -1,6 +1,6 @@
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 
-__version__ = "2.1.4"
+__version__ = "2.1.5"
 
 
 class ops:
@@ -27,9 +27,9 @@ class utils:
     from blksprs.utils.processing import apply_torch_linear, apply_torch_normalisation, apply_torch_dropout, \
         apply_function_applicable_row_wise
     from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
+    from blksprs.utils.validation import disable_contiguous, disable_validation
 
 class validation:
-    from blksprs.utils.validation import disable_validation
     from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_dtype_float, \
         validate_dtype_int, validate_device, validate_sparsity, validate_sparsity_dense, \
         validate_sparsity_block_size
blksprs/ops/conversion.py
CHANGED

@@ -106,17 +106,13 @@ def to_sparse_kernel(x,
     pid_col = tl.program_id(axis=2)
 
     # Get sparsity index of current output block consisting of its batch, row, and column index
-    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-    spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
-    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+    spa_val_idx = pid_blk * s_lut_r_s + tl.arange(0, 4) * s_lut_c_s
+    spa_val_msk = (tl.arange(0, 4) < 3)
+    spa_val = tl.load(s_lut + spa_val_idx, mask=spa_val_msk)
 
-    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-    spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
-    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
-
-    spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
-    spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
-    spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+    spa_bat = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+    spa_row = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+    spa_col = tl.sum(spa_val * (tl.arange(0, 4) == 2))
 
     # Load block from dense tensor
     blk_d_idx = (spa_bat * x_b_s +
@@ -445,17 +441,13 @@ def adapt_layout_kernel(x,
     pid_col = tl.program_id(axis=2)
 
     # Get position of current sparsity block consisting of its batch, row, and column index
-    spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
-    spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
-    spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
-
-    spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
-    spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
-    spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+    spa_val_idx = pid_blk * s_lut_o_r_s + tl.arange(0, 4) * s_lut_o_c_s
+    spa_val_msk = (tl.arange(0, 4) < 3)
+    spa_val = tl.load(s_lut_o + spa_val_idx, mask=spa_val_msk)
 
-    spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
-    spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
-    spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+    spa_bat_o = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+    spa_row_o = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+    spa_col_o = tl.sum(spa_val * (tl.arange(0, 4) == 2))
 
     # Get equivalent sparsity block in from layout
     spa_bat_x = spa_bat_o
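This hunk is the template for every kernel change in this release: three masked scalar loads of the block's (batch, row, column) triple from the sparsity LUT become one vectorized load over four lanes (the fourth lane only pads the width to a power of two), and indicator-weighted sums extract the three scalars from the loaded vector. A self-contained sketch of the idiom; the kernel, tensor names, and launch below are illustrative and not taken from the package (other=0 is added here to make the padding lane explicit, where the kernels above rely on the indicator sums to ignore it):

import torch
import triton
import triton.language as tl


@triton.jit
def read_lut_kernel(s_lut, o, s_lut_r_s, s_lut_c_s):
    # One program instance per sparse block; each LUT row holds (batch, row, col).
    pid_blk = tl.program_id(axis=0)

    # One masked vector load replaces three scalar loads and bounds checks.
    spa_val_idx = pid_blk * s_lut_r_s + tl.arange(0, 4) * s_lut_c_s
    spa_val_msk = (tl.arange(0, 4) < 3)
    spa_val = tl.load(s_lut + spa_val_idx, mask=spa_val_msk, other=0)

    # Indicator-weighted sums collapse the vector back into three scalars.
    spa_bat = tl.sum(spa_val * (tl.arange(0, 4) == 0))
    spa_row = tl.sum(spa_val * (tl.arange(0, 4) == 1))
    spa_col = tl.sum(spa_val * (tl.arange(0, 4) == 2))

    tl.store(o + pid_blk * 3 + 0, spa_bat)
    tl.store(o + pid_blk * 3 + 1, spa_row)
    tl.store(o + pid_blk * 3 + 2, spa_col)


# Hypothetical LUT for two sparse blocks, one (batch, row, col) triple per row.
lut = torch.tensor([[0, 1, 2], [1, 0, 3]], dtype=torch.int32, device="cuda")
out = torch.empty_like(lut)
read_lut_kernel[(lut.shape[0],)](lut, out, lut.stride(0), lut.stride(1))
assert torch.equal(out, lut)

The trade is one memory transaction and one mask per block instead of three of each, at the cost of two extra register-level reductions.
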
blksprs/ops/distribution.py
CHANGED

@@ -125,17 +125,13 @@ def gather_kernel(x,
     pid_col = tl.program_id(axis=2)
 
     # Get position of current sparsity block consisting of its batch, row, and column index
-    spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
-    spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
-    spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
+    spa_val_idx = pid_blk * s_lut_o_r_s + tl.arange(0, 4) * s_lut_o_c_s
+    spa_val_msk = (tl.arange(0, 4) < 3)
+    spa_val = tl.load(s_lut_o + spa_val_idx, mask=spa_val_msk)
 
-    spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
-    spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
-    spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
-
-    spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
-    spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
-    spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+    spa_bat_o = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+    spa_row_o = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+    spa_col_o = tl.sum(spa_val * (tl.arange(0, 4) == 2))
 
     # Load index values
     blk_i_idx = ((pid_blk * i_b_s) +
@@ -374,17 +374,13 @@ def scatter_reduce_kernel(x,
     pid_col = tl.program_id(axis=2)
 
     # Get position of current sparsity block consisting of its batch, row, and column index
-    spa_bat_x_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
-    spa_bat_x_msk = (spa_bat_x_idx >= 0 and spa_bat_x_idx < s_lut_x_r * s_lut_x_r_s)
-    spa_bat_x = tl.load(s_lut_x + spa_bat_x_idx, mask=spa_bat_x_msk)
-
-    spa_row_x_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
-    spa_row_x_msk = (spa_row_x_idx >= 0 and spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
-    spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)
+    spa_val_idx = pid_blk * s_lut_x_r_s + tl.arange(0, 4) * s_lut_x_c_s
+    spa_val_msk = (tl.arange(0, 4) < 3)
+    spa_val = tl.load(s_lut_x + spa_val_idx, mask=spa_val_msk)
 
-    spa_col_x_idx = (pid_blk * s_lut_x_r_s + 2 * s_lut_x_c_s)
-    spa_col_x_msk = (spa_col_x_idx >= 0 and spa_col_x_idx < s_lut_x_r * s_lut_x_r_s)
-    spa_col_x = tl.load(s_lut_x + spa_col_x_idx, mask=spa_col_x_msk)
+    spa_bat_x = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+    spa_row_x = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+    spa_col_x = tl.sum(spa_val * (tl.arange(0, 4) == 2))
 
     # Load x values
     blk_x_idx = ((pid_blk * x_b_s) +
blksprs/ops/flow.py
CHANGED

@@ -66,17 +66,13 @@ def flow_pull_kernel(x,
     pid_col = tl.program_id(axis=2)
 
     # Get sparsity index of current output block consisting of its batch, row, and column index
-    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-    spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
-    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+    spa_val_idx = pid_blk * s_lut_r_s + tl.arange(0, 4) * s_lut_c_s
+    spa_val_msk = (tl.arange(0, 4) < 3)
+    spa_val = tl.load(s_lut + spa_val_idx, mask=spa_val_msk)
 
-    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-    spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
-    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
-
-    spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
-    spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
-    spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+    spa_bat = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+    spa_row = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+    spa_col = tl.sum(spa_val * (tl.arange(0, 4) == 2))
 
     # Load reverse sparsity index
     rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
@@ -157,17 +153,13 @@ def flow_push_kernel(x,
     pid_col = tl.program_id(axis=2)
 
     # Get sparsity index of current input block consisting of its batch, row, and column index
-    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-    spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
-    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
-
-    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-    spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
-    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+    spa_val_idx = pid_blk * s_lut_r_s + tl.arange(0, 4) * s_lut_c_s
+    spa_val_msk = (tl.arange(0, 4) < 3)
+    spa_val = tl.load(s_lut + spa_val_idx, mask=spa_val_msk)
 
-    spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
-    spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
-    spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+    spa_bat = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+    spa_row = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+    spa_col = tl.sum(spa_val * (tl.arange(0, 4) == 2))
 
     # Get reverse sparsity index
     rev_idx_spa_idx = (spa_bat * s_l_x_b_s +
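Both flow kernels, like gather and scatter_reduce above, pair the forward LUT (sparse block index to batch/row/column) with a reverse LUT indexed by position (the rev_idx_spa_idx loads). A small host-side sketch of how such a pair of tables relates to a boolean layout; this illustrates the data structure only and is not the package's own layouting code:

import torch

# Hypothetical layout with one batch of 2x2 blocks; blocks (0,0) and (1,1) are set.
layout = torch.tensor([[[True, False],
                        [False, True]]])

# Forward LUT: one (batch, row, col) triple per sparse block, in scan order.
lut = layout.nonzero().to(torch.int32)  # shape [n_sparse_blocks, 3]

# Reverse LUT: dense grid mapping each position to its block index (-1 if empty).
rev_lut = torch.full(layout.shape, -1, dtype=torch.int32)
rev_lut[layout] = torch.arange(lut.shape[0], dtype=torch.int32)

assert lut.tolist() == [[0, 0, 0], [0, 1, 1]]
assert rev_lut.flatten().tolist() == [0, -1, -1, 1]
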
blksprs/ops/matmul.py
CHANGED

@@ -145,17 +145,13 @@ def matmul_kernel(x,
     pid_col = tl.program_id(axis=2)
 
     # Get position of current sparsity block consisting of its batch, row, and column index
-    spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
-    spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
-    spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
+    spa_val_idx = pid_blk * s_lut_o_r_s + tl.arange(0, 4) * s_lut_o_c_s
+    spa_val_msk = (tl.arange(0, 4) < 3)
+    spa_val = tl.load(s_lut_o + spa_val_idx, mask=spa_val_msk)
 
-    spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
-    spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
-    spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
-
-    spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
-    spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
-    spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+    spa_bat_o = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+    spa_row_o = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+    spa_col_o = tl.sum(spa_val * (tl.arange(0, 4) == 2))
 
     # Setup buffer
     buf = tl.zeros(shape=(TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), dtype=tl.float32)

blksprs/ops/misc/broadcast_ops.py
CHANGED

@@ -110,17 +110,13 @@ def broadcast_add_kernel(x,
     pid_col = tl.program_id(axis=2)
 
     # Get position of current sparsity block consisting of its batch, row, and column index
-    spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
-    spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
-    spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
+    spa_val_idx = pid_blk * s_lut_o_r_s + tl.arange(0, 4) * s_lut_o_c_s
+    spa_val_msk = (tl.arange(0, 4) < 3)
+    spa_val = tl.load(s_lut_o + spa_val_idx, mask=spa_val_msk)
 
-    spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
-    spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
-    spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
-
-    spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
-    spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
-    spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+    spa_bat_o = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+    spa_row_o = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+    spa_col_o = tl.sum(spa_val * (tl.arange(0, 4) == 2))
 
     # Load x block
     blk_x_idx = (spa_bat_o * x_b_s +
blksprs/ops/misc/row_wise.py
CHANGED

@@ -119,17 +119,17 @@ def row_wise_sum_kernel(x,
     pid_col = tl.program_id(axis=2)
 
     # Get position of current sparsity block consisting of its batch and row index
-    spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
-    spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
-    spa_bat = tl.load(s_lut_x + spa_bat_idx, mask=spa_bat_msk)
+    spa_val_idx = pid_blk * s_lut_x_r_s + tl.arange(0, 4) * s_lut_x_c_s
+    spa_val_msk = (tl.arange(0, 4) < 3)
+    spa_val = tl.load(s_lut_x + spa_val_idx, mask=spa_val_msk)
 
-    spa_row_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
-    spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_x_r * s_lut_x_r_s)
-    spa_row = tl.load(s_lut_x + spa_row_idx, mask=spa_row_msk)
+    spa_bat_x = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+    spa_row_x = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+    spa_col_x = tl.sum(spa_val * (tl.arange(0, 4) == 2))
 
     # Load reverse sparsity index for current block
-    rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
-                       spa_row * s_l_o_r_s)
+    rev_idx_spa_idx = (spa_bat_x * s_l_o_b_s +
+                       spa_row_x * s_l_o_r_s)
     rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
     rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
@@ -263,17 +263,17 @@ def row_wise_max_kernel(x,
     pid_col = tl.program_id(axis=2)
 
     # Get position of current sparsity block consisting of its batch and row index
-    spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
-    spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
-    spa_bat = tl.load(s_lut_x + spa_bat_idx, mask=spa_bat_msk)
+    spa_val_idx = pid_blk * s_lut_x_r_s + tl.arange(0, 4) * s_lut_x_c_s
+    spa_val_msk = (tl.arange(0, 4) < 3)
+    spa_val = tl.load(s_lut_x + spa_val_idx, mask=spa_val_msk)
 
-    spa_row_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
-    spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_x_r * s_lut_x_r_s)
-    spa_row = tl.load(s_lut_x + spa_row_idx, mask=spa_row_msk)
+    spa_bat_x = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+    spa_row_x = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+    spa_col_x = tl.sum(spa_val * (tl.arange(0, 4) == 2))
 
     # Load reverse sparsity index for current block
-    rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
-                       spa_row * s_l_o_r_s)
+    rev_idx_spa_idx = (spa_bat_x * s_l_o_b_s +
+                       spa_row_x * s_l_o_r_s)
     rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
     rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
@@ -361,7 +361,7 @@ def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
 
-    (wrap_triton(…
+    (wrap_triton(row_wise_add_kernel)[triton_grid]
     (x,
      x_b, x_b_s, x_r_s, x_c_s,
      sparsity_lut_x, s_lut_r, s_lut_r_s, s_lut_c_s,
@@ -383,33 +383,33 @@ def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
     reset_to_zero=["o"]
 )
 @triton.jit
-def …
+def row_wise_add_kernel(x,
+                        x_b, x_b_s, x_r_s, x_c_s,
+                        s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+                        y, y_b, y_b_s, y_r_s, y_c_s,
+                        s_l_y_b, s_l_y_b_s, s_l_y_r_s,
+                        r_lut_y,
+                        o,
+                        o_b, o_b_s, o_r_s, o_c_s,
+                        sparsity_block_size,
+                        TRITON_BLOCK_SIZE: tl.constexpr) -> None:
     # Get triton block indices
     pid_blk = tl.program_id(axis=0)
     pid_row = tl.program_id(axis=1)
     pid_col = tl.program_id(axis=2)
 
     # Get position of current sparsity block consisting of its batch and row index
-    spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
-    spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
-    spa_bat = tl.load(s_lut_x + spa_bat_idx, mask=spa_bat_msk)
+    spa_val_idx = pid_blk * s_lut_x_r_s + tl.arange(0, 4) * s_lut_x_c_s
+    spa_val_msk = (tl.arange(0, 4) < 3)
+    spa_val = tl.load(s_lut_x + spa_val_idx, mask=spa_val_msk)
 
-    spa_row_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
-    spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_x_r * s_lut_x_r_s)
-    spa_row = tl.load(s_lut_x + spa_row_idx, mask=spa_row_msk)
+    spa_bat_x = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+    spa_row_x = tl.sum(spa_val * (tl.arange(0, 4) == 1))
+    spa_col_x = tl.sum(spa_val * (tl.arange(0, 4) == 2))
 
     # Get reverse sparsity indices for s
-    rev_idx_spa_s_idx = (spa_bat * s_l_y_b_s +
-                         spa_row * s_l_y_r_s)
+    rev_idx_spa_s_idx = (spa_bat_x * s_l_y_b_s +
+                         spa_row_x * s_l_y_r_s)
     rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_y_b * s_l_y_b_s)
     rev_idx_spa_s = tl.load(r_lut_y + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
 
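For context on the restored launch line: wrap_triton here is presumably torch.library.wrap_triton (available since PyTorch 2.6), which wraps a raw Triton kernel so that calls to it from inside a triton_op stay traceable by torch.compile. A minimal sketch of that launch pattern with hypothetical kernel and operator names, not taken from blksprs:

import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton


@triton.jit
def add_one_kernel(x_ptr, n, BLOCK: tl.constexpr):
    # Each program instance increments one tile of the flattened tensor.
    offs = tl.program_id(axis=0) * BLOCK + tl.arange(0, BLOCK)
    msk = offs < n
    tl.store(x_ptr + offs, tl.load(x_ptr + offs, mask=msk, other=0) + 1, mask=msk)


@triton_op("demo::add_one", mutates_args={"x"})
def add_one(x: torch.Tensor) -> None:
    n = x.numel()
    triton_grid = lambda meta: (triton.cdiv(n, meta["BLOCK"]),)
    # Same shape as the call in row_wise_add_forward above.
    wrap_triton(add_one_kernel)[triton_grid](x, n, BLOCK=128)
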
blksprs/ops/repeat.py
CHANGED

@@ -142,7 +142,7 @@ def repeat_build_lut(lut: dict, sparsity_layout_x: Tensor, repeats: tuple[int, i
     n_sparse_blocks = torch.sum(lut["sparsity_layout_o"].to(torch.int)).item()
     lut["n_sparse_blocks"] = n_sparse_blocks
 
-    validate_contiguous(sparsity_layout_o, lut["sparsity_lut"], lut["sparsity_reverse_lut"])
+    validate_contiguous(lut["sparsity_layout_o"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])
 
     return lut
 
@@ -178,7 +178,7 @@ def repeat_interleave_build_lut(lut: dict, sparsity_layout_x: Tensor, repeats: i
     n_sparse_blocks = torch.sum(lut["sparsity_layout_o"].to(torch.int)).item()
     lut["n_sparse_blocks"] = n_sparse_blocks
 
-    validate_contiguous(sparsity_layout_o, lut["sparsity_lut"], lut["sparsity_reverse_lut"])
+    validate_contiguous(lut["sparsity_layout_o"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])
 
     return lut
 

Both hunks replace a reference to sparsity_layout_o, a name these build functions do not define, with the layout stored under lut["sparsity_layout_o"], fixing what appears to be a NameError at LUT-validation time.
blksprs/ops/softmax.py
CHANGED

@@ -176,13 +176,12 @@ def softmax_kernel(x,
     pid_col = tl.program_id(axis=2)
 
     # Get position of current sparsity block consisting of its batch and row index
-    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-    spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
-    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+    spa_val_idx = pid_blk * s_lut_r_s + tl.arange(0, 4) * s_lut_c_s
+    spa_val_msk = (tl.arange(0, 4) < 3)
+    spa_val = tl.load(s_lut + spa_val_idx, mask=spa_val_msk)
 
-    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-    spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
-    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+    spa_bat = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+    spa_row = tl.sum(spa_val * (tl.arange(0, 4) == 1))
 
     # Get reverse sparsity indices for s
     rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
@@ -241,13 +240,12 @@ def softmax_kernel_grad(g,
     pid_col = tl.program_id(axis=2)
 
     # Get position of current sparsity block consisting of its batch and row index
-    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-    spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
-    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+    spa_val_idx = pid_blk * s_lut_r_s + tl.arange(0, 4) * s_lut_c_s
+    spa_val_msk = (tl.arange(0, 4) < 3)
+    spa_val = tl.load(s_lut + spa_val_idx, mask=spa_val_msk)
 
-    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-    spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
-    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+    spa_bat = tl.sum(spa_val * (tl.arange(0, 4) == 0))
+    spa_row = tl.sum(spa_val * (tl.arange(0, 4) == 1))
 
     rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
                          spa_row * s_l_s_r_s)
blksprs/utils/autotuning.py
CHANGED
blksprs/utils/validation.py
CHANGED

@@ -1,9 +1,17 @@
 import torch
 from torch import Tensor
 
+CONTIGUOUS = True
 VALIDATION = True
 
 
+def ensure_contiguous(*tensors: Tensor) -> tuple[Tensor, ...]:
+    if _check_skip_contiguous():
+        return tensors
+
+    return tuple(tensor.contiguous() for tensor in tensors)
+
+
 def validate_dimensions(*tensors: Tensor, dims=3) -> None:
     if _check_skip_validation():
         return
@@ -124,6 +132,19 @@ def validate_sparsity_block_size(sparsity_block_size: int, *tensors):
             raise ValueError("Tensor sizes must be divisible by sparsity block size")
 
 
+def _check_skip_contiguous():
+    return not CONTIGUOUS
+
+
+def _set_skip_contiguous(skip_contiguous: bool):
+    global CONTIGUOUS
+    CONTIGUOUS = not skip_contiguous
+
+
+def disable_contiguous():
+    _set_skip_contiguous(True)
+
+
 def _check_skip_validation():
     return not VALIDATION
 
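The new ensure_contiguous helper and its disable_contiguous switch mirror the existing validation toggle: by default tensors are copied to contiguous memory, and the switch turns the helper into a pass-through. A short usage sketch (the imports match the diff above; the call site is illustrative):

import torch
from blksprs.utils.validation import disable_contiguous, ensure_contiguous

x = torch.arange(16.0).reshape(4, 4).t()  # a non-contiguous view

(x_c,) = ensure_contiguous(x)             # default behaviour: contiguous copies
assert x_c.is_contiguous()

disable_contiguous()                      # global opt-out
(x_same,) = ensure_contiguous(x)
assert x_same is x                        # tensors now pass through untouched
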
{blksprs-2.1.4.dist-info → blksprs-2.1.5.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: blksprs
-Version: 2.1.4
+Version: 2.1.5
 Summary: A lightweight library for operations on block-sparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
blksprs-2.1.5.dist-info/RECORD
ADDED

@@ -0,0 +1,23 @@
+blksprs/__init__.py,sha256=xlrLL9EgaiEnGQsdzFScy4SVZMN9g_5nvX-LkWxVKCw,1631
+blksprs/layouting/distribution_layout.py,sha256=ur1ty_2U-Hfj78hMWsLZvu7ZuGhzW3qGLKMc72DfTZM,5861
+blksprs/layouting/sparsity_layout.py,sha256=eXHmu2h7K5Q-YUpfOxocJoeP_5ZoQFZf_eHLxRZQbYU,11207
+blksprs/ops/conversion.py,sha256=_LgkT-6aSLPO2FXeMA2lE26g9qAEzxhMWcagenMedFU,21368
+blksprs/ops/distribution.py,sha256=HcFKcB1x59cP8Im_LuxKeJXTknZNM2Kx8hz3nu1GpvE,20183
+blksprs/ops/flow.py,sha256=JEGES5ZbMqxR02rwi2Ym4j3VDxkcRxhFO1f-5nNUlM8,7760
+blksprs/ops/matmul.py,sha256=ZYOv8Qeb7pBpbMsMnndk7IR2WO8rEXfL_KtYhbVeFdw,11576
+blksprs/ops/partitioning.py,sha256=cfQmY9BZqGTvvJorIhtb-EyuGRJGPraWR-wTKdb47aI,9954
+blksprs/ops/repeat.py,sha256=-rFC-u2eytmFxKi7vZTXpvyxReHOPZeRz4SvuO07NxE,9049
+blksprs/ops/softmax.py,sha256=iJ8GniyM83iKM3J9BXTpLdqqEVeRjxeU2rAKP553VPM,23439
+blksprs/ops/transpose.py,sha256=U-VAyLRT6_NDv9qYSFzBqfVlDeIpTqAMEXkqto0VF6w,4072
+blksprs/ops/misc/broadcast_ops.py,sha256=TD7wjBJIMn-4SUdYy7e_5bpf0UQ4Sga4QEipQFaaVPM,5684
+blksprs/ops/misc/row_wise.py,sha256=kKKpDfpq92UU5P7HVuK9gh2MNPvHOB2KQ6ijKE1RmHM,19359
+blksprs/utils/autotuning.py,sha256=xalNP3sWdRn8XiVG4jE1-_iy2QhUmIJvTGM83YwgKA0,2052
+blksprs/utils/benchmarking.py,sha256=dLabDscTFn5NkmOI1g7DnKeTneUYW3RIVv9MDF-8BKc,1271
+blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
+blksprs/utils/processing.py,sha256=RNkEDc0g-sNHRuMPkRzNWU13d3_lIkXMJdoqES4yQTM,3738
+blksprs/utils/tools.py,sha256=TKygEKge4wJtJnXXDg8BTL8vzBpqIJsQ_A3_5FmLpcE,859
+blksprs/utils/validation.py,sha256=hME6hf5t7-IxM1rHypqlzk7IE1kYEQACqCZ9KEtW6N0,4775
+blksprs-2.1.5.dist-info/METADATA,sha256=MnA7fThFWn_mrMk0BFEkBm29rtWtPY-npQecpGF1P7c,9590
+blksprs-2.1.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+blksprs-2.1.5.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
+blksprs-2.1.5.dist-info/RECORD,,
blksprs-2.1.4.dist-info/RECORD
DELETED

@@ -1,23 +0,0 @@
-blksprs/__init__.py,sha256=XERzTtkiElDeBppOO8rNrF6OktUQf_yozDiA4DUXqTY,1615
-blksprs/layouting/distribution_layout.py,sha256=ur1ty_2U-Hfj78hMWsLZvu7ZuGhzW3qGLKMc72DfTZM,5861
-blksprs/layouting/sparsity_layout.py,sha256=eXHmu2h7K5Q-YUpfOxocJoeP_5ZoQFZf_eHLxRZQbYU,11207
-blksprs/ops/conversion.py,sha256=nv5gXiyZkUtk1kCIlPr0Vpaj4G8G6dJdW7StlbV3nDw,21914
-blksprs/ops/distribution.py,sha256=0tPldv0ARzmCV1CU2jvfqpHBgOuHPrDFiCtqsLs7CZc,20789
-blksprs/ops/flow.py,sha256=oUn_xDT74220-EmnBnB8bRNtbS1mjbxWpm76PFsK22o,8246
-blksprs/ops/matmul.py,sha256=ES9bpiCIRBxaynNIL5ftDP0c9LSArbj8YJqkPEzBaIU,11879
-blksprs/ops/partitioning.py,sha256=cfQmY9BZqGTvvJorIhtb-EyuGRJGPraWR-wTKdb47aI,9954
-blksprs/ops/repeat.py,sha256=TLYNxwPuT9y5K9xyM41WK5gnggAJF3lI61Q2K7zWjns,9035
-blksprs/ops/softmax.py,sha256=tfC_jaAKrA956rxGeb57klMuYRKTiyMCd5Zg5DIH3fc,23649
-blksprs/ops/transpose.py,sha256=U-VAyLRT6_NDv9qYSFzBqfVlDeIpTqAMEXkqto0VF6w,4072
-blksprs/ops/misc/broadcast_ops.py,sha256=-PrHiSJikZh8nXUmXxSCtFEP27TTxFr4wcrNxBjnimk,5987
-blksprs/ops/misc/row_wise.py,sha256=n5FJjAuOd8BHBJQx4bsQwr-HmXkR9PYVAqfk77wjOFU,19653
-blksprs/utils/autotuning.py,sha256=a-kmWRjJ3eED2XbjkQeOJSyW8bdIs27HgKMPvAKqWeU,2052
-blksprs/utils/benchmarking.py,sha256=dLabDscTFn5NkmOI1g7DnKeTneUYW3RIVv9MDF-8BKc,1271
-blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
-blksprs/utils/processing.py,sha256=RNkEDc0g-sNHRuMPkRzNWU13d3_lIkXMJdoqES4yQTM,3738
-blksprs/utils/tools.py,sha256=TKygEKge4wJtJnXXDg8BTL8vzBpqIJsQ_A3_5FmLpcE,859
-blksprs/utils/validation.py,sha256=G8eQlvJVMKfEX3k2AwBD0A6Ck-gFoRLpLNY6HXsB3fA,4348
-blksprs-2.1.4.dist-info/METADATA,sha256=qGLQunHEIoHlmRvFnM0TVDjOSApwGzBglpZezmfhHLU,9590
-blksprs-2.1.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-blksprs-2.1.4.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
-blksprs-2.1.4.dist-info/RECORD,,
{blksprs-2.1.4.dist-info → blksprs-2.1.5.dist-info}/WHEEL
File without changes

{blksprs-2.1.4.dist-info → blksprs-2.1.5.dist-info}/top_level.txt
File without changes