blksprs 1.2.1__py3-none-any.whl → 1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +18 -0
- blksprs/misc/{broadcast_addition.py → broadcast_ops.py} +6 -6
- blksprs/misc/repeat_interleave.py +130 -0
- blksprs/misc/row_wise.py +386 -0
- blksprs/ops/softmax.py +11 -13
- blksprs/ops/transpose.py +1 -1
- blksprs/utils/tools.py +1 -1
- {blksprs-1.2.1.dist-info → blksprs-1.4.dist-info}/METADATA +28 -25
- blksprs-1.4.dist-info/RECORD +19 -0
- blksprs/ops/row_wise_sum.py +0 -231
- blksprs-1.2.1.dist-info/RECORD +0 -17
- {blksprs-1.2.1.dist-info → blksprs-1.4.dist-info}/WHEEL +0 -0
- {blksprs-1.2.1.dist-info → blksprs-1.4.dist-info}/top_level.txt +0 -0
blksprs/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from blksprs.ops.conversion import to_dense, to_sparse
|
|
2
|
+
from blksprs.ops.distribution import gather, scatter, scatter_reduce
|
|
3
|
+
from blksprs.ops.exp import exp
|
|
4
|
+
from blksprs.ops.matmul import matmul
|
|
5
|
+
from blksprs.ops.softmax import softmax
|
|
6
|
+
from blksprs.ops.transpose import transpose
|
|
7
|
+
|
|
8
|
+
class layout:
    """Namespace grouping the layout builders so they are reachable as ``blksprs.layout.<name>``."""
    from blksprs.layouting.distribution_layout import build_distribution_layout
    from blksprs.layouting.sparsity_layout import build_sparsity_layout
    from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
|
|
11
|
+
|
|
12
|
+
class misc:
    """Namespace grouping the miscellaneous operations so they are reachable as ``blksprs.misc.<name>``."""
    from blksprs.misc.broadcast_ops import (
        broadcast_add,
        broadcast_sub,
    )
    from blksprs.misc.repeat_interleave import repeat_interleave
    from blksprs.misc.row_wise import (
        row_wise_sum,
        row_wise_max,
        row_wise_add,
        row_wise_sub,
    )
|
|
16
|
+
|
|
17
|
+
class util:
    """Namespace grouping the shape helpers so they are reachable as ``blksprs.util.<name>``."""
    from blksprs.utils.tools import (
        do_shape_blocksparse,
        undo_shape_blocksparse,
    )
|
|
@@ -8,8 +8,8 @@ from blksprs.utils.validation import validate_contiguous, validate_device, \
|
|
|
8
8
|
validate_sparsity_block_size, validate_triton_block_size
|
|
9
9
|
|
|
10
10
|
|
|
11
|
-
def
|
|
12
|
-
|
|
11
|
+
def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
|
|
12
|
+
sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
|
|
13
13
|
"""Performs a broadcast and subsequent addition of two dense tensors x and y. Returns a block-sparse tensor in
|
|
14
14
|
compressed form.
|
|
15
15
|
|
|
@@ -70,12 +70,12 @@ def broadcast_addition(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
|
|
|
70
70
|
return output
|
|
71
71
|
|
|
72
72
|
|
|
73
|
-
def
|
|
74
|
-
|
|
75
|
-
"""Wrapper for ``
|
|
73
|
+
def broadcast_sub(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
|
|
74
|
+
sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
|
|
75
|
+
"""Wrapper for ``broadcast_add`` with negated y.
|
|
76
76
|
|
|
77
77
|
"""
|
|
78
|
-
return
|
|
78
|
+
return broadcast_add(x, torch.neg(y), sparsity_layout_output, sparsity_block_size, triton_block_size)
|
|
79
79
|
|
|
80
80
|
|
|
81
81
|
@triton.jit
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import triton
|
|
3
|
+
from torch import Tensor
|
|
4
|
+
from triton import language as tl
|
|
5
|
+
|
|
6
|
+
from blksprs.utils.tools import get_triton_block_size
|
|
7
|
+
from blksprs.utils.validation import validate_contiguous, validate_device, \
|
|
8
|
+
validate_sparsity_block_size, validate_triton_block_size, validate_dimensions
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def repeat_interleave(x: Tensor, sparsity_layout: Tensor, repeats: int,
                      sparsity_block_size: int, triton_block_size: int = None) -> tuple[Tensor, Tensor]:
    """Repeats and interleaves the block-sparse tensor in compressed form.

    Repeats each matrix contained in the tensors by ``repeats`` amount and places them consecutively in the output
    tensor.

    Args:
        x (Tensor): A block-sparse tensor in compressed form.
        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
        repeats (int): The number of times to repeat the matrices.
        sparsity_block_size (int): The size of the sparsity blocks.
        triton_block_size (int): The block size to use for the triton kernel (default ``None``).

    Returns:
        tuple[Tensor, Tensor]: A tuple containing a block-sparse tensor in compressed form holding the repeated and
        interleaved matrices, and the sparsity layout of that output tensor.

    """
    validate_dimensions(x)
    validate_contiguous(x)
    validate_device(x)
    validate_sparsity_block_size(sparsity_block_size, x)
    validate_triton_block_size(triton_block_size, sparsity_block_size)

    # The layout must be repeated by ``repeats`` (previously a hard-coded 3): the output buffer below holds
    # ``n_sparse_blocks * repeats`` blocks and the kernel addresses output layout batch ``spa_bat * repeats + repeat``,
    # so the returned layout and the reverse lookup table have to agree with that factor.
    sparsity_layout_output = torch.repeat_interleave(sparsity_layout, repeats, dim=0).contiguous()

    # One (batch, row, col) entry per non-zero block of the input layout
    sparsity_lut = torch.nonzero(sparsity_layout).contiguous()

    # Maps every position of the flattened output layout to its index in the compressed output, or -1 if absent
    sparsity_layout_output_flat = sparsity_layout_output.reshape(-1)
    sparsity_output_reverse_lut = ((torch.cumsum(sparsity_layout_output_flat, dim=-1) - 1) *
                                   (sparsity_layout_output_flat == 1) -
                                   (1 * (sparsity_layout_output_flat == 0)))

    n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()

    validate_contiguous(sparsity_layout, sparsity_lut, sparsity_layout_output, sparsity_output_reverse_lut)

    output = torch.empty(n_sparse_blocks * repeats, sparsity_block_size, sparsity_block_size,
                         dtype=x.dtype, device=x.device)

    x_b, x_r, x_c = x.size()
    x_b_s, x_r_s, x_c_s = x.stride()
    s_lut_r, s_lut_c = sparsity_lut.size()
    s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
    o_b, o_r, o_c = output.size()
    o_b_s, o_r_s, o_c_s = output.stride()
    s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_output.stride()

    if triton_block_size is None:
        triton_block_size = get_triton_block_size(sparsity_block_size)

    # One program per (input block, tile row, tile column); each program writes its tile ``repeats`` times
    triton_grid = lambda meta: [x_b,
                                triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
                                triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]

    (kernel_repeat_interleave[triton_grid]
     (x,
      x_b, x_b_s, x_r_s, x_c_s,
      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
      output,
      o_b, o_b_s, o_r_s, o_c_s,
      s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
      sparsity_output_reverse_lut,
      repeats,
      triton_block_size))

    return output, sparsity_layout_output
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@triton.jit
def kernel_repeat_interleave(x,
                             x_b, x_b_s, x_r_s, x_c_s,
                             s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
                             o,
                             o_b, o_b_s, o_r_s, o_c_s,
                             s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
                             r_lut_o,
                             repeats,
                             TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    """Copies one Triton tile of an input block to each of its ``repeats`` destinations in the output.

    Program axis 0 selects the compressed input block, axes 1 and 2 select the tile row and column inside it. The
    destination block of every copy is looked up in ``r_lut_o``, the reverse lookup table of the output layout.
    """
    blk_id = tl.program_id(axis=0)
    tile_r = tl.program_id(axis=1)
    tile_c = tl.program_id(axis=2)

    # Batch, row and column of the current block, read from its row of the sparsity lookup table
    lut_row_start = blk_id * s_lut_r_s
    lut_limit = s_lut_r * s_lut_r_s

    bat_off = lut_row_start + 0 * s_lut_c_s
    spa_bat = tl.load(s_lut + bat_off, mask=bat_off < lut_limit)

    row_off = lut_row_start + 1 * s_lut_c_s
    spa_row = tl.load(s_lut + row_off, mask=row_off < lut_limit)

    col_off = lut_row_start + 2 * s_lut_c_s
    spa_col = tl.load(s_lut + col_off, mask=col_off < lut_limit)

    # Element rows / columns covered by this tile inside a block
    tile_rows = tile_r * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)
    tile_cols = tile_c * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)

    # The tile is read once and written unchanged for every repeat
    in_off = (blk_id * x_b_s) + (tile_rows * x_r_s)[:, None] + (tile_cols * x_c_s)[None, :]
    tile = tl.load(x + in_off, mask=in_off < x_b * x_b_s)

    out_tile_off = (tile_rows * o_r_s)[:, None] + (tile_cols * o_c_s)[None, :]
    rev_limit = s_l_o_b * s_l_o_b_s

    for rep in range(repeats):
        # Copy ``rep`` lives at layout batch ``spa_bat * repeats + rep``; translate it to a compressed block index
        rev_off = (spa_bat * repeats + rep) * s_l_o_b_s + spa_row * s_l_o_r_s + spa_col * s_l_o_c_s
        dst_blk = tl.load(r_lut_o + rev_off, mask=rev_off < rev_limit).to(tl.int32)

        out_off = (dst_blk * o_b_s) + out_tile_off
        tl.store(o + out_off, tile, mask=out_off < o_b * o_b_s)
|
blksprs/misc/row_wise.py
ADDED
|
@@ -0,0 +1,386 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import triton
|
|
3
|
+
from torch import Tensor
|
|
4
|
+
from triton import language as tl
|
|
5
|
+
|
|
6
|
+
from blksprs.utils.tools import get_triton_block_size
|
|
7
|
+
from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, validate_sparsity, \
|
|
8
|
+
validate_sparsity_block_size, validate_triton_block_size
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
                 flag_slice_only: bool = False, triton_block_size: int = None) -> tuple[Tensor, Tensor]:
    """Computes the row-wise sum of a block-sparse tensor.

    The output layout has a single column of blocks: one block for every (batch, block-row) of ``sparsity_layout``
    that contains at least one block. Column 0 of each output block holds the sum of the corresponding rows of
    ``x``; any further columns stay zero.

    Note:
        If ``flag_slice_only`` is set, only that first column is materialised and the output has shape
        ``[n_output_blocks, sparsity_block_size, 1]``. The output is allocated with torch's default dtype rather than
        ``x.dtype``.

    Args:
        x (Tensor): A block-sparse tensor in compressed form.
        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
        sparsity_block_size (int): The size of the sparsity blocks.
        flag_slice_only (bool, optional): If set, allocate only a single column per output block
            (default ``False``).
        triton_block_size (int): The block size to use for the triton kernel (default ``None``).

    Returns:
        tuple[Tensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise sum
        of the input and the sparsity layout of the output tensor.

    """
    validate_dimensions(x)
    validate_contiguous(x)
    validate_device(x)
    validate_sparsity(sparsity_block_size, (x, sparsity_layout))
    validate_sparsity_block_size(sparsity_block_size, x)
    validate_triton_block_size(triton_block_size, sparsity_block_size)

    lut_in = torch.nonzero(sparsity_layout).contiguous()

    # An output block-row exists iff any block in that row of the input layout exists
    layout_out, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
    layout_out_flat = layout_out.reshape(-1)
    present = layout_out_flat == 1
    absent = layout_out_flat == 0
    # Compressed index of every output block-row, -1 where the row is absent
    reverse_lut_out = (torch.cumsum(layout_out_flat, dim=-1) - 1) * present - (1 * absent)

    n_blocks_out = torch.sum(layout_out.to(torch.int)).item()

    validate_contiguous(sparsity_layout, lut_in, layout_out, reverse_lut_out)

    out_cols = 1 if flag_slice_only else sparsity_block_size
    # Zero-initialised because the kernel accumulates partial sums with atomic adds
    output = torch.zeros(size=(n_blocks_out, sparsity_block_size, out_cols), device=x.device)

    x_b, x_r, x_c = x.size()
    x_b_s, x_r_s, x_c_s = x.stride()
    s_lut_x_r, _ = lut_in.size()
    s_lut_x_r_s, s_lut_x_c_s = lut_in.stride()
    o_b, _, _ = output.size()
    o_b_s, o_r_s, _ = output.stride()
    s_l_o_b, _, _ = layout_out.size()
    s_l_o_b_s, s_l_o_r_s, _ = layout_out.stride()

    block = get_triton_block_size(sparsity_block_size) if triton_block_size is None else triton_block_size

    def grid(meta):
        # One program per (input block, tile row, tile column)
        return [x_b,
                triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
                triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]

    kernel_blocksparse_row_wise_sum[grid](x,
                                          x_b, x_b_s, x_r_s, x_c_s,
                                          lut_in, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
                                          output,
                                          o_b, o_b_s, o_r_s,
                                          s_l_o_b, s_l_o_b_s, s_l_o_r_s,
                                          reverse_lut_out,
                                          block)

    return output, layout_out
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@triton.jit
def kernel_blocksparse_row_wise_sum(x,
                                    x_b, x_b_s, x_r_s, x_c_s,
                                    s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
                                    o,
                                    o_b, o_b_s, o_r_s,
                                    s_l_o_b, s_l_o_b_s, s_l_o_r_s,
                                    r_lut_o,
                                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    """Sums one Triton tile of an input block along its columns.

    The partial sums are atomically added into column 0 of the output block that collects the block-row.
    """
    blk_id = tl.program_id(axis=0)
    tile_r = tl.program_id(axis=1)
    tile_c = tl.program_id(axis=2)

    # Batch and row of the current block, read from its row of the sparsity lookup table
    lut_row_start = blk_id * s_lut_x_r_s
    lut_limit = s_lut_x_r * s_lut_x_r_s

    bat_off = lut_row_start + 0 * s_lut_x_c_s
    spa_bat = tl.load(s_lut_x + bat_off, mask=bat_off < lut_limit)

    row_off = lut_row_start + 1 * s_lut_x_c_s
    spa_row = tl.load(s_lut_x + row_off, mask=row_off < lut_limit)

    # Compressed index of the output block that collects this block-row
    rev_off = spa_bat * s_l_o_b_s + spa_row * s_l_o_r_s
    dst_blk = tl.load(r_lut_o + rev_off, mask=rev_off < s_l_o_b * s_l_o_b_s).to(tl.int32)

    tile_rows = tile_r * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)
    tile_cols = tile_c * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)

    in_off = (blk_id * x_b_s) + (tile_rows * x_r_s)[:, None] + (tile_cols * x_c_s)[None, :]
    tile = tl.load(x + in_off, mask=in_off < x_b * x_b_s)

    # Per-row partial sums of this tile, shaped as a single column
    partial = tl.reshape(tl.sum(tile, axis=-1), (TRITON_BLOCK_SIZE, 1))

    # Several tiles and blocks of the same block-row hit the same entries, hence the atomic accumulation
    out_off = dst_blk * o_b_s + (tile_rows * o_r_s)[:, None] + (tl.arange(0, 1))[None, :]
    tl.atomic_add(o + out_off, partial, out_off < o_b * o_b_s)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def row_wise_max(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
                 flag_slice_only: bool = False, triton_block_size: int = None) -> tuple[Tensor, Tensor]:
    """Computes the row-wise max of a block-sparse tensor.

    The output layout has a single column of blocks: one block for every (batch, block-row) of ``sparsity_layout``
    that contains at least one block. Column 0 of each output block holds the maximum of the corresponding rows of
    ``x``; any further columns stay ``-inf``.

    Note:
        If ``flag_slice_only`` is set, only that first column is materialised and the output has shape
        ``[n_output_blocks, sparsity_block_size, 1]``. The output is allocated with torch's default dtype rather than
        ``x.dtype``.

    Args:
        x (Tensor): A block-sparse tensor in compressed form.
        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
        sparsity_block_size (int): The size of the sparsity blocks.
        flag_slice_only (bool, optional): If set, allocate only a single column per output block
            (default ``False``).
        triton_block_size (int): The block size to use for the triton kernel (default ``None``).

    Returns:
        tuple[Tensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise max
        of the input and the sparsity layout of the output tensor.

    """
    validate_dimensions(x)
    validate_contiguous(x)
    validate_device(x)
    validate_sparsity(sparsity_block_size, (x, sparsity_layout))
    validate_sparsity_block_size(sparsity_block_size, x)
    validate_triton_block_size(triton_block_size, sparsity_block_size)

    lut_in = torch.nonzero(sparsity_layout).contiguous()

    # An output block-row exists iff any block in that row of the input layout exists
    layout_out, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
    layout_out_flat = layout_out.reshape(-1)
    present = layout_out_flat == 1
    absent = layout_out_flat == 0
    # Compressed index of every output block-row, -1 where the row is absent
    reverse_lut_out = (torch.cumsum(layout_out_flat, dim=-1) - 1) * present - (1 * absent)

    n_blocks_out = torch.sum(layout_out.to(torch.int)).item()

    validate_contiguous(sparsity_layout, lut_in, layout_out, reverse_lut_out)

    out_cols = 1 if flag_slice_only else sparsity_block_size
    # Initialised to -inf so that the kernel's atomic max sees the identity element
    output = torch.full(size=(n_blocks_out, sparsity_block_size, out_cols),
                        fill_value=float("-inf"),
                        device=x.device)

    x_b, x_r, x_c = x.size()
    x_b_s, x_r_s, x_c_s = x.stride()
    s_lut_x_r, _ = lut_in.size()
    s_lut_x_r_s, s_lut_x_c_s = lut_in.stride()
    o_b, _, _ = output.size()
    o_b_s, o_r_s, _ = output.stride()
    s_l_o_b, _, _ = layout_out.size()
    s_l_o_b_s, s_l_o_r_s, _ = layout_out.stride()

    block = get_triton_block_size(sparsity_block_size) if triton_block_size is None else triton_block_size

    def grid(meta):
        # One program per (input block, tile row, tile column)
        return [x_b,
                triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
                triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]

    kernel_blocksparse_row_wise_max[grid](x,
                                          x_b, x_b_s, x_r_s, x_c_s,
                                          lut_in, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
                                          output,
                                          o_b, o_b_s, o_r_s,
                                          s_l_o_b, s_l_o_b_s, s_l_o_r_s,
                                          reverse_lut_out,
                                          block)

    return output, layout_out
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
@triton.jit
def kernel_blocksparse_row_wise_max(x,
                                    x_b, x_b_s, x_r_s, x_c_s,
                                    s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
                                    o,
                                    o_b, o_b_s, o_r_s,
                                    s_l_o_b, s_l_o_b_s, s_l_o_r_s,
                                    r_lut_o,
                                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    """Reduces one Triton tile of an input block with max along its columns.

    The partial maxima are merged into column 0 of the output block that collects the block-row via an atomic max.
    """
    blk_id = tl.program_id(axis=0)
    tile_r = tl.program_id(axis=1)
    tile_c = tl.program_id(axis=2)

    # Batch and row of the current block, read from its row of the sparsity lookup table
    lut_row_start = blk_id * s_lut_x_r_s
    lut_limit = s_lut_x_r * s_lut_x_r_s

    bat_off = lut_row_start + 0 * s_lut_x_c_s
    spa_bat = tl.load(s_lut_x + bat_off, mask=bat_off < lut_limit)

    row_off = lut_row_start + 1 * s_lut_x_c_s
    spa_row = tl.load(s_lut_x + row_off, mask=row_off < lut_limit)

    # Compressed index of the output block that collects this block-row
    rev_off = spa_bat * s_l_o_b_s + spa_row * s_l_o_r_s
    dst_blk = tl.load(r_lut_o + rev_off, mask=rev_off < s_l_o_b * s_l_o_b_s).to(tl.int32)

    tile_rows = tile_r * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)
    tile_cols = tile_c * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)

    in_off = (blk_id * x_b_s) + (tile_rows * x_r_s)[:, None] + (tile_cols * x_c_s)[None, :]
    tile = tl.load(x + in_off, mask=in_off < x_b * x_b_s)

    # Per-row partial maxima of this tile, shaped as a single column
    partial = tl.reshape(tl.max(tile, axis=-1), (TRITON_BLOCK_SIZE, 1))

    # Several tiles and blocks of the same block-row hit the same entries, hence the atomic merge
    out_off = dst_blk * o_b_s + (tile_rows * o_r_s)[:, None] + (tl.arange(0, 1))[None, :]
    tl.atomic_max(o + out_off, partial, out_off < o_b * o_b_s)
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def row_wise_add(x: Tensor, sparsity_layout_x: Tensor, y: Tensor,
                 sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
    """For each row in ``y`` adds the value to each value in the corresponding row of the block-sparse tensor ``x``.

    ``y`` is addressed through the layout obtained by reducing ``sparsity_layout_x`` with ``max`` over its last
    dimension (one block per non-empty block-row), which is the layout ``row_wise_sum`` / ``row_wise_max`` return.
    Only column 0 of each ``y`` block is read.

    Args:
        x (Tensor): A block-sparse tensor in compressed form.
        sparsity_layout_x (Tensor): The sparsity layout of the block-sparse tensor.
        y (Tensor): A block-sparse tensor in compressed form with only one value per row and a single column of
            sparse blocks.
        sparsity_block_size (int): The size of the sparsity blocks.
        triton_block_size (int): The block size to use for the triton kernel (default ``None``).

    Returns:
        Tensor: The values of ``x`` with the first value of ``y`` in each row added to them as a block-sparse tensor
        in compressed form.

    """
    validate_dimensions(x)
    validate_contiguous(x)
    validate_device(x)
    validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
    validate_sparsity_block_size(sparsity_block_size, x)
    validate_triton_block_size(triton_block_size, sparsity_block_size)

    lut_x = torch.nonzero(sparsity_layout_x).contiguous()

    # Row-reduced layout under which ``y`` is stored, plus its reverse lookup table (-1 for absent rows)
    layout_y, _ = torch.max(sparsity_layout_x, dim=-1, keepdim=True)
    layout_y_flat = layout_y.reshape(-1)
    present = layout_y_flat == 1
    absent = layout_y_flat == 0
    reverse_lut_y = (torch.cumsum(layout_y_flat, dim=-1) - 1) * present - (1 * absent)

    validate_contiguous(sparsity_layout_x, lut_x, reverse_lut_y)

    output = torch.empty_like(x)

    x_b, _, _ = x.size()
    x_b_s, x_r_s, x_c_s = x.stride()
    s_lut_r, _ = lut_x.size()
    s_lut_r_s, s_lut_c_s = lut_x.stride()
    y_b, _, _ = y.size()
    y_b_s, y_r_s, y_c_s = y.stride()
    s_l_y_b, _, _ = layout_y.size()
    s_l_y_b_s, s_l_y_r_s, _ = layout_y.stride()
    o_b, o_r, o_c = output.size()
    o_b_s, o_r_s, o_c_s = output.stride()

    block = get_triton_block_size(sparsity_block_size) if triton_block_size is None else triton_block_size

    def grid(meta):
        # One program per (output block, tile row, tile column)
        return [o_b,
                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

    kernel_blocksparse_row_wise_add[grid](x,
                                          x_b, x_b_s, x_r_s, x_c_s,
                                          lut_x, s_lut_r, s_lut_r_s, s_lut_c_s,
                                          y, y_b, y_b_s, y_r_s, y_c_s,
                                          s_l_y_b, s_l_y_b_s, s_l_y_r_s,
                                          reverse_lut_y,
                                          output,
                                          o_b, o_b_s, o_r_s, o_c_s,
                                          block)

    return output
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
def row_wise_sub(x: Tensor, sparsity_layout_x: Tensor, y: Tensor,
                 sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
    """Subtracts, per row, the first value of ``y`` from every value in the corresponding row of ``x``.

    Implemented as ``row_wise_add`` applied to the negation of ``y``; the arguments follow that function's contract.

    Returns:
        Tensor: ``x`` with the row values of ``y`` subtracted, as a block-sparse tensor in compressed form.

    """
    negated_y = torch.neg(y)
    return row_wise_add(x, sparsity_layout_x, negated_y,
                        sparsity_block_size=sparsity_block_size,
                        triton_block_size=triton_block_size)
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
@triton.jit
def kernel_blocksparse_row_wise_add(x,
                                    x_b, x_b_s, x_r_s, x_c_s,
                                    s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
                                    y, y_b, y_b_s, y_r_s, y_c_s,
                                    s_l_y_b, s_l_y_b_s, s_l_y_r_s,
                                    r_lut_y,
                                    o,
                                    o_b, o_b_s, o_r_s, o_c_s,
                                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    """Adds the per-row values stored in column 0 of ``y`` to one Triton tile of a block of ``x``.

    The result is written to the same block and tile position in ``o``. Program axis 0 selects the compressed
    block, axes 1 and 2 select the tile row and column inside it.
    """
    # Get triton block indices
    pid_blk = tl.program_id(axis=0)
    pid_row = tl.program_id(axis=1)
    pid_col = tl.program_id(axis=2)

    # Get position of current sparsity block consisting of its batch and row index
    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
    spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)

    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
    spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)

    # Get reverse sparsity index of the ``y`` block that covers this block-row
    rev_idx_spa_s_idx = (spa_bat * s_l_y_b_s +
                         spa_row * s_l_y_r_s)
    rev_idx_spa_s_msk = (rev_idx_spa_s_idx < s_l_y_b * s_l_y_b_s)
    rev_idx_spa_s = tl.load(r_lut_y + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)

    # -1 marks a block-row absent from the row-reduced layout
    # NOTE(review): device-side asserts in Triton may only be checked in debug mode — confirm before relying on this
    if rev_idx_spa_s == -1:
        assert False, "Invalid sparsity block"

    # Load x block
    blk_x_idx = ((pid_blk * x_b_s) +
                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
    blk_x_msk = (blk_x_idx < x_b * x_b_s)
    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

    # Load the per-row values of y (only column 0 of the y block is read)
    blk_s_idx = (rev_idx_spa_s * y_b_s +
                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
                 (tl.arange(0, 1) * y_c_s)[None, :])
    blk_s_msk = (blk_s_idx < y_b * y_b_s)
    blk_s = tl.load(y + blk_s_idx, mask=blk_s_msk)

    # Broadcast the per-row values across the tile's columns and add them to x
    buf = blk_x + tl.broadcast_to(blk_s, (TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE))

    # Store block
    blk_o_idx = ((pid_blk * o_b_s) +
                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
    blk_o_msk = (blk_o_idx < o_b * o_b_s)
    tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
|
blksprs/ops/softmax.py
CHANGED
|
@@ -4,7 +4,7 @@ from torch import Tensor
|
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
6
|
from blksprs.ops.exp import exp
|
|
7
|
-
from blksprs.
|
|
7
|
+
from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
|
|
8
8
|
from blksprs.utils.tools import get_triton_block_size
|
|
9
9
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
10
10
|
validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
|
|
@@ -33,12 +33,6 @@ def softmax(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton
|
|
|
33
33
|
validate_sparsity_block_size(sparsity_block_size, x)
|
|
34
34
|
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
35
35
|
|
|
36
|
-
if x.size(0) != 0:
|
|
37
|
-
max_val = torch.max(x).item()
|
|
38
|
-
else:
|
|
39
|
-
max_val = 0
|
|
40
|
-
x_scaled = x - max_val
|
|
41
|
-
|
|
42
36
|
sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
|
|
43
37
|
|
|
44
38
|
sparsity_layout_rws, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
|
|
@@ -49,7 +43,7 @@ def softmax(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton
|
|
|
49
43
|
|
|
50
44
|
validate_contiguous(sparsity_layout, sparsity_lut, sparsity_reverse_lut_rws)
|
|
51
45
|
|
|
52
|
-
return _BlocksparseSoftmax.apply(
|
|
46
|
+
return _BlocksparseSoftmax.apply(x, sparsity_layout,
|
|
53
47
|
sparsity_lut,
|
|
54
48
|
sparsity_reverse_lut_rws,
|
|
55
49
|
sparsity_block_size, triton_block_size)
|
|
@@ -64,13 +58,17 @@ class _BlocksparseSoftmax(torch.autograd.Function):
|
|
|
64
58
|
sparsity_block_size: int, triton_block_size: int) -> Tensor:
|
|
65
59
|
output = torch.empty_like(x)
|
|
66
60
|
|
|
67
|
-
x_b, x_r, x_c = x.
|
|
61
|
+
x_b, x_r, x_c = x.size()
|
|
68
62
|
x_b_s, x_r_s, x_c_s = x.stride()
|
|
69
|
-
s_lut_r, s_lut_c = sparsity_lut.
|
|
63
|
+
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
70
64
|
s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
|
|
71
|
-
o_b, o_r, o_c = output.
|
|
65
|
+
o_b, o_r, o_c = output.size()
|
|
72
66
|
|
|
73
|
-
|
|
67
|
+
x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,
|
|
68
|
+
flag_slice_only=True,
|
|
69
|
+
triton_block_size=triton_block_size)
|
|
70
|
+
x_scaled = row_wise_sub(x, sparsity_layout, x_row_wise_max, sparsity_block_size, triton_block_size)
|
|
71
|
+
x_exp = exp(x_scaled, sparsity_block_size, triton_block_size=triton_block_size)
|
|
74
72
|
x_exp_row_wise_sum, sparsity_layout_rws = row_wise_sum(x_exp, sparsity_layout, sparsity_block_size,
|
|
75
73
|
flag_slice_only=True,
|
|
76
74
|
triton_block_size=triton_block_size)
|
|
@@ -174,7 +172,7 @@ class _BlocksparseSoftmax(torch.autograd.Function):
|
|
|
174
172
|
spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
|
|
175
173
|
spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
|
|
176
174
|
|
|
177
|
-
# Get reverse sparsity indices for
|
|
175
|
+
# Get reverse sparsity indices for s
|
|
178
176
|
rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
|
|
179
177
|
spa_row * s_l_s_r_s)
|
|
180
178
|
rev_idx_spa_s_msk = (rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
|
blksprs/ops/transpose.py
CHANGED
|
@@ -129,7 +129,7 @@ class _BlocksparseTranspose(torch.autograd.Function):
|
|
|
129
129
|
spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
|
|
130
130
|
spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
|
|
131
131
|
|
|
132
|
-
# Get reverse sparsity
|
|
132
|
+
# Get reverse sparsity index
|
|
133
133
|
rev_idx_spa_idx = (spa_bat * s_l_b_s +
|
|
134
134
|
spa_row * s_l_r_s +
|
|
135
135
|
spa_col * s_l_c_s)
|
blksprs/utils/tools.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: blksprs
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.4
|
|
4
4
|
Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
|
|
5
5
|
Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
|
|
6
6
|
Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
|
|
@@ -83,14 +83,7 @@ the [test cases](https://github.com/FelixSchoen/blksprs/blob/main/test/cases/tes
|
|
|
83
83
|
|
|
84
84
|
```python
|
|
85
85
|
import torch
|
|
86
|
-
|
|
87
|
-
from blksprs.layouting.sparsity_layout import build_sparsity_layout
|
|
88
|
-
from blksprs.ops.conversion import to_sparse, to_dense
|
|
89
|
-
from blksprs.ops.matmul import matmul
|
|
90
|
-
from blksprs.ops.row_wise_sum import row_wise_sum
|
|
91
|
-
from blksprs.ops.softmax import softmax
|
|
92
|
-
from blksprs.ops.transpose import transpose
|
|
93
|
-
from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
|
|
86
|
+
import blksprs as bs
|
|
94
87
|
|
|
95
88
|
|
|
96
89
|
def test_readme():
|
|
@@ -112,47 +105,57 @@ def test_readme():
|
|
|
112
105
|
y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()
|
|
113
106
|
|
|
114
107
|
# Convert tensors to three-dimensional (dense) tensors since Triton can only handle tensors of exactly three dimensions
|
|
115
|
-
x_dense, x_shape_original = do_shape_blocksparse(x)
|
|
116
|
-
y_dense, y_shape_original = do_shape_blocksparse(y)
|
|
108
|
+
x_dense, x_shape_original = bs.util.do_shape_blocksparse(x)
|
|
109
|
+
y_dense, y_shape_original = bs.util.do_shape_blocksparse(y)
|
|
117
110
|
|
|
118
111
|
# Create sparsity layouts from existing tensors
|
|
119
|
-
sparsity_layout_x = build_sparsity_layout(x_dense, sparsity_block_size,
|
|
120
|
-
|
|
112
|
+
sparsity_layout_x = bs.layout.build_sparsity_layout(x_dense, sparsity_block_size,
|
|
113
|
+
triton_block_size=triton_block_size)
|
|
114
|
+
sparsity_layout_y = bs.layout.build_sparsity_layout(y_dense, sparsity_block_size,
|
|
115
|
+
triton_block_size=triton_block_size)
|
|
121
116
|
|
|
122
117
|
# Create random sparsity layout for output tensor
|
|
123
118
|
sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)
|
|
124
119
|
|
|
125
120
|
# Convert tensors to sparse tensors for matrix multiplication
|
|
126
|
-
x_sparse = to_sparse(x_dense, sparsity_layout_x, sparsity_block_size, triton_block_size=triton_block_size)
|
|
127
|
-
y_sparse = to_sparse(y_dense, sparsity_layout_y, sparsity_block_size, triton_block_size=triton_block_size)
|
|
121
|
+
x_sparse = bs.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size, triton_block_size=triton_block_size)
|
|
122
|
+
y_sparse = bs.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size, triton_block_size=triton_block_size)
|
|
128
123
|
|
|
129
124
|
# Perform matrix multiplication
|
|
130
|
-
o_sparse = matmul(x_sparse, sparsity_layout_x, y_sparse, sparsity_layout_y, sparsity_layout_o,
|
|
131
|
-
|
|
132
|
-
|
|
125
|
+
o_sparse = bs.matmul(x_sparse, sparsity_layout_x, y_sparse, sparsity_layout_y, sparsity_layout_o,
|
|
126
|
+
sparsity_block_size,
|
|
127
|
+
triton_block_size=triton_block_size)
|
|
128
|
+
|
|
129
|
+
# Apply element-wise operation
|
|
130
|
+
o_sparse = torch.add(o_sparse, 1)
|
|
131
|
+
|
|
132
|
+
o_dense = bs.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
|
|
133
133
|
|
|
134
134
|
# Sanity check
|
|
135
135
|
o_torch = torch.matmul(x_dense, y_dense)
|
|
136
|
+
o_torch = torch.add(o_torch, 1)
|
|
136
137
|
|
|
137
138
|
# Perform round trip to set sparse blocks to 0
|
|
138
|
-
o_torch_round_trip = to_dense(
|
|
139
|
-
to_sparse(o_torch, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size),
|
|
139
|
+
o_torch_round_trip = bs.to_dense(
|
|
140
|
+
bs.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size),
|
|
140
141
|
sparsity_layout_o, sparsity_block_size, fill_value=0, triton_block_size=triton_block_size)
|
|
141
142
|
|
|
142
143
|
# Assert that the output is correct
|
|
143
144
|
assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2) # Note that small numerical differences are expected
|
|
144
145
|
|
|
145
146
|
# Assert that the output has the correct sparsity layout
|
|
146
|
-
actual_sparsity_layout_o = build_sparsity_layout(o_dense, sparsity_block_size,
|
|
147
|
+
actual_sparsity_layout_o = bs.layout.build_sparsity_layout(o_dense, sparsity_block_size,
|
|
148
|
+
triton_block_size=triton_block_size)
|
|
147
149
|
assert torch.allclose(actual_sparsity_layout_o, sparsity_layout_o)
|
|
148
150
|
|
|
149
151
|
# Convert output tensor back to original shape
|
|
150
|
-
o = undo_shape_blocksparse(o_dense, x_shape_original)
|
|
152
|
+
o = bs.util.undo_shape_blocksparse(o_dense, x_shape_original)
|
|
151
153
|
|
|
152
154
|
# Other available functions
|
|
153
|
-
transpose(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
|
|
154
|
-
softmax(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
|
|
155
|
-
row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
|
|
155
|
+
bs.transpose(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
|
|
156
|
+
bs.softmax(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
|
|
157
|
+
bs.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
|
|
158
|
+
bs.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
|
|
156
159
|
|
|
157
160
|
|
|
158
161
|
def _get_random_sparsity_layout(b, m, n, sparsity_block_size, sparsity_percentage):
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
blksprs/__init__.py,sha256=vUthykoYgmHqo2rNYgfrKTNMq7IDalRpCa1nVdFEOqA,813
|
|
2
|
+
blksprs/layouting/distribution_layout.py,sha256=GQ-ZRXbeImiLcbaqnL2FuUZ6DoFwmB0naT_YrOpD84Q,4940
|
|
3
|
+
blksprs/layouting/sparsity_layout.py,sha256=TtADT_WWcZpW3zyGy6KAgkAo44gDryXZqdJLZGEX2V8,7895
|
|
4
|
+
blksprs/misc/broadcast_ops.py,sha256=xLj7CH5yEBihI5gT8SRFqQta1DXvl3iSskhbHsOX_EM,5261
|
|
5
|
+
blksprs/misc/repeat_interleave.py,sha256=WrIp7uJsnvjIhFeLYPfkL2j5vXyKmDQGrJ69b3Y0lQ8,5644
|
|
6
|
+
blksprs/misc/row_wise.py,sha256=Fa57BVfmneXT_8Ms-Vao8H8fh89sT3Z0b_gtN-7gano,16805
|
|
7
|
+
blksprs/ops/conversion.py,sha256=-AOzj_j3WrBLGIgd2oVPvYS8XKfzlvGtSIWzW_qP1lk,21260
|
|
8
|
+
blksprs/ops/distribution.py,sha256=_fQb6fWpLxocAh86D74ATahChi0EK0eBb4eUOUEBVps,16769
|
|
9
|
+
blksprs/ops/exp.py,sha256=qs8fVtCzxl4CKT4GepaqurjEL62jyi8VjMY12JFrFAU,3674
|
|
10
|
+
blksprs/ops/matmul.py,sha256=x3lrYg4g8fIf5PeMtZY_SEpi11kP9RFcRoemCIxcSDE,11086
|
|
11
|
+
blksprs/ops/softmax.py,sha256=1lxgS12oJ5UcRkDxq13OOjp9AHwhgzSfBosEO1GzKvs,11948
|
|
12
|
+
blksprs/ops/transpose.py,sha256=cX_E3b-QMhsUDNn9D8HVkYesc2JBc-EcVBUZfCWExM8,6720
|
|
13
|
+
blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
|
|
14
|
+
blksprs/utils/tools.py,sha256=DwophH01AeNTZAo0B1uWbKFSGBQjI5z0WmFnYKh-BBk,465
|
|
15
|
+
blksprs/utils/validation.py,sha256=gJYZO5C48YUrXV3Fy_Z_lCaOpiFj951FT-Od7sKfprg,3007
|
|
16
|
+
blksprs-1.4.dist-info/METADATA,sha256=mC9Vql8wtF_gLYwnGXx8_p9aKL7PnxrQSoZNkQegxic,7675
|
|
17
|
+
blksprs-1.4.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
|
|
18
|
+
blksprs-1.4.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
|
|
19
|
+
blksprs-1.4.dist-info/RECORD,,
|
blksprs/ops/row_wise_sum.py
DELETED
|
@@ -1,231 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
import triton
|
|
3
|
-
from torch import Tensor
|
|
4
|
-
from triton import language as tl
|
|
5
|
-
|
|
6
|
-
from blksprs.utils.tools import get_triton_block_size
|
|
7
|
-
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
8
|
-
validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
12
|
-
flag_slice_only: bool = False, triton_block_size: int = None) -> tuple[Tensor, Tensor]:
|
|
13
|
-
"""Computes the row-wise sum of a block-sparse tensor.
|
|
14
|
-
|
|
15
|
-
Returns a block-sparse tensor in compressed form with only one block per row, where the first entry contains the sum
|
|
16
|
-
of the corresponding row.
|
|
17
|
-
|
|
18
|
-
Note:
|
|
19
|
-
If ``flag_slice_only`` is set the output will be of shape ``[x.size(0), x.size(1), 1]``.
|
|
20
|
-
|
|
21
|
-
Args:
|
|
22
|
-
x (Tensor): A block-sparse tensor in compressed form.
|
|
23
|
-
sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
|
|
24
|
-
sparsity_block_size (int): The size of the sparsity blocks.
|
|
25
|
-
flag_slice_only (bool, optional): If set the output will be of shape ``[x.size(0), x.size(1), 1]``
|
|
26
|
-
(default ``False``).
|
|
27
|
-
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
28
|
-
|
|
29
|
-
Returns:
|
|
30
|
-
tuple[Tensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise sum
|
|
31
|
-
of the input and the sparsity layout of the output tensor.
|
|
32
|
-
|
|
33
|
-
"""
|
|
34
|
-
validate_dimensions(x)
|
|
35
|
-
validate_contiguous(x)
|
|
36
|
-
validate_device(x)
|
|
37
|
-
validate_sparsity(sparsity_block_size, (x, sparsity_layout))
|
|
38
|
-
validate_sparsity_block_size(sparsity_block_size, x)
|
|
39
|
-
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
40
|
-
|
|
41
|
-
sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
|
|
42
|
-
sparsity_layout_flat = sparsity_layout.reshape(-1)
|
|
43
|
-
sparsity_reverse_lut = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
|
|
44
|
-
(sparsity_layout_flat == 1) -
|
|
45
|
-
(1 * (sparsity_layout_flat == 0)))
|
|
46
|
-
|
|
47
|
-
sparsity_layout_output, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
|
|
48
|
-
sparsity_lut_output = torch.nonzero(sparsity_layout_output).contiguous()
|
|
49
|
-
sparsity_layout_output_flat = sparsity_layout_output.reshape(-1)
|
|
50
|
-
sparsity_reverse_lut_output = ((torch.cumsum(sparsity_layout_output_flat, dim=-1) - 1) *
|
|
51
|
-
(sparsity_layout_output_flat == 1) -
|
|
52
|
-
(1 * (sparsity_layout_output_flat == 0)))
|
|
53
|
-
|
|
54
|
-
n_sparse_blocks_output = torch.sum(sparsity_layout_output.to(torch.int)).item()
|
|
55
|
-
|
|
56
|
-
validate_contiguous(sparsity_layout, sparsity_lut, sparsity_reverse_lut,
|
|
57
|
-
sparsity_layout_output, sparsity_lut_output, sparsity_reverse_lut_output)
|
|
58
|
-
|
|
59
|
-
return (_BlocksparseRowWiseSum.apply(x,
|
|
60
|
-
sparsity_layout, sparsity_lut, sparsity_reverse_lut,
|
|
61
|
-
sparsity_layout_output, sparsity_lut_output, sparsity_reverse_lut_output,
|
|
62
|
-
n_sparse_blocks_output,
|
|
63
|
-
flag_slice_only,
|
|
64
|
-
sparsity_block_size, triton_block_size),
|
|
65
|
-
sparsity_layout_output)
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
class _BlocksparseRowWiseSum(torch.autograd.Function):
|
|
69
|
-
IMPLEMENTATION = "atomic_add"
|
|
70
|
-
|
|
71
|
-
@staticmethod
|
|
72
|
-
def forward(ctx, x: Tensor,
|
|
73
|
-
sparsity_layout: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
74
|
-
sparsity_layout_output: Tensor, sparsity_lut_output: Tensor, sparsity_reverse_lut_output: Tensor,
|
|
75
|
-
n_sparse_blocks_output: int,
|
|
76
|
-
flag_slice_only: bool,
|
|
77
|
-
sparsity_block_size: int, triton_block_size: int) -> Tensor:
|
|
78
|
-
output = torch.zeros(size=(n_sparse_blocks_output,
|
|
79
|
-
sparsity_block_size,
|
|
80
|
-
1 if flag_slice_only else sparsity_block_size),
|
|
81
|
-
device=x.device)
|
|
82
|
-
|
|
83
|
-
x_b, x_r, x_c = x.size()
|
|
84
|
-
x_b_s, x_r_s, x_c_s = x.stride()
|
|
85
|
-
s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout.size()
|
|
86
|
-
s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout.stride()
|
|
87
|
-
s_lut_x_r, s_lut_x_c = sparsity_lut.size()
|
|
88
|
-
s_lut_x_r_s, s_lut_x_c_s = sparsity_lut.stride()
|
|
89
|
-
o_b, o_r, o_c = output.size()
|
|
90
|
-
o_b_s, o_r_s, o_c_s = output.stride()
|
|
91
|
-
s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
|
|
92
|
-
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_output.stride()
|
|
93
|
-
s_lut_o_r, s_lut_o_c = sparsity_lut_output.size()
|
|
94
|
-
s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_output.stride()
|
|
95
|
-
|
|
96
|
-
if triton_block_size is None:
|
|
97
|
-
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
98
|
-
|
|
99
|
-
if _BlocksparseRowWiseSum.IMPLEMENTATION == "basic":
|
|
100
|
-
triton_grid = lambda meta: [o_b,
|
|
101
|
-
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"])]
|
|
102
|
-
|
|
103
|
-
(_BlocksparseRowWiseSum.kernel_blocksparse_row_wise_sum[triton_grid]
|
|
104
|
-
(x,
|
|
105
|
-
x_b, x_b_s, x_r_s, x_c_s,
|
|
106
|
-
s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
|
|
107
|
-
sparsity_reverse_lut,
|
|
108
|
-
output,
|
|
109
|
-
o_b, o_b_s, o_r_s,
|
|
110
|
-
sparsity_lut_output, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
|
|
111
|
-
sparsity_block_size,
|
|
112
|
-
triton_block_size))
|
|
113
|
-
elif _BlocksparseRowWiseSum.IMPLEMENTATION == "atomic_add":
|
|
114
|
-
triton_grid = lambda meta: [x_b,
|
|
115
|
-
triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
|
|
116
|
-
triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
|
|
117
|
-
|
|
118
|
-
(_BlocksparseRowWiseSum.kernel_blocksparse_row_wise_sum_atomic_add[triton_grid]
|
|
119
|
-
(x,
|
|
120
|
-
x_b, x_b_s, x_r_s, x_c_s,
|
|
121
|
-
sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
|
|
122
|
-
output,
|
|
123
|
-
o_b, o_b_s, o_r_s,
|
|
124
|
-
s_l_o_b, s_l_o_b_s, s_l_o_r_s,
|
|
125
|
-
sparsity_reverse_lut_output,
|
|
126
|
-
triton_block_size))
|
|
127
|
-
|
|
128
|
-
return output
|
|
129
|
-
|
|
130
|
-
@staticmethod
|
|
131
|
-
def backward(ctx, grad_output):
|
|
132
|
-
raise NotImplementedError
|
|
133
|
-
|
|
134
|
-
@staticmethod
|
|
135
|
-
@triton.jit
|
|
136
|
-
def kernel_blocksparse_row_wise_sum(x,
|
|
137
|
-
x_b, x_b_s, x_r_s, x_c_s,
|
|
138
|
-
s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
|
|
139
|
-
r_lut_x,
|
|
140
|
-
o,
|
|
141
|
-
o_b, o_b_s, o_r_s,
|
|
142
|
-
s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
|
|
143
|
-
sparsity_block_size,
|
|
144
|
-
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
145
|
-
pid_blk = tl.program_id(axis=0)
|
|
146
|
-
pid_row = tl.program_id(axis=1)
|
|
147
|
-
|
|
148
|
-
# Get position of current sparsity block consisting of its batch and row index
|
|
149
|
-
spa_bat_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
|
|
150
|
-
spa_bat_msk = (spa_bat_idx < s_lut_o_r * s_lut_o_r_s)
|
|
151
|
-
spa_bat = tl.load(s_lut_o + spa_bat_idx, mask=spa_bat_msk)
|
|
152
|
-
|
|
153
|
-
spa_row_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
|
|
154
|
-
spa_row_msk = (spa_row_idx < s_lut_o_r * s_lut_o_r_s)
|
|
155
|
-
spa_row = tl.load(s_lut_o + spa_row_idx, mask=spa_row_msk)
|
|
156
|
-
|
|
157
|
-
buf = tl.zeros(shape=(TRITON_BLOCK_SIZE, 1), dtype=tl.float32)
|
|
158
|
-
|
|
159
|
-
# Slide over triton block sized segments of input tensor
|
|
160
|
-
for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size, TRITON_BLOCK_SIZE)):
|
|
161
|
-
# Convert to segment index of sparsity layout
|
|
162
|
-
i_seg_spa = (i_seg_tri * TRITON_BLOCK_SIZE) // sparsity_block_size
|
|
163
|
-
# Calculate the triton segment index within a block
|
|
164
|
-
i_seg_tri_mod = i_seg_tri % (sparsity_block_size // TRITON_BLOCK_SIZE)
|
|
165
|
-
|
|
166
|
-
# Load reverse sparsity index for current block
|
|
167
|
-
rev_idx_spa_idx = (spa_bat * s_l_x_b_s +
|
|
168
|
-
spa_row * s_l_x_r_s +
|
|
169
|
-
i_seg_spa * s_l_x_c_s)
|
|
170
|
-
rev_idx_spa_msk = (rev_idx_spa_idx < s_l_x_b * s_l_x_b_s)
|
|
171
|
-
rev_idx_spa = tl.load(r_lut_x + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
|
|
172
|
-
|
|
173
|
-
# If block is present commence operations
|
|
174
|
-
if rev_idx_spa >= 0:
|
|
175
|
-
blk_idx = ((rev_idx_spa * x_b_s) +
|
|
176
|
-
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
177
|
-
((i_seg_tri_mod * TRITON_BLOCK_SIZE +
|
|
178
|
-
tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
179
|
-
blk_msk = (blk_idx < x_b * x_b_s)
|
|
180
|
-
blk = tl.load(x + blk_idx, mask=blk_msk)
|
|
181
|
-
|
|
182
|
-
buf = buf + tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
|
|
183
|
-
|
|
184
|
-
o_idx = (pid_blk * o_b_s +
|
|
185
|
-
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
186
|
-
(tl.arange(0, 1))[None, :])
|
|
187
|
-
o_msk = (o_idx < o_b * o_b_s)
|
|
188
|
-
tl.store(o + o_idx, buf, o_msk)
|
|
189
|
-
|
|
190
|
-
@staticmethod
|
|
191
|
-
@triton.jit
|
|
192
|
-
def kernel_blocksparse_row_wise_sum_atomic_add(x,
|
|
193
|
-
x_b, x_b_s, x_r_s, x_c_s,
|
|
194
|
-
s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
|
|
195
|
-
o,
|
|
196
|
-
o_b, o_b_s, o_r_s,
|
|
197
|
-
s_l_o_b, s_l_o_b_s, s_l_o_r_s,
|
|
198
|
-
r_lut_o,
|
|
199
|
-
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
200
|
-
pid_blk = tl.program_id(axis=0)
|
|
201
|
-
pid_row = tl.program_id(axis=1)
|
|
202
|
-
pid_col = tl.program_id(axis=2)
|
|
203
|
-
|
|
204
|
-
# Get position of current sparsity block consisting of its batch and row index
|
|
205
|
-
spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
|
|
206
|
-
spa_bat_msk = (spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
|
|
207
|
-
spa_bat = tl.load(s_lut_x + spa_bat_idx, mask=spa_bat_msk)
|
|
208
|
-
|
|
209
|
-
spa_row_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
|
|
210
|
-
spa_row_msk = (spa_row_idx < s_lut_x_r * s_lut_x_r_s)
|
|
211
|
-
spa_row = tl.load(s_lut_x + spa_row_idx, mask=spa_row_msk)
|
|
212
|
-
|
|
213
|
-
# Load reverse sparsity index for current block
|
|
214
|
-
rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
|
|
215
|
-
spa_row * s_l_o_r_s)
|
|
216
|
-
rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
|
|
217
|
-
rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
|
|
218
|
-
|
|
219
|
-
blk_idx = ((pid_blk * x_b_s) +
|
|
220
|
-
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
221
|
-
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
222
|
-
blk_msk = (blk_idx < x_b * x_b_s)
|
|
223
|
-
blk = tl.load(x + blk_idx, mask=blk_msk)
|
|
224
|
-
|
|
225
|
-
buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
|
|
226
|
-
|
|
227
|
-
o_idx = (rev_idx_spa * o_b_s +
|
|
228
|
-
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
229
|
-
(tl.arange(0, 1))[None, :])
|
|
230
|
-
o_msk = (o_idx < o_b * o_b_s)
|
|
231
|
-
tl.atomic_add(o + o_idx, buf, o_msk)
|
blksprs-1.2.1.dist-info/RECORD
DELETED
|
@@ -1,17 +0,0 @@
|
|
|
1
|
-
blksprs/layouting/distribution_layout.py,sha256=GQ-ZRXbeImiLcbaqnL2FuUZ6DoFwmB0naT_YrOpD84Q,4940
|
|
2
|
-
blksprs/layouting/sparsity_layout.py,sha256=TtADT_WWcZpW3zyGy6KAgkAo44gDryXZqdJLZGEX2V8,7895
|
|
3
|
-
blksprs/misc/broadcast_addition.py,sha256=vf1Hdqz9Uyqykto3DCjmdyepMzpMXL238SpANQqRAwI,5297
|
|
4
|
-
blksprs/ops/conversion.py,sha256=-AOzj_j3WrBLGIgd2oVPvYS8XKfzlvGtSIWzW_qP1lk,21260
|
|
5
|
-
blksprs/ops/distribution.py,sha256=_fQb6fWpLxocAh86D74ATahChi0EK0eBb4eUOUEBVps,16769
|
|
6
|
-
blksprs/ops/exp.py,sha256=qs8fVtCzxl4CKT4GepaqurjEL62jyi8VjMY12JFrFAU,3674
|
|
7
|
-
blksprs/ops/matmul.py,sha256=x3lrYg4g8fIf5PeMtZY_SEpi11kP9RFcRoemCIxcSDE,11086
|
|
8
|
-
blksprs/ops/row_wise_sum.py,sha256=ojuSejV37cLtRNS3lBfknA5KY3TEg8EHxOqVT6JZzoM,11387
|
|
9
|
-
blksprs/ops/softmax.py,sha256=ZyeAVqmG_VzJ72FArGrpUSFfoSM4GPxyubrmNKERVIA,11654
|
|
10
|
-
blksprs/ops/transpose.py,sha256=DVEXoxo2MoTNL3NZrjxsukMDrzk2vnEXL1uRnKFWkn0,6722
|
|
11
|
-
blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
|
|
12
|
-
blksprs/utils/tools.py,sha256=P2UALvccRjJJ7w05YGuaxB3qmNObgct4idfM0jlE2wg,465
|
|
13
|
-
blksprs/utils/validation.py,sha256=gJYZO5C48YUrXV3Fy_Z_lCaOpiFj951FT-Od7sKfprg,3007
|
|
14
|
-
blksprs-1.2.1.dist-info/METADATA,sha256=hzuuw0MkMzpNZgM5PFdD_zI7QAvVmeu65tk0dx8TUTI,7517
|
|
15
|
-
blksprs-1.2.1.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
|
|
16
|
-
blksprs-1.2.1.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
|
|
17
|
-
blksprs-1.2.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|