blksprs 1.6.1__py3-none-any.whl → 1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +11 -6
- blksprs/experimental/distribution_mdi.py +14 -14
- blksprs/layouting/distribution_layout.py +4 -4
- blksprs/layouting/sparsity_layout.py +6 -6
- blksprs/misc/broadcast_ops.py +5 -5
- blksprs/{ops → misc}/exp.py +3 -3
- blksprs/{ops → misc}/partitioning.py +9 -98
- blksprs/misc/row_wise.py +16 -15
- blksprs/ops/conversion.py +23 -12
- blksprs/ops/distribution.py +11 -11
- blksprs/ops/matmul.py +7 -7
- blksprs/ops/repeat.py +322 -0
- blksprs/ops/softmax.py +12 -11
- blksprs/ops/transpose.py +7 -6
- blksprs/utils/tools.py +3 -0
- blksprs/utils/validation.py +20 -1
- {blksprs-1.6.1.dist-info → blksprs-1.8.dist-info}/METADATA +12 -5
- blksprs-1.8.dist-info/RECORD +21 -0
- blksprs/misc/repeat_interleave.py +0 -132
- blksprs-1.6.1.dist-info/RECORD +0 -21
- {blksprs-1.6.1.dist-info → blksprs-1.8.dist-info}/WHEEL +0 -0
- {blksprs-1.6.1.dist-info → blksprs-1.8.dist-info}/top_level.txt +0 -0
blksprs/__init__.py
CHANGED
|
@@ -1,22 +1,27 @@
|
|
|
1
|
-
from blksprs.ops.conversion import to_dense, to_sparse
|
|
1
|
+
from blksprs.ops.conversion import to_dense, to_sparse, from_blksprs, to_blksprs
|
|
2
2
|
from blksprs.ops.distribution import gather, scatter, scatter_reduce
|
|
3
|
-
from blksprs.ops.exp import exp
|
|
4
3
|
from blksprs.ops.matmul import matmul
|
|
5
4
|
from blksprs.ops.softmax import softmax
|
|
6
5
|
from blksprs.ops.transpose import transpose
|
|
7
|
-
from blksprs.ops.
|
|
6
|
+
from blksprs.ops.repeat import repeat, repeat_interleave
|
|
7
|
+
from blksprs.misc.partitioning import split, merge
|
|
8
|
+
|
|
8
9
|
|
|
9
10
|
class layout:
|
|
10
11
|
from blksprs.layouting.distribution_layout import build_distribution_layout
|
|
11
|
-
from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption,
|
|
12
|
+
from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption, \
|
|
13
|
+
build_sparsity_layout_matmul
|
|
14
|
+
|
|
12
15
|
|
|
13
16
|
class misc:
|
|
14
17
|
from blksprs.misc.broadcast_ops import broadcast_add, broadcast_sub
|
|
15
|
-
from blksprs.misc.
|
|
18
|
+
from blksprs.misc.exp import exp
|
|
16
19
|
from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_add, row_wise_sub
|
|
17
20
|
|
|
21
|
+
|
|
18
22
|
class util:
|
|
19
23
|
from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse, disable_validation
|
|
20
24
|
|
|
25
|
+
|
|
21
26
|
class experimental:
|
|
22
|
-
from blksprs.experimental.distribution_mdi import gather_mdi, scatter_reduce_mdi
|
|
27
|
+
from blksprs.experimental.distribution_mdi import gather_mdi, scatter_reduce_mdi
|
|
@@ -3,7 +3,7 @@ import triton
|
|
|
3
3
|
from torch import Tensor
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
|
-
from blksprs.utils.tools import get_triton_block_size
|
|
6
|
+
from blksprs.utils.tools import get_triton_block_size, stride
|
|
7
7
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
8
8
|
validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
|
|
9
9
|
|
|
@@ -51,15 +51,15 @@ class _BlocksparseGatherMDI(torch.autograd.Function):
|
|
|
51
51
|
output = torch.empty_like(idx_col, dtype=x.dtype)
|
|
52
52
|
|
|
53
53
|
x_b, x_r, x_c = x.size()
|
|
54
|
-
x_b_s, x_r_s, x_c_s =
|
|
54
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
55
55
|
s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
|
|
56
|
-
s_l_x_b_s, s_l_x_r_s, s_l_x_c_s =
|
|
56
|
+
s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
|
|
57
57
|
i_b, i_r, i_c = idx_col.size()
|
|
58
|
-
i_b_s, i_r_s, i_c_s =
|
|
58
|
+
i_b_s, i_r_s, i_c_s = stride(idx_col)
|
|
59
59
|
s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
|
|
60
|
-
s_lut_i_r_s, s_lut_i_c_s =
|
|
60
|
+
s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
|
|
61
61
|
o_b, o_r, o_c = output.size()
|
|
62
|
-
o_b_s, o_r_s, o_c_s =
|
|
62
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
63
63
|
|
|
64
64
|
if triton_block_size is None:
|
|
65
65
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -224,15 +224,15 @@ class _BlocksparseScatterReduceMDI(torch.autograd.Function):
|
|
|
224
224
|
dtype=x.dtype, device=x.device)
|
|
225
225
|
|
|
226
226
|
x_b, x_r, x_c = x.size()
|
|
227
|
-
x_b_s, x_r_s, x_c_s =
|
|
227
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
228
228
|
s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
|
|
229
|
-
s_lut_x_r_s, s_lut_x_c_s =
|
|
229
|
+
s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut_x)
|
|
230
230
|
i_b, i_r, i_c = idx_col.size()
|
|
231
|
-
i_b_s, i_r_s, i_c_s =
|
|
231
|
+
i_b_s, i_r_s, i_c_s = stride(idx_col)
|
|
232
232
|
o_b, o_r, o_c = output.size()
|
|
233
|
-
o_b_s, o_r_s, o_c_s =
|
|
233
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
234
234
|
s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
|
|
235
|
-
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s =
|
|
235
|
+
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
|
|
236
236
|
|
|
237
237
|
if triton_block_size is None:
|
|
238
238
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -366,11 +366,11 @@ def build_distribution_layout_mdi(idx_bat: Tensor, idx_row: Tensor, idx_col: Ten
|
|
|
366
366
|
dtype=torch.bool, device=idx_col.device)
|
|
367
367
|
|
|
368
368
|
i_b, i_r, i_c = idx_col.size()
|
|
369
|
-
i_b_s, i_r_s, i_c_s =
|
|
369
|
+
i_b_s, i_r_s, i_c_s = stride(idx_col)
|
|
370
370
|
s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
|
|
371
|
-
s_lut_i_r_s, s_lut_i_c_s =
|
|
371
|
+
s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
|
|
372
372
|
o_b, o_r, o_c = output.size()
|
|
373
|
-
o_b_s, o_r_s, o_c_s =
|
|
373
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
374
374
|
|
|
375
375
|
if triton_block_size is None:
|
|
376
376
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -3,7 +3,7 @@ import triton
|
|
|
3
3
|
from torch import Tensor
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
|
-
from blksprs.utils.tools import get_triton_block_size
|
|
6
|
+
from blksprs.utils.tools import get_triton_block_size, stride
|
|
7
7
|
from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
|
|
8
8
|
validate_contiguous
|
|
9
9
|
|
|
@@ -34,11 +34,11 @@ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
|
|
|
34
34
|
dtype=torch.bool, device=indices.device)
|
|
35
35
|
|
|
36
36
|
i_b, i_r, i_c = indices.size()
|
|
37
|
-
i_b_s, i_r_s, i_c_s =
|
|
37
|
+
i_b_s, i_r_s, i_c_s = stride(indices)
|
|
38
38
|
s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
|
|
39
|
-
s_lut_i_r_s, s_lut_i_c_s =
|
|
39
|
+
s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
|
|
40
40
|
o_b, o_r, o_c = output.size()
|
|
41
|
-
o_b_s, o_r_s, o_c_s =
|
|
41
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
42
42
|
|
|
43
43
|
if triton_block_size is None:
|
|
44
44
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -5,7 +5,7 @@ import triton
|
|
|
5
5
|
from torch import Tensor
|
|
6
6
|
from triton import language as tl
|
|
7
7
|
|
|
8
|
-
from blksprs.utils.tools import get_triton_block_size
|
|
8
|
+
from blksprs.utils.tools import get_triton_block_size, stride
|
|
9
9
|
from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
|
|
10
10
|
validate_contiguous, validate_sparsity, validate_sparsity_block_size
|
|
11
11
|
|
|
@@ -30,9 +30,9 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_size
|
|
|
30
30
|
dtype=torch.bool, device=x.device)
|
|
31
31
|
|
|
32
32
|
x_b, x_r, x_c = x.size()
|
|
33
|
-
x_b_s, x_r_s, x_c_s =
|
|
33
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
34
34
|
o_b, o_r, o_c = output.size()
|
|
35
|
-
o_b_s, o_r_s, o_c_s =
|
|
35
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
36
36
|
|
|
37
37
|
if triton_block_size is None:
|
|
38
38
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -120,10 +120,10 @@ def build_sparsity_layout_adaption(x: Tensor, sparsity_layout_from: Tensor,
|
|
|
120
120
|
output = torch.zeros(o_b, o_r, o_c, dtype=torch.bool, device=x.device)
|
|
121
121
|
|
|
122
122
|
x_b, x_r, x_c = x.size()
|
|
123
|
-
x_b_s, x_r_s, x_c_s =
|
|
123
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
124
124
|
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
125
|
-
s_lut_r_s, s_lut_c_s =
|
|
126
|
-
o_b_s, o_r_s, o_c_s =
|
|
125
|
+
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
126
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
127
127
|
|
|
128
128
|
if triton_block_size is None:
|
|
129
129
|
triton_block_size = get_triton_block_size(sparsity_block_size_from)
|
blksprs/misc/broadcast_ops.py
CHANGED
|
@@ -3,7 +3,7 @@ import triton
|
|
|
3
3
|
from torch import Tensor
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
|
-
from blksprs.utils.tools import get_triton_block_size
|
|
6
|
+
from blksprs.utils.tools import get_triton_block_size, stride
|
|
7
7
|
from blksprs.utils.validation import validate_contiguous, validate_device, \
|
|
8
8
|
validate_sparsity_block_size, validate_triton_block_size
|
|
9
9
|
|
|
@@ -44,13 +44,13 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
|
|
|
44
44
|
output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, dtype=x.dtype, device=x.device)
|
|
45
45
|
|
|
46
46
|
x_b, x_c = x.size()
|
|
47
|
-
x_b_s, x_c_s =
|
|
47
|
+
x_b_s, x_c_s = stride(x)
|
|
48
48
|
y_b, y_c = y.size()
|
|
49
|
-
y_b_s, y_c_s =
|
|
49
|
+
y_b_s, y_c_s = stride(y)
|
|
50
50
|
o_b, o_r, o_c = output.size()
|
|
51
|
-
o_b_s, o_r_s, o_c_s =
|
|
51
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
52
52
|
s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
|
|
53
|
-
s_lut_o_r_s, s_lut_o_c_s =
|
|
53
|
+
s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
|
|
54
54
|
|
|
55
55
|
if triton_block_size is None:
|
|
56
56
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
blksprs/{ops → misc}/exp.py
RENAMED
|
@@ -3,7 +3,7 @@ import triton
|
|
|
3
3
|
from torch import Tensor
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
|
-
from blksprs.utils.tools import get_triton_block_size
|
|
6
|
+
from blksprs.utils.tools import get_triton_block_size, stride
|
|
7
7
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
8
8
|
validate_sparsity_block_size, validate_triton_block_size
|
|
9
9
|
|
|
@@ -43,9 +43,9 @@ class _BlocksparseExp(torch.autograd.Function):
|
|
|
43
43
|
output = torch.empty_like(x)
|
|
44
44
|
|
|
45
45
|
x_b, x_r, x_c = x.shape
|
|
46
|
-
x_b_s, x_r_s, x_c_s =
|
|
46
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
47
47
|
o_b, o_r, o_c = output.shape
|
|
48
|
-
o_b_s, o_r_s, o_c_s =
|
|
48
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
49
49
|
|
|
50
50
|
if triton_block_size is None:
|
|
51
51
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -1,10 +1,7 @@
|
|
|
1
1
|
import torch
|
|
2
|
-
import triton
|
|
3
|
-
from sympy.utilities.iterables import partitions
|
|
4
2
|
from torch import Tensor
|
|
5
|
-
from triton import language as tl
|
|
6
3
|
|
|
7
|
-
from blksprs.
|
|
4
|
+
from blksprs.ops.repeat import forward_flow
|
|
8
5
|
|
|
9
6
|
from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
|
|
10
7
|
validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
|
|
@@ -48,12 +45,11 @@ def split(x: Tensor, sparsity_layout: Tensor, partitions: int,
|
|
|
48
45
|
sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
|
|
49
46
|
(sparsity_layout_flat == 1) -
|
|
50
47
|
(1 * (sparsity_layout_flat == 0)))
|
|
51
|
-
.reshape(sparsity_layout.size())
|
|
52
48
|
.reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
|
|
53
49
|
sparsity_layout.size(2) // partitions)
|
|
54
50
|
.permute(0, 2, 1, 3).reshape(-1).contiguous())
|
|
55
51
|
|
|
56
|
-
n_sparse_blocks = torch.sum(
|
|
52
|
+
n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
|
|
57
53
|
|
|
58
54
|
validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
|
|
59
55
|
|
|
@@ -66,10 +62,11 @@ class _BlocksparseSplit(torch.autograd.Function):
|
|
|
66
62
|
@staticmethod
|
|
67
63
|
def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
68
64
|
num_partitions: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
|
|
65
|
+
ctx.save_for_backward(sparsity_layout_o)
|
|
69
66
|
ctx.num_partitions = num_partitions
|
|
70
67
|
|
|
71
|
-
return
|
|
72
|
-
|
|
68
|
+
return forward_flow(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
|
|
69
|
+
n_sparse_blocks, triton_block_size)
|
|
73
70
|
|
|
74
71
|
@staticmethod
|
|
75
72
|
def backward(ctx, grad_output):
|
|
@@ -126,7 +123,7 @@ def merge(x: Tensor, sparsity_layout: Tensor, partitions: int,
|
|
|
126
123
|
sparsity_layout.size(1), sparsity_layout.size(2) * partitions)
|
|
127
124
|
.reshape(-1).contiguous())
|
|
128
125
|
|
|
129
|
-
n_sparse_blocks = torch.sum(
|
|
126
|
+
n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
|
|
130
127
|
|
|
131
128
|
validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
|
|
132
129
|
|
|
@@ -139,10 +136,11 @@ class _BlocksparseMerge(torch.autograd.Function):
|
|
|
139
136
|
@staticmethod
|
|
140
137
|
def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
141
138
|
num_partitions: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
|
|
139
|
+
ctx.save_for_backward(sparsity_layout_o)
|
|
142
140
|
ctx.num_partitions = num_partitions
|
|
143
141
|
|
|
144
|
-
return
|
|
145
|
-
|
|
142
|
+
return forward_flow(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
|
|
143
|
+
n_sparse_blocks, triton_block_size)
|
|
146
144
|
|
|
147
145
|
@staticmethod
|
|
148
146
|
def backward(ctx, grad_output):
|
|
@@ -155,90 +153,3 @@ class _BlocksparseMerge(torch.autograd.Function):
|
|
|
155
153
|
sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None
|
|
156
154
|
|
|
157
155
|
|
|
158
|
-
def forward_reorder(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
159
|
-
sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
|
|
160
|
-
output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
|
|
161
|
-
dtype=x.dtype, device=x.device)
|
|
162
|
-
|
|
163
|
-
x_b, x_r, x_c = x.size()
|
|
164
|
-
x_b_s, x_r_s, x_c_s = x.stride()
|
|
165
|
-
s_l_b, s_l_r, s_l_c = sparsity_layout_o.size()
|
|
166
|
-
s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout_o.stride()
|
|
167
|
-
s_lut_r, s_lut_c = sparsity_lut.shape
|
|
168
|
-
s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
|
|
169
|
-
o_b, o_r, o_c = output.size()
|
|
170
|
-
o_b_s, o_r_s, o_c_s = output.stride()
|
|
171
|
-
|
|
172
|
-
if triton_block_size is None:
|
|
173
|
-
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
174
|
-
|
|
175
|
-
triton_grid = lambda meta: [o_b,
|
|
176
|
-
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
177
|
-
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
178
|
-
|
|
179
|
-
(kernel_blocksparse_reorder[triton_grid]
|
|
180
|
-
(x,
|
|
181
|
-
x_b, x_b_s, x_r_s, x_c_s,
|
|
182
|
-
s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
|
|
183
|
-
sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
184
|
-
sparsity_reverse_lut,
|
|
185
|
-
output,
|
|
186
|
-
o_b, o_b_s,
|
|
187
|
-
triton_block_size))
|
|
188
|
-
|
|
189
|
-
# Save for backward pass
|
|
190
|
-
ctx.save_for_backward(sparsity_layout_o)
|
|
191
|
-
ctx.sparsity_block_size = sparsity_block_size
|
|
192
|
-
ctx.triton_block_size = triton_block_size
|
|
193
|
-
|
|
194
|
-
return output
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
@triton.jit
|
|
198
|
-
def kernel_blocksparse_reorder(x,
|
|
199
|
-
x_b, x_b_s, x_r_s, x_c_s,
|
|
200
|
-
s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
|
|
201
|
-
s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
202
|
-
r_lut,
|
|
203
|
-
o,
|
|
204
|
-
o_b, o_b_s,
|
|
205
|
-
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
206
|
-
# Get triton block indices
|
|
207
|
-
pid_blk = tl.program_id(axis=0)
|
|
208
|
-
pid_row = tl.program_id(axis=1)
|
|
209
|
-
pid_col = tl.program_id(axis=2)
|
|
210
|
-
|
|
211
|
-
# Get sparsity index of current output block consisting of its batch, row, and column index
|
|
212
|
-
spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
|
|
213
|
-
spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
|
|
214
|
-
spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
|
|
215
|
-
|
|
216
|
-
spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
|
|
217
|
-
spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
|
|
218
|
-
spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
|
|
219
|
-
|
|
220
|
-
spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
|
|
221
|
-
spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
|
|
222
|
-
spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
|
|
223
|
-
|
|
224
|
-
# Get reverse sparsity index
|
|
225
|
-
rev_idx_spa_idx = (spa_bat * s_l_b_s +
|
|
226
|
-
spa_row * s_l_r_s +
|
|
227
|
-
spa_col * s_l_c_s)
|
|
228
|
-
rev_idx_spa_msk = (rev_idx_spa_idx < s_l_b * s_l_b_s)
|
|
229
|
-
rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
|
|
230
|
-
|
|
231
|
-
if rev_idx_spa == -1:
|
|
232
|
-
assert False, "Invalid sparsity block"
|
|
233
|
-
|
|
234
|
-
blk_x_idx = (rev_idx_spa * x_b_s +
|
|
235
|
-
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
236
|
-
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
237
|
-
blk_x_msk = (blk_x_idx < x_b * x_b_s)
|
|
238
|
-
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
239
|
-
|
|
240
|
-
blk_o_idx = (pid_blk * o_b_s +
|
|
241
|
-
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
242
|
-
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
243
|
-
blk_o_msk = (blk_o_idx < o_b * o_b_s)
|
|
244
|
-
tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
|
blksprs/misc/row_wise.py
CHANGED
|
@@ -3,7 +3,7 @@ import triton
|
|
|
3
3
|
from torch import Tensor
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
|
-
from blksprs.utils.tools import get_triton_block_size
|
|
6
|
+
from blksprs.utils.tools import get_triton_block_size, stride
|
|
7
7
|
from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, validate_sparsity, \
|
|
8
8
|
validate_sparsity_block_size, validate_triton_block_size
|
|
9
9
|
|
|
@@ -60,13 +60,13 @@ def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
|
60
60
|
device=x.device)
|
|
61
61
|
|
|
62
62
|
x_b, x_r, x_c = x.size()
|
|
63
|
-
x_b_s, x_r_s, x_c_s =
|
|
63
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
64
64
|
s_lut_x_r, s_lut_x_c = sparsity_lut.size()
|
|
65
|
-
s_lut_x_r_s, s_lut_x_c_s =
|
|
65
|
+
s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
|
|
66
66
|
o_b, o_r, o_c = output.size()
|
|
67
|
-
o_b_s, o_r_s, o_c_s =
|
|
67
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
68
68
|
s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
|
|
69
|
-
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s =
|
|
69
|
+
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
|
|
70
70
|
|
|
71
71
|
if triton_block_size is None:
|
|
72
72
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -183,13 +183,13 @@ def row_wise_max(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
|
183
183
|
device=x.device)
|
|
184
184
|
|
|
185
185
|
x_b, x_r, x_c = x.size()
|
|
186
|
-
x_b_s, x_r_s, x_c_s =
|
|
186
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
187
187
|
s_lut_x_r, s_lut_x_c = sparsity_lut.size()
|
|
188
|
-
s_lut_x_r_s, s_lut_x_c_s =
|
|
188
|
+
s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
|
|
189
189
|
o_b, o_r, o_c = output.size()
|
|
190
|
-
o_b_s, o_r_s, o_c_s =
|
|
190
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
191
191
|
s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
|
|
192
|
-
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s =
|
|
192
|
+
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
|
|
193
193
|
|
|
194
194
|
if triton_block_size is None:
|
|
195
195
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -290,15 +290,15 @@ def row_wise_add(x: Tensor, sparsity_layout_x: Tensor, y: Tensor,
|
|
|
290
290
|
output = torch.empty_like(x)
|
|
291
291
|
|
|
292
292
|
x_b, x_r, x_c = x.size()
|
|
293
|
-
x_b_s, x_r_s, x_c_s =
|
|
293
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
294
294
|
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
295
|
-
s_lut_r_s, s_lut_c_s =
|
|
295
|
+
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
296
296
|
y_b, y_r, y_c = y.size()
|
|
297
|
-
y_b_s, y_r_s, y_c_s =
|
|
297
|
+
y_b_s, y_r_s, y_c_s = stride(y)
|
|
298
298
|
s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_rwm.size()
|
|
299
|
-
s_l_y_b_s, s_l_y_r_s, s_l_y_c_s =
|
|
299
|
+
s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_rwm)
|
|
300
300
|
o_b, o_r, o_c = output.size()
|
|
301
|
-
o_b_s, o_r_s, o_c_s =
|
|
301
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
302
302
|
|
|
303
303
|
if triton_block_size is None:
|
|
304
304
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -361,7 +361,8 @@ def kernel_blocksparse_row_wise_add(x,
|
|
|
361
361
|
rev_idx_spa_s = tl.load(r_lut_y + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
|
|
362
362
|
|
|
363
363
|
if rev_idx_spa_s == -1:
|
|
364
|
-
|
|
364
|
+
tl.device_assert(False)
|
|
365
|
+
return
|
|
365
366
|
|
|
366
367
|
# Load x block
|
|
367
368
|
blk_x_idx = ((pid_blk * x_b_s) +
|
blksprs/ops/conversion.py
CHANGED
|
@@ -6,9 +6,14 @@ from torch import Tensor
|
|
|
6
6
|
from triton import language as tl
|
|
7
7
|
|
|
8
8
|
from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
|
|
9
|
-
from blksprs.utils.tools import get_triton_block_size
|
|
9
|
+
from blksprs.utils.tools import get_triton_block_size, stride
|
|
10
10
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
11
|
-
validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
|
|
11
|
+
validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_sparsity_dense
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def from_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
|
|
15
|
+
triton_block_size: int = None) -> Tensor:
|
|
16
|
+
return to_dense(x, sparsity_layout, sparsity_block_size, fill_value, triton_block_size)
|
|
12
17
|
|
|
13
18
|
|
|
14
19
|
def to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
|
|
@@ -65,11 +70,11 @@ class _BlocksparseToDense(torch.autograd.Function):
|
|
|
65
70
|
dtype=x.dtype, device=x.device)
|
|
66
71
|
|
|
67
72
|
x_b, x_r, x_c = x.shape
|
|
68
|
-
x_b_s, x_r_s, x_c_s =
|
|
73
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
69
74
|
s_l_b, s_l_r, s_l_c = sparsity_layout.size()
|
|
70
|
-
s_l_b_s, s_l_r_s, s_l_c_s =
|
|
75
|
+
s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout)
|
|
71
76
|
o_b, o_r, o_c = output.size()
|
|
72
|
-
o_b_s, o_r_s, o_c_s =
|
|
77
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
73
78
|
|
|
74
79
|
if triton_block_size is None:
|
|
75
80
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -144,6 +149,11 @@ class _BlocksparseToDense(torch.autograd.Function):
|
|
|
144
149
|
tl.store(o + o_idx, blk, o_msk)
|
|
145
150
|
|
|
146
151
|
|
|
152
|
+
def to_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
153
|
+
triton_block_size: int = None) -> Tensor:
|
|
154
|
+
return to_sparse(x, sparsity_layout, sparsity_block_size, triton_block_size)
|
|
155
|
+
|
|
156
|
+
|
|
147
157
|
def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
|
|
148
158
|
"""Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
|
|
149
159
|
sparsity layout.
|
|
@@ -163,6 +173,7 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
|
|
|
163
173
|
validate_dimensions(x)
|
|
164
174
|
validate_contiguous(x)
|
|
165
175
|
validate_device(x)
|
|
176
|
+
validate_sparsity_dense(sparsity_block_size, (x, sparsity_layout))
|
|
166
177
|
validate_sparsity_block_size(sparsity_block_size, x)
|
|
167
178
|
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
168
179
|
|
|
@@ -190,11 +201,11 @@ class _BlocksparseToSparse(torch.autograd.Function):
|
|
|
190
201
|
dtype=x.dtype, device=x.device)
|
|
191
202
|
|
|
192
203
|
x_b, x_r, x_c = x.size()
|
|
193
|
-
x_b_s, x_r_s, x_c_s =
|
|
204
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
194
205
|
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
195
|
-
s_lut_r_s, s_lut_c_s =
|
|
206
|
+
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
196
207
|
o_b, o_r, o_c = output.size()
|
|
197
|
-
o_b_s, o_r_s, o_c_s =
|
|
208
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
198
209
|
|
|
199
210
|
if triton_block_size is None:
|
|
200
211
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -347,13 +358,13 @@ class _BlocksparseAdaptLayout(torch.autograd.Function):
|
|
|
347
358
|
dtype=x.dtype, device=x.device)
|
|
348
359
|
|
|
349
360
|
x_b, x_r, x_c = x.size()
|
|
350
|
-
x_b_s, x_r_s, x_c_s =
|
|
361
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
351
362
|
s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_from.size()
|
|
352
|
-
s_l_x_b_s, s_l_x_r_s, s_l_x_c_s =
|
|
363
|
+
s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_from)
|
|
353
364
|
o_b, o_r, o_c = output.size()
|
|
354
|
-
o_b_s, o_r_s, o_c_s =
|
|
365
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
355
366
|
s_lut_o_r, s_lut_o_c = sparsity_lut_to.size()
|
|
356
|
-
s_lut_o_r_s, s_lut_o_c_s =
|
|
367
|
+
s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_to)
|
|
357
368
|
|
|
358
369
|
if triton_block_size is None:
|
|
359
370
|
triton_block_size = get_triton_block_size(min_sparsity_block_size)
|
blksprs/ops/distribution.py
CHANGED
|
@@ -3,7 +3,7 @@ import triton
|
|
|
3
3
|
from torch import Tensor
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
|
-
from blksprs.utils.tools import get_triton_block_size
|
|
6
|
+
from blksprs.utils.tools import get_triton_block_size, stride
|
|
7
7
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
8
8
|
validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
|
|
9
9
|
|
|
@@ -59,15 +59,15 @@ class _BlocksparseGather(torch.autograd.Function):
|
|
|
59
59
|
output = torch.empty_like(i, dtype=x.dtype)
|
|
60
60
|
|
|
61
61
|
x_b, x_r, x_c = x.size()
|
|
62
|
-
x_b_s, x_r_s, x_c_s =
|
|
62
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
63
63
|
s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
|
|
64
|
-
s_l_x_b_s, s_l_x_r_s, s_l_x_c_s =
|
|
64
|
+
s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
|
|
65
65
|
i_b, i_r, i_c = i.size()
|
|
66
|
-
i_b_s, i_r_s, i_c_s =
|
|
66
|
+
i_b_s, i_r_s, i_c_s = stride(i)
|
|
67
67
|
s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
|
|
68
|
-
s_lut_i_r_s, s_lut_i_c_s =
|
|
68
|
+
s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
|
|
69
69
|
o_b, o_r, o_c = output.size()
|
|
70
|
-
o_b_s, o_r_s, o_c_s =
|
|
70
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
71
71
|
|
|
72
72
|
if triton_block_size is None:
|
|
73
73
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -248,15 +248,15 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
|
|
|
248
248
|
dtype=x.dtype, device=x.device)
|
|
249
249
|
|
|
250
250
|
x_b, x_r, x_c = x.size()
|
|
251
|
-
x_b_s, x_r_s, x_c_s =
|
|
251
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
252
252
|
s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
|
|
253
|
-
s_lut_x_r_s, s_lut_x_c_s =
|
|
253
|
+
s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut_x)
|
|
254
254
|
i_b, i_r, i_c = i.size()
|
|
255
|
-
i_b_s, i_r_s, i_c_s =
|
|
255
|
+
i_b_s, i_r_s, i_c_s = stride(i)
|
|
256
256
|
o_b, o_r, o_c = output.size()
|
|
257
|
-
o_b_s, o_r_s, o_c_s =
|
|
257
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
258
258
|
s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
|
|
259
|
-
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s =
|
|
259
|
+
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
|
|
260
260
|
|
|
261
261
|
if triton_block_size is None:
|
|
262
262
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
blksprs/ops/matmul.py
CHANGED
|
@@ -4,7 +4,7 @@ from torch import Tensor
|
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
6
|
from blksprs.ops.transpose import transpose
|
|
7
|
-
from blksprs.utils.tools import get_triton_block_size
|
|
7
|
+
from blksprs.utils.tools import get_triton_block_size, stride
|
|
8
8
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
9
9
|
validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_dtype_float
|
|
10
10
|
|
|
@@ -82,17 +82,17 @@ class _BlocksparseMatmulSSS(torch.autograd.Function):
|
|
|
82
82
|
dtype=x.dtype, device=x.device)
|
|
83
83
|
|
|
84
84
|
x_b, x_r, x_c = x.size()
|
|
85
|
-
x_b_s, x_r_s, x_c_s =
|
|
85
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
86
86
|
s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
|
|
87
|
-
s_l_x_b_s, s_l_x_r_s, s_l_x_c_s =
|
|
87
|
+
s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
|
|
88
88
|
y_b, y_r, y_c = y.size()
|
|
89
|
-
y_b_s, y_r_s, y_c_s =
|
|
89
|
+
y_b_s, y_r_s, y_c_s = stride(y)
|
|
90
90
|
s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_y.size()
|
|
91
|
-
s_l_y_b_s, s_l_y_r_s, s_l_y_c_s =
|
|
91
|
+
s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_y)
|
|
92
92
|
o_b, o_r, o_c = output.size()
|
|
93
|
-
o_b_s, o_r_s, o_c_s =
|
|
93
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
94
94
|
s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
|
|
95
|
-
s_lut_o_r_s, s_lut_o_c_s =
|
|
95
|
+
s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
|
|
96
96
|
|
|
97
97
|
if triton_block_size is None:
|
|
98
98
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|