blksprs 0.2b4__py3-none-any.whl → 1.1__py3-none-any.whl
- blksprs/layouting/distribution_layout.py +114 -0
- blksprs/layouting/sparsity_layout.py +78 -0
- blksprs/misc/broadcast_addition.py +132 -0
- blksprs/ops/conversion.py +256 -0
- blksprs/ops/distribution.py +362 -0
- blksprs/ops/exp.py +101 -0
- blksprs/ops/matmul.py +221 -0
- blksprs/ops/row_wise_sum.py +231 -0
- blksprs/ops/softmax.py +263 -0
- blksprs/ops/transpose.py +154 -0
- blksprs/utils/tools.py +20 -0
- blksprs/utils/validation.py +97 -0
- blksprs-1.1.dist-info/METADATA +164 -0
- blksprs-1.1.dist-info/RECORD +17 -0
- {blksprs-0.2b4.dist-info → blksprs-1.1.dist-info}/WHEEL +1 -1
- blksprs/ops/blocksparse.py +0 -589
- blksprs-0.2b4.dist-info/METADATA +0 -26
- blksprs-0.2b4.dist-info/RECORD +0 -6
- {blksprs-0.2b4.dist-info → blksprs-1.1.dist-info}/top_level.txt +0 -0
blksprs/layouting/distribution_layout.py
@@ -0,0 +1,114 @@
import torch
import triton
from torch import Tensor
from triton import language as tl

from blksprs.utils.tools import get_triton_block_size
from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
    validate_contiguous


def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
                              size_target: torch.Size,
                              sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
    """Builds the sparsity layout of either the source of a gather or the target of a scatter operation.

    Args:
        indices (Tensor): The block-sparse indices tensor in compressed form used for the gather or scatter operation.
        sparsity_layout_indices (Tensor): The sparsity layout of the indices block-sparse tensor.
        size_target (torch.Size): The size of the block-sparse target tensor in regular form.
        sparsity_block_size (int): The size of the sparsity blocks.
        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).

    Returns:
        Tensor: The sparsity layout of the source or target tensor.

    """
    validate_dimensions(indices)
    validate_contiguous(indices)
    validate_device(indices)

    sparsity_lut_i = torch.nonzero(sparsity_layout_indices).contiguous()

    output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size, size_target[2] // sparsity_block_size,
                         device=indices.device, dtype=torch.int32)

    i_b, i_r, i_c = indices.size()
    i_b_s, i_r_s, i_c_s = indices.stride()
    s_l_i_b, s_l_i_r, s_l_i_c = sparsity_layout_indices.size()
    s_l_i_b_s, s_l_i_r_s, s_l_i_c_s = sparsity_layout_indices.stride()
    s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
    s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
    o_b, o_r, o_c = output.size()
    o_b_s, o_r_s, o_c_s = output.stride()

    if triton_block_size is None:
        triton_block_size = get_triton_block_size(sparsity_block_size)

    validate_triton_block_size(triton_block_size, sparsity_block_size)

    triton_grid = lambda meta: [i_b,
                                triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
                                triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]

    (kernel_distribution_layout[triton_grid]
     (indices,
      i_b, i_b_s, i_r_s, i_c_s,
      sparsity_layout_indices,
      s_l_i_b, s_l_i_b_s, s_l_i_r, s_l_i_r_s, s_l_i_c, s_l_i_c_s,
      sparsity_lut_i,
      s_lut_i_r, s_lut_i_r_s, s_lut_i_c, s_lut_i_c_s,
      output,
      o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
      sparsity_block_size,
      triton_block_size))

    return output


@triton.jit
def kernel_distribution_layout(i,
                               i_b, i_b_s, i_r_s, i_c_s,
                               s_l_i,
                               s_l_i_b, s_l_i_b_s, s_l_i_r, s_l_i_r_s, s_l_i_c, s_l_i_c_s,
                               s_lut_i,
                               s_lut_i_r, s_lut_i_r_s, s_lut_i_c, s_lut_i_c_s,
                               o,
                               o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
                               sparsity_block_size,
                               TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    # Get triton block indices
    pid_blk = tl.program_id(axis=0)
    pid_row = tl.program_id(axis=1)
    pid_col = tl.program_id(axis=2)

    # Get position of current sparsity block consisting of its batch, row, and column index
    spa_bat_i_idx = (pid_blk * s_lut_i_r_s + 0 * s_lut_i_c_s)
    spa_bat_i_msk = (spa_bat_i_idx < s_lut_i_r * s_lut_i_r_s)
    spa_bat_i = tl.load(s_lut_i + spa_bat_i_idx, mask=spa_bat_i_msk)

    spa_row_i_idx = (pid_blk * s_lut_i_r_s + 1 * s_lut_i_c_s)
    spa_row_i_msk = (spa_row_i_idx < s_lut_i_r * s_lut_i_r_s)
    spa_row_i = tl.load(s_lut_i + spa_row_i_idx, mask=spa_row_i_msk)

    blk_i_idx = (pid_blk * i_b_s +
                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
    blk_i_msk = (blk_i_idx < i_b * i_b_s)
    blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)

    blk_i = blk_i // sparsity_block_size
    blk_v = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), 1, dtype=tl.int32)

    blk_o_idx = ((spa_bat_i * o_b_s) +
                 (spa_row_i * o_r_s) +
                 (blk_i * o_c_s))
    blk_o_msk = (blk_o_idx < o_b * o_b_s)
    tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)

    # if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
    #     blk_o_idx = (pid_bat * o_b_s +
    #                  (((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +
    #                   ((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))
    #     blk_o_msk = (blk_o_idx < o_b * o_b_s)
    #     tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
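For orientation, a minimal usage sketch (an illustration under assumed shapes, not taken from the package): the index tensor is first compressed with ``to_sparse`` from ``blksprs.ops.conversion`` (shown further below), and ``build_distribution_layout`` then marks which blocks of the gather source (or scatter target) are touched. The CUDA device and all sizes are assumptions.

# Illustrative sketch only; shapes, device, and usage pattern are assumptions.
import torch
from blksprs.layouting.distribution_layout import build_distribution_layout
from blksprs.ops.conversion import to_sparse

# Column indices into a (2, 64, 64) target, compressed with an all-present layout of 32x32 blocks
indices = torch.randint(0, 64, size=(2, 64, 64), device="cuda")
sparsity_layout_indices = torch.ones(2, 2, 2, dtype=torch.int32, device="cuda")
indices_compressed = to_sparse(indices, sparsity_layout_indices, sparsity_block_size=32)

# Layout of the (2, 64, 64) source/target: one entry per 32x32 block, 1 where a block is addressed
layout = build_distribution_layout(indices_compressed, sparsity_layout_indices,
                                   size_target=torch.Size((2, 64, 64)), sparsity_block_size=32)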
blksprs/layouting/sparsity_layout.py
@@ -0,0 +1,78 @@
import torch
import triton
from torch import Tensor
from triton import language as tl

from blksprs.utils.tools import get_triton_block_size
from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
    validate_contiguous


def build_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
    """Builds the sparsity layout of a dense tensor covering its sparse blocks.

    Args:
        x (Tensor): A block-sparse (or dense) tensor in regular form.
        sparsity_block_size (int): The size of the sparsity blocks.
        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).

    Returns:
        Tensor: The sparsity layout of the input block-sparse (or dense) tensor.

    """
    validate_dimensions(x)
    validate_contiguous(x)
    validate_device(x)

    output = torch.zeros(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size,
                         device=x.device, dtype=torch.int32)

    x_b, x_r, x_c = x.size()
    x_b_s, x_r_s, x_c_s = x.stride()
    o_b, o_r, o_c = output.size()
    o_b_s, o_r_s, o_c_s = output.stride()

    if triton_block_size is None:
        triton_block_size = get_triton_block_size(sparsity_block_size)

    validate_triton_block_size(triton_block_size, sparsity_block_size)

    triton_grid = lambda meta: [x_b,
                                triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
                                triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]

    (kernel_sparsity_layout[triton_grid]
     (x,
      x_b, x_b_s, x_r_s, x_c_s,
      output,
      o_b, o_b_s, o_r_s, o_c_s,
      sparsity_block_size,
      triton_block_size))

    return output


@triton.jit
def kernel_sparsity_layout(x,
                           x_b, x_b_s, x_r_s, x_c_s,
                           o,
                           o_b, o_b_s, o_r_s, o_c_s,
                           sparsity_block_size,
                           TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    # Get triton block indices
    pid_bat = tl.program_id(axis=0)
    pid_row = tl.program_id(axis=1)
    pid_col = tl.program_id(axis=2)

    blk_x_idx = (pid_bat * x_b_s +
                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
    blk_x_msk = (blk_x_idx < x_b * x_b_s)
    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

    if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
        blk_o_idx = (pid_bat * o_b_s +
                     (((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +
                      ((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))
        blk_o_msk = (blk_o_idx < o_b * o_b_s)
        tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
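The kernel marks a layout entry as 1 whenever any triton block inside the corresponding sparsity block contains a non-zero value. A minimal sketch of the entry point (illustrative shapes, assumes a CUDA device):

# Illustrative sketch only; shapes and device are assumptions.
import torch
from blksprs.layouting.sparsity_layout import build_sparsity_layout

x = torch.randn(1, 32, 64, device="cuda")
x[:, :, 32:] = 0  # zero out the second 32x32 block column
layout = build_sparsity_layout(x, sparsity_block_size=32)
# layout -> tensor([[[1, 0]]], dtype=torch.int32): only the first block is non-zero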
blksprs/misc/broadcast_addition.py
@@ -0,0 +1,132 @@
import torch
import triton
from torch import Tensor
from triton import language as tl

from blksprs.utils.tools import get_triton_block_size
from blksprs.utils.validation import validate_contiguous, validate_device, \
    validate_sparsity_block_size, validate_triton_block_size


def broadcast_addition(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
                       sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
    """Performs a broadcast and subsequent addition of two dense tensors x and y. Returns a block-sparse tensor in
    compressed form.

    Args:
        x (Tensor): A dense input tensor.
        y (Tensor): A dense input tensor.
        sparsity_layout_output (Tensor): The sparsity layout of the output tensor.
        sparsity_block_size (int): The size of the sparsity blocks.
        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).

    Returns:
        Tensor: The result of the operation as a block-sparse tensor in compressed form. Each element o(i, j) of the
            output tensor corresponds to x(i) + y(j).

    """
    validate_device(x, y)
    validate_contiguous(x, y)
    if x.size(-1) != y.size(-1):
        raise ValueError("Dimensions of tensors must match")
    validate_sparsity_block_size(sparsity_block_size)
    validate_triton_block_size(triton_block_size, sparsity_block_size)

    sparsity_lut_o = torch.nonzero(sparsity_layout_output).contiguous()

    n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()

    validate_contiguous(sparsity_layout_output, sparsity_lut_o)

    output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, device=x.device)

    x_b, x_c = x.size()
    x_b_s, x_c_s = x.stride()
    y_b, y_c = y.size()
    y_b_s, y_c_s = y.stride()
    o_b, o_r, o_c = output.size()
    o_b_s, o_r_s, o_c_s = output.stride()
    s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
    s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_o.stride()

    if triton_block_size is None:
        triton_block_size = get_triton_block_size(sparsity_block_size)

    triton_grid = lambda meta: [o_b,
                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

    (kernel_broadcast_addition[triton_grid]
     (x,
      x_b, x_b_s, x_c_s,
      y,
      y_b, y_b_s, y_c_s,
      output,
      o_b, o_b_s, o_r_s, o_c_s,
      sparsity_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
      sparsity_block_size,
      triton_block_size))

    return output


def broadcast_subtraction(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
                          sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
    """Wrapper for ``broadcast_addition`` with negated y.

    """
    return broadcast_addition(x, torch.neg(y), sparsity_layout_output, sparsity_block_size, triton_block_size)


@triton.jit
def kernel_broadcast_addition(x,
                              x_b, x_b_s, x_c_s,
                              y,
                              y_b, y_b_s, y_c_s,
                              o,
                              o_b, o_b_s, o_r_s, o_c_s,
                              s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
                              sparsity_block_size,
                              TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    # Get triton block indices
    pid_blk = tl.program_id(axis=0)
    pid_row = tl.program_id(axis=1)
    pid_col = tl.program_id(axis=2)

    # Get position of current sparsity block consisting of its batch, row, and column index
    spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
    spa_bat_o_msk = (spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
    spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)

    spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
    spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
    spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)

    spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
    spa_col_o_msk = (spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
    spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)

    # Load x block
    blk_x_idx = (spa_bat_o * x_b_s +
                 ((spa_row_o * sparsity_block_size + pid_row * TRITON_BLOCK_SIZE +
                   tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
    blk_x_msk = (blk_x_idx < x_b * x_b_s)
    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

    # Load y block
    blk_y_idx = (spa_bat_o * y_b_s +
                 ((spa_col_o * sparsity_block_size + pid_col * TRITON_BLOCK_SIZE +
                   tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
    blk_y_msk = (blk_y_idx < y_b * y_b_s)
    blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)

    # Compute sum
    blk_x, blk_y = tl.broadcast(tl.trans(blk_x), blk_y)
    buf = blk_x + blk_y

    # Store result
    blk_o_idx = ((pid_blk * o_b_s) +
                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
    blk_o_msk = (blk_o_idx < o_b * o_b_s)
    tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
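The docstring's claim that each output element o(i, j) equals x(i) + y(j) has a straightforward dense counterpart: per batch, the kernel transposes a row slice of x into a column, broadcasts it against a row slice of y, and adds. A plain PyTorch reference of that semantics (an illustration, not part of the package; it produces the uncompressed equivalent of the block-sparse output):

# Dense reference sketch for the documented semantics; x and y assumed of shape (B, C).
import torch

def broadcast_addition_reference(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # o[b, i, j] = x[b, i] + y[b, j]
    return x.unsqueeze(-1) + y.unsqueeze(-2)

x = torch.arange(4.0).reshape(2, 2)        # batch of 2, feature size 2
y = 10 * torch.arange(4.0).reshape(2, 2)
print(broadcast_addition_reference(x, y)[0])
# tensor([[ 0., 10.],
#         [ 1., 11.]])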
blksprs/ops/conversion.py
@@ -0,0 +1,256 @@
import torch
import triton
from torch import Tensor
from triton import language as tl

from blksprs.utils.tools import get_triton_block_size
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
    validate_sparsity, validate_sparsity_block_size, validate_triton_block_size


def to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
             triton_block_size: int = None) -> Tensor:
    """Converts a block-sparse tensor in compressed form to a block-sparse tensor in regular form based on the given
    sparsity layout.

    Args:
        x (Tensor): A block-sparse tensor in compressed form.
        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
        sparsity_block_size (int): The size of the sparsity blocks.
        fill_value (float): The value to fill the resulting dense tensor with where the block-sparse tensor is not
            present (default ``0``).
        triton_block_size (int): The block size to use for the triton kernel (default ``None``).

    Returns:
        Tensor: The block-sparse tensor converted to regular form.

    """
    validate_dimensions(x)
    validate_contiguous(x, sparsity_layout)
    validate_device(x)
    validate_sparsity(sparsity_block_size, (x, sparsity_layout))
    validate_sparsity_block_size(sparsity_block_size, x)
    validate_triton_block_size(triton_block_size, sparsity_block_size)

    # Maps each layout position to its block index in compressed form, or -1 where the block is absent
    sparsity_layout_flat = sparsity_layout.reshape(-1)
    sparsity_reverse_lut = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
                            (sparsity_layout_flat == 1) -
                            (1 * (sparsity_layout_flat == 0)))

    validate_contiguous(sparsity_reverse_lut)

    return _BlocksparseToDense.apply(x,
                                     sparsity_layout, sparsity_reverse_lut,
                                     sparsity_block_size, fill_value,
                                     triton_block_size)


class _BlocksparseToDense(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x: Tensor,
                sparsity_layout: Tensor, sparsity_reverse_lut: Tensor,
                sparsity_block_size: int, fill_value: float,
                triton_block_size: int) -> Tensor:
        output = torch.full(size=(sparsity_layout.size(0), sparsity_layout.size(1) * sparsity_block_size,
                                  sparsity_layout.size(2) * sparsity_block_size), fill_value=fill_value,
                            dtype=x.dtype, device=x.device)

        x_b, x_r, x_c = x.shape
        x_b_s, x_r_s, x_c_s = x.stride()
        s_l_b, s_l_r, s_l_c = sparsity_layout.size()
        s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout.stride()
        o_b, o_r, o_c = output.size()
        o_b_s, o_r_s, o_c_s = output.stride()

        if triton_block_size is None:
            triton_block_size = get_triton_block_size(sparsity_block_size)

        triton_grid = lambda meta: [o_b,
                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

        (_BlocksparseToDense.kernel_blocksparse_to_dense[triton_grid]
         (x,
          x_b, x_b_s, x_r_s, x_c_s,
          s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
          sparsity_reverse_lut,
          output,
          o_b, o_b_s, o_r_s, o_c_s,
          sparsity_block_size,
          triton_block_size))

        ctx.save_for_backward(sparsity_layout)
        ctx.sparsity_block_size = sparsity_block_size
        ctx.triton_block_size = triton_block_size

        return output

    @staticmethod
    def backward(ctx, grad_output):
        sparsity_layout = ctx.saved_tensors[0]
        sparsity_block_size = ctx.sparsity_block_size
        triton_block_size = ctx.triton_block_size

        return to_sparse(grad_output, sparsity_layout, sparsity_block_size,
                         triton_block_size), None, None, None, None, None

    @staticmethod
    @triton.jit
    def kernel_blocksparse_to_dense(x,
                                    x_b, x_b_s, x_r_s, x_c_s,
                                    s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
                                    sparsity_reverse_lut,
                                    o,
                                    o_b, o_b_s, o_r_s, o_c_s,
                                    sparsity_block_size,
                                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
        # Get triton block indices
        pid_blk = tl.program_id(axis=0)
        pid_row = tl.program_id(axis=1)
        pid_col = tl.program_id(axis=2)

        # Get sparsity index of current block
        spa_row = (pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size
        spa_col = (pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size

        # Get reverse sparsity index for current block
        rev_idx_spa_idx = (pid_blk * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)
        rev_idx_spa_msk = (rev_idx_spa_idx < s_l_b * s_l_b_s)
        rev_idx_spa = tl.load(sparsity_reverse_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

        # If block is present commence operations
        if rev_idx_spa >= 0:
            blk_idx = (rev_idx_spa * x_b_s +
                       (((pid_row % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
                         tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                       (((pid_col % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
                         tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
            blk_msk = (blk_idx < x_b * x_b_s)
            blk = tl.load(x + blk_idx, mask=blk_msk)

            o_idx = (pid_blk * o_b_s +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
            o_msk = (o_idx < o_b * o_b_s)
            tl.store(o + o_idx, blk, mask=o_msk)


def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
    """Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
    sparsity layout.

    Args:
        x (Tensor): A block-sparse tensor in regular form.
        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
        sparsity_block_size (int): The size of the sparsity blocks.
        triton_block_size (int): The block size to use for the triton kernel (default ``None``).

    Returns:
        Tensor: The block-sparse tensor converted to compressed form.

    """
    validate_dimensions(x)
    validate_contiguous(x)
    validate_device(x)
    validate_sparsity_block_size(sparsity_block_size, x)
    validate_triton_block_size(triton_block_size, sparsity_block_size)

    sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
    n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()

    validate_contiguous(sparsity_layout, sparsity_lut)

    return _BlocksparseToSparse.apply(x,
                                      sparsity_layout, sparsity_lut,
                                      sparsity_block_size, n_sparse_blocks,
                                      triton_block_size)


class _BlocksparseToSparse(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x: Tensor,
                sparsity_layout: Tensor, sparsity_lut: Tensor,
                sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size), dtype=x.dtype,
                             device=x.device)

        x_b, x_r, x_c = x.size()
        x_b_s, x_r_s, x_c_s = x.stride()
        o_b, o_r, o_c = output.size()
        o_b_s, o_r_s, o_c_s = output.stride()
        s_lut_r, s_lut_c = sparsity_lut.size()
        s_lut_r_s, s_lut_c_s = sparsity_lut.stride()

        if triton_block_size is None:
            triton_block_size = get_triton_block_size(sparsity_block_size)

        triton_grid = lambda meta: [o_b,
                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

        (_BlocksparseToSparse.kernel_blocksparse_to_sparse[triton_grid]
         (x, x_b, x_b_s, x_r_s, x_c_s,
          sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
          output, o_b_s, o_r_s, o_c_s,
          sparsity_block_size,
          triton_block_size))

        ctx.save_for_backward(sparsity_layout)
        ctx.sparsity_block_size = sparsity_block_size
        ctx.triton_block_size = triton_block_size

        return output

    @staticmethod
    def backward(ctx, grad_output):
        sparsity_layout = ctx.saved_tensors[0]
        sparsity_block_size = ctx.sparsity_block_size
        triton_block_size = ctx.triton_block_size

        return to_dense(grad_output, sparsity_layout, sparsity_block_size,
                        triton_block_size=triton_block_size), None, None, None, None, None

    @staticmethod
    @triton.jit
    def kernel_blocksparse_to_sparse(x,
                                     x_b, x_b_s, x_r_s, x_c_s,
                                     s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
                                     o,
                                     o_b_s, o_r_s, o_c_s,
                                     sparsity_block_size,
                                     TRITON_BLOCK_SIZE: tl.constexpr) -> None:
        # Get triton block indices
        pid_blk = tl.program_id(axis=0)
        pid_row = tl.program_id(axis=1)
        pid_col = tl.program_id(axis=2)

        # Get sparsity index of current output block consisting of its batch, row, and column index
        spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
        spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
        spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)

        spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
        spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
        spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)

        spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
        spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
        spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)

        # Load block from dense tensor
        blk_d_idx = (spa_bat * x_b_s +
                     ((spa_row * sparsity_block_size + pid_row * TRITON_BLOCK_SIZE +
                       tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                     ((spa_col * sparsity_block_size + pid_col * TRITON_BLOCK_SIZE +
                       tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
        blk_d_msk = (blk_d_idx < x_b * x_b_s)
        blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)

        # Store block in sparse tensor
        blk_o_idx = ((pid_blk * o_b_s) +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
        blk_o_msk = (blk_o_idx < (pid_blk + 1) * o_b_s)
        tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
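Since ``to_sparse`` and ``to_dense`` are registered as each other's backward passes, a round trip over a fully covered layout should reproduce the input. A minimal sketch (illustrative shapes, assumes a CUDA device):

# Illustrative round trip; shapes, layout, and device are assumptions.
import torch
from blksprs.ops.conversion import to_sparse, to_dense

x = torch.randn(1, 64, 64, device="cuda")
layout = torch.ones(1, 2, 2, dtype=torch.int32, device="cuda")  # all 32x32 blocks present
compressed = to_sparse(x, layout, sparsity_block_size=32)        # shape (4, 32, 32)
restored = to_dense(compressed, layout, sparsity_block_size=32)
assert torch.allclose(x, restored)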