blksprs 1.4.2.tar.gz → 1.5.tar.gz
This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- {blksprs-1.4.2 → blksprs-1.5}/PKG-INFO +1 -1
- {blksprs-1.4.2 → blksprs-1.5}/blksprs/__init__.py +4 -1
- blksprs-1.5/blksprs/experimental/distribution_mdi.py +438 -0
- {blksprs-1.4.2 → blksprs-1.5}/blksprs/layouting/distribution_layout.py +4 -17
- {blksprs-1.4.2 → blksprs-1.5}/blksprs.egg-info/PKG-INFO +1 -1
- {blksprs-1.4.2 → blksprs-1.5}/blksprs.egg-info/SOURCES.txt +1 -0
- {blksprs-1.4.2 → blksprs-1.5}/pyproject.toml +1 -1
- {blksprs-1.4.2 → blksprs-1.5}/README.md +0 -0
- {blksprs-1.4.2 → blksprs-1.5}/blksprs/layouting/sparsity_layout.py +0 -0
- {blksprs-1.4.2 → blksprs-1.5}/blksprs/misc/broadcast_ops.py +0 -0
- {blksprs-1.4.2 → blksprs-1.5}/blksprs/misc/repeat_interleave.py +0 -0
- {blksprs-1.4.2 → blksprs-1.5}/blksprs/misc/row_wise.py +0 -0
- {blksprs-1.4.2 → blksprs-1.5}/blksprs/ops/conversion.py +0 -0
- {blksprs-1.4.2 → blksprs-1.5}/blksprs/ops/distribution.py +0 -0
- {blksprs-1.4.2 → blksprs-1.5}/blksprs/ops/exp.py +0 -0
- {blksprs-1.4.2 → blksprs-1.5}/blksprs/ops/matmul.py +0 -0
- {blksprs-1.4.2 → blksprs-1.5}/blksprs/ops/softmax.py +0 -0
- {blksprs-1.4.2 → blksprs-1.5}/blksprs/ops/transpose.py +0 -0
- {blksprs-1.4.2 → blksprs-1.5}/blksprs/utils/benchmarking.py +0 -0
- {blksprs-1.4.2 → blksprs-1.5}/blksprs/utils/tools.py +0 -0
- {blksprs-1.4.2 → blksprs-1.5}/blksprs/utils/validation.py +0 -0
- {blksprs-1.4.2 → blksprs-1.5}/blksprs.egg-info/dependency_links.txt +0 -0
- {blksprs-1.4.2 → blksprs-1.5}/blksprs.egg-info/requires.txt +0 -0
- {blksprs-1.4.2 → blksprs-1.5}/blksprs.egg-info/top_level.txt +0 -0
- {blksprs-1.4.2 → blksprs-1.5}/setup.cfg +0 -0
--- blksprs-1.4.2/PKG-INFO
+++ blksprs-1.5/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.4.2
+Version: 1.5
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
--- blksprs-1.4.2/blksprs/__init__.py
+++ blksprs-1.5/blksprs/__init__.py
@@ -15,4 +15,7 @@ class misc:
     from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_add, row_wise_sub
 
 class util:
-    from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse, disable_validation
+    from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse, disable_validation
+
+class experimental:
+    from blksprs.experimental.distribution_mdi import gather_mdi
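
For orientation: blksprs exposes its public API through small class-based namespaces in __init__.py, so the new gather lands under a dedicated experimental namespace. A minimal usage sketch (assuming blksprs 1.5 is installed; inferred from the import above, not from documentation):

import blksprs as bs

# The experimental namespace mirrors the existing misc/util namespaces:
gather_mdi = bs.experimental.gather_mdi
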
--- /dev/null
+++ blksprs-1.5/blksprs/experimental/distribution_mdi.py
@@ -0,0 +1,438 @@
+import torch
+import triton
+from torch import Tensor
+from triton import language as tl
+
+from blksprs.utils.tools import get_triton_block_size
+from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
+    validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
+
+
+def gather_mdi(src: Tensor, sparsity_layout_src: Tensor,
+               idx_bat: Tensor,
+               idx_row: Tensor,
+               idx_col: Tensor,
+               sparsity_layout_idx: Tensor,
+               sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+    src = src.contiguous()
+    idx_bat = idx_bat.contiguous()
+    idx_col = idx_col.contiguous()
+
+    validate_dimensions(src, idx_bat, idx_col)
+    validate_contiguous(src, idx_bat, idx_col)
+    validate_dtype_int(idx_bat, idx_col)
+    validate_device(src, idx_bat, idx_col)
+    validate_sparsity(sparsity_block_size, (src, sparsity_layout_src),
+                      (idx_bat, sparsity_layout_idx), (idx_col, sparsity_layout_idx))
+    validate_sparsity_block_size(sparsity_block_size, src, idx_bat, idx_col)
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+    sparsity_layout_x_flat = sparsity_layout_src.reshape(-1)
+    sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
+                              (sparsity_layout_x_flat == 1) -
+                              (1 * (sparsity_layout_x_flat == 0)))
+
+    sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
+
+    validate_contiguous(sparsity_layout_src, sparsity_reverse_lut_x,
+                        sparsity_layout_idx, sparsity_lut_i)
+
+    return _BlocksparseGatherMDI.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
+                                       idx_bat, idx_col, sparsity_layout_idx, sparsity_lut_i,
+                                       sparsity_block_size, triton_block_size)
+
+
+class _BlocksparseGatherMDI(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
+                idx_bat: Tensor, idx_col: Tensor, sparsity_layout_i: Tensor, sparsity_lut_i: Tensor,
+                sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+        output = torch.empty_like(idx_col, dtype=x.dtype)
+
+        x_b, x_r, x_c = x.size()
+        x_b_s, x_r_s, x_c_s = x.stride()
+        s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
+        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_x.stride()
+        i_b, i_r, i_c = idx_col.size()
+        i_b_s, i_r_s, i_c_s = idx_col.stride()
+        s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
+        s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = output.stride()
+
+        if triton_block_size is None:
+            triton_block_size = get_triton_block_size(sparsity_block_size)
+
+        triton_grid = lambda meta: [o_b,
+                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+        (_BlocksparseGatherMDI.kernel_blocksparse_gather_mdi[triton_grid]
+         (x,
+          x_b, x_b_s, x_r_s, x_c_s,
+          s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+          sparsity_reverse_lut_x,
+          idx_bat,
+          idx_col,
+          i_b, i_b_s, i_r_s, i_c_s,
+          output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          sparsity_lut_i, s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+          sparsity_block_size,
+          triton_block_size))
+
+        ctx.save_for_backward(sparsity_layout_x, idx_bat, idx_col, sparsity_layout_i)
+        ctx.sparsity_block_size = sparsity_block_size
+        ctx.triton_block_size = triton_block_size
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        sparsity_layout_x, idx_bat, idx_col, sparsity_layout_i = ctx.saved_tensors
+        sparsity_block_size = ctx.sparsity_block_size
+        triton_block_size = ctx.triton_block_size
+
+        return scatter_reduce_mdi(grad_output, sparsity_layout_i,
+                                  idx_bat,
+                                  None,
+                                  idx_col,
+                                  sparsity_layout_x,
+                                  sparsity_block_size,
+                                  reduce_op="sum",
+                                  triton_block_size=triton_block_size), None, None, None, None, None, None, None, None
+
+    @staticmethod
+    @triton.jit
+    def kernel_blocksparse_gather_mdi(x,
+                                      x_b, x_b_s, x_r_s, x_c_s,
+                                      s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+                                      r_lut_x,
+                                      idx_bat,
+                                      idx_col,
+                                      i_b, i_b_s, i_r_s, i_c_s,
+                                      o,
+                                      o_b, o_b_s, o_r_s, o_c_s,
+                                      s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+                                      sparsity_block_size,
+                                      TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+        # Get triton block indices
+        pid_blk = tl.program_id(axis=0)
+        pid_row = tl.program_id(axis=1)
+        pid_col = tl.program_id(axis=2)
+
+        # Load batch index values
+        blk_idx_bat_idx = ((pid_blk * i_b_s) +
+                           ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                           ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+        blk_idx_bat_msk = (blk_idx_bat_idx < i_b * i_b_s)
+        blk_idx_bat = tl.load(idx_bat + blk_idx_bat_idx, mask=blk_idx_bat_msk).to(tl.int32)
+
+        # Get position of current sparsity block row
+        spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
+        spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+        spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+
+        # Load column index values
+        blk_idx_col_idx = ((pid_blk * i_b_s) +
+                           ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                           ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+        blk_idx_col_msk = (blk_idx_col_idx < i_b * i_b_s)
+        blk_idx_col = tl.load(idx_col + blk_idx_col_idx, mask=blk_idx_col_msk).to(tl.int32)
+
+        # Get positions of sparsity blocks
+        pos_spa_blk_x = blk_idx_col // sparsity_block_size
+        pos_spa_col_x = blk_idx_col % sparsity_block_size
+
+        # Load reverse sparsity indices for x
+        rev_idx_spa_x_idx = ((blk_idx_bat * s_l_x_b_s) +
+                             (spa_row_o * s_l_x_r_s) +
+                             (pos_spa_blk_x * s_l_x_c_s))
+        rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+        rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
+
+        # Load x values
+        blk_x_idx = ((rev_idx_spa_x * x_b_s) +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                     (pos_spa_col_x * x_c_s))
+        blk_x_msk = (blk_x_idx < x_b * x_b_s)
+        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+        # Store output
+        blk_o_idx = ((pid_blk * o_b_s) +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+        blk_o_msk = (blk_o_idx < o_b * o_b_s)
+        tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+def scatter_reduce_mdi(src: Tensor, sparsity_layout_src: Tensor,
+                       idx_bat: Tensor,
+                       idx_row: Tensor,
+                       idx_col: Tensor,
+                       sparsity_layout_tgt: Tensor,
+                       sparsity_block_size: int,
+                       reduce_op: str = "sum", triton_block_size: int = None) -> Tensor:
+    src = src.contiguous()
+    idx_bat = idx_bat.contiguous()
+    idx_col = idx_col.contiguous()
+
+    validate_dimensions(src, idx_bat, idx_col)
+    validate_contiguous(src, idx_bat, idx_col)
+    validate_dtype_int(idx_bat, idx_col)
+    validate_device(src, idx_bat, idx_col)
+    validate_sparsity(sparsity_block_size, (src, sparsity_layout_src),
+                      (idx_bat, sparsity_layout_src),
+                      (idx_col, sparsity_layout_src))
+    validate_sparsity_block_size(sparsity_block_size, src, idx_bat, idx_col)
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+    if reduce_op not in ["none", "sum"]:
+        raise ValueError(f"Reduction operation '{reduce_op}' is not supported")
+
+    sparsity_lut_x = torch.nonzero(sparsity_layout_src).contiguous()
+
+    sparsity_layout_o_flat = sparsity_layout_tgt.reshape(-1)
+    sparsity_reverse_lut_o = ((torch.cumsum(sparsity_layout_o_flat, dim=-1) - 1) *
+                              (sparsity_layout_o_flat == 1) -
+                              (1 * (sparsity_layout_o_flat == 0)))
+
+    n_sparse_blocks = torch.sum(sparsity_layout_tgt.to(torch.int)).item()
+
+    validate_contiguous(sparsity_layout_src, sparsity_lut_x,
+                        sparsity_layout_tgt, sparsity_reverse_lut_o)
+
+    return _BlocksparseScatterReduceMDI.apply(src, sparsity_layout_src, sparsity_lut_x,
+                                              idx_bat,
+                                              idx_col,
+                                              sparsity_layout_tgt, sparsity_reverse_lut_o,
+                                              sparsity_block_size, n_sparse_blocks,
+                                              reduce_op, triton_block_size)
+
+
+class _BlocksparseScatterReduceMDI(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_lut_x: Tensor,
+                idx_bat: Tensor,
+                idx_col: Tensor,
+                sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
+                sparsity_block_size: int, n_sparse_blocks: int,
+                reduce_op: str, triton_block_size: int) -> Tensor:
+        output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                             dtype=x.dtype, device=x.device)
+
+        x_b, x_r, x_c = x.size()
+        x_b_s, x_r_s, x_c_s = x.stride()
+        s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
+        s_lut_x_r_s, s_lut_x_c_s = sparsity_lut_x.stride()
+        i_b, i_r, i_c = idx_col.size()
+        i_b_s, i_r_s, i_c_s = idx_col.stride()
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = output.stride()
+        s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
+        s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_o.stride()
+
+        if triton_block_size is None:
+            triton_block_size = get_triton_block_size(sparsity_block_size)
+
+        triton_grid = lambda meta: [x_b,
+                                    triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+        reduce_op_ind = 0
+        if reduce_op == "sum":
+            reduce_op_ind = 1
+
+        (_BlocksparseScatterReduceMDI.kernel_blocksparse_scatter_mdi[triton_grid]
+         (x,
+          x_b, x_b_s, x_r_s, x_c_s,
+          sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+          idx_bat,
+          idx_col,
+          i_b, i_b_s, i_r_s, i_c_s,
+          output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+          sparsity_reverse_lut_o,
+          reduce_op_ind,
+          sparsity_block_size,
+          triton_block_size))
+
+        ctx.save_for_backward(sparsity_layout_x, idx_bat, idx_col, sparsity_layout_o)
+        ctx.sparsity_block_size = sparsity_block_size
+        ctx.reduce_op = reduce_op
+        ctx.triton_block_size = triton_block_size
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        sparsity_layout_x, idx_bat, idx_col, sparsity_layout_o = ctx.saved_tensors
+        sparsity_block_size = ctx.sparsity_block_size
+        reduce_op = ctx.reduce_op
+        triton_block_size = ctx.triton_block_size
+
+        if reduce_op == "sum":
+            return gather_mdi(grad_output, sparsity_layout_o,
+                              idx_bat,
+                              None,
+                              idx_col,
+                              sparsity_layout_x, sparsity_block_size,
+                              triton_block_size=triton_block_size), None, None, None, None, None, None, None, None, None, None
+        else:
+            raise ValueError(f"Reduction operation '{reduce_op}' does not support backward pass")
+
+    @staticmethod
+    @triton.jit
+    def kernel_blocksparse_scatter_mdi(x,
+                                       x_b, x_b_s, x_r_s, x_c_s,
+                                       s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+                                       idx_bat,
+                                       idx_col,
+                                       i_b, i_b_s, i_r_s, i_c_s,
+                                       o,
+                                       o_b, o_b_s, o_r_s, o_c_s,
+                                       s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+                                       r_lut_o,
+                                       reduce_op_ind,
+                                       sparsity_block_size,
+                                       TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+        # Get triton block indices
+        pid_blk = tl.program_id(axis=0)
+        pid_row = tl.program_id(axis=1)
+        pid_col = tl.program_id(axis=2)
+
+        # Load batch index values
+        blk_idx_bat_idx = ((pid_blk * i_b_s) +
+                           ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                           ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+        blk_idx_bat_msk = (blk_idx_bat_idx < i_b * i_b_s)
+        blk_idx_bat = tl.load(idx_bat + blk_idx_bat_idx, mask=blk_idx_bat_msk).to(tl.int32)
+
+        # Get position of current sparsity block row
+        spa_row_x_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
+        spa_row_x_msk = (spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
+        spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)
+
+        # Load x values
+        blk_x_idx = ((pid_blk * x_b_s) +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+        blk_x_msk = (blk_x_idx < x_b * x_b_s)
+        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+        # Load column index values
+        blk_idx_col_idx = ((pid_blk * i_b_s) +
+                           ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                           ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+        blk_idx_col_msk = (blk_idx_col_idx < i_b * i_b_s)
+        blk_idx_col = tl.load(idx_col + blk_idx_col_idx, mask=blk_idx_col_msk).to(tl.int32)
+
+        # Get positions of sparsity blocks
+        pos_spa_blk_o = blk_idx_col // sparsity_block_size
+        pos_spa_col_o = blk_idx_col % sparsity_block_size
+
+        # Load reverse sparsity indices for o
+        rev_idx_spa_o_idx = ((blk_idx_bat * s_l_o_b_s) +
+                             (spa_row_x * s_l_o_r_s) +
+                             (pos_spa_blk_o * s_l_o_c_s))
+        rev_idx_spa_o_msk = (rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
+        rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
+
+        # Store output
+        blk_o_idx = ((rev_idx_spa_o * o_b_s) +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                     (pos_spa_col_o * o_c_s))
+        blk_o_msk = (blk_o_idx < o_b * o_b_s)
+
+        if reduce_op_ind == 0:
+            tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+        elif reduce_op_ind == 1:
+            tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+def build_distribution_layout_mdi(idx_bat: Tensor, idx_row: Tensor, idx_col: Tensor, sparsity_layout_idx: Tensor,
+                                  size_target: torch.Size,
+                                  sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+    validate_dimensions(idx_bat, idx_col)
+    validate_contiguous(idx_bat, idx_col)
+    validate_device(idx_bat, idx_col)
+
+    sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
+
+    output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size, size_target[2] // sparsity_block_size,
+                         dtype=torch.bool, device=idx_col.device)
+
+    i_b, i_r, i_c = idx_col.size()
+    i_b_s, i_r_s, i_c_s = idx_col.stride()
+    s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
+    s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
+    o_b, o_r, o_c = output.size()
+    o_b_s, o_r_s, o_c_s = output.stride()
+
+    if triton_block_size is None:
+        triton_block_size = get_triton_block_size(sparsity_block_size)
+
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+    triton_grid = lambda meta: [i_b,
+                                triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
+                                triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]
+
+    (kernel_distribution_layout_mdi[triton_grid]
+     (idx_bat,
+      idx_col,
+      i_b, i_b_s, i_r_s, i_c_s,
+      sparsity_lut_i,
+      s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+      output,
+      o_b, o_b_s, o_r_s, o_c_s,
+      sparsity_block_size,
+      triton_block_size))
+
+    return output
+
+
+@triton.jit
+def kernel_distribution_layout_mdi(idx_bat,
+                                   idx_col,
+                                   i_b, i_b_s, i_r_s, i_c_s,
+                                   s_lut_i,
+                                   s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+                                   o,
+                                   o_b, o_b_s, o_r_s, o_c_s,
+                                   sparsity_block_size,
+                                   TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Load batch index values
+    blk_idx_bat_idx = ((pid_blk * i_b_s) +
+                       ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                       ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+    blk_idx_bat_msk = (blk_idx_bat_idx < i_b * i_b_s)
+    blk_idx_bat = tl.load(idx_bat + blk_idx_bat_idx, mask=blk_idx_bat_msk).to(tl.int32)
+
+    # Get position of current sparsity block row
+    spa_row_i_idx = (pid_blk * s_lut_i_r_s + 1 * s_lut_i_c_s)
+    spa_row_i_msk = (spa_row_i_idx < s_lut_i_r * s_lut_i_r_s)
+    spa_row_i = tl.load(s_lut_i + spa_row_i_idx, mask=spa_row_i_msk)
+
+    blk_i_idx = (pid_blk * i_b_s +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+    blk_i_msk = (blk_i_idx < i_b * i_b_s)
+    blk_i = tl.load(idx_col + blk_i_idx, mask=blk_i_msk)
+
+    blk_i = blk_i // sparsity_block_size
+    blk_v = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), 1, dtype=tl.int32)
+
+    blk_o_idx = ((blk_idx_bat * o_b_s) +
+                 (spa_row_i * o_r_s) +
+                 (blk_i * o_c_s))
+    blk_o_msk = (blk_o_idx < o_b * o_b_s)
+    tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
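
The module above ships without docstrings, so the following is a hedged, end-to-end sketch of how gather_mdi might be driven. All shapes are assumptions inferred from the validation calls and kernels in the diff (compressed block tensors of shape (n_sparse_blocks, sparsity_block_size, sparsity_block_size); layouts of shape (batch, rows // sparsity_block_size, cols // sparsity_block_size)); it is illustrative, not a documented API.

# Hedged usage sketch for the experimental MDI gather; requires a CUDA device.
# Shapes and values are assumptions inferred from the diff, not documentation.
import torch
import blksprs as bs

sbs = 32                             # sparsity_block_size
b, r, c = 2, 64, 64                  # conceptual dense batch/rows/cols

# Fully dense layouts keep the sketch simple: every block is materialised.
layout_src = torch.ones(b, r // sbs, c // sbs, dtype=torch.bool, device="cuda")
layout_idx = torch.ones(b, r // sbs, c // sbs, dtype=torch.bool, device="cuda")
n_blocks = int(layout_src.sum())     # equals int(layout_idx.sum()) here

# Compressed block-sparse source plus per-element batch and column indices.
src = torch.randn(n_blocks, sbs, sbs, device="cuda", requires_grad=True)
idx_bat = torch.zeros(n_blocks, sbs, sbs, dtype=torch.int64, device="cuda")
idx_col = torch.randint(0, c, (n_blocks, sbs, sbs), device="cuda")

# idx_row is accepted but never read in this version; passing None mirrors
# what the backward passes themselves forward for that argument.
out = bs.experimental.gather_mdi(src, layout_src,
                                 idx_bat, None, idx_col,
                                 layout_idx, sbs)

# Autograd routes the gradient through scatter_reduce_mdi with reduce_op="sum".
out.sum().backward()
print(out.shape, src.grad.shape)     # both (n_blocks, sbs, sbs)

Note how the two operations are paired: gather_mdi's backward is scatter_reduce_mdi with reduce_op="sum", and the scatter's backward is gather_mdi again, which is why only the "sum" reduction supports a backward pass.
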
--- blksprs-1.4.2/blksprs/layouting/distribution_layout.py
+++ blksprs-1.5/blksprs/layouting/distribution_layout.py
@@ -35,8 +35,6 @@ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
 
     i_b, i_r, i_c = indices.size()
     i_b_s, i_r_s, i_c_s = indices.stride()
-    s_l_i_b, s_l_i_r, s_l_i_c = sparsity_layout_indices.size()
-    s_l_i_b_s, s_l_i_r_s, s_l_i_c_s = sparsity_layout_indices.stride()
     s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
     s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
     o_b, o_r, o_c = output.size()
@@ -54,12 +52,10 @@ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
     (kernel_distribution_layout[triton_grid]
      (indices,
       i_b, i_b_s, i_r_s, i_c_s,
-      sparsity_layout_indices,
-      s_l_i_b, s_l_i_b_s, s_l_i_r, s_l_i_r_s, s_l_i_c, s_l_i_c_s,
       sparsity_lut_i,
-      s_lut_i_r, s_lut_i_r_s,
+      s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
       output,
-      o_b, o_b_s,
+      o_b, o_b_s, o_r_s, o_c_s,
       sparsity_block_size,
       triton_block_size))
 
@@ -69,12 +65,10 @@ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
 @triton.jit
 def kernel_distribution_layout(i,
                                i_b, i_b_s, i_r_s, i_c_s,
-                               s_l_i,
-                               s_l_i_b, s_l_i_b_s, s_l_i_r, s_l_i_r_s, s_l_i_c, s_l_i_c_s,
                                s_lut_i,
-                               s_lut_i_r, s_lut_i_r_s,
+                               s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
                                o,
-                               o_b, o_b_s,
+                               o_b, o_b_s, o_r_s, o_c_s,
                                sparsity_block_size,
                                TRITON_BLOCK_SIZE: tl.constexpr) -> None:
     # Get triton block indices
@@ -105,10 +99,3 @@ def kernel_distribution_layout(i,
                  (blk_i * o_c_s))
     blk_o_msk = (blk_o_idx < o_b * o_b_s)
     tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
-
-    # if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
-    #     blk_o_idx = (pid_bat * o_b_s +
-    #                  (((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +
-    #                   ((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))
-    #     blk_o_msk = (blk_o_idx < o_b * o_b_s)
-    #     tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
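
Both this layouting code and the new MDI module lean on the same two lookup structures: torch.nonzero(sparsity_layout) yields, for each compressed block, its (batch, row, col) position in the layout (the LUT passed as s_lut_*), while the cumsum expression in distribution_mdi.py builds a reverse LUT mapping every layout slot back to its index in compressed storage, with -1 marking absent blocks. A standalone worked example of that reverse-LUT construction (toy values, chosen only for illustration):

import torch

# A flattened sparsity layout: 1 where a block is materialised, 0 where not.
layout_flat = torch.tensor([1, 0, 1, 1, 0])

# The exact expression used by gather_mdi / scatter_reduce_mdi above:
reverse_lut = ((torch.cumsum(layout_flat, dim=-1) - 1) *
               (layout_flat == 1) -
               (1 * (layout_flat == 0)))

print(reverse_lut)  # tensor([ 0, -1,  1,  2, -1])
# Slot i holds that block's index in compressed storage, or -1 if empty.
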
--- blksprs-1.4.2/blksprs.egg-info/PKG-INFO
+++ blksprs-1.5/blksprs.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.4.2
+Version: 1.5
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
--- blksprs-1.4.2/blksprs.egg-info/SOURCES.txt
+++ blksprs-1.5/blksprs.egg-info/SOURCES.txt
@@ -6,6 +6,7 @@
 blksprs.egg-info/dependency_links.txt
 blksprs.egg-info/requires.txt
 blksprs.egg-info/top_level.txt
+blksprs/experimental/distribution_mdi.py
 blksprs/layouting/distribution_layout.py
 blksprs/layouting/sparsity_layout.py
 blksprs/misc/broadcast_ops.py