blksprs 1.10.1__py3-none-any.whl → 1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +0 -1
- blksprs/ops/conversion.py +42 -15
- blksprs/ops/distribution.py +60 -30
- blksprs/ops/flow.py +63 -31
- blksprs/ops/matmul.py +40 -22
- blksprs/ops/partitioning.py +102 -59
- blksprs/ops/repeat.py +88 -76
- blksprs/ops/softmax.py +71 -63
- blksprs/ops/transpose.py +38 -101
- blksprs/utils/tools.py +7 -1
- {blksprs-1.10.1.dist-info → blksprs-1.11.dist-info}/METADATA +2 -2
- blksprs-1.11.dist-info/RECORD +23 -0
- {blksprs-1.10.1.dist-info → blksprs-1.11.dist-info}/WHEEL +1 -1
- blksprs/ops/misc/exp.py +0 -104
- blksprs-1.10.1.dist-info/RECORD +0 -24
- {blksprs-1.10.1.dist-info → blksprs-1.11.dist-info}/top_level.txt +0 -0
blksprs/ops/softmax.py
CHANGED
|
@@ -3,7 +3,6 @@ import triton
|
|
|
3
3
|
from torch import Tensor
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
|
-
from blksprs.ops.misc.exp import exp
|
|
7
6
|
from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
|
|
8
7
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
9
8
|
from blksprs.utils.tools import get_triton_block_size, stride
|
|
@@ -12,7 +11,7 @@ from blksprs.utils.validation import validate_contiguous, validate_dimensions, v
|
|
|
12
11
|
|
|
13
12
|
|
|
14
13
|
def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
15
|
-
triton_block_size: int = None) -> BlksprsTensor:
|
|
14
|
+
triton_block_size: int = None, lut: dict = None) -> BlksprsTensor:
|
|
16
15
|
"""Computes the softmax of a block-sparse tensor in compressed form.
|
|
17
16
|
|
|
18
17
|
Note:
|
|
@@ -23,6 +22,7 @@ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
|
23
22
|
sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
|
|
24
23
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
25
24
|
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
25
|
+
lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
|
|
26
26
|
|
|
27
27
|
Returns:
|
|
28
28
|
BlksprsTensor: The result of the softmax operation as a block-sparse tensor in compressed form.
|
|
@@ -37,24 +37,38 @@ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
|
37
37
|
validate_sparsity_block_size(sparsity_block_size, x)
|
|
38
38
|
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
39
39
|
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
sparsity_layout_rws, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
|
|
43
|
-
sparsity_layout_rws_flat = sparsity_layout_rws.reshape(-1)
|
|
44
|
-
sparsity_reverse_lut_rws = ((torch.cumsum(sparsity_layout_rws_flat, dim=-1) - 1) *
|
|
45
|
-
(sparsity_layout_rws_flat == 1) -
|
|
46
|
-
(1 * (sparsity_layout_rws_flat == 0)))
|
|
47
|
-
|
|
48
|
-
validate_contiguous(sparsity_layout, sparsity_lut, sparsity_reverse_lut_rws)
|
|
40
|
+
lut = _BlocksparseSoftmax.build_lut(lut, sparsity_layout)
|
|
49
41
|
|
|
50
42
|
return BlksprsTensor(_BlocksparseSoftmax.apply(x, sparsity_layout,
|
|
51
|
-
sparsity_lut,
|
|
52
|
-
sparsity_reverse_lut_rws,
|
|
43
|
+
lut["sparsity_lut"],
|
|
44
|
+
lut["sparsity_reverse_lut_rws"],
|
|
53
45
|
sparsity_block_size, triton_block_size))
|
|
54
46
|
|
|
55
47
|
|
|
56
48
|
class _BlocksparseSoftmax(torch.autograd.Function):
|
|
57
49
|
|
|
50
|
+
@staticmethod
|
|
51
|
+
def build_lut(lut: dict, sparsity_layout: Tensor):
|
|
52
|
+
if lut is None:
|
|
53
|
+
lut = dict()
|
|
54
|
+
|
|
55
|
+
if "sparsity_lut" not in lut:
|
|
56
|
+
sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
|
|
57
|
+
lut["sparsity_lut"] = sparsity_lut
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
if "sparsity_reverse_lut_rws" not in lut:
|
|
61
|
+
sparsity_layout_rws, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
|
|
62
|
+
sparsity_layout_rws_flat = sparsity_layout_rws.reshape(-1)
|
|
63
|
+
sparsity_reverse_lut_rws = ((torch.cumsum(sparsity_layout_rws_flat, dim=-1) - 1) *
|
|
64
|
+
(sparsity_layout_rws_flat == 1) -
|
|
65
|
+
(1 * (sparsity_layout_rws_flat == 0)))
|
|
66
|
+
lut["sparsity_reverse_lut_rws"] = sparsity_reverse_lut_rws
|
|
67
|
+
|
|
68
|
+
validate_contiguous(sparsity_layout, lut["sparsity_lut"], lut["sparsity_reverse_lut_rws"])
|
|
69
|
+
|
|
70
|
+
return lut
|
|
71
|
+
|
|
58
72
|
@staticmethod
|
|
59
73
|
def forward(ctx, x: Tensor, sparsity_layout: Tensor,
|
|
60
74
|
sparsity_lut: Tensor,
|
|
@@ -72,7 +86,7 @@ class _BlocksparseSoftmax(torch.autograd.Function):
|
|
|
72
86
|
flag_slice_only=True,
|
|
73
87
|
triton_block_size=triton_block_size)
|
|
74
88
|
x_scaled = row_wise_sub(x, sparsity_layout, x_row_wise_max, sparsity_block_size, triton_block_size)
|
|
75
|
-
x_exp = exp(x_scaled
|
|
89
|
+
x_exp = torch.exp(x_scaled)
|
|
76
90
|
x_exp_row_wise_sum, sparsity_layout_rws = row_wise_sum(x_exp, sparsity_layout, sparsity_block_size,
|
|
77
91
|
flag_slice_only=True,
|
|
78
92
|
triton_block_size=triton_block_size)
|
|
@@ -182,29 +196,26 @@ class _BlocksparseSoftmax(torch.autograd.Function):
|
|
|
182
196
|
rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
|
|
183
197
|
rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
|
|
184
198
|
|
|
185
|
-
if rev_idx_spa_s
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
193
|
-
blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
|
|
194
|
-
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
199
|
+
if rev_idx_spa_s >= 0:
|
|
200
|
+
# Load x block
|
|
201
|
+
blk_x_idx = ((pid_blk * x_b_s) +
|
|
202
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
203
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
204
|
+
blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
|
|
205
|
+
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
195
206
|
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
207
|
+
# Load sum block
|
|
208
|
+
blk_s_idx = (rev_idx_spa_s * s_b_s +
|
|
209
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
|
|
210
|
+
(tl.arange(0, 1) * s_c_s)[None, :])
|
|
211
|
+
blk_s_msk = (blk_s_idx >= 0 and blk_s_idx < s_b * s_b_s)
|
|
212
|
+
blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
|
|
202
213
|
|
|
203
|
-
|
|
204
|
-
|
|
214
|
+
# Compute softmax
|
|
215
|
+
buf = tl.div_rn(blk_x, blk_s)
|
|
205
216
|
|
|
206
|
-
|
|
207
|
-
|
|
217
|
+
# Store output
|
|
218
|
+
tl.store(o + blk_x_idx, buf, mask=blk_x_msk)
|
|
208
219
|
|
|
209
220
|
@staticmethod
|
|
210
221
|
@triton.jit
|
|
@@ -239,32 +250,29 @@ class _BlocksparseSoftmax(torch.autograd.Function):
|
|
|
239
250
|
rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
|
|
240
251
|
rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
|
|
241
252
|
|
|
242
|
-
if rev_idx_spa_s
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
269
|
-
blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
|
|
270
|
-
tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
|
|
253
|
+
if rev_idx_spa_s >= 0:
|
|
254
|
+
blk_s_idx = (rev_idx_spa_s * s_b_s +
|
|
255
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
|
|
256
|
+
(tl.arange(0, 1) * s_c_s)[None, :])
|
|
257
|
+
blk_s_msk = (blk_s_idx >= 0 and blk_s_idx < s_b * s_b_s)
|
|
258
|
+
blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
|
|
259
|
+
|
|
260
|
+
blk_g_idx = ((pid_blk * g_b_s) +
|
|
261
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +
|
|
262
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])
|
|
263
|
+
blk_g_msk = (blk_g_idx >= 0 and blk_g_idx < g_b * g_b_s)
|
|
264
|
+
blk_g = tl.load(g + blk_g_idx, mask=blk_g_msk)
|
|
265
|
+
|
|
266
|
+
blk_x_idx = ((pid_blk * x_b_s) +
|
|
267
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
268
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
269
|
+
blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
|
|
270
|
+
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
271
|
+
|
|
272
|
+
buf = blk_x * (blk_g - blk_s)
|
|
273
|
+
|
|
274
|
+
blk_o_idx = ((pid_blk * o_b_s) +
|
|
275
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
276
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
277
|
+
blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
|
|
278
|
+
tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
|
blksprs/ops/transpose.py
CHANGED
|
@@ -3,14 +3,15 @@ import triton
|
|
|
3
3
|
from torch import Tensor
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
|
+
from blksprs.ops.flow import flow_forward_pull
|
|
6
7
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
7
8
|
from blksprs.utils.tools import get_triton_block_size, stride
|
|
8
9
|
from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
|
|
9
10
|
validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
|
|
10
11
|
|
|
11
12
|
|
|
12
|
-
def transpose(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None
|
|
13
|
-
|
|
13
|
+
def transpose(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None,
|
|
14
|
+
lut: dict = None) -> (BlksprsTensor, Tensor):
|
|
14
15
|
"""Transposes a block-sparse tensor in compressed form.
|
|
15
16
|
|
|
16
17
|
Note:
|
|
@@ -21,6 +22,7 @@ def transpose(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: in
|
|
|
21
22
|
sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
|
|
22
23
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
23
24
|
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
25
|
+
lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
|
|
24
26
|
|
|
25
27
|
Returns:
|
|
26
28
|
BlksprsTensor: The transposed block-sparse tensor in compressed form.
|
|
@@ -28,6 +30,7 @@ def transpose(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: in
|
|
|
28
30
|
|
|
29
31
|
"""
|
|
30
32
|
x = x.contiguous()
|
|
33
|
+
x_t = x.transpose(-1, -2).contiguous()
|
|
31
34
|
|
|
32
35
|
validate_dimensions(x)
|
|
33
36
|
validate_contiguous(x)
|
|
@@ -36,66 +39,53 @@ def transpose(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: in
|
|
|
36
39
|
validate_sparsity_block_size(sparsity_block_size, x)
|
|
37
40
|
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
38
41
|
|
|
39
|
-
|
|
42
|
+
lut = _BlocksparseTranspose.build_lut(lut, sparsity_layout)
|
|
40
43
|
|
|
41
|
-
|
|
44
|
+
return BlksprsTensor(
|
|
45
|
+
_BlocksparseTranspose.apply(x_t, lut["sparsity_layout_t"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
|
|
46
|
+
sparsity_block_size,
|
|
47
|
+
lut["n_sparse_blocks"], triton_block_size)), lut["sparsity_layout_t"]
|
|
42
48
|
|
|
43
|
-
sparsity_layout_flat = sparsity_layout.reshape(-1)
|
|
44
|
-
sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
|
|
45
|
-
(sparsity_layout_flat == 1) -
|
|
46
|
-
(1 * (sparsity_layout_flat == 0)))
|
|
47
|
-
.reshape(sparsity_layout.size()).transpose(-1, -2).contiguous().reshape(-1))
|
|
48
49
|
|
|
49
|
-
|
|
50
|
+
class _BlocksparseTranspose(torch.autograd.Function):
|
|
50
51
|
|
|
51
|
-
|
|
52
|
+
@staticmethod
|
|
53
|
+
def build_lut(lut: dict, sparsity_layout: Tensor):
|
|
54
|
+
if lut is None:
|
|
55
|
+
lut = dict()
|
|
52
56
|
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
57
|
+
if "sparsity_layout_t" not in lut:
|
|
58
|
+
sparsity_layout_t = sparsity_layout.transpose(-1, -2).contiguous()
|
|
59
|
+
lut["sparsity_layout_t"] = sparsity_layout_t
|
|
56
60
|
|
|
61
|
+
if "sparsity_lut" not in lut:
|
|
62
|
+
sparsity_lut = torch.nonzero(lut["sparsity_layout_t"]).contiguous()
|
|
63
|
+
lut["sparsity_lut"] = sparsity_lut
|
|
57
64
|
|
|
58
|
-
|
|
65
|
+
if "sparsity_reverse_lut" not in lut:
|
|
66
|
+
sparsity_layout_flat = sparsity_layout.reshape(-1)
|
|
67
|
+
sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
|
|
68
|
+
(sparsity_layout_flat == 1) -
|
|
69
|
+
(1 * (sparsity_layout_flat == 0)))
|
|
70
|
+
.reshape(sparsity_layout.size()).transpose(-1, -2).contiguous().reshape(-1))
|
|
71
|
+
lut["sparsity_reverse_lut"] = sparsity_reverse_lut
|
|
72
|
+
|
|
73
|
+
if "n_sparse_blocks" not in lut:
|
|
74
|
+
n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
|
|
75
|
+
lut["n_sparse_blocks"] = n_sparse_blocks
|
|
76
|
+
|
|
77
|
+
validate_contiguous(lut["sparsity_layout_t"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])
|
|
78
|
+
|
|
79
|
+
return lut
|
|
59
80
|
|
|
60
81
|
@staticmethod
|
|
61
82
|
def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
62
83
|
sparsity_block_size: int,
|
|
63
84
|
n_sparse_blocks: int, triton_block_size: int) -> Tensor:
|
|
64
|
-
output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
|
|
65
|
-
dtype=x.dtype, device=x.device)
|
|
66
|
-
|
|
67
|
-
x_b, x_r, x_c = x.size()
|
|
68
|
-
x_b_s, x_r_s, x_c_s = stride(x)
|
|
69
|
-
s_l_b, s_l_r, s_l_c = sparsity_layout_o.size()
|
|
70
|
-
s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout_o)
|
|
71
|
-
s_lut_r, s_lut_c = sparsity_lut.shape
|
|
72
|
-
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
73
|
-
o_b, o_r, o_c = output.size()
|
|
74
|
-
o_b_s, o_r_s, o_c_s = stride(output)
|
|
75
|
-
|
|
76
|
-
if triton_block_size is None:
|
|
77
|
-
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
78
|
-
|
|
79
|
-
triton_grid = lambda meta: [o_b,
|
|
80
|
-
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
81
|
-
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
82
|
-
|
|
83
|
-
(_BlocksparseTranspose.kernel_blocksparse_transpose[triton_grid]
|
|
84
|
-
(x,
|
|
85
|
-
x_b, x_b_s, x_r_s, x_c_s,
|
|
86
|
-
s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
|
|
87
|
-
sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
88
|
-
sparsity_reverse_lut,
|
|
89
|
-
output,
|
|
90
|
-
o_b, o_b_s,
|
|
91
|
-
triton_block_size))
|
|
92
|
-
|
|
93
|
-
# Save for backward pass
|
|
94
85
|
ctx.save_for_backward(sparsity_layout_o)
|
|
95
|
-
ctx.sparsity_block_size = sparsity_block_size
|
|
96
|
-
ctx.triton_block_size = triton_block_size
|
|
97
86
|
|
|
98
|
-
return
|
|
87
|
+
return flow_forward_pull(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
|
|
88
|
+
sparsity_block_size, n_sparse_blocks, triton_block_size)
|
|
99
89
|
|
|
100
90
|
@staticmethod
|
|
101
91
|
def backward(ctx, grad_output):
|
|
@@ -105,56 +95,3 @@ class _BlocksparseTranspose(torch.autograd.Function):
|
|
|
105
95
|
|
|
106
96
|
return transpose(grad_output, sparsity_layout, sparsity_block_size, triton_block_size)[
|
|
107
97
|
0], None, None, None, None, None, None
|
|
108
|
-
|
|
109
|
-
@staticmethod
|
|
110
|
-
@triton.jit
|
|
111
|
-
def kernel_blocksparse_transpose(x,
|
|
112
|
-
x_b, x_b_s, x_r_s, x_c_s,
|
|
113
|
-
s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
|
|
114
|
-
s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
115
|
-
r_lut,
|
|
116
|
-
o,
|
|
117
|
-
o_b, o_b_s,
|
|
118
|
-
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
119
|
-
# Get triton block indices
|
|
120
|
-
pid_blk = tl.program_id(axis=0)
|
|
121
|
-
pid_row = tl.program_id(axis=1)
|
|
122
|
-
pid_col = tl.program_id(axis=2)
|
|
123
|
-
|
|
124
|
-
# Get sparsity index of current output block consisting of its batch, row, and column index
|
|
125
|
-
spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
|
|
126
|
-
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
|
|
127
|
-
spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
|
|
128
|
-
|
|
129
|
-
spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
|
|
130
|
-
spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
|
|
131
|
-
spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
|
|
132
|
-
|
|
133
|
-
spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
|
|
134
|
-
spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
|
|
135
|
-
spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
|
|
136
|
-
|
|
137
|
-
# Get reverse sparsity index
|
|
138
|
-
rev_idx_spa_idx = (spa_bat * s_l_b_s +
|
|
139
|
-
spa_row * s_l_r_s +
|
|
140
|
-
spa_col * s_l_c_s)
|
|
141
|
-
rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_b * s_l_b_s)
|
|
142
|
-
rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
|
|
143
|
-
|
|
144
|
-
if rev_idx_spa == -1:
|
|
145
|
-
tl.device_assert(False)
|
|
146
|
-
return
|
|
147
|
-
|
|
148
|
-
blk_x_idx = (rev_idx_spa * x_b_s +
|
|
149
|
-
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
150
|
-
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
151
|
-
blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
|
|
152
|
-
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
153
|
-
|
|
154
|
-
blk_x_t = tl.trans(blk_x)
|
|
155
|
-
|
|
156
|
-
blk_o_idx = (pid_blk * o_b_s +
|
|
157
|
-
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
158
|
-
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
159
|
-
blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
|
|
160
|
-
tl.store(o + blk_o_idx, blk_x_t, mask=blk_o_msk)
|
blksprs/utils/tools.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import torch
|
|
1
2
|
from torch import Tensor, Size
|
|
2
3
|
|
|
3
4
|
|
|
@@ -20,4 +21,9 @@ def get_triton_block_size(sparsity_block_size: int, limit: int = 128):
|
|
|
20
21
|
|
|
21
22
|
|
|
22
23
|
def stride(x: Tensor):
|
|
23
|
-
|
|
24
|
+
if x.dim() == 2:
|
|
25
|
+
return x.size(1), 1
|
|
26
|
+
elif x.dim() == 3:
|
|
27
|
+
return x.size(1) * x.size(2), x.size(2), 1
|
|
28
|
+
else:
|
|
29
|
+
raise NotImplementedError
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.2
|
|
2
2
|
Name: blksprs
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.11
|
|
4
4
|
Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
|
|
5
5
|
Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
|
|
6
6
|
Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
blksprs/__init__.py,sha256=AJYVfR40nOfE5F3waHPVSuajwYDcoGkiEQc8HhQbUBU,1721
|
|
2
|
+
blksprs/layouting/distribution_layout.py,sha256=xDGY5-J7uSD8oenlf8bEJ2amMiQG3NBf2klTTydbTJE,5140
|
|
3
|
+
blksprs/layouting/sparsity_layout.py,sha256=IVtHc_nN3ZM2y4GFcys70PqDWmWc7tkHlVGlToErANk,9894
|
|
4
|
+
blksprs/ops/conversion.py,sha256=QFtZ-nmY2JAWutheiO07vatXqz3eSZBP5Ym_U2Q1oWk,23299
|
|
5
|
+
blksprs/ops/distribution.py,sha256=nHTuE7Tq0Q404VN8bWNC2sEwmmdAtgZI6I7auRICdps,21749
|
|
6
|
+
blksprs/ops/flow.py,sha256=7tOXfTBKOAixYmDa_VXg7TwviLV5ZQMHQjtbyOjqA00,7879
|
|
7
|
+
blksprs/ops/matmul.py,sha256=eVj_BGj78bJkXYuvw4KctMfcfveQBt5OdYmeXzdpO88,12631
|
|
8
|
+
blksprs/ops/partitioning.py,sha256=qMv9w3yFWXwXIhIppdcJ_JMsoZ25HCH38vb6GRneoLM,10416
|
|
9
|
+
blksprs/ops/repeat.py,sha256=i824ijprfYpCaEjiSG5FTUZz7wMS5ksVy_-vY7ZX8Fg,9729
|
|
10
|
+
blksprs/ops/softmax.py,sha256=_mGkA2jHN8cXwtWXYswobEPyM7UC0JyzRszoE4ZYs7w,13063
|
|
11
|
+
blksprs/ops/transpose.py,sha256=O1XhGIGiVkhOSKcBD0HrYaeK6HmpvEEzLb7zJl7xsIM,4246
|
|
12
|
+
blksprs/ops/misc/broadcast_ops.py,sha256=pv0nssSDOdDbQFttpqUIs2ZXShqfm2RYCfJH-C5x3H0,5544
|
|
13
|
+
blksprs/ops/misc/row_wise.py,sha256=DnV5-xEJUbqZlK2fETwHiPQDUMwT-lkc0VUhBlnJ5Y0,17458
|
|
14
|
+
blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
|
|
15
|
+
blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
|
|
16
|
+
blksprs/utils/layout_utils.py,sha256=49ZdPS_gMn_IrWty3FARbi2rda5a8g5DmAEL8LOrC30,670
|
|
17
|
+
blksprs/utils/processing.py,sha256=WLuMJQ8v-YovXwcDjhlDn3N31WMZXrtyeeyKSgq_zn4,3642
|
|
18
|
+
blksprs/utils/tools.py,sha256=k2OfEplbQiAwVjP84zZf7SNB8FzvMtOFBL9sC98OCbI,683
|
|
19
|
+
blksprs/utils/validation.py,sha256=CbxBbeQWJo8wox5eMoVzaTlP9FVBwt3-gxUOmi3EUgw,4213
|
|
20
|
+
blksprs-1.11.dist-info/METADATA,sha256=NUEiHexWiFNbMxQI2TUEzMw9iGBhxqflhWr2xCgOw28,9105
|
|
21
|
+
blksprs-1.11.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
|
|
22
|
+
blksprs-1.11.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
|
|
23
|
+
blksprs-1.11.dist-info/RECORD,,
|
blksprs/ops/misc/exp.py
DELETED
|
@@ -1,104 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
import triton
|
|
3
|
-
from torch import Tensor
|
|
4
|
-
from triton import language as tl
|
|
5
|
-
|
|
6
|
-
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
7
|
-
from blksprs.utils.tools import get_triton_block_size, stride
|
|
8
|
-
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
9
|
-
validate_sparsity_block_size, validate_triton_block_size
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
def exp(x: BlksprsTensor, sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
|
|
13
|
-
"""Applies the element-wise exponential function to a block-sparse tensor.
|
|
14
|
-
|
|
15
|
-
Note:
|
|
16
|
-
This operation does not consider sparse blocks, i.e., these will not be set to ``e^0``.
|
|
17
|
-
Consider this when converting back to tensors in regular form.
|
|
18
|
-
|
|
19
|
-
Args:
|
|
20
|
-
x (BlksprsTensor): A block-sparse tensor in compressed form.
|
|
21
|
-
sparsity_block_size (int): The size of the sparsity blocks.
|
|
22
|
-
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
23
|
-
|
|
24
|
-
Returns:
|
|
25
|
-
BlksprsTensor: The exponential function applied to all elements of the input tensor as a block-sparse tensor in
|
|
26
|
-
compressed form.
|
|
27
|
-
|
|
28
|
-
"""
|
|
29
|
-
x = x.contiguous()
|
|
30
|
-
|
|
31
|
-
validate_dimensions(x)
|
|
32
|
-
validate_contiguous(x)
|
|
33
|
-
validate_device(x)
|
|
34
|
-
validate_sparsity_block_size(sparsity_block_size, x)
|
|
35
|
-
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
36
|
-
|
|
37
|
-
return BlksprsTensor(_BlocksparseExp.apply(x, sparsity_block_size, triton_block_size))
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
class _BlocksparseExp(torch.autograd.Function):
|
|
41
|
-
|
|
42
|
-
@staticmethod
|
|
43
|
-
def forward(ctx, x: Tensor, sparsity_block_size: int, triton_block_size: int) -> Tensor:
|
|
44
|
-
output = torch.empty_like(x)
|
|
45
|
-
|
|
46
|
-
x_b, x_r, x_c = x.shape
|
|
47
|
-
x_b_s, x_r_s, x_c_s = stride(x)
|
|
48
|
-
o_b, o_r, o_c = output.shape
|
|
49
|
-
o_b_s, o_r_s, o_c_s = stride(output)
|
|
50
|
-
|
|
51
|
-
if triton_block_size is None:
|
|
52
|
-
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
53
|
-
|
|
54
|
-
triton_grid = lambda meta: [o_b,
|
|
55
|
-
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
56
|
-
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
57
|
-
|
|
58
|
-
(_BlocksparseExp.kernel_blocksparse_exp[triton_grid]
|
|
59
|
-
(x,
|
|
60
|
-
x_b, x_b_s, x_r_s, x_c_s,
|
|
61
|
-
output,
|
|
62
|
-
o_b, o_b_s, o_r_s, o_c_s,
|
|
63
|
-
triton_block_size))
|
|
64
|
-
|
|
65
|
-
ctx.save_for_backward(output)
|
|
66
|
-
|
|
67
|
-
return output
|
|
68
|
-
|
|
69
|
-
@staticmethod
|
|
70
|
-
def backward(ctx, grad_output):
|
|
71
|
-
o = ctx.saved_tensors[0]
|
|
72
|
-
|
|
73
|
-
grad_x = torch.mul(grad_output, o)
|
|
74
|
-
|
|
75
|
-
return grad_x, None, None
|
|
76
|
-
|
|
77
|
-
@staticmethod
|
|
78
|
-
@triton.jit
|
|
79
|
-
def kernel_blocksparse_exp(x,
|
|
80
|
-
x_b, x_b_s, x_r_s, x_c_s,
|
|
81
|
-
o,
|
|
82
|
-
o_b, o_b_s, o_r_s, o_c_s,
|
|
83
|
-
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
84
|
-
# Get triton block indices
|
|
85
|
-
pid_blk = tl.program_id(axis=0)
|
|
86
|
-
pid_row = tl.program_id(axis=1)
|
|
87
|
-
pid_col = tl.program_id(axis=2)
|
|
88
|
-
|
|
89
|
-
# Load block
|
|
90
|
-
blk_x_idx = ((pid_blk * x_b_s) +
|
|
91
|
-
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
92
|
-
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
93
|
-
blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
|
|
94
|
-
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
95
|
-
|
|
96
|
-
# Compute exp
|
|
97
|
-
buf = tl.exp(blk_x)
|
|
98
|
-
|
|
99
|
-
# Store block
|
|
100
|
-
blk_o_idx = ((pid_blk * o_b_s) +
|
|
101
|
-
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
102
|
-
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
103
|
-
blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
|
|
104
|
-
tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
|
blksprs-1.10.1.dist-info/RECORD
DELETED
|
@@ -1,24 +0,0 @@
|
|
|
1
|
-
blksprs/__init__.py,sha256=wnpk-20jXq7xV0xa-WpHfPQuauI2gEZz9sH-0blKxP0,1766
|
|
2
|
-
blksprs/layouting/distribution_layout.py,sha256=xDGY5-J7uSD8oenlf8bEJ2amMiQG3NBf2klTTydbTJE,5140
|
|
3
|
-
blksprs/layouting/sparsity_layout.py,sha256=IVtHc_nN3ZM2y4GFcys70PqDWmWc7tkHlVGlToErANk,9894
|
|
4
|
-
blksprs/ops/conversion.py,sha256=NK5uXMepPJ9yYh0vnxKwx5_Ffj_bAvhqPVogf_7PY0g,22248
|
|
5
|
-
blksprs/ops/distribution.py,sha256=qK5t5XgQSJxXPced8RohprqCtUMMTaEP2pFm3KU1c8o,20267
|
|
6
|
-
blksprs/ops/flow.py,sha256=SWHDQ5zx0cjnPR0CcAcRNZdSusSAHSU840SwDNUr24g,6437
|
|
7
|
-
blksprs/ops/matmul.py,sha256=LAQyPNwWVmBMRnAex3msLSPD_aG5SblLCMiutJWqvus,11632
|
|
8
|
-
blksprs/ops/partitioning.py,sha256=ugKnpvH36ND7qeJQp56M74qqfACkzcTVuXebzw__28Y,8286
|
|
9
|
-
blksprs/ops/repeat.py,sha256=RCa-dITomA5v12K5Oxa5_ReA361zS7WHPNNHxSp9PGw,8578
|
|
10
|
-
blksprs/ops/softmax.py,sha256=i8NJhvPRYya94AzpN6qiki6_G9KfDrtPifhWd7wbYzk,12496
|
|
11
|
-
blksprs/ops/transpose.py,sha256=oAtUu7QzQnNAH3lvRs_MIvIKpBu9h74f9Sk07AxKnDM,6991
|
|
12
|
-
blksprs/ops/misc/broadcast_ops.py,sha256=pv0nssSDOdDbQFttpqUIs2ZXShqfm2RYCfJH-C5x3H0,5544
|
|
13
|
-
blksprs/ops/misc/exp.py,sha256=ygfw7oD6ALdPwNQX_HelKgO8I3-LCgIXH_x0gWzkUN8,3840
|
|
14
|
-
blksprs/ops/misc/row_wise.py,sha256=DnV5-xEJUbqZlK2fETwHiPQDUMwT-lkc0VUhBlnJ5Y0,17458
|
|
15
|
-
blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
|
|
16
|
-
blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
|
|
17
|
-
blksprs/utils/layout_utils.py,sha256=49ZdPS_gMn_IrWty3FARbi2rda5a8g5DmAEL8LOrC30,670
|
|
18
|
-
blksprs/utils/processing.py,sha256=WLuMJQ8v-YovXwcDjhlDn3N31WMZXrtyeeyKSgq_zn4,3642
|
|
19
|
-
blksprs/utils/tools.py,sha256=r7Y4C37vfSWUyQTGwa8NyRqgovmsq9hMufkenqYHOxo,539
|
|
20
|
-
blksprs/utils/validation.py,sha256=CbxBbeQWJo8wox5eMoVzaTlP9FVBwt3-gxUOmi3EUgw,4213
|
|
21
|
-
blksprs-1.10.1.dist-info/METADATA,sha256=5in6lYCZo1bd8urYR0wkTxIiTTAIAANukLpKeZfGasY,9107
|
|
22
|
-
blksprs-1.10.1.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
|
23
|
-
blksprs-1.10.1.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
|
|
24
|
-
blksprs-1.10.1.dist-info/RECORD,,
|
|
File without changes
|