blksprs 1.4.1__tar.gz → 1.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {blksprs-1.4.1 → blksprs-1.5}/PKG-INFO +1 -1
- {blksprs-1.4.1 → blksprs-1.5}/blksprs/__init__.py +4 -1
- blksprs-1.5/blksprs/experimental/distribution_mdi.py +438 -0
- {blksprs-1.4.1 → blksprs-1.5}/blksprs/layouting/distribution_layout.py +4 -17
- {blksprs-1.4.1 → blksprs-1.5}/blksprs/misc/broadcast_ops.py +1 -1
- {blksprs-1.4.1 → blksprs-1.5}/blksprs/misc/row_wise.py +1 -0
- {blksprs-1.4.1 → blksprs-1.5}/blksprs/ops/conversion.py +2 -2
- {blksprs-1.4.1 → blksprs-1.5}/blksprs/ops/matmul.py +2 -1
- {blksprs-1.4.1 → blksprs-1.5}/blksprs/ops/softmax.py +1 -1
- {blksprs-1.4.1 → blksprs-1.5}/blksprs/ops/transpose.py +4 -2
- {blksprs-1.4.1 → blksprs-1.5}/blksprs/utils/tools.py +1 -2
- {blksprs-1.4.1 → blksprs-1.5}/blksprs/utils/validation.py +6 -3
- {blksprs-1.4.1 → blksprs-1.5}/blksprs.egg-info/PKG-INFO +1 -1
- {blksprs-1.4.1 → blksprs-1.5}/blksprs.egg-info/SOURCES.txt +1 -0
- {blksprs-1.4.1 → blksprs-1.5}/pyproject.toml +1 -1
- {blksprs-1.4.1 → blksprs-1.5}/README.md +0 -0
- {blksprs-1.4.1 → blksprs-1.5}/blksprs/layouting/sparsity_layout.py +0 -0
- {blksprs-1.4.1 → blksprs-1.5}/blksprs/misc/repeat_interleave.py +0 -0
- {blksprs-1.4.1 → blksprs-1.5}/blksprs/ops/distribution.py +0 -0
- {blksprs-1.4.1 → blksprs-1.5}/blksprs/ops/exp.py +0 -0
- {blksprs-1.4.1 → blksprs-1.5}/blksprs/utils/benchmarking.py +0 -0
- {blksprs-1.4.1 → blksprs-1.5}/blksprs.egg-info/dependency_links.txt +0 -0
- {blksprs-1.4.1 → blksprs-1.5}/blksprs.egg-info/requires.txt +0 -0
- {blksprs-1.4.1 → blksprs-1.5}/blksprs.egg-info/top_level.txt +0 -0
- {blksprs-1.4.1 → blksprs-1.5}/setup.cfg +0 -0
{blksprs-1.4.1 → blksprs-1.5}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.4.1
+Version: 1.5
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
{blksprs-1.4.1 → blksprs-1.5}/blksprs/__init__.py
@@ -15,4 +15,7 @@ class misc:
     from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_add, row_wise_sub

 class util:
-    from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse, disable_validation
+    from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse, disable_validation
+
+class experimental:
+    from blksprs.experimental.distribution_mdi import gather_mdi
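Version 1.5 adds an `experimental` namespace to the package root, so the new gather operation is reachable from the top-level import. A minimal sketch of the access path only (nothing beyond what the diff shows):

```python
# Sketch only: the new namespace re-exports gather_mdi from
# blksprs/experimental/distribution_mdi.py (path and name taken from the diff above).
import blksprs

gather_mdi = blksprs.experimental.gather_mdi
```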
blksprs-1.5/blksprs/experimental/distribution_mdi.py
@@ -0,0 +1,438 @@
+import torch
+import triton
+from torch import Tensor
+from triton import language as tl
+
+from blksprs.utils.tools import get_triton_block_size
+from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
+    validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
+
+
+def gather_mdi(src: Tensor, sparsity_layout_src: Tensor,
+               idx_bat: Tensor,
+               idx_row: Tensor,
+               idx_col: Tensor,
+               sparsity_layout_idx: Tensor,
+               sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+    src = src.contiguous()
+    idx_bat = idx_bat.contiguous()
+    idx_col = idx_col.contiguous()
+
+    validate_dimensions(src, idx_bat, idx_col)
+    validate_contiguous(src, idx_bat, idx_col)
+    validate_dtype_int(idx_bat, idx_col)
+    validate_device(src, idx_bat, idx_col)
+    validate_sparsity(sparsity_block_size, (src, sparsity_layout_src),
+                      (idx_bat, sparsity_layout_idx), (idx_col, sparsity_layout_idx))
+    validate_sparsity_block_size(sparsity_block_size, src, idx_bat, idx_col)
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+    sparsity_layout_x_flat = sparsity_layout_src.reshape(-1)
+    sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
+                              (sparsity_layout_x_flat == 1) -
+                              (1 * (sparsity_layout_x_flat == 0)))
+
+    sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
+
+    validate_contiguous(sparsity_layout_src, sparsity_reverse_lut_x,
+                        sparsity_layout_idx, sparsity_lut_i)
+
+    return _BlocksparseGatherMDI.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
+                                       idx_bat, idx_col, sparsity_layout_idx, sparsity_lut_i,
+                                       sparsity_block_size, triton_block_size)
+
+
+class _BlocksparseGatherMDI(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
+                idx_bat: Tensor, idx_col: Tensor, sparsity_layout_i: Tensor, sparsity_lut_i: Tensor,
+                sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+        output = torch.empty_like(idx_col, dtype=x.dtype)
+
+        x_b, x_r, x_c = x.size()
+        x_b_s, x_r_s, x_c_s = x.stride()
+        s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
+        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_x.stride()
+        i_b, i_r, i_c = idx_col.size()
+        i_b_s, i_r_s, i_c_s = idx_col.stride()
+        s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
+        s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = output.stride()
+
+        if triton_block_size is None:
+            triton_block_size = get_triton_block_size(sparsity_block_size)
+
+        triton_grid = lambda meta: [o_b,
+                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+        (_BlocksparseGatherMDI.kernel_blocksparse_gather_mdi[triton_grid]
+         (x,
+          x_b, x_b_s, x_r_s, x_c_s,
+          s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+          sparsity_reverse_lut_x,
+          idx_bat,
+          idx_col,
+          i_b, i_b_s, i_r_s, i_c_s,
+          output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          sparsity_lut_i, s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+          sparsity_block_size,
+          triton_block_size))
+
+        ctx.save_for_backward(sparsity_layout_x, idx_bat, idx_col, sparsity_layout_i)
+        ctx.sparsity_block_size = sparsity_block_size
+        ctx.triton_block_size = triton_block_size
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        sparsity_layout_x, idx_bat, idx_col, sparsity_layout_i = ctx.saved_tensors
+        sparsity_block_size = ctx.sparsity_block_size
+        triton_block_size = ctx.triton_block_size
+
+        return scatter_reduce_mdi(grad_output, sparsity_layout_i,
+                                  idx_bat,
+                                  None,
+                                  idx_col,
+                                  sparsity_layout_x,
+                                  sparsity_block_size,
+                                  reduce_op="sum",
+                                  triton_block_size=triton_block_size), None, None, None, None, None, None, None, None
+
+    @staticmethod
+    @triton.jit
+    def kernel_blocksparse_gather_mdi(x,
+                                      x_b, x_b_s, x_r_s, x_c_s,
+                                      s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+                                      r_lut_x,
+                                      idx_bat,
+                                      idx_col,
+                                      i_b, i_b_s, i_r_s, i_c_s,
+                                      o,
+                                      o_b, o_b_s, o_r_s, o_c_s,
+                                      s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+                                      sparsity_block_size,
+                                      TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+        # Get triton block indices
+        pid_blk = tl.program_id(axis=0)
+        pid_row = tl.program_id(axis=1)
+        pid_col = tl.program_id(axis=2)
+
+        # Load batch index values
+        blk_idx_bat_idx = ((pid_blk * i_b_s) +
+                           ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                           ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+        blk_idx_bat_msk = (blk_idx_bat_idx < i_b * i_b_s)
+        blk_idx_bat = tl.load(idx_bat + blk_idx_bat_idx, mask=blk_idx_bat_msk).to(tl.int32)
+
+        # Get position of current sparsity block row
+        spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
+        spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+        spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+
+        # Load column index values
+        blk_idx_col_idx = ((pid_blk * i_b_s) +
+                           ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                           ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+        blk_idx_col_msk = (blk_idx_col_idx < i_b * i_b_s)
+        blk_idx_col = tl.load(idx_col + blk_idx_col_idx, mask=blk_idx_col_msk).to(tl.int32)
+
+        # Get positions of sparsity blocks
+        pos_spa_blk_x = blk_idx_col // sparsity_block_size
+        pos_spa_col_x = blk_idx_col % sparsity_block_size
+
+        # Load reverse sparsity indices for x
+        rev_idx_spa_x_idx = ((blk_idx_bat * s_l_x_b_s) +
+                             (spa_row_o * s_l_x_r_s) +
+                             (pos_spa_blk_x * s_l_x_c_s))
+        rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+        rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
+
+        # Load x values
+        blk_x_idx = ((rev_idx_spa_x * x_b_s) +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                     (pos_spa_col_x * x_c_s))
+        blk_x_msk = (blk_x_idx < x_b * x_b_s)
+        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+        # Store output
+        blk_o_idx = ((pid_blk * o_b_s) +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+        blk_o_msk = (blk_o_idx < o_b * o_b_s)
+        tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+def scatter_reduce_mdi(src: Tensor, sparsity_layout_src: Tensor,
+                       idx_bat: Tensor,
+                       idx_row: Tensor,
+                       idx_col: Tensor,
+                       sparsity_layout_tgt: Tensor,
+                       sparsity_block_size: int,
+                       reduce_op: str = "sum", triton_block_size: int = None) -> Tensor:
+    src = src.contiguous()
+    idx_bat = idx_bat.contiguous()
+    idx_col = idx_col.contiguous()
+
+    validate_dimensions(src, idx_bat, idx_col)
+    validate_contiguous(src, idx_bat, idx_col)
+    validate_dtype_int(idx_bat, idx_col)
+    validate_device(src, idx_bat, idx_col)
+    validate_sparsity(sparsity_block_size, (src, sparsity_layout_src),
+                      (idx_bat, sparsity_layout_src),
+                      (idx_col, sparsity_layout_src))
+    validate_sparsity_block_size(sparsity_block_size, src, idx_bat, idx_col)
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+    if reduce_op not in ["none", "sum"]:
+        raise ValueError(f"Reduction operation '{reduce_op}' is not supported")
+
+    sparsity_lut_x = torch.nonzero(sparsity_layout_src).contiguous()
+
+    sparsity_layout_o_flat = sparsity_layout_tgt.reshape(-1)
+    sparsity_reverse_lut_o = ((torch.cumsum(sparsity_layout_o_flat, dim=-1) - 1) *
+                              (sparsity_layout_o_flat == 1) -
+                              (1 * (sparsity_layout_o_flat == 0)))
+
+    n_sparse_blocks = torch.sum(sparsity_layout_tgt.to(torch.int)).item()
+
+    validate_contiguous(sparsity_layout_src, sparsity_lut_x,
+                        sparsity_layout_tgt, sparsity_reverse_lut_o)
+
+    return _BlocksparseScatterReduceMDI.apply(src, sparsity_layout_src, sparsity_lut_x,
+                                              idx_bat,
+                                              idx_col,
+                                              sparsity_layout_tgt, sparsity_reverse_lut_o,
+                                              sparsity_block_size, n_sparse_blocks,
+                                              reduce_op, triton_block_size)
+
+
+class _BlocksparseScatterReduceMDI(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_lut_x: Tensor,
+                idx_bat: Tensor,
+                idx_col: Tensor,
+                sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
+                sparsity_block_size: int, n_sparse_blocks: int,
+                reduce_op: str, triton_block_size: int) -> Tensor:
+        output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                             dtype=x.dtype, device=x.device)
+
+        x_b, x_r, x_c = x.size()
+        x_b_s, x_r_s, x_c_s = x.stride()
+        s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
+        s_lut_x_r_s, s_lut_x_c_s = sparsity_lut_x.stride()
+        i_b, i_r, i_c = idx_col.size()
+        i_b_s, i_r_s, i_c_s = idx_col.stride()
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = output.stride()
+        s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
+        s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_o.stride()
+
+        if triton_block_size is None:
+            triton_block_size = get_triton_block_size(sparsity_block_size)
+
+        triton_grid = lambda meta: [x_b,
+                                    triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+        reduce_op_ind = 0
+        if reduce_op == "sum":
+            reduce_op_ind = 1
+
+        (_BlocksparseScatterReduceMDI.kernel_blocksparse_scatter_mdi[triton_grid]
+         (x,
+          x_b, x_b_s, x_r_s, x_c_s,
+          sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+          idx_bat,
+          idx_col,
+          i_b, i_b_s, i_r_s, i_c_s,
+          output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+          sparsity_reverse_lut_o,
+          reduce_op_ind,
+          sparsity_block_size,
+          triton_block_size))
+
+        ctx.save_for_backward(sparsity_layout_x, idx_bat, idx_col, sparsity_layout_o)
+        ctx.sparsity_block_size = sparsity_block_size
+        ctx.reduce_op = reduce_op
+        ctx.triton_block_size = triton_block_size
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        sparsity_layout_x, idx_bat, idx_col, sparsity_layout_o = ctx.saved_tensors
+        sparsity_block_size = ctx.sparsity_block_size
+        reduce_op = ctx.reduce_op
+        triton_block_size = ctx.triton_block_size
+
+        if reduce_op == "sum":
+            return gather_mdi(grad_output, sparsity_layout_o,
+                              idx_bat,
+                              None,
+                              idx_col,
+                              sparsity_layout_x, sparsity_block_size,
+                              triton_block_size=triton_block_size), None, None, None, None, None, None, None, None, None, None
+        else:
+            raise ValueError(f"Reduction operation '{reduce_op}' does not support backward pass")
+
+    @staticmethod
+    @triton.jit
+    def kernel_blocksparse_scatter_mdi(x,
+                                       x_b, x_b_s, x_r_s, x_c_s,
+                                       s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+                                       idx_bat,
+                                       idx_col,
+                                       i_b, i_b_s, i_r_s, i_c_s,
+                                       o,
+                                       o_b, o_b_s, o_r_s, o_c_s,
+                                       s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+                                       r_lut_o,
+                                       reduce_op_ind,
+                                       sparsity_block_size,
+                                       TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+        # Get triton block indices
+        pid_blk = tl.program_id(axis=0)
+        pid_row = tl.program_id(axis=1)
+        pid_col = tl.program_id(axis=2)
+
+        # Load batch index values
+        blk_idx_bat_idx = ((pid_blk * i_b_s) +
+                           ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                           ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+        blk_idx_bat_msk = (blk_idx_bat_idx < i_b * i_b_s)
+        blk_idx_bat = tl.load(idx_bat + blk_idx_bat_idx, mask=blk_idx_bat_msk).to(tl.int32)
+
+        # Get position of current sparsity block row
+        spa_row_x_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
+        spa_row_x_msk = (spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
+        spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)
+
+        # Load x values
+        blk_x_idx = ((pid_blk * x_b_s) +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+        blk_x_msk = (blk_x_idx < x_b * x_b_s)
+        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+        # Load column index values
+        blk_idx_col_idx = ((pid_blk * i_b_s) +
+                           ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                           ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+        blk_idx_col_msk = (blk_idx_col_idx < i_b * i_b_s)
+        blk_idx_col = tl.load(idx_col + blk_idx_col_idx, mask=blk_idx_col_msk).to(tl.int32)
+
+        # Get positions of sparsity blocks
+        pos_spa_blk_o = blk_idx_col // sparsity_block_size
+        pos_spa_col_o = blk_idx_col % sparsity_block_size
+
+        # Load reverse sparsity indices for o
+        rev_idx_spa_o_idx = ((blk_idx_bat * s_l_o_b_s) +
+                             (spa_row_x * s_l_o_r_s) +
+                             (pos_spa_blk_o * s_l_o_c_s))
+        rev_idx_spa_o_msk = (rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
+        rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
+
+        # Store output
+        blk_o_idx = ((rev_idx_spa_o * o_b_s) +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                     (pos_spa_col_o * o_c_s))
+        blk_o_msk = (blk_o_idx < o_b * o_b_s)
+
+        if reduce_op_ind == 0:
+            tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+        elif reduce_op_ind == 1:
+            tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+def build_distribution_layout_mdi(idx_bat: Tensor, idx_row: Tensor, idx_col: Tensor, sparsity_layout_idx: Tensor,
+                                  size_target: torch.Size,
+                                  sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+    validate_dimensions(idx_bat, idx_col)
+    validate_contiguous(idx_bat, idx_col)
+    validate_device(idx_bat, idx_col)
+
+    sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
+
+    output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size, size_target[2] // sparsity_block_size,
+                         dtype=torch.bool, device=idx_col.device)
+
+    i_b, i_r, i_c = idx_col.size()
+    i_b_s, i_r_s, i_c_s = idx_col.stride()
+    s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
+    s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
+    o_b, o_r, o_c = output.size()
+    o_b_s, o_r_s, o_c_s = output.stride()
+
+    if triton_block_size is None:
+        triton_block_size = get_triton_block_size(sparsity_block_size)
+
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+    triton_grid = lambda meta: [i_b,
+                                triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
+                                triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]
+
+    (kernel_distribution_layout_mdi[triton_grid]
+     (idx_bat,
+      idx_col,
+      i_b, i_b_s, i_r_s, i_c_s,
+      sparsity_lut_i,
+      s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+      output,
+      o_b, o_b_s, o_r_s, o_c_s,
+      sparsity_block_size,
+      triton_block_size))
+
+    return output
+
+
+@triton.jit
+def kernel_distribution_layout_mdi(idx_bat,
+                                   idx_col,
+                                   i_b, i_b_s, i_r_s, i_c_s,
+                                   s_lut_i,
+                                   s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+                                   o,
+                                   o_b, o_b_s, o_r_s, o_c_s,
+                                   sparsity_block_size,
+                                   TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Load batch index values
+    blk_idx_bat_idx = ((pid_blk * i_b_s) +
+                       ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                       ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+    blk_idx_bat_msk = (blk_idx_bat_idx < i_b * i_b_s)
+    blk_idx_bat = tl.load(idx_bat + blk_idx_bat_idx, mask=blk_idx_bat_msk).to(tl.int32)
+
+    # Get position of current sparsity block row
+    spa_row_i_idx = (pid_blk * s_lut_i_r_s + 1 * s_lut_i_c_s)
+    spa_row_i_msk = (spa_row_i_idx < s_lut_i_r * s_lut_i_r_s)
+    spa_row_i = tl.load(s_lut_i + spa_row_i_idx, mask=spa_row_i_msk)
+
+    blk_i_idx = (pid_blk * i_b_s +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+    blk_i_msk = (blk_i_idx < i_b * i_b_s)
+    blk_i = tl.load(idx_col + blk_i_idx, mask=blk_i_msk)
+
+    blk_i = blk_i // sparsity_block_size
+    blk_v = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), 1, dtype=tl.int32)
+
+    blk_o_idx = ((blk_idx_bat * o_b_s) +
+                 (spa_row_i * o_r_s) +
+                 (blk_i * o_c_s))
+    blk_o_msk = (blk_o_idx < o_b * o_b_s)
+    tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
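The new module repeatedly builds a "reverse LUT" from a flattened sparsity layout: every block position maps to the index of its sparse block, and to -1 where the layout has no block; the kernels then use it to translate (batch, row block, column block) coordinates into storage offsets. A minimal standalone sketch of that construction (tensor values illustrative, not part of the package):

```python
import torch

# Sketch of the reverse-LUT trick used in gather_mdi / scatter_reduce_mdi above.
sparsity_layout = torch.tensor([[[1, 0],
                                 [0, 1]]])  # (batch, row blocks, column blocks)
layout_flat = sparsity_layout.reshape(-1)

# Running count of present blocks, shifted to 0-based indices; -1 marks absent blocks.
reverse_lut = ((torch.cumsum(layout_flat, dim=-1) - 1) * (layout_flat == 1)
               - (1 * (layout_flat == 0)))

print(reverse_lut.tolist())  # [0, -1, -1, 1]
```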
{blksprs-1.4.1 → blksprs-1.5}/blksprs/layouting/distribution_layout.py
@@ -35,8 +35,6 @@ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
 
     i_b, i_r, i_c = indices.size()
     i_b_s, i_r_s, i_c_s = indices.stride()
-    s_l_i_b, s_l_i_r, s_l_i_c = sparsity_layout_indices.size()
-    s_l_i_b_s, s_l_i_r_s, s_l_i_c_s = sparsity_layout_indices.stride()
     s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
     s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
     o_b, o_r, o_c = output.size()
@@ -54,12 +52,10 @@ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
     (kernel_distribution_layout[triton_grid]
      (indices,
       i_b, i_b_s, i_r_s, i_c_s,
-      sparsity_layout_indices,
-      s_l_i_b, s_l_i_b_s, s_l_i_r, s_l_i_r_s, s_l_i_c, s_l_i_c_s,
       sparsity_lut_i,
-      s_lut_i_r, s_lut_i_r_s,
+      s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
       output,
-      o_b, o_b_s,
+      o_b, o_b_s, o_r_s, o_c_s,
       sparsity_block_size,
       triton_block_size))
 
@@ -69,12 +65,10 @@ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
 @triton.jit
 def kernel_distribution_layout(i,
                                i_b, i_b_s, i_r_s, i_c_s,
-                               s_l_i,
-                               s_l_i_b, s_l_i_b_s, s_l_i_r, s_l_i_r_s, s_l_i_c, s_l_i_c_s,
                                s_lut_i,
-                               s_lut_i_r, s_lut_i_r_s,
+                               s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
                                o,
-                               o_b, o_b_s,
+                               o_b, o_b_s, o_r_s, o_c_s,
                                sparsity_block_size,
                                TRITON_BLOCK_SIZE: tl.constexpr) -> None:
     # Get triton block indices
@@ -105,10 +99,3 @@ def kernel_distribution_layout(i,
                  (blk_i * o_c_s))
     blk_o_msk = (blk_o_idx < o_b * o_b_s)
     tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
-
-    # if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
-    #     blk_o_idx = (pid_bat * o_b_s +
-    #                  (((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +
-    #                   ((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))
-    #     blk_o_msk = (blk_o_idx < o_b * o_b_s)
-    #     tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
{blksprs-1.4.1 → blksprs-1.5}/blksprs/misc/broadcast_ops.py
@@ -41,7 +41,7 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
 
     validate_contiguous(sparsity_layout_output, sparsity_lut_o)
 
-    output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, device=x.device)
+    output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, dtype=x.dtype, device=x.device)
 
     x_b, x_c = x.size()
     x_b_s, x_c_s = x.stride()
{blksprs-1.4.1 → blksprs-1.5}/blksprs/misc/row_wise.py
@@ -56,6 +56,7 @@ def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
     output = torch.zeros(size=(n_sparse_blocks_output,
                                sparsity_block_size,
                                1 if flag_slice_only else sparsity_block_size),
+                         dtype=x.dtype,
                          device=x.device)
 
     x_b, x_r, x_c = x.size()
{blksprs-1.4.1 → blksprs-1.5}/blksprs/ops/conversion.py
@@ -186,8 +186,8 @@ class _BlocksparseToSparse(torch.autograd.Function):
     def forward(ctx, x: Tensor,
                 sparsity_layout: Tensor, sparsity_lut: Tensor,
                 sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
-        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
-                             device=x.device)
+        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                             dtype=x.dtype, device=x.device)
 
         x_b, x_r, x_c = x.size()
         x_b_s, x_r_s, x_c_s = x.stride()
{blksprs-1.4.1 → blksprs-1.5}/blksprs/ops/matmul.py
@@ -78,7 +78,8 @@ class _BlocksparseMatmulSSS(torch.autograd.Function):
                 sparsity_layout_y: Tensor, sparsity_reverse_lut_y: Tensor,
                 sparsity_layout_o: Tensor, sparsity_lut_o: Tensor,
                 sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
-        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                             dtype=x.dtype, device=x.device)
 
         x_b, x_r, x_c = x.size()
         x_b_s, x_r_s, x_c_s = x.stride()
{blksprs-1.4.1 → blksprs-1.5}/blksprs/ops/softmax.py
@@ -127,7 +127,7 @@ class _BlocksparseSoftmax(torch.autograd.Function):
         s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
         s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = sparsity_layout_s.stride()
 
-        grad_x = torch.empty_like(o)
+        grad_x = torch.empty_like(o, dtype=torch.float)
 
         triton_grid = lambda meta: [o_b,
                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
{blksprs-1.4.1 → blksprs-1.5}/blksprs/ops/transpose.py
@@ -59,7 +59,8 @@ class _BlocksparseTranspose(torch.autograd.Function):
     def forward(ctx, x: Tensor,
                 sparsity_layout: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor, sparsity_block_size: int,
                 n_sparse_blocks: int, triton_block_size: int) -> (Tensor, Tensor):
-        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                             dtype=x.dtype, device=x.device)
 
         x_b, x_r, x_c = x.size()
         x_b_s, x_r_s, x_c_s = x.stride()
@@ -101,7 +102,8 @@ class _BlocksparseTranspose(torch.autograd.Function):
         sparsity_block_size = ctx.sparsity_block_size
         triton_block_size = ctx.triton_block_size
 
-        return transpose(grad_output, sparsity_layout, sparsity_block_size, triton_block_size)[
+        return transpose(grad_output, sparsity_layout, sparsity_block_size, triton_block_size)[
+            0], None, None, None, None, None, None
 
     @staticmethod
     @triton.jit
{blksprs-1.4.1 → blksprs-1.5}/blksprs/utils/tools.py
@@ -1,4 +1,3 @@
-import torch
 from torch import Tensor, Size
 
 from blksprs.utils.validation import _set_skip_validation
@@ -8,7 +7,7 @@ def do_shape_blocksparse(x: Tensor):
     if x.dim() == 3:
         return x.contiguous(), x.size()
 
-    return x.reshape(-1, x.size(-2), x.size(-1)), x.size()
+    return x.reshape(-1, x.size(-2), x.size(-1)).contiguous(), x.size()
 
 
 def undo_shape_blocksparse(x: Tensor, shape: Size):
{blksprs-1.4.1 → blksprs-1.5}/blksprs/utils/validation.py
@@ -3,13 +3,13 @@ from torch import Tensor
 
 VALIDATION = True
 
-def validate_dimensions(*tensors: Tensor) -> None:
+def validate_dimensions(*tensors: Tensor, dims=3) -> None:
     if _check_skip_validation():
         return
 
     for tensor in tensors:
-        if tensor.dim() != 3:
-            raise ValueError("Tensor must have 3 dimensions")
+        if tensor.dim() != dims:
+            raise ValueError(f"Tensor must have {dims} dimensions")
 
 
 def validate_contiguous(*tensors: Tensor) -> None:
@@ -91,6 +91,9 @@ def validate_triton_block_size(triton_block_size: int, sparsity_block_size: int)
     if triton_block_size is None:
         return
 
+    if not (triton_block_size & (triton_block_size - 1)) == 0:
+        raise ValueError("Triton block size must be a power of 2")
+
     if triton_block_size > sparsity_block_size:
         raise ValueError("Triton block size cannot be larger than sparsity block size")
 
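The added validation relies on the standard bit trick that `n & (n - 1) == 0` holds exactly for powers of two (and for 0). A quick standalone sketch, independent of the package:

```python
# Illustrative check of the bit trick used in validate_triton_block_size above.
for n in (8, 16, 24, 33):
    print(n, (n & (n - 1)) == 0)  # 8 True, 16 True, 24 False, 33 False
```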
{blksprs-1.4.1 → blksprs-1.5}/blksprs.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.4.1
+Version: 1.5
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
{blksprs-1.4.1 → blksprs-1.5}/blksprs.egg-info/SOURCES.txt
@@ -6,6 +6,7 @@ blksprs.egg-info/SOURCES.txt
 blksprs.egg-info/dependency_links.txt
 blksprs.egg-info/requires.txt
 blksprs.egg-info/top_level.txt
+blksprs/experimental/distribution_mdi.py
 blksprs/layouting/distribution_layout.py
 blksprs/layouting/sparsity_layout.py
 blksprs/misc/broadcast_ops.py