blksprs-1.3-py3-none-any.whl → blksprs-1.4.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +18 -0
- blksprs/layouting/distribution_layout.py +1 -1
- blksprs/layouting/sparsity_layout.py +2 -2
- blksprs/misc/{broadcast_addition.py → broadcast_ops.py} +9 -6
- blksprs/misc/repeat_interleave.py +2 -0
- blksprs/misc/row_wise.py +390 -0
- blksprs/ops/conversion.py +6 -0
- blksprs/ops/distribution.py +6 -0
- blksprs/ops/exp.py +2 -0
- blksprs/ops/matmul.py +6 -2
- blksprs/ops/softmax.py +13 -13
- blksprs/ops/transpose.py +2 -0
- blksprs/utils/tools.py +7 -1
- blksprs/utils/validation.py +15 -10
- {blksprs-1.3.dist-info → blksprs-1.4.1.dist-info}/METADATA +31 -30
- blksprs-1.4.1.dist-info/RECORD +19 -0
- blksprs/ops/row_wise_sum.py +0 -231
- blksprs-1.3.dist-info/RECORD +0 -18
- {blksprs-1.3.dist-info → blksprs-1.4.1.dist-info}/WHEEL +0 -0
- {blksprs-1.3.dist-info → blksprs-1.4.1.dist-info}/top_level.txt +0 -0
blksprs/__init__.py
ADDED
@@ -0,0 +1,18 @@
+from blksprs.ops.conversion import to_dense, to_sparse
+from blksprs.ops.distribution import gather, scatter, scatter_reduce
+from blksprs.ops.exp import exp
+from blksprs.ops.matmul import matmul
+from blksprs.ops.softmax import softmax
+from blksprs.ops.transpose import transpose
+
+class layout:
+    from blksprs.layouting.distribution_layout import build_distribution_layout
+    from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption
+
+class misc:
+    from blksprs.misc.broadcast_ops import broadcast_add, broadcast_sub
+    from blksprs.misc.repeat_interleave import repeat_interleave
+    from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_add, row_wise_sub
+
+class util:
+    from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse, disable_validation
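
The flat re-exports above are the headline change of this release: callers can now address everything through a single `import blksprs as bs`. A minimal round-trip sketch of the new surface; the sizes, device, and block size are illustrative assumptions, not taken from the diff:

```python
import torch
import blksprs as bs

x = torch.randn(2, 64, 64, device="cuda")

# Build a layout from a dense tensor, compress, then decompress.
layout = bs.layout.build_sparsity_layout(x, sparsity_block_size=32)
x_sparse = bs.to_sparse(x, layout, sparsity_block_size=32)
x_round_trip = bs.to_dense(x_sparse, layout, sparsity_block_size=32)

assert torch.allclose(x, x_round_trip)
```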
blksprs/layouting/distribution_layout.py
CHANGED
@@ -31,7 +31,7 @@ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
     sparsity_lut_i = torch.nonzero(sparsity_layout_indices).contiguous()
 
     output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size, size_target[2] // sparsity_block_size,
-                         device=indices.device)
+                         dtype=torch.bool, device=indices.device)
 
     i_b, i_r, i_c = indices.size()
     i_b_s, i_r_s, i_c_s = indices.stride()
blksprs/layouting/sparsity_layout.py
CHANGED
@@ -27,7 +27,7 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_size
     validate_device(x)
 
     output = torch.zeros(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size,
-                         device=x.device)
+                         dtype=torch.bool, device=x.device)
 
     x_b, x_r, x_c = x.size()
     x_b_s, x_r_s, x_c_s = x.stride()
@@ -117,7 +117,7 @@ def build_sparsity_layout_adaption(x: Tensor, sparsity_layout_from: Tensor,
     o_r = math.ceil(sparsity_layout_from.size(1) * sparsity_block_size_from // sparsity_block_size_to)
     o_c = math.ceil(sparsity_layout_from.size(2) * sparsity_block_size_from // sparsity_block_size_to)
 
-    output = torch.zeros(o_b, o_r, o_c, device=x.device)
+    output = torch.zeros(o_b, o_r, o_c, dtype=torch.bool, device=x.device)
 
     x_b, x_r, x_c = x.size()
     x_b_s, x_r_s, x_c_s = x.stride()
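
All three layouting functions now allocate their output with `dtype=torch.bool` rather than the previous default of `torch.float32`. A small sketch of the observable effect; device and sizes are assumed for illustration:

```python
import torch
import blksprs as bs

x = torch.randn(1, 64, 64, device="cuda")
layout = bs.layout.build_sparsity_layout(x, sparsity_block_size=32)

# One layout entry per 32x32 block; as of 1.4.1 it is a boolean mask.
assert layout.shape == (1, 2, 2)
assert layout.dtype == torch.bool
```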
blksprs/misc/{broadcast_addition.py → broadcast_ops.py}
CHANGED
@@ -8,8 +8,8 @@ from blksprs.utils.validation import validate_contiguous, validate_device, \
     validate_sparsity_block_size, validate_triton_block_size
 
 
-def broadcast_addition(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
-                       sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
+                  sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
     """Performs a broadcast and subsequent addition of two dense tensors x and y. Returns a block-sparse tensor in
     compressed form.
 
@@ -25,6 +25,9 @@ def broadcast_addition(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
     output tensor corresponds to x(i) + y(j).
 
     """
+    x = x.contiguous()
+    y = y.contiguous()
+
     validate_device(x, y)
     validate_contiguous(x, y)
     if x.size(-1) != y.size(-1):
@@ -70,12 +73,12 @@
     return output
 
 
-def …
-…
-    """Wrapper for ``…
+def broadcast_sub(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
+                  sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+    """Wrapper for ``broadcast_add`` with negated y.
 
     """
-    return …
+    return broadcast_add(x, torch.neg(y), sparsity_layout_output, sparsity_block_size, triton_block_size)
 
 
 @triton.jit
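
Besides the rename to `broadcast_add`/`broadcast_sub`, both entry points now force their dense inputs contiguous before validation. Per the docstring, entry (i, j) of the output corresponds to x(i) + y(j), an outer sum materialized only for the blocks named in the output layout. A hedged usage sketch; the input shapes and the all-ones output layout are assumptions for illustration:

```python
import torch
import blksprs as bs

b, n, sparsity_block_size = 2, 64, 32
x = torch.randn(b, n, device="cuda")
y = torch.randn(b, n, device="cuda")

# Fully dense output layout, so every (i, j) pair is materialized.
sparsity_layout_o = torch.ones(b, n // sparsity_block_size, n // sparsity_block_size,
                               dtype=torch.bool, device="cuda")

o_add = bs.misc.broadcast_add(x, y, sparsity_layout_o, sparsity_block_size)

# Dense reference of the documented semantics: o[i, j] = x[i] + y[j].
ref = x.unsqueeze(-1) + y.unsqueeze(-2)
o_dense = bs.to_dense(o_add, sparsity_layout_o, sparsity_block_size)
assert torch.allclose(o_dense, ref, atol=1e-5)
```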
blksprs/misc/row_wise.py
ADDED
@@ -0,0 +1,390 @@
+import torch
+import triton
+from torch import Tensor
+from triton import language as tl
+
+from blksprs.utils.tools import get_triton_block_size
+from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, validate_sparsity, \
+    validate_sparsity_block_size, validate_triton_block_size
+
+
+def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
+                 flag_slice_only: bool = False, triton_block_size: int = None) -> tuple[Tensor, Tensor]:
+    """Computes the row-wise sum of a block-sparse tensor.
+
+    Returns a block-sparse tensor in compressed form with only one block per row, where the first entry contains the sum
+    of the corresponding row.
+
+    Note:
+        If ``flag_slice_only`` is set the output will be of shape ``[x.size(0), x.size(1), 1]``.
+
+    Args:
+        x (Tensor): A block-sparse tensor in compressed form.
+        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+        sparsity_block_size (int): The size of the sparsity blocks.
+        flag_slice_only (bool, optional): If set the output will be of shape ``[x.size(0), x.size(1), 1]``
+            (default ``False``).
+        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+    Returns:
+        tuple[Tensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise sum
+            of the input and the sparsity layout of the output tensor.
+
+    """
+    x = x.contiguous()
+
+    validate_dimensions(x)
+    validate_contiguous(x)
+    validate_device(x)
+    validate_sparsity(sparsity_block_size, (x, sparsity_layout))
+    validate_sparsity_block_size(sparsity_block_size, x)
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+    sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
+
+    sparsity_layout_output, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
+    sparsity_layout_output_flat = sparsity_layout_output.reshape(-1)
+    sparsity_reverse_lut_output = ((torch.cumsum(sparsity_layout_output_flat, dim=-1) - 1) *
+                                   (sparsity_layout_output_flat == 1) -
+                                   (1 * (sparsity_layout_output_flat == 0)))
+
+    n_sparse_blocks_output = torch.sum(sparsity_layout_output.to(torch.int)).item()
+
+    validate_contiguous(sparsity_layout, sparsity_lut,
+                        sparsity_layout_output, sparsity_reverse_lut_output)
+
+    output = torch.zeros(size=(n_sparse_blocks_output,
+                               sparsity_block_size,
+                               1 if flag_slice_only else sparsity_block_size),
+                         device=x.device)
+
+    x_b, x_r, x_c = x.size()
+    x_b_s, x_r_s, x_c_s = x.stride()
+    s_lut_x_r, s_lut_x_c = sparsity_lut.size()
+    s_lut_x_r_s, s_lut_x_c_s = sparsity_lut.stride()
+    o_b, o_r, o_c = output.size()
+    o_b_s, o_r_s, o_c_s = output.stride()
+    s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
+    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_output.stride()
+
+    if triton_block_size is None:
+        triton_block_size = get_triton_block_size(sparsity_block_size)
+
+    triton_grid = lambda meta: [x_b,
+                                triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+    (kernel_blocksparse_row_wise_sum[triton_grid]
+     (x,
+      x_b, x_b_s, x_r_s, x_c_s,
+      sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+      output,
+      o_b, o_b_s, o_r_s,
+      s_l_o_b, s_l_o_b_s, s_l_o_r_s,
+      sparsity_reverse_lut_output,
+      triton_block_size))
+
+    return (output, sparsity_layout_output)
+
+
+@triton.jit
+def kernel_blocksparse_row_wise_sum(x,
+                                    x_b, x_b_s, x_r_s, x_c_s,
+                                    s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+                                    o,
+                                    o_b, o_b_s, o_r_s,
+                                    s_l_o_b, s_l_o_b_s, s_l_o_r_s,
+                                    r_lut_o,
+                                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get position of current sparsity block consisting of its batch and row index
+    spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
+    spa_bat_msk = (spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
+    spa_bat = tl.load(s_lut_x + spa_bat_idx, mask=spa_bat_msk)
+
+    spa_row_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
+    spa_row_msk = (spa_row_idx < s_lut_x_r * s_lut_x_r_s)
+    spa_row = tl.load(s_lut_x + spa_row_idx, mask=spa_row_msk)
+
+    # Load reverse sparsity index for current block
+    rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
+                       spa_row * s_l_o_r_s)
+    rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
+    rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
+
+    blk_idx = ((pid_blk * x_b_s) +
+               ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+               ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_msk = (blk_idx < x_b * x_b_s)
+    blk = tl.load(x + blk_idx, mask=blk_msk)
+
+    buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
+
+    o_idx = (rev_idx_spa * o_b_s +
+             ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+             (tl.arange(0, 1))[None, :])
+    o_msk = (o_idx < o_b * o_b_s)
+    tl.atomic_add(o + o_idx, buf, o_msk)
+
+
+def row_wise_max(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
+                 flag_slice_only: bool = False, triton_block_size: int = None) -> tuple[Tensor, Tensor]:
+    """Computes the row-wise max of a block-sparse tensor.
+
+    Returns a block-sparse tensor in compressed form with only one block per row, where the first entry contains the
+    maximum of the corresponding row.
+
+    Note:
+        If ``flag_slice_only`` is set the output will be of shape ``[x.size(0), x.size(1), 1]``.
+
+    Args:
+        x (Tensor): A block-sparse tensor in compressed form.
+        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+        sparsity_block_size (int): The size of the sparsity blocks.
+        flag_slice_only (bool, optional): If set the output will be of shape ``[x.size(0), x.size(1), 1]``
+            (default ``False``).
+        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+    Returns:
+        tuple[Tensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise max
+            of the input and the sparsity layout of the output tensor.
+
+    """
+    x = x.contiguous()
+
+    validate_dimensions(x)
+    validate_contiguous(x)
+    validate_device(x)
+    validate_sparsity(sparsity_block_size, (x, sparsity_layout))
+    validate_sparsity_block_size(sparsity_block_size, x)
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+    sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
+
+    sparsity_layout_output, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
+    sparsity_layout_output_flat = sparsity_layout_output.reshape(-1)
+    sparsity_reverse_lut_output = ((torch.cumsum(sparsity_layout_output_flat, dim=-1) - 1) *
+                                   (sparsity_layout_output_flat == 1) -
+                                   (1 * (sparsity_layout_output_flat == 0)))
+
+    n_sparse_blocks_output = torch.sum(sparsity_layout_output.to(torch.int)).item()
+
+    validate_contiguous(sparsity_layout, sparsity_lut,
+                        sparsity_layout_output, sparsity_reverse_lut_output)
+
+    output = torch.full(size=(n_sparse_blocks_output,
+                              sparsity_block_size,
+                              1 if flag_slice_only else sparsity_block_size),
+                        fill_value=float("-inf"),
+                        device=x.device)
+
+    x_b, x_r, x_c = x.size()
+    x_b_s, x_r_s, x_c_s = x.stride()
+    s_lut_x_r, s_lut_x_c = sparsity_lut.size()
+    s_lut_x_r_s, s_lut_x_c_s = sparsity_lut.stride()
+    o_b, o_r, o_c = output.size()
+    o_b_s, o_r_s, o_c_s = output.stride()
+    s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
+    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_output.stride()
+
+    if triton_block_size is None:
+        triton_block_size = get_triton_block_size(sparsity_block_size)
+
+    triton_grid = lambda meta: [x_b,
+                                triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+    (kernel_blocksparse_row_wise_max[triton_grid]
+     (x,
+      x_b, x_b_s, x_r_s, x_c_s,
+      sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+      output,
+      o_b, o_b_s, o_r_s,
+      s_l_o_b, s_l_o_b_s, s_l_o_r_s,
+      sparsity_reverse_lut_output,
+      triton_block_size))
+
+    return output, sparsity_layout_output
+
+
+@triton.jit
+def kernel_blocksparse_row_wise_max(x,
+                                    x_b, x_b_s, x_r_s, x_c_s,
+                                    s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+                                    o,
+                                    o_b, o_b_s, o_r_s,
+                                    s_l_o_b, s_l_o_b_s, s_l_o_r_s,
+                                    r_lut_o,
+                                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get position of current sparsity block consisting of its batch and row index
+    spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
+    spa_bat_msk = (spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
+    spa_bat = tl.load(s_lut_x + spa_bat_idx, mask=spa_bat_msk)
+
+    spa_row_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
+    spa_row_msk = (spa_row_idx < s_lut_x_r * s_lut_x_r_s)
+    spa_row = tl.load(s_lut_x + spa_row_idx, mask=spa_row_msk)
+
+    # Load reverse sparsity index for current block
+    rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
+                       spa_row * s_l_o_r_s)
+    rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
+    rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
+
+    blk_idx = ((pid_blk * x_b_s) +
+               ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+               ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_msk = (blk_idx < x_b * x_b_s)
+    blk = tl.load(x + blk_idx, mask=blk_msk)
+
+    buf = tl.reshape(tl.max(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
+
+    o_idx = (rev_idx_spa * o_b_s +
+             ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+             (tl.arange(0, 1))[None, :])
+    o_msk = (o_idx < o_b * o_b_s)
+    tl.atomic_max(o + o_idx, buf, o_msk)
+
+
+def row_wise_add(x: Tensor, sparsity_layout_x: Tensor, y: Tensor,
+                 sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+    """For each row in ``y`` adds the value to each value in the corresponding row of the block-sparse tensor ``x``.
+
+    Args:
+        x (Tensor): A block-sparse tensor in compressed form.
+        sparsity_layout_x (Tensor): The sparsity layout of the block-sparse tensor.
+        y (Tensor): A block-sparse tensor in compressed form with only one value per row and a single column of sparse blocks.
+        sparsity_block_size (int): The size of the sparsity blocks.
+        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+    Returns:
+        Tensor: The values of ``x`` with the first value of ``y`` in each row added to them as a block-sparse tensor in
+            compressed form.
+
+    """
+    validate_dimensions(x)
+    validate_contiguous(x)
+    validate_device(x)
+    validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
+    validate_sparsity_block_size(sparsity_block_size, x)
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+    sparsity_lut = torch.nonzero(sparsity_layout_x).contiguous()
+
+    sparsity_layout_rwm, _ = torch.max(sparsity_layout_x, dim=-1, keepdim=True)
+    sparsity_layout_rwm_flat = sparsity_layout_rwm.reshape(-1)
+    sparsity_reverse_lut_rwm = ((torch.cumsum(sparsity_layout_rwm_flat, dim=-1) - 1) *
+                                (sparsity_layout_rwm_flat == 1) -
+                                (1 * (sparsity_layout_rwm_flat == 0)))
+
+    validate_contiguous(sparsity_layout_x, sparsity_lut, sparsity_reverse_lut_rwm)
+
+    output = torch.empty_like(x)
+
+    x_b, x_r, x_c = x.size()
+    x_b_s, x_r_s, x_c_s = x.stride()
+    s_lut_r, s_lut_c = sparsity_lut.size()
+    s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
+    y_b, y_r, y_c = y.size()
+    y_b_s, y_r_s, y_c_s = y.stride()
+    s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_rwm.size()
+    s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = sparsity_layout_rwm.stride()
+    o_b, o_r, o_c = output.size()
+    o_b_s, o_r_s, o_c_s = output.stride()
+
+    if triton_block_size is None:
+        triton_block_size = get_triton_block_size(sparsity_block_size)
+
+    triton_grid = lambda meta: [o_b,
+                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+    (kernel_blocksparse_row_wise_add[triton_grid]
+     (x,
+      x_b, x_b_s, x_r_s, x_c_s,
+      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+      y, y_b, y_b_s, y_r_s, y_c_s,
+      s_l_y_b, s_l_y_b_s, s_l_y_r_s,
+      sparsity_reverse_lut_rwm,
+      output,
+      o_b, o_b_s, o_r_s, o_c_s,
+      triton_block_size
+      ))
+
+    return output
+
+
+def row_wise_sub(x: Tensor, sparsity_layout_x: Tensor, y: Tensor,
+                 sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+    """Wrapper for ``row_wise_add`` with negated y.
+
+    """
+    return row_wise_add(x, sparsity_layout_x, torch.neg(y), sparsity_block_size, triton_block_size)
+
+
+@triton.jit
+def kernel_blocksparse_row_wise_add(x,
+                                    x_b, x_b_s, x_r_s, x_c_s,
+                                    s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                                    y, y_b, y_b_s, y_r_s, y_c_s,
+                                    s_l_y_b, s_l_y_b_s, s_l_y_r_s,
+                                    r_lut_y,
+                                    o,
+                                    o_b, o_b_s, o_r_s, o_c_s,
+                                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get position of current sparsity block consisting of its batch and row index
+    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+    spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+    spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+    # Get reverse sparsity indices for s
+    rev_idx_spa_s_idx = (spa_bat * s_l_y_b_s +
+                         spa_row * s_l_y_r_s)
+    rev_idx_spa_s_msk = (rev_idx_spa_s_idx < s_l_y_b * s_l_y_b_s)
+    rev_idx_spa_s = tl.load(r_lut_y + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
+
+    if rev_idx_spa_s == -1:
+        assert False, "Invalid sparsity block"
+
+    # Load x block
+    blk_x_idx = ((pid_blk * x_b_s) +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_x_msk = (blk_x_idx < x_b * x_b_s)
+    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+    # Load sum block
+    blk_s_idx = (rev_idx_spa_s * y_b_s +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
+                 (tl.arange(0, 1) * y_c_s)[None, :])
+    blk_s_msk = (blk_s_idx < y_b * y_b_s)
+    blk_s = tl.load(y + blk_s_idx, mask=blk_s_msk)
+
+    # Compute exp
+    buf = blk_x + tl.broadcast_to(blk_s, (TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE))
+
+    # debug
+    asdf = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), 1.0, dtype=tl.float32)
+
+    # Store block
+    blk_o_idx = ((pid_blk * o_b_s) +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+    blk_o_msk = (blk_o_idx < o_b * o_b_s)
+    tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
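
All four functions in this new module build the same `sparsity_reverse_lut_*` expression before launching their kernels: a cumulative sum over the flattened layout that maps every present block to its index in the compressed tensor and every absent block to -1, which the kernels then treat as invalid. A standalone worked example of that expression:

```python
import torch

# Flattened sparsity layout: 1 marks a block present in compressed form.
flat = torch.tensor([1, 0, 1, 1, 0])

reverse_lut = ((torch.cumsum(flat, dim=-1) - 1) * (flat == 1)
               - (1 * (flat == 0)))

# Present blocks receive their compressed index, absent blocks -1.
assert reverse_lut.tolist() == [0, -1, 1, 2, -1]
```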
blksprs/ops/conversion.py
CHANGED
@@ -28,6 +28,8 @@ def to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_
         Tensor: The block-sparse tensor converted to regular form.
 
     """
+    x = x.contiguous()
+
     validate_dimensions(x)
     validate_contiguous(x, sparsity_layout)
     validate_device(x)
@@ -156,6 +158,8 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
         Tensor: The block-sparse tensor converted to compressed form.
 
     """
+    x = x.contiguous()
+
     validate_dimensions(x)
     validate_contiguous(x)
     validate_device(x)
@@ -282,6 +286,8 @@ def adapt_layout(x: Tensor, sparsity_layout_from: Tensor, sparsity_block_size_fr
         Tensor: The block-sparse tensor in compressed form with the adapted sparsity layout and sparsity block size.
 
     """
+    x = x.contiguous()
+
     validate_dimensions(x)
     validate_contiguous(x, sparsity_layout_from)
     validate_device(x)
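
The `x = x.contiguous()` lines added throughout this release normalize strides before `validate_contiguous` runs, so lazily transposed views no longer have to be copied by the caller. A small sketch of the situation being handled:

```python
import torch

x = torch.randn(2, 64, 32)
x_t = x.transpose(-1, -2)   # a view with swapped strides, not contiguous
assert not x_t.is_contiguous()

x_fixed = x_t.contiguous()  # copies into standard row-major layout
assert x_fixed.is_contiguous()
assert x_fixed.stride() == (32 * 64, 64, 1)
```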
blksprs/ops/distribution.py
CHANGED
@@ -24,6 +24,9 @@ def gather(src: Tensor, sparsity_layout_src: Tensor, idx: Tensor, sparsity_layou
         Tensor: The result of the gather operation as a block-sparse tensor in compressed form.
 
     """
+    src = src.contiguous()
+    idx = idx.contiguous()
+
     validate_dimensions(src, idx)
     validate_contiguous(src, idx)
     validate_dtype_int(idx)
@@ -200,6 +203,9 @@ def scatter_reduce(src: Tensor, sparsity_layout_src: Tensor,
         Tensor: The result of the scatter operation as a block-sparse tensor in compressed form.
 
     """
+    src = src.contiguous()
+    idx = idx.contiguous()
+
     validate_dimensions(src, idx)
     validate_contiguous(src, idx)
     validate_dtype_int(idx)
blksprs/ops/exp.py
CHANGED
blksprs/ops/matmul.py
CHANGED
@@ -6,7 +6,7 @@ from triton import language as tl
 from blksprs.ops.transpose import transpose
 from blksprs.utils.tools import get_triton_block_size
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-    validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
+    validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_dtype_float
 
 
 def matmul(x: Tensor, sparsity_layout_x: Tensor,
@@ -30,8 +30,12 @@ def matmul(x: Tensor, sparsity_layout_x: Tensor,
         Tensor: The result of the matrix multiplication as a block-sparse tensor in compressed form.
 
     """
+    x = x.contiguous()
+    y = y.contiguous()
+
     validate_dimensions(x, y)
     validate_contiguous(x, y)
+    validate_dtype_float(x, y)
     validate_device(x, y)
     validate_sparsity(sparsity_block_size, (x, sparsity_layout_x), (y, sparsity_layout_y))
     if sparsity_layout_x.size(-1) != sparsity_layout_y.size(-2):
@@ -211,7 +215,7 @@ class _BlocksparseMatmulSSS(torch.autograd.Function):
         blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
 
         # Perform matrix multiplication
-        buf += tl.dot(blk_x, blk_y)
+        buf += tl.dot(blk_x, blk_y, input_precision="tf32")
 
         # Store output
         blk_o_idx = ((pid_blk * o_b_s) +
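
Two behavioral changes land in `matmul`: operands must now be floating point (`validate_dtype_float`), and the accumulating `tl.dot` pins `input_precision="tf32"` explicitly instead of relying on the backend default. A sketch of the new up-front check; that it raises `ValueError` like the other validators in `validation.py` is an assumption, since the validator body is not part of this diff:

```python
import torch
import blksprs as bs

sparsity_block_size = 32
layout = torch.ones(1, 2, 2, dtype=torch.bool, device="cuda")
x_int = torch.ones(4, 32, 32, dtype=torch.int32, device="cuda")

try:
    bs.matmul(x_int, layout, x_int, layout, layout, sparsity_block_size)
except ValueError:
    pass  # new in 1.4.1: integer operands are rejected before kernel launch
```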
blksprs/ops/softmax.py
CHANGED
@@ -4,7 +4,7 @@ from torch import Tensor
 from triton import language as tl
 
 from blksprs.ops.exp import exp
-from blksprs.ops.row_wise_sum import row_wise_sum
+from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
 from blksprs.utils.tools import get_triton_block_size
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
@@ -26,6 +26,8 @@ def softmax(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton
         Tensor: The result of the softmax operation as a block-sparse tensor in compressed form.
 
     """
+    x = x.contiguous()
+
     validate_dimensions(x)
     validate_contiguous(x)
     validate_device(x)
@@ -33,12 +35,6 @@ def softmax(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton
     validate_sparsity_block_size(sparsity_block_size, x)
     validate_triton_block_size(triton_block_size, sparsity_block_size)
 
-    if x.size(0) != 0:
-        max_val = torch.max(x).item()
-    else:
-        max_val = 0
-    x_scaled = x - max_val
-
     sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
 
     sparsity_layout_rws, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
@@ -49,7 +45,7 @@ def softmax(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton
 
     validate_contiguous(sparsity_layout, sparsity_lut, sparsity_reverse_lut_rws)
 
-    return _BlocksparseSoftmax.apply(x_scaled, sparsity_layout,
+    return _BlocksparseSoftmax.apply(x, sparsity_layout,
                                      sparsity_lut,
                                      sparsity_reverse_lut_rws,
                                      sparsity_block_size, triton_block_size)
@@ -64,13 +60,17 @@ class _BlocksparseSoftmax(torch.autograd.Function):
                 sparsity_block_size: int, triton_block_size: int) -> Tensor:
         output = torch.empty_like(x)
 
-        x_b, x_r, x_c = x.…
+        x_b, x_r, x_c = x.size()
         x_b_s, x_r_s, x_c_s = x.stride()
-        s_lut_r, s_lut_c = sparsity_lut.…
+        s_lut_r, s_lut_c = sparsity_lut.size()
         s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
-        o_b, o_r, o_c = output.…
+        o_b, o_r, o_c = output.size()
 
-        …
+        x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,
+                                                           flag_slice_only=True,
+                                                           triton_block_size=triton_block_size)
+        x_scaled = row_wise_sub(x, sparsity_layout, x_row_wise_max, sparsity_block_size, triton_block_size)
+        x_exp = exp(x_scaled, sparsity_block_size, triton_block_size=triton_block_size)
         x_exp_row_wise_sum, sparsity_layout_rws = row_wise_sum(x_exp, sparsity_layout, sparsity_block_size,
                                                                flag_slice_only=True,
                                                                triton_block_size=triton_block_size)
@@ -174,7 +174,7 @@ class _BlocksparseSoftmax(torch.autograd.Function):
         spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
         spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
 
-        # Get reverse sparsity indices for …
+        # Get reverse sparsity indices for s
        rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
                              spa_row * s_l_s_r_s)
         rev_idx_spa_s_msk = (rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
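
The net effect of the softmax rewrite: the old code subtracted a single global maximum obtained via `torch.max(x).item()`, a host-device synchronization, while the new forward pass subtracts each row's maximum entirely on device using the new `row_wise_max`/`row_wise_sub` kernels. A dense PyTorch reference of the per-row computation the forward pass now performs:

```python
import torch

def stable_softmax_reference(x: torch.Tensor) -> torch.Tensor:
    # Subtracting the row maximum leaves softmax unchanged but keeps
    # exp() from overflowing for large inputs.
    x_scaled = x - x.max(dim=-1, keepdim=True).values
    x_exp = x_scaled.exp()
    return x_exp / x_exp.sum(dim=-1, keepdim=True)

x = torch.randn(2, 8, 8) * 100  # magnitudes that would overflow a naive exp
assert torch.allclose(stable_softmax_reference(x), torch.softmax(x, dim=-1), atol=1e-6)
```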
blksprs/ops/transpose.py
CHANGED
blksprs/utils/tools.py
CHANGED
@@ -1,10 +1,12 @@
 import torch
 from torch import Tensor, Size
 
+from blksprs.utils.validation import _set_skip_validation
+
 
 def do_shape_blocksparse(x: Tensor):
     if x.dim() == 3:
-        return x, x.size()
+        return x.contiguous(), x.size()
 
     return x.reshape(-1, x.size(-2), x.size(-1)), x.size()
 
@@ -18,3 +20,7 @@ def undo_shape_blocksparse(x: Tensor, shape: Size):
 
 def get_triton_block_size(sparsity_block_size: int, limit: int = 128):
     return min(sparsity_block_size, limit)
+
+
+def disable_validation():
+    _set_skip_validation(True)
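
`disable_validation` is the new public switch over the module-level `VALIDATION` flag; once called, every `validate_*` function returns immediately. A short usage sketch (note the diff ships no public re-enable helper, only the private `_set_skip_validation`):

```python
import blksprs as bs

# Skip all validate_* checks, e.g. in a hot loop with known-good inputs.
bs.util.disable_validation()

# Equivalent, via the private helper it wraps:
# from blksprs.utils.validation import _set_skip_validation
# _set_skip_validation(False)  # would turn validation back on
```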
blksprs/utils/validation.py
CHANGED
@@ -1,9 +1,10 @@
 import torch
 from torch import Tensor
 
+VALIDATION = True
 
 def validate_dimensions(*tensors: Tensor) -> None:
-    if …
+    if _check_skip_validation():
         return
 
     for tensor in tensors:
@@ -12,7 +13,7 @@ def validate_dimensions(*tensors: Tensor) -> None:
 
 
 def validate_contiguous(*tensors: Tensor) -> None:
-    if …
+    if _check_skip_validation():
         return
 
     for tensor in tensors:
@@ -21,7 +22,7 @@ def validate_contiguous(*tensors: Tensor) -> None:
 
 
 def validate_dtype_float(*tensors: Tensor) -> None:
-    if …
+    if _check_skip_validation():
         return
 
     for tensor in tensors:
@@ -30,7 +31,7 @@ def validate_dtype_float(*tensors: Tensor) -> None:
 
 
 def validate_dtype_int(*tensors: Tensor) -> None:
-    if …
+    if _check_skip_validation():
         return
 
     for tensor in tensors:
@@ -39,7 +40,7 @@ def validate_dtype_int(*tensors: Tensor) -> None:
 
 
 def validate_device(*tensors: Tensor) -> None:
-    if …
+    if _check_skip_validation():
         return
 
     device = None
@@ -56,7 +57,7 @@ def validate_device(*tensors: Tensor) -> None:
 
 
 def validate_sparsity(sparsity_block_size: int, *tensor_sparsity_layout_tuples: tuple[Tensor, Tensor]) -> None:
-    if …
+    if _check_skip_validation():
         return
 
     for (tensor, sparsity_layout) in tensor_sparsity_layout_tuples:
@@ -73,7 +74,7 @@ def _validate_sparsity_layout_values(sparsity_layout: Tensor):
         raise ValueError("Sparsity layout values must be either 0 or 1")
 
 def validate_sparsity_block_size(sparsity_block_size: int, *tensors):
-    if …
+    if _check_skip_validation():
         return
 
     if not (sparsity_block_size & (sparsity_block_size - 1)) == 0:
@@ -84,7 +85,7 @@ def validate_sparsity_block_size(sparsity_block_size: int, *tensors):
         raise ValueError("Tensor sizes must be divisible by sparsity block size")
 
 def validate_triton_block_size(triton_block_size: int, sparsity_block_size: int):
-    if …
+    if _check_skip_validation():
         return
 
     if triton_block_size is None:
@@ -93,5 +94,9 @@ def validate_triton_block_size(triton_block_size: int, sparsity_block_size: int)
     if triton_block_size > sparsity_block_size:
         raise ValueError("Triton block size cannot be larger than sparsity block size")
 
-def …
-    return …
+def _check_skip_validation():
+    return not VALIDATION
+
+def _set_skip_validation(skip_validation: bool):
+    global VALIDATION
+    VALIDATION = not skip_validation
{blksprs-1.3.dist-info → blksprs-1.4.1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.3
+Version: 1.4.1
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -8,10 +8,8 @@ Project-URL: Bugtracker, https://github.com/FelixSchoen/blksprs/issues
 Requires-Python: >=3.11
 Description-Content-Type: text/markdown
 Requires-Dist: torch
-Provides-Extra: deploy
-Requires-Dist: build; extra == "deploy"
-Requires-Dist: twine; extra == "deploy"
-Requires-Dist: pdoc3; extra == "deploy"
+Provides-Extra: build
+Requires-Dist: build; extra == "build"
 Provides-Extra: test
 Requires-Dist: pytest; extra == "test"
 Requires-Dist: pytest-xdist; extra == "test"
@@ -83,14 +81,7 @@ the [test cases](https://github.com/FelixSchoen/blksprs/blob/main/test/cases/tes
 
 ```python
 import torch
-
-from blksprs.layouting.sparsity_layout import build_sparsity_layout
-from blksprs.ops.conversion import to_sparse, to_dense
-from blksprs.ops.matmul import matmul
-from blksprs.ops.row_wise_sum import row_wise_sum
-from blksprs.ops.softmax import softmax
-from blksprs.ops.transpose import transpose
-from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
+import blksprs as bs
 
 
 def test_readme():
@@ -112,47 +103,57 @@ def test_readme():
     y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()
 
     # Convert tensors to three-dimensional (dense) tensors since Triton can only handle tensors of exactly three dimensions
-    x_dense, x_shape_original = do_shape_blocksparse(x)
-    y_dense, y_shape_original = do_shape_blocksparse(y)
+    x_dense, x_shape_original = bs.util.do_shape_blocksparse(x)
+    y_dense, y_shape_original = bs.util.do_shape_blocksparse(y)
 
     # Create sparsity layouts from existing tensors
-    sparsity_layout_x = build_sparsity_layout(x_dense, sparsity_block_size,
-                                              …
+    sparsity_layout_x = bs.layout.build_sparsity_layout(x_dense, sparsity_block_size,
+                                                        triton_block_size=triton_block_size)
+    sparsity_layout_y = bs.layout.build_sparsity_layout(y_dense, sparsity_block_size,
+                                                        triton_block_size=triton_block_size)
 
     # Create random sparsity layout for output tensor
     sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)
 
     # Convert tensors to sparse tensors for matrix multiplication
-    x_sparse = to_sparse(x_dense, sparsity_layout_x, sparsity_block_size, triton_block_size=triton_block_size)
-    y_sparse = to_sparse(y_dense, sparsity_layout_y, sparsity_block_size, triton_block_size=triton_block_size)
+    x_sparse = bs.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size, triton_block_size=triton_block_size)
+    y_sparse = bs.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size, triton_block_size=triton_block_size)
 
     # Perform matrix multiplication
-    o_sparse = matmul(x_sparse, sparsity_layout_x, y_sparse, sparsity_layout_y, sparsity_layout_o,
-                      …
-                      …
+    o_sparse = bs.matmul(x_sparse, sparsity_layout_x, y_sparse, sparsity_layout_y, sparsity_layout_o,
+                         sparsity_block_size,
+                         triton_block_size=triton_block_size)
+
+    # Apply element-wise operation
+    o_sparse = torch.add(o_sparse, 1)
+
+    o_dense = bs.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
 
     # Sanity check
     o_torch = torch.matmul(x_dense, y_dense)
+    o_torch = torch.add(o_torch, 1)
 
     # Perform round trip to set sparse blocks to 0
-    o_torch_round_trip = to_dense(
-        to_sparse(o_torch, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size),
+    o_torch_round_trip = bs.to_dense(
+        bs.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size),
         sparsity_layout_o, sparsity_block_size, fill_value=0, triton_block_size=triton_block_size)
 
     # Assert that the output is correct
     assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2)  # Note that small numerical differences are expected
 
     # Assert that the output has the correct sparsity layout
-    actual_sparsity_layout_o = build_sparsity_layout(o_dense, sparsity_block_size,
-                                                     …
+    actual_sparsity_layout_o = bs.layout.build_sparsity_layout(o_dense, sparsity_block_size,
+                                                               triton_block_size=triton_block_size)
+    assert torch.allclose(actual_sparsity_layout_o.to(torch.int), sparsity_layout_o)
 
     # Convert output tensor back to original shape
-    o = undo_shape_blocksparse(o_dense, x_shape_original)
+    o = bs.util.undo_shape_blocksparse(o_dense, x_shape_original)
 
     # Other available functions
-    transpose(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
-    softmax(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
-    row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+    bs.transpose(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+    bs.softmax(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+    bs.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+    bs.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
 
 
 def _get_random_sparsity_layout(b, m, n, sparsity_block_size, sparsity_percentage):
blksprs-1.4.1.dist-info/RECORD
ADDED
@@ -0,0 +1,19 @@
+blksprs/__init__.py,sha256=ORAVhGR91G1wyIOs9Wecv-xfmjju3bJ4Jynq_SGOVY4,833
+blksprs/layouting/distribution_layout.py,sha256=Xd8KjZwI87L9EL1Bw5SGUW9YztFD5q0Ygr99sffvdak,4939
+blksprs/layouting/sparsity_layout.py,sha256=vZL8r5LkMwILYYqTYPZcN_NYFJuVFIB6mmBkdtRyXmI,7893
+blksprs/misc/broadcast_ops.py,sha256=RTcqvx6X_THRBb55jipeEe63YSLIAh27jdpuze0aSek,5308
+blksprs/misc/repeat_interleave.py,sha256=KJeapmxbpA7zGFfa5hUhCGrk4aFmhOhlMw-hbTh9PLI,5668
+blksprs/misc/row_wise.py,sha256=KCDO5ry5TkjI88LLD_QINZwBkzfmjoQpOOvYLfpUn5I,16853
+blksprs/ops/conversion.py,sha256=h1c5T74rQjqYgY9dwWXfPTXRpgzy0dtAhCmtUp8-6uo,21332
+blksprs/ops/distribution.py,sha256=KhtHRVcv4_woyNlldAjIWF-7021-KX-xyIcN6rE-UgE,16879
+blksprs/ops/exp.py,sha256=CVWVq_emO2CnS_xk6Unx67P7EI7IL26dwtsmBJZOLzQ,3698
+blksprs/ops/matmul.py,sha256=6DaYxecJgwiW8L-UISkgyNyzQ31AAkmDL-Oq1EjHt98,11210
+blksprs/ops/softmax.py,sha256=cSTxDnNmMRlJGOlCSpdg1U5KUIFpVtHulz8fteJFeh0,11972
+blksprs/ops/transpose.py,sha256=et8R124L29TUqihci18ms_hBoYXTtPu5LXgEA8sxk_w,6744
+blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
+blksprs/utils/tools.py,sha256=RKGWCGd5h1qFOIoShsdJObx4-QsS0RxCyzFie0geNxo,596
+blksprs/utils/validation.py,sha256=Gsx3aah6355bWXRPpbFuZ1p0fOrYduIqaM3ON9d5NiI,3197
+blksprs-1.4.1.dist-info/METADATA,sha256=3xRmBFHv2U2KnrW3_QX3003SHLkQ1JCaSqh4AUBsJD4,7609
+blksprs-1.4.1.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+blksprs-1.4.1.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
+blksprs-1.4.1.dist-info/RECORD,,
blksprs/ops/row_wise_sum.py
DELETED
@@ -1,231 +0,0 @@
-import torch
-import triton
-from torch import Tensor
-from triton import language as tl
-
-from blksprs.utils.tools import get_triton_block_size
-from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-    validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
-
-
-def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
-                 flag_slice_only: bool = False, triton_block_size: int = None) -> tuple[Tensor, Tensor]:
-    """Computes the row-wise sum of a block-sparse tensor.
-
-    Returns a block-sparse tensor in compressed form with only one block per row, where the first entry contains the sum
-    of the corresponding row.
-
-    Note:
-        If ``flag_slice_only`` is set the output will be of shape ``[x.size(0), x.size(1), 1]``.
-
-    Args:
-        x (Tensor): A block-sparse tensor in compressed form.
-        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
-        sparsity_block_size (int): The size of the sparsity blocks.
-        flag_slice_only (bool, optional): If set the output will be of shape ``[x.size(0), x.size(1), 1]``
-            (default ``False``).
-        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
-
-    Returns:
-        tuple[Tensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise sum
-            of the input and the sparsity layout of the output tensor.
-
-    """
-    validate_dimensions(x)
-    validate_contiguous(x)
-    validate_device(x)
-    validate_sparsity(sparsity_block_size, (x, sparsity_layout))
-    validate_sparsity_block_size(sparsity_block_size, x)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)
-
-    sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
-    sparsity_layout_flat = sparsity_layout.reshape(-1)
-    sparsity_reverse_lut = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                            (sparsity_layout_flat == 1) -
-                            (1 * (sparsity_layout_flat == 0)))
-
-    sparsity_layout_output, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
-    sparsity_lut_output = torch.nonzero(sparsity_layout_output).contiguous()
-    sparsity_layout_output_flat = sparsity_layout_output.reshape(-1)
-    sparsity_reverse_lut_output = ((torch.cumsum(sparsity_layout_output_flat, dim=-1) - 1) *
-                                   (sparsity_layout_output_flat == 1) -
-                                   (1 * (sparsity_layout_output_flat == 0)))
-
-    n_sparse_blocks_output = torch.sum(sparsity_layout_output.to(torch.int)).item()
-
-    validate_contiguous(sparsity_layout, sparsity_lut, sparsity_reverse_lut,
-                        sparsity_layout_output, sparsity_lut_output, sparsity_reverse_lut_output)
-
-    return (_BlocksparseRowWiseSum.apply(x,
-                                         sparsity_layout, sparsity_lut, sparsity_reverse_lut,
-                                         sparsity_layout_output, sparsity_lut_output, sparsity_reverse_lut_output,
-                                         n_sparse_blocks_output,
-                                         flag_slice_only,
-                                         sparsity_block_size, triton_block_size),
-            sparsity_layout_output)
-
-
-class _BlocksparseRowWiseSum(torch.autograd.Function):
-    IMPLEMENTATION = "atomic_add"
-
-    @staticmethod
-    def forward(ctx, x: Tensor,
-                sparsity_layout: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
-                sparsity_layout_output: Tensor, sparsity_lut_output: Tensor, sparsity_reverse_lut_output: Tensor,
-                n_sparse_blocks_output: int,
-                flag_slice_only: bool,
-                sparsity_block_size: int, triton_block_size: int) -> Tensor:
-        output = torch.zeros(size=(n_sparse_blocks_output,
-                                   sparsity_block_size,
-                                   1 if flag_slice_only else sparsity_block_size),
-                             device=x.device)
-
-        x_b, x_r, x_c = x.size()
-        x_b_s, x_r_s, x_c_s = x.stride()
-        s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout.size()
-        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout.stride()
-        s_lut_x_r, s_lut_x_c = sparsity_lut.size()
-        s_lut_x_r_s, s_lut_x_c_s = sparsity_lut.stride()
-        o_b, o_r, o_c = output.size()
-        o_b_s, o_r_s, o_c_s = output.stride()
-        s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
-        s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_output.stride()
-        s_lut_o_r, s_lut_o_c = sparsity_lut_output.size()
-        s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_output.stride()
-
-        if triton_block_size is None:
-            triton_block_size = get_triton_block_size(sparsity_block_size)
-
-        if _BlocksparseRowWiseSum.IMPLEMENTATION == "basic":
-            triton_grid = lambda meta: [o_b,
-                                        triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"])]
-
-            (_BlocksparseRowWiseSum.kernel_blocksparse_row_wise_sum[triton_grid]
-             (x,
-              x_b, x_b_s, x_r_s, x_c_s,
-              s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
-              sparsity_reverse_lut,
-              output,
-              o_b, o_b_s, o_r_s,
-              sparsity_lut_output, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
-              sparsity_block_size,
-              triton_block_size))
-        elif _BlocksparseRowWiseSum.IMPLEMENTATION == "atomic_add":
-            triton_grid = lambda meta: [x_b,
-                                        triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
-                                        triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
-
-            (_BlocksparseRowWiseSum.kernel_blocksparse_row_wise_sum_atomic_add[triton_grid]
-             (x,
-              x_b, x_b_s, x_r_s, x_c_s,
-              sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
-              output,
-              o_b, o_b_s, o_r_s,
-              s_l_o_b, s_l_o_b_s, s_l_o_r_s,
-              sparsity_reverse_lut_output,
-              triton_block_size))
-
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        raise NotImplementedError
-
-    @staticmethod
-    @triton.jit
-    def kernel_blocksparse_row_wise_sum(x,
-                                        x_b, x_b_s, x_r_s, x_c_s,
-                                        s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
-                                        r_lut_x,
-                                        o,
-                                        o_b, o_b_s, o_r_s,
-                                        s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
-                                        sparsity_block_size,
-                                        TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-        pid_blk = tl.program_id(axis=0)
-        pid_row = tl.program_id(axis=1)
-
-        # Get position of current sparsity block consisting of its batch and row index
-        spa_bat_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
-        spa_bat_msk = (spa_bat_idx < s_lut_o_r * s_lut_o_r_s)
-        spa_bat = tl.load(s_lut_o + spa_bat_idx, mask=spa_bat_msk)
-
-        spa_row_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
-        spa_row_msk = (spa_row_idx < s_lut_o_r * s_lut_o_r_s)
-        spa_row = tl.load(s_lut_o + spa_row_idx, mask=spa_row_msk)
-
-        buf = tl.zeros(shape=(TRITON_BLOCK_SIZE, 1), dtype=tl.float32)
-
-        # Slide over triton block sized segments of input tensor
-        for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size, TRITON_BLOCK_SIZE)):
-            # Convert to segment index of sparsity layout
-            i_seg_spa = (i_seg_tri * TRITON_BLOCK_SIZE) // sparsity_block_size
-            # Calculate the triton segment index within a block
-            i_seg_tri_mod = i_seg_tri % (sparsity_block_size // TRITON_BLOCK_SIZE)
-
-            # Load reverse sparsity index for current block
-            rev_idx_spa_idx = (spa_bat * s_l_x_b_s +
-                               spa_row * s_l_x_r_s +
-                               i_seg_spa * s_l_x_c_s)
-            rev_idx_spa_msk = (rev_idx_spa_idx < s_l_x_b * s_l_x_b_s)
-            rev_idx_spa = tl.load(r_lut_x + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
-
-            # If block is present commence operations
-            if rev_idx_spa >= 0:
-                blk_idx = ((rev_idx_spa * x_b_s) +
-                           ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                           ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
-                             tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-                blk_msk = (blk_idx < x_b * x_b_s)
-                blk = tl.load(x + blk_idx, mask=blk_msk)
-
-                buf = buf + tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
-
-        o_idx = (pid_blk * o_b_s +
-                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                 (tl.arange(0, 1))[None, :])
-        o_msk = (o_idx < o_b * o_b_s)
-        tl.store(o + o_idx, buf, o_msk)
-
-    @staticmethod
-    @triton.jit
-    def kernel_blocksparse_row_wise_sum_atomic_add(x,
-                                                   x_b, x_b_s, x_r_s, x_c_s,
-                                                   s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
-                                                   o,
-                                                   o_b, o_b_s, o_r_s,
-                                                   s_l_o_b, s_l_o_b_s, s_l_o_r_s,
-                                                   r_lut_o,
-                                                   TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-        pid_blk = tl.program_id(axis=0)
-        pid_row = tl.program_id(axis=1)
-        pid_col = tl.program_id(axis=2)
-
-        # Get position of current sparsity block consisting of its batch and row index
-        spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
-        spa_bat_msk = (spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
-        spa_bat = tl.load(s_lut_x + spa_bat_idx, mask=spa_bat_msk)
-
-        spa_row_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
-        spa_row_msk = (spa_row_idx < s_lut_x_r * s_lut_x_r_s)
-        spa_row = tl.load(s_lut_x + spa_row_idx, mask=spa_row_msk)
-
-        # Load reverse sparsity index for current block
-        rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
-                           spa_row * s_l_o_r_s)
-        rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
-        rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
-
-        blk_idx = ((pid_blk * x_b_s) +
-                   ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                   ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_msk = (blk_idx < x_b * x_b_s)
-        blk = tl.load(x + blk_idx, mask=blk_msk)
-
-        buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
-
-        o_idx = (rev_idx_spa * o_b_s +
-                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                 (tl.arange(0, 1))[None, :])
-        o_msk = (o_idx < o_b * o_b_s)
-        tl.atomic_add(o + o_idx, buf, o_msk)
blksprs-1.3.dist-info/RECORD
DELETED
@@ -1,18 +0,0 @@
-blksprs/layouting/distribution_layout.py,sha256=GQ-ZRXbeImiLcbaqnL2FuUZ6DoFwmB0naT_YrOpD84Q,4940
-blksprs/layouting/sparsity_layout.py,sha256=TtADT_WWcZpW3zyGy6KAgkAo44gDryXZqdJLZGEX2V8,7895
-blksprs/misc/broadcast_addition.py,sha256=vf1Hdqz9Uyqykto3DCjmdyepMzpMXL238SpANQqRAwI,5297
-blksprs/misc/repeat_interleave.py,sha256=WrIp7uJsnvjIhFeLYPfkL2j5vXyKmDQGrJ69b3Y0lQ8,5644
-blksprs/ops/conversion.py,sha256=-AOzj_j3WrBLGIgd2oVPvYS8XKfzlvGtSIWzW_qP1lk,21260
-blksprs/ops/distribution.py,sha256=_fQb6fWpLxocAh86D74ATahChi0EK0eBb4eUOUEBVps,16769
-blksprs/ops/exp.py,sha256=qs8fVtCzxl4CKT4GepaqurjEL62jyi8VjMY12JFrFAU,3674
-blksprs/ops/matmul.py,sha256=x3lrYg4g8fIf5PeMtZY_SEpi11kP9RFcRoemCIxcSDE,11086
-blksprs/ops/row_wise_sum.py,sha256=ojuSejV37cLtRNS3lBfknA5KY3TEg8EHxOqVT6JZzoM,11387
-blksprs/ops/softmax.py,sha256=ZyeAVqmG_VzJ72FArGrpUSFfoSM4GPxyubrmNKERVIA,11654
-blksprs/ops/transpose.py,sha256=cX_E3b-QMhsUDNn9D8HVkYesc2JBc-EcVBUZfCWExM8,6720
-blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
-blksprs/utils/tools.py,sha256=DwophH01AeNTZAo0B1uWbKFSGBQjI5z0WmFnYKh-BBk,465
-blksprs/utils/validation.py,sha256=gJYZO5C48YUrXV3Fy_Z_lCaOpiFj951FT-Od7sKfprg,3007
-blksprs-1.3.dist-info/METADATA,sha256=bs4_e4DjSYyAQ354tLVNIKcGLkww_-C2AfHnJIMdjA8,7515
-blksprs-1.3.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
-blksprs-1.3.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
-blksprs-1.3.dist-info/RECORD,,
{blksprs-1.3.dist-info → blksprs-1.4.1.dist-info}/WHEEL
File without changes

{blksprs-1.3.dist-info → blksprs-1.4.1.dist-info}/top_level.txt
File without changes