blksprs 1.4.2__tar.gz → 1.6.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {blksprs-1.4.2 → blksprs-1.6.1}/PKG-INFO +3 -2
- {blksprs-1.4.2 → blksprs-1.6.1}/README.md +2 -1
- {blksprs-1.4.2 → blksprs-1.6.1}/blksprs/__init__.py +6 -2
- blksprs-1.6.1/blksprs/experimental/distribution_mdi.py +438 -0
- {blksprs-1.4.2 → blksprs-1.6.1}/blksprs/layouting/distribution_layout.py +4 -17
- {blksprs-1.4.2 → blksprs-1.6.1}/blksprs/layouting/sparsity_layout.py +36 -0
- {blksprs-1.4.2 → blksprs-1.6.1}/blksprs/misc/repeat_interleave.py +1 -1
- blksprs-1.6.1/blksprs/ops/partitioning.py +244 -0
- {blksprs-1.4.2 → blksprs-1.6.1}/blksprs/ops/transpose.py +6 -7
- {blksprs-1.4.2 → blksprs-1.6.1}/blksprs/utils/validation.py +2 -0
- {blksprs-1.4.2 → blksprs-1.6.1}/blksprs.egg-info/PKG-INFO +3 -2
- {blksprs-1.4.2 → blksprs-1.6.1}/blksprs.egg-info/SOURCES.txt +2 -0
- {blksprs-1.4.2 → blksprs-1.6.1}/pyproject.toml +1 -1
- {blksprs-1.4.2 → blksprs-1.6.1}/blksprs/misc/broadcast_ops.py +0 -0
- {blksprs-1.4.2 → blksprs-1.6.1}/blksprs/misc/row_wise.py +0 -0
- {blksprs-1.4.2 → blksprs-1.6.1}/blksprs/ops/conversion.py +0 -0
- {blksprs-1.4.2 → blksprs-1.6.1}/blksprs/ops/distribution.py +0 -0
- {blksprs-1.4.2 → blksprs-1.6.1}/blksprs/ops/exp.py +0 -0
- {blksprs-1.4.2 → blksprs-1.6.1}/blksprs/ops/matmul.py +0 -0
- {blksprs-1.4.2 → blksprs-1.6.1}/blksprs/ops/softmax.py +0 -0
- {blksprs-1.4.2 → blksprs-1.6.1}/blksprs/utils/benchmarking.py +0 -0
- {blksprs-1.4.2 → blksprs-1.6.1}/blksprs/utils/tools.py +0 -0
- {blksprs-1.4.2 → blksprs-1.6.1}/blksprs.egg-info/dependency_links.txt +0 -0
- {blksprs-1.4.2 → blksprs-1.6.1}/blksprs.egg-info/requires.txt +0 -0
- {blksprs-1.4.2 → blksprs-1.6.1}/blksprs.egg-info/top_level.txt +0 -0
- {blksprs-1.4.2 → blksprs-1.6.1}/setup.cfg +0 -0
{blksprs-1.4.2 → blksprs-1.6.1}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.4.2
+Version: 1.6.1
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs

@@ -31,9 +31,10 @@ Currently supported operations (includes gradient calculation):
 - Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
   for `sparse = sparse @ sparse` matmul_)
 - Softmax
--
+- Transpose
 - Gather
 - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
+- Splitting and merging of matrices along the last dimension
 - Conversion to and from sparse form
 - Conversion to different sparsity layouts and different sparsity block sizes

{blksprs-1.4.2 → blksprs-1.6.1}/README.md

@@ -12,9 +12,10 @@ Currently supported operations (includes gradient calculation):
 - Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
   for `sparse = sparse @ sparse` matmul_)
 - Softmax
--
+- Transpose
 - Gather
 - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
+- Splitting and merging of matrices along the last dimension
 - Conversion to and from sparse form
 - Conversion to different sparsity layouts and different sparsity block sizes

{blksprs-1.4.2 → blksprs-1.6.1}/blksprs/__init__.py

@@ -4,10 +4,11 @@ from blksprs.ops.exp import exp
 from blksprs.ops.matmul import matmul
 from blksprs.ops.softmax import softmax
 from blksprs.ops.transpose import transpose
+from blksprs.ops.partitioning import split, merge

 class layout:
     from blksprs.layouting.distribution_layout import build_distribution_layout
-    from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption
+    from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption, build_sparsity_layout_matmul

 class misc:
     from blksprs.misc.broadcast_ops import broadcast_add, broadcast_sub

@@ -15,4 +16,7 @@ class misc:
     from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_add, row_wise_sub

 class util:
-    from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse, disable_validation
+    from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse, disable_validation
+
+class experimental:
+    from blksprs.experimental.distribution_mdi import gather_mdi, scatter_reduce_mdi

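The names wired up here all come from files added in this release; a minimal sketch, assuming blksprs 1.6.1 is installed, of where they surface in the public namespaces:

    # Minimal sketch (assumes blksprs 1.6.1 is installed): the new names this
    # __init__.py exposes after the change, accessed via the public namespaces.
    import blksprs as bs

    print(bs.split, bs.merge)                                   # from blksprs.ops.partitioning
    print(bs.layout.build_sparsity_layout_matmul)               # new layouting helper
    print(bs.experimental.gather_mdi,
          bs.experimental.scatter_reduce_mdi)                   # experimental MDI ops
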
blksprs-1.6.1/blksprs/experimental/distribution_mdi.py

@@ -0,0 +1,438 @@
+import torch
+import triton
+from torch import Tensor
+from triton import language as tl
+
+from blksprs.utils.tools import get_triton_block_size
+from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
+    validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
+
+
+def gather_mdi(src: Tensor, sparsity_layout_src: Tensor,
+               idx_bat: Tensor,
+               idx_row: Tensor,
+               idx_col: Tensor,
+               sparsity_layout_idx: Tensor,
+               sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+    src = src.contiguous()
+    idx_bat = idx_bat.contiguous()
+    idx_col = idx_col.contiguous()
+
+    validate_dimensions(src, idx_bat, idx_col)
+    validate_contiguous(src, idx_bat, idx_col)
+    validate_dtype_int(idx_bat, idx_col)
+    validate_device(src, idx_bat, idx_col)
+    validate_sparsity(sparsity_block_size, (src, sparsity_layout_src),
+                      (idx_bat, sparsity_layout_idx), (idx_col, sparsity_layout_idx))
+    validate_sparsity_block_size(sparsity_block_size, src, idx_bat, idx_col)
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+    sparsity_layout_x_flat = sparsity_layout_src.reshape(-1)
+    sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
+                              (sparsity_layout_x_flat == 1) -
+                              (1 * (sparsity_layout_x_flat == 0)))
+
+    sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
+
+    validate_contiguous(sparsity_layout_src, sparsity_reverse_lut_x,
+                        sparsity_layout_idx, sparsity_lut_i)
+
+    return _BlocksparseGatherMDI.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
+                                       idx_bat, idx_col, sparsity_layout_idx, sparsity_lut_i,
+                                       sparsity_block_size, triton_block_size)
+
+
+class _BlocksparseGatherMDI(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
+                idx_bat: Tensor, idx_col: Tensor, sparsity_layout_i: Tensor, sparsity_lut_i: Tensor,
+                sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+        output = torch.empty_like(idx_col, dtype=x.dtype)
+
+        x_b, x_r, x_c = x.size()
+        x_b_s, x_r_s, x_c_s = x.stride()
+        s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
+        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_x.stride()
+        i_b, i_r, i_c = idx_col.size()
+        i_b_s, i_r_s, i_c_s = idx_col.stride()
+        s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
+        s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = output.stride()
+
+        if triton_block_size is None:
+            triton_block_size = get_triton_block_size(sparsity_block_size)
+
+        triton_grid = lambda meta: [o_b,
+                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+        (_BlocksparseGatherMDI.kernel_blocksparse_gather_mdi[triton_grid]
+         (x,
+          x_b, x_b_s, x_r_s, x_c_s,
+          s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+          sparsity_reverse_lut_x,
+          idx_bat,
+          idx_col,
+          i_b, i_b_s, i_r_s, i_c_s,
+          output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          sparsity_lut_i, s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+          sparsity_block_size,
+          triton_block_size))
+
+        ctx.save_for_backward(sparsity_layout_x, idx_bat, idx_col, sparsity_layout_i)
+        ctx.sparsity_block_size = sparsity_block_size
+        ctx.triton_block_size = triton_block_size
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        sparsity_layout_x, idx_bat, idx_col, sparsity_layout_i = ctx.saved_tensors
+        sparsity_block_size = ctx.sparsity_block_size
+        triton_block_size = ctx.triton_block_size
+
+        return scatter_reduce_mdi(grad_output, sparsity_layout_i,
+                                  idx_bat,
+                                  None,
+                                  idx_col,
+                                  sparsity_layout_x,
+                                  sparsity_block_size,
+                                  reduce_op="sum",
+                                  triton_block_size=triton_block_size), None, None, None, None, None, None, None, None
+
+    @staticmethod
+    @triton.jit
+    def kernel_blocksparse_gather_mdi(x,
+                                      x_b, x_b_s, x_r_s, x_c_s,
+                                      s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+                                      r_lut_x,
+                                      idx_bat,
+                                      idx_col,
+                                      i_b, i_b_s, i_r_s, i_c_s,
+                                      o,
+                                      o_b, o_b_s, o_r_s, o_c_s,
+                                      s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+                                      sparsity_block_size,
+                                      TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+        # Get triton block indices
+        pid_blk = tl.program_id(axis=0)
+        pid_row = tl.program_id(axis=1)
+        pid_col = tl.program_id(axis=2)
+
+        # Load batch index values
+        blk_idx_bat_idx = ((pid_blk * i_b_s) +
+                           ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                           ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+        blk_idx_bat_msk = (blk_idx_bat_idx < i_b * i_b_s)
+        blk_idx_bat = tl.load(idx_bat + blk_idx_bat_idx, mask=blk_idx_bat_msk).to(tl.int32)
+
+        # Get position of current sparsity block row
+        spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
+        spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+        spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+
+        # Load column index values
+        blk_idx_col_idx = ((pid_blk * i_b_s) +
+                           ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                           ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+        blk_idx_col_msk = (blk_idx_col_idx < i_b * i_b_s)
+        blk_idx_col = tl.load(idx_col + blk_idx_col_idx, mask=blk_idx_col_msk).to(tl.int32)
+
+        # Get positions of sparsity blocks
+        pos_spa_blk_x = blk_idx_col // sparsity_block_size
+        pos_spa_col_x = blk_idx_col % sparsity_block_size
+
+        # Load reverse sparsity indices for x
+        rev_idx_spa_x_idx = ((blk_idx_bat * s_l_x_b_s) +
+                             (spa_row_o * s_l_x_r_s) +
+                             (pos_spa_blk_x * s_l_x_c_s))
+        rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+        rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
+
+        # Load x values
+        blk_x_idx = ((rev_idx_spa_x * x_b_s) +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                     (pos_spa_col_x * x_c_s))
+        blk_x_msk = (blk_x_idx < x_b * x_b_s)
+        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+        # Store output
+        blk_o_idx = ((pid_blk * o_b_s) +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+        blk_o_msk = (blk_o_idx < o_b * o_b_s)
+        tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+def scatter_reduce_mdi(src: Tensor, sparsity_layout_src: Tensor,
+                       idx_bat: Tensor,
+                       idx_row: Tensor,
+                       idx_col: Tensor,
+                       sparsity_layout_tgt: Tensor,
+                       sparsity_block_size: int,
+                       reduce_op: str = "sum", triton_block_size: int = None) -> Tensor:
+    src = src.contiguous()
+    idx_bat = idx_bat.contiguous()
+    idx_col = idx_col.contiguous()
+
+    validate_dimensions(src, idx_bat, idx_col)
+    validate_contiguous(src, idx_bat, idx_col)
+    validate_dtype_int(idx_bat, idx_col)
+    validate_device(src, idx_bat, idx_col)
+    validate_sparsity(sparsity_block_size, (src, sparsity_layout_src),
+                      (idx_bat, sparsity_layout_src),
+                      (idx_col, sparsity_layout_src))
+    validate_sparsity_block_size(sparsity_block_size, src, idx_bat, idx_col)
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+    if reduce_op not in ["none", "sum"]:
+        raise ValueError(f"Reduction operation '{reduce_op}' is not supported")
+
+    sparsity_lut_x = torch.nonzero(sparsity_layout_src).contiguous()
+
+    sparsity_layout_o_flat = sparsity_layout_tgt.reshape(-1)
+    sparsity_reverse_lut_o = ((torch.cumsum(sparsity_layout_o_flat, dim=-1) - 1) *
+                              (sparsity_layout_o_flat == 1) -
+                              (1 * (sparsity_layout_o_flat == 0)))
+
+    n_sparse_blocks = torch.sum(sparsity_layout_tgt.to(torch.int)).item()
+
+    validate_contiguous(sparsity_layout_src, sparsity_lut_x,
+                        sparsity_layout_tgt, sparsity_reverse_lut_o)
+
+    return _BlocksparseScatterReduceMDI.apply(src, sparsity_layout_src, sparsity_lut_x,
+                                              idx_bat,
+                                              idx_col,
+                                              sparsity_layout_tgt, sparsity_reverse_lut_o,
+                                              sparsity_block_size, n_sparse_blocks,
+                                              reduce_op, triton_block_size)
+
+
+class _BlocksparseScatterReduceMDI(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_lut_x: Tensor,
+                idx_bat: Tensor,
+                idx_col: Tensor,
+                sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
+                sparsity_block_size: int, n_sparse_blocks: int,
+                reduce_op: str, triton_block_size: int) -> Tensor:
+        output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                             dtype=x.dtype, device=x.device)
+
+        x_b, x_r, x_c = x.size()
+        x_b_s, x_r_s, x_c_s = x.stride()
+        s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
+        s_lut_x_r_s, s_lut_x_c_s = sparsity_lut_x.stride()
+        i_b, i_r, i_c = idx_col.size()
+        i_b_s, i_r_s, i_c_s = idx_col.stride()
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = output.stride()
+        s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
+        s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_o.stride()
+
+        if triton_block_size is None:
+            triton_block_size = get_triton_block_size(sparsity_block_size)
+
+        triton_grid = lambda meta: [x_b,
+                                    triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+        reduce_op_ind = 0
+        if reduce_op == "sum":
+            reduce_op_ind = 1
+
+        (_BlocksparseScatterReduceMDI.kernel_blocksparse_scatter_mdi[triton_grid]
+         (x,
+          x_b, x_b_s, x_r_s, x_c_s,
+          sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+          idx_bat,
+          idx_col,
+          i_b, i_b_s, i_r_s, i_c_s,
+          output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+          sparsity_reverse_lut_o,
+          reduce_op_ind,
+          sparsity_block_size,
+          triton_block_size))
+
+        ctx.save_for_backward(sparsity_layout_x, idx_bat, idx_col, sparsity_layout_o)
+        ctx.sparsity_block_size = sparsity_block_size
+        ctx.reduce_op = reduce_op
+        ctx.triton_block_size = triton_block_size
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        sparsity_layout_x, idx_bat, idx_col, sparsity_layout_o = ctx.saved_tensors
+        sparsity_block_size = ctx.sparsity_block_size
+        reduce_op = ctx.reduce_op
+        triton_block_size = ctx.triton_block_size
+
+        if reduce_op == "sum":
+            return gather_mdi(grad_output, sparsity_layout_o,
+                              idx_bat,
+                              None,
+                              idx_col,
+                              sparsity_layout_x, sparsity_block_size,
+                              triton_block_size=triton_block_size), None, None, None, None, None, None, None, None, None, None
+        else:
+            raise ValueError(f"Reduction operation '{reduce_op}' does not support backward pass")
+
+    @staticmethod
+    @triton.jit
+    def kernel_blocksparse_scatter_mdi(x,
+                                       x_b, x_b_s, x_r_s, x_c_s,
+                                       s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+                                       idx_bat,
+                                       idx_col,
+                                       i_b, i_b_s, i_r_s, i_c_s,
+                                       o,
+                                       o_b, o_b_s, o_r_s, o_c_s,
+                                       s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+                                       r_lut_o,
+                                       reduce_op_ind,
+                                       sparsity_block_size,
+                                       TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+        # Get triton block indices
+        pid_blk = tl.program_id(axis=0)
+        pid_row = tl.program_id(axis=1)
+        pid_col = tl.program_id(axis=2)
+
+        # Load batch index values
+        blk_idx_bat_idx = ((pid_blk * i_b_s) +
+                           ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                           ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+        blk_idx_bat_msk = (blk_idx_bat_idx < i_b * i_b_s)
+        blk_idx_bat = tl.load(idx_bat + blk_idx_bat_idx, mask=blk_idx_bat_msk).to(tl.int32)
+
+        # Get position of current sparsity block row
+        spa_row_x_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
+        spa_row_x_msk = (spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
+        spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)
+
+        # Load x values
+        blk_x_idx = ((pid_blk * x_b_s) +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+        blk_x_msk = (blk_x_idx < x_b * x_b_s)
+        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+        # Load column index values
+        blk_idx_col_idx = ((pid_blk * i_b_s) +
+                           ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                           ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+        blk_idx_col_msk = (blk_idx_col_idx < i_b * i_b_s)
+        blk_idx_col = tl.load(idx_col + blk_idx_col_idx, mask=blk_idx_col_msk).to(tl.int32)
+
+        # Get positions of sparsity blocks
+        pos_spa_blk_o = blk_idx_col // sparsity_block_size
+        pos_spa_col_o = blk_idx_col % sparsity_block_size
+
+        # Load reverse sparsity indices for o
+        rev_idx_spa_o_idx = ((blk_idx_bat * s_l_o_b_s) +
+                             (spa_row_x * s_l_o_r_s) +
+                             (pos_spa_blk_o * s_l_o_c_s))
+        rev_idx_spa_o_msk = (rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
+        rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
+
+        # Store output
+        blk_o_idx = ((rev_idx_spa_o * o_b_s) +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                     (pos_spa_col_o * o_c_s))
+        blk_o_msk = (blk_o_idx < o_b * o_b_s)
+
+        if reduce_op_ind == 0:
+            tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+        elif reduce_op_ind == 1:
+            tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+def build_distribution_layout_mdi(idx_bat: Tensor, idx_row: Tensor, idx_col: Tensor, sparsity_layout_idx: Tensor,
+                                  size_target: torch.Size,
+                                  sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+    validate_dimensions(idx_bat, idx_col)
+    validate_contiguous(idx_bat, idx_col)
+    validate_device(idx_bat, idx_col)
+
+    sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
+
+    output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size, size_target[2] // sparsity_block_size,
+                         dtype=torch.bool, device=idx_col.device)
+
+    i_b, i_r, i_c = idx_col.size()
+    i_b_s, i_r_s, i_c_s = idx_col.stride()
+    s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
+    s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
+    o_b, o_r, o_c = output.size()
+    o_b_s, o_r_s, o_c_s = output.stride()
+
+    if triton_block_size is None:
+        triton_block_size = get_triton_block_size(sparsity_block_size)
+
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+    triton_grid = lambda meta: [i_b,
+                                triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
+                                triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]
+
+    (kernel_distribution_layout_mdi[triton_grid]
+     (idx_bat,
+      idx_col,
+      i_b, i_b_s, i_r_s, i_c_s,
+      sparsity_lut_i,
+      s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+      output,
+      o_b, o_b_s, o_r_s, o_c_s,
+      sparsity_block_size,
+      triton_block_size))
+
+    return output
+
+
+@triton.jit
+def kernel_distribution_layout_mdi(idx_bat,
+                                   idx_col,
+                                   i_b, i_b_s, i_r_s, i_c_s,
+                                   s_lut_i,
+                                   s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+                                   o,
+                                   o_b, o_b_s, o_r_s, o_c_s,
+                                   sparsity_block_size,
+                                   TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Load batch index values
+    blk_idx_bat_idx = ((pid_blk * i_b_s) +
+                       ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                       ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+    blk_idx_bat_msk = (blk_idx_bat_idx < i_b * i_b_s)
+    blk_idx_bat = tl.load(idx_bat + blk_idx_bat_idx, mask=blk_idx_bat_msk).to(tl.int32)
+
+    # Get position of current sparsity block row
+    spa_row_i_idx = (pid_blk * s_lut_i_r_s + 1 * s_lut_i_c_s)
+    spa_row_i_msk = (spa_row_i_idx < s_lut_i_r * s_lut_i_r_s)
+    spa_row_i = tl.load(s_lut_i + spa_row_i_idx, mask=spa_row_i_msk)
+
+    blk_i_idx = (pid_blk * i_b_s +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+    blk_i_msk = (blk_i_idx < i_b * i_b_s)
+    blk_i = tl.load(idx_col + blk_i_idx, mask=blk_i_msk)
+
+    blk_i = blk_i // sparsity_block_size
+    blk_v = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), 1, dtype=tl.int32)
+
+    blk_o_idx = ((blk_idx_bat * o_b_s) +
+                 (spa_row_i * o_r_s) +
+                 (blk_i * o_c_s))
+    blk_o_msk = (blk_o_idx < o_b * o_b_s)
+    tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)

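Both gather_mdi and scatter_reduce_mdi build their reverse look-up tables with the same cumsum expression seen above; the mapping it produces can be reproduced on its own with a made-up layout (only torch required):

    # Standalone illustration of the reverse sparsity LUT: for every position
    # in the flattened sparsity layout it yields the index of the corresponding
    # compressed block, or -1 where no block is stored.
    import torch

    sparsity_layout = torch.tensor([[[1, 0, 1],
                                     [0, 1, 0]]])    # (batch, block rows, block cols)
    flat = sparsity_layout.reshape(-1)
    reverse_lut = ((torch.cumsum(flat, dim=-1) - 1) * (flat == 1) - (1 * (flat == 0)))
    print(reverse_lut)  # tensor([ 0, -1,  1, -1,  2, -1])
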
{blksprs-1.4.2 → blksprs-1.6.1}/blksprs/layouting/distribution_layout.py

@@ -35,8 +35,6 @@ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,

     i_b, i_r, i_c = indices.size()
     i_b_s, i_r_s, i_c_s = indices.stride()
-    s_l_i_b, s_l_i_r, s_l_i_c = sparsity_layout_indices.size()
-    s_l_i_b_s, s_l_i_r_s, s_l_i_c_s = sparsity_layout_indices.stride()
     s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
     s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
     o_b, o_r, o_c = output.size()

@@ -54,12 +52,10 @@ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
     (kernel_distribution_layout[triton_grid]
      (indices,
       i_b, i_b_s, i_r_s, i_c_s,
-      sparsity_layout_indices,
-      s_l_i_b, s_l_i_b_s, s_l_i_r, s_l_i_r_s, s_l_i_c, s_l_i_c_s,
       sparsity_lut_i,
-      s_lut_i_r, s_lut_i_r_s,
+      s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
       output,
-      o_b, o_b_s,
+      o_b, o_b_s, o_r_s, o_c_s,
       sparsity_block_size,
       triton_block_size))

@@ -69,12 +65,10 @@ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
 @triton.jit
 def kernel_distribution_layout(i,
                                i_b, i_b_s, i_r_s, i_c_s,
-                               s_l_i,
-                               s_l_i_b, s_l_i_b_s, s_l_i_r, s_l_i_r_s, s_l_i_c, s_l_i_c_s,
                                s_lut_i,
-                               s_lut_i_r, s_lut_i_r_s,
+                               s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
                                o,
-                               o_b, o_b_s,
+                               o_b, o_b_s, o_r_s, o_c_s,
                                sparsity_block_size,
                                TRITON_BLOCK_SIZE: tl.constexpr) -> None:
     # Get triton block indices

@@ -105,10 +99,3 @@ def kernel_distribution_layout(i,
                  (blk_i * o_c_s))
     blk_o_msk = (blk_o_idx < o_b * o_b_s)
     tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
-
-    # if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
-    #     blk_o_idx = (pid_bat * o_b_s +
-    #                  (((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +
-    #                   ((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))
-    #     blk_o_msk = (blk_o_idx < o_b * o_b_s)
-    #     tl.store(o + blk_o_idx, 1, mask=blk_o_msk)

{blksprs-1.4.2 → blksprs-1.6.1}/blksprs/layouting/sparsity_layout.py

@@ -188,3 +188,39 @@ def kernel_sparsity_layout_adaption(x,
                     // sparsity_block_size_to) * o_c_s))
     blk_o_msk = (blk_o_idx < o_b * o_b_s)
     tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
+
+
+def build_sparsity_layout_matmul(sparsity_layout_x: Tensor, sparsity_layout_y: Tensor):
+    """Builds the precise sparsity layout of the result of a matrix multiplication between the two input tensors.
+
+    Args:
+        sparsity_layout_x (Tensor): The sparsity layout of the first block-sparse tensor.
+        sparsity_layout_y (Tensor): The sparsity layout of the second block-sparse tensor.
+
+    Returns:
+        Tensor: The precise sparsity layout of the result of a matrix multiplication between the two input tensors.
+
+    """
+    return torch.matmul(sparsity_layout_x.to(torch.float), sparsity_layout_y.to(torch.float)).to(torch.bool)
+
+
+def build_sparsity_layout_matmul_fast(sparsity_layout_x: Tensor, sparsity_layout_y: Tensor):
+    """Builds the approximate sparsity layout of the result of a matrix multiplication between the two input tensors.
+
+    Note:
+        This function is faster than the ``build_sparsity_layout_matmul`` function due to the fact that it only checks
+        whether at least one of the blocks in either of the vectors participating in the matmul is non-sparse. The
+        resulting sparsity layout may thus overestimate the actual sparsity of the result.
+
+    Args:
+        sparsity_layout_x (Tensor): The sparsity layout of the first block-sparse tensor.
+        sparsity_layout_y (Tensor): The sparsity layout of the second block-sparse tensor.
+
+    Returns:
+        Tensor: The approximate sparsity layout of the result of a matrix multiplication between the two input tensors.
+
+    """
+    sparsity_layout_x_slice = torch.max(sparsity_layout_x, dim=-1).values.unsqueeze(-1)
+    sparsity_layout_y_slice = torch.max(sparsity_layout_y, dim=-2).values.unsqueeze(1)
+
+    return torch.logical_or(sparsity_layout_x_slice, sparsity_layout_y_slice)

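The difference between the precise rule and the fast rule added above can be seen on a small standalone example (made-up 0/1 layouts, only torch required); the fast rule marks a superset of the blocks the precise rule marks:

    # Precise rule: multiply the layouts as floats; a result block is non-sparse
    # only if some non-sparse block of X meets a non-sparse block of Y.
    # Fast rule: OR together "row of X has any block" with "column of Y has any
    # block", which can overestimate.
    import torch

    lx = torch.tensor([[[1, 0],
                        [0, 0]]])      # layout of X: (batch, block rows, K blocks)
    ly = torch.tensor([[[0, 1],
                        [1, 0]]])      # layout of Y: (batch, K blocks, block cols)

    precise = torch.matmul(lx.to(torch.float), ly.to(torch.float)).to(torch.bool)
    fast = torch.logical_or(torch.max(lx, dim=-1).values.unsqueeze(-1),
                            torch.max(ly, dim=-2).values.unsqueeze(1))

    print(precise)  # only block [0, 0, 1] is non-sparse
    print(fast)     # every block is marked non-sparse (superset of precise)
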
{blksprs-1.4.2 → blksprs-1.6.1}/blksprs/misc/repeat_interleave.py

@@ -35,7 +35,7 @@ def repeat_interleave(x: Tensor, sparsity_layout: Tensor, repeats: int,
     validate_sparsity_block_size(sparsity_block_size, x)
     validate_triton_block_size(triton_block_size, sparsity_block_size)

-    sparsity_layout_output = torch.repeat_interleave(sparsity_layout,
+    sparsity_layout_output = torch.repeat_interleave(sparsity_layout, repeats, dim=0).contiguous()

     sparsity_lut = torch.nonzero(sparsity_layout).contiguous()

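What the corrected call computes can be checked in isolation with a made-up layout (only torch required): repeating the sparsity layout along the batch dimension duplicates each batch's layout in place.

    # Standalone illustration of torch.repeat_interleave along dim 0,
    # as used in the corrected line above.
    import torch

    sparsity_layout = torch.tensor([[[1, 0]], [[0, 1]]])   # two batches of blocks
    repeated = torch.repeat_interleave(sparsity_layout, 2, dim=0).contiguous()
    print(repeated.shape)  # torch.Size([4, 1, 2]); batches ordered b0, b0, b1, b1
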
blksprs-1.6.1/blksprs/ops/partitioning.py

@@ -0,0 +1,244 @@
+import torch
+import triton
+from sympy.utilities.iterables import partitions
+from torch import Tensor
+from triton import language as tl
+
+from blksprs.utils.tools import get_triton_block_size
+
+from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
+    validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
+
+
+def split(x: Tensor, sparsity_layout: Tensor, partitions: int,
+          sparsity_block_size: int, triton_block_size: int = None) -> (Tensor, Tensor):
+    """Splits a block-sparse tensor in compressed form along the last dimension into partitions.
+
+    Args:
+        x (Tensor): A block-sparse tensor in compressed form.
+        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+        partitions (int): The number of partitions to split the block-sparse tensor into.
+        sparsity_block_size (int): The size of the sparsity blocks.
+        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+    Returns:
+        Tensor: The block-sparse tensor split into partitions in compressed form.
+        Tensor: The sparsity layout of the output tensor.
+
+    """
+    x = x.contiguous()
+
+    validate_dimensions(x)
+    validate_contiguous(x)
+    validate_device(x)
+    validate_sparsity(sparsity_block_size, (x, sparsity_layout))
+    validate_sparsity_block_size(sparsity_block_size, x)
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+    sparsity_layout_output = (sparsity_layout
+                              .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
+                                       sparsity_layout.size(2) // partitions)
+                              .permute(0, 2, 1, 3)
+                              .reshape(sparsity_layout.size(0) * partitions, sparsity_layout.size(1),
+                                       sparsity_layout.size(2) // partitions).contiguous())
+
+    sparsity_lut = torch.nonzero(sparsity_layout_output).contiguous()
+
+    sparsity_layout_flat = sparsity_layout.reshape(-1)
+    sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                             (sparsity_layout_flat == 1) -
+                             (1 * (sparsity_layout_flat == 0)))
+                            .reshape(sparsity_layout.size())
+                            .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
+                                     sparsity_layout.size(2) // partitions)
+                            .permute(0, 2, 1, 3).reshape(-1).contiguous())
+
+    n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
+
+    validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
+
+    return _BlocksparseSplit.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
+                                   sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_output
+
+
+class _BlocksparseSplit(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                num_partitions: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+        ctx.num_partitions = num_partitions
+
+        return forward_reorder(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                               n_sparse_blocks, triton_block_size)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        sparsity_layout = ctx.saved_tensors[0]
+        num_partitions = ctx.num_partitions
+        sparsity_block_size = ctx.sparsity_block_size
+        triton_block_size = ctx.triton_block_size
+
+        return merge(grad_output, sparsity_layout, num_partitions,
+                     sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None
+
+
+def merge(x: Tensor, sparsity_layout: Tensor, partitions: int,
+          sparsity_block_size: int, triton_block_size: int = None) -> (Tensor, Tensor):
+    """Merges the specified partitions of a block-sparse tensor in compressed form along the last dimension.
+
+    Args:
+        x (Tensor): A block-sparse tensor in compressed form.
+        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+        partitions (int): The number of partitions to be merged.
+        sparsity_block_size (int): The size of the sparsity blocks.
+        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+    Returns:
+        Tensor: The merged block-sparse tensor in compressed form.
+        Tensor: The sparsity layout of the output tensor.
+
+    """
+    x = x.contiguous()
+
+    validate_dimensions(x)
+    validate_contiguous(x)
+    validate_device(x)
+    validate_sparsity(sparsity_block_size, (x, sparsity_layout))
+    validate_sparsity_block_size(sparsity_block_size, x)
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+    sparsity_layout_output = (sparsity_layout.reshape(sparsity_layout.size(0) // partitions, partitions,
+                                                      sparsity_layout.size(1), sparsity_layout.size(2))
+                              .permute(0, 2, 1, 3)
+                              .reshape(sparsity_layout.size(0) // partitions,
+                                       sparsity_layout.size(1), sparsity_layout.size(2) * partitions).contiguous())
+
+    sparsity_lut = torch.nonzero(sparsity_layout_output).contiguous()
+
+    sparsity_layout_flat = sparsity_layout.reshape(-1)
+    sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                             (sparsity_layout_flat == 1) -
+                             (1 * (sparsity_layout_flat == 0)))
+                            .reshape(sparsity_layout.size(0) // partitions, partitions,
+                                     sparsity_layout.size(1), sparsity_layout.size(2))
+                            .permute(0, 2, 1, 3)
+                            .reshape(sparsity_layout.size(0) // partitions,
+                                     sparsity_layout.size(1), sparsity_layout.size(2) * partitions)
+                            .reshape(-1).contiguous())
+
+    n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
+
+    validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
+
+    return _BlocksparseMerge.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
+                                   sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_output
+
+
+class _BlocksparseMerge(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                num_partitions: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+        ctx.num_partitions = num_partitions
+
+        return forward_reorder(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                               n_sparse_blocks, triton_block_size)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        sparsity_layout = ctx.saved_tensors[0]
+        num_partitions = ctx.num_partitions
+        sparsity_block_size = ctx.sparsity_block_size
+        triton_block_size = ctx.triton_block_size
+
+        return split(grad_output, sparsity_layout, num_partitions,
+                     sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None
+
+
+def forward_reorder(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                    sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+    output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                         dtype=x.dtype, device=x.device)
+
+    x_b, x_r, x_c = x.size()
+    x_b_s, x_r_s, x_c_s = x.stride()
+    s_l_b, s_l_r, s_l_c = sparsity_layout_o.size()
+    s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout_o.stride()
+    s_lut_r, s_lut_c = sparsity_lut.shape
+    s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
+    o_b, o_r, o_c = output.size()
+    o_b_s, o_r_s, o_c_s = output.stride()
+
+    if triton_block_size is None:
+        triton_block_size = get_triton_block_size(sparsity_block_size)
+
+    triton_grid = lambda meta: [o_b,
+                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+    (kernel_blocksparse_reorder[triton_grid]
+     (x,
+      x_b, x_b_s, x_r_s, x_c_s,
+      s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+      sparsity_reverse_lut,
+      output,
+      o_b, o_b_s,
+      triton_block_size))
+
+    # Save for backward pass
+    ctx.save_for_backward(sparsity_layout_o)
+    ctx.sparsity_block_size = sparsity_block_size
+    ctx.triton_block_size = triton_block_size
+
+    return output
+
+
+@triton.jit
+def kernel_blocksparse_reorder(x,
+                               x_b, x_b_s, x_r_s, x_c_s,
+                               s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+                               s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                               r_lut,
+                               o,
+                               o_b, o_b_s,
+                               TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get sparsity index of current output block consisting of its batch, row, and column index
+    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+    spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+    spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+    spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
+    spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
+    spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+
+    # Get reverse sparsity index
+    rev_idx_spa_idx = (spa_bat * s_l_b_s +
+                       spa_row * s_l_r_s +
+                       spa_col * s_l_c_s)
+    rev_idx_spa_msk = (rev_idx_spa_idx < s_l_b * s_l_b_s)
+    rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
+
+    if rev_idx_spa == -1:
+        assert False, "Invalid sparsity block"
+
+    blk_x_idx = (rev_idx_spa * x_b_s +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_x_msk = (blk_x_idx < x_b * x_b_s)
+    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+    blk_o_idx = (pid_blk * o_b_s +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_o_msk = (blk_o_idx < o_b * o_b_s)
+    tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)

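The reshape/permute chain that split applies to the sparsity layout can be reproduced on its own; a standalone sketch with a made-up layout (only torch required) shows how the last block dimension is cut into partitions that are stacked into the batch dimension:

    # Standalone illustration of split()'s layout rearrangement: the last block
    # dimension is divided into `partitions` pieces which become extra batch
    # entries; merge() applies the inverse rearrangement.
    import torch

    partitions = 2
    layout = torch.tensor([[[1, 0, 0, 1],
                            [0, 1, 1, 0]]])    # (1 batch, 2 block rows, 4 block cols)

    layout_split = (layout
                    .reshape(layout.size(0), layout.size(1), partitions, layout.size(2) // partitions)
                    .permute(0, 2, 1, 3)
                    .reshape(layout.size(0) * partitions, layout.size(1), layout.size(2) // partitions)
                    .contiguous())
    print(layout_split.shape)  # torch.Size([2, 2, 2]): one batch entry per partition
    print(layout_split[0])     # layout of the left half of the original block columns
    print(layout_split[1])     # layout of the right half of the original block columns
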
{blksprs-1.4.2 → blksprs-1.6.1}/blksprs/ops/transpose.py

@@ -56,16 +56,16 @@ def transpose(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
 class _BlocksparseTranspose(torch.autograd.Function):

     @staticmethod
-    def forward(ctx, x: Tensor,
-
-                n_sparse_blocks: int, triton_block_size: int) ->
+    def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                sparsity_block_size: int,
+                n_sparse_blocks: int, triton_block_size: int) -> Tensor:
         output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                              dtype=x.dtype, device=x.device)

         x_b, x_r, x_c = x.size()
         x_b_s, x_r_s, x_c_s = x.stride()
-        s_l_b, s_l_r, s_l_c =
-        s_l_b_s, s_l_r_s, s_l_c_s =
+        s_l_b, s_l_r, s_l_c = sparsity_layout_o.size()
+        s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout_o.stride()
         s_lut_r, s_lut_c = sparsity_lut.shape
         s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
         o_b, o_r, o_c = output.size()

@@ -89,8 +89,7 @@ class _BlocksparseTranspose(torch.autograd.Function):
          triton_block_size))

        # Save for backward pass
-        ctx.save_for_backward(
-        ctx.sparsity_layout = sparsity_layout
+        ctx.save_for_backward(sparsity_layout_o)
        ctx.sparsity_block_size = sparsity_block_size
        ctx.triton_block_size = triton_block_size

{blksprs-1.4.2 → blksprs-1.6.1}/blksprs/utils/validation.py

@@ -63,6 +63,8 @@ def validate_sparsity(sparsity_block_size: int, *tensor_sparsity_layout_tuples:
     for (tensor, sparsity_layout) in tensor_sparsity_layout_tuples:
         _validate_sparsity_layout_values(sparsity_layout)

+        if not sparsity_layout.dim() == 3:
+            raise ValueError("Sparsity layout must have exactly 3 dimensions")
         if not (tensor.size(-1) == tensor.size(-2) == sparsity_block_size):
             raise ValueError("Blocks not conforming to sparsity block size")
         if not tensor.size(0) == torch.sum(sparsity_layout.reshape(-1)):

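A standalone sketch (made-up layout, only torch required) of what the added check rejects: sparsity layouts are expected to be 3-dimensional (batch, block rows, block cols), so a 2D layout now raises.

    # Standalone sketch of the new dimensionality check in validate_sparsity().
    import torch

    sparsity_layout = torch.ones(4, 4)   # missing the batch dimension
    try:
        if not sparsity_layout.dim() == 3:
            raise ValueError("Sparsity layout must have exactly 3 dimensions")
    except ValueError as err:
        print(err)                       # the new validation rejects non-3D layouts
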
{blksprs-1.4.2 → blksprs-1.6.1}/blksprs.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.4.2
+Version: 1.6.1
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs

@@ -31,9 +31,10 @@ Currently supported operations (includes gradient calculation):
 - Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
   for `sparse = sparse @ sparse` matmul_)
 - Softmax
--
+- Transpose
 - Gather
 - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
+- Splitting and merging of matrices along the last dimension
 - Conversion to and from sparse form
 - Conversion to different sparsity layouts and different sparsity block sizes

{blksprs-1.4.2 → blksprs-1.6.1}/blksprs.egg-info/SOURCES.txt

@@ -6,6 +6,7 @@ blksprs.egg-info/SOURCES.txt
 blksprs.egg-info/dependency_links.txt
 blksprs.egg-info/requires.txt
 blksprs.egg-info/top_level.txt
+blksprs/experimental/distribution_mdi.py
 blksprs/layouting/distribution_layout.py
 blksprs/layouting/sparsity_layout.py
 blksprs/misc/broadcast_ops.py

@@ -15,6 +16,7 @@ blksprs/ops/conversion.py
 blksprs/ops/distribution.py
 blksprs/ops/exp.py
 blksprs/ops/matmul.py
+blksprs/ops/partitioning.py
 blksprs/ops/softmax.py
 blksprs/ops/transpose.py
 blksprs/utils/benchmarking.py