blksprs 1.11__py3-none-any.whl → 2.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +2 -5
- blksprs/layouting/distribution_layout.py +32 -25
- blksprs/layouting/sparsity_layout.py +65 -52
- blksprs/ops/conversion.py +421 -399
- blksprs/ops/distribution.py +404 -366
- blksprs/ops/flow.py +125 -106
- blksprs/ops/matmul.py +220 -204
- blksprs/ops/misc/broadcast_ops.py +53 -35
- blksprs/ops/misc/row_wise.py +151 -91
- blksprs/ops/partitioning.py +136 -132
- blksprs/ops/repeat.py +115 -120
- blksprs/ops/softmax.py +274 -246
- blksprs/ops/transpose.py +52 -51
- blksprs/utils/benchmarking.py +3 -3
- blksprs/utils/tools.py +31 -4
- blksprs/utils/validation.py +0 -14
- {blksprs-1.11.dist-info → blksprs-2.0rc1.dist-info}/METADATA +42 -36
- blksprs-2.0rc1.dist-info/RECORD +22 -0
- {blksprs-1.11.dist-info → blksprs-2.0rc1.dist-info}/WHEEL +1 -1
- blksprs/utils/layout_utils.py +0 -17
- blksprs-1.11.dist-info/RECORD +0 -23
- {blksprs-1.11.dist-info → blksprs-2.0rc1.dist-info}/top_level.txt +0 -0
blksprs/ops/conversion.py
CHANGED
|
@@ -1,321 +1,333 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
import triton
|
|
3
3
|
from torch import Tensor
|
|
4
|
+
from torch._library.triton import wrap_triton, triton_op
|
|
4
5
|
from triton import language as tl
|
|
5
6
|
|
|
6
7
|
from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
|
|
7
8
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
8
|
-
from blksprs.utils.tools import
|
|
9
|
+
from blksprs.utils.tools import stride, get_autotune_configs
|
|
9
10
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
10
|
-
validate_sparsity, validate_sparsity_block_size,
|
|
11
|
+
validate_sparsity, validate_sparsity_block_size, validate_sparsity_dense
|
|
11
12
|
|
|
12
13
|
|
|
13
|
-
def
|
|
14
|
-
|
|
15
|
-
"""Wrapper for ``to_dense``.
|
|
14
|
+
def to_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
    """Wrapper for ``to_sparse``.

    Mirrors ``from_blksprs`` by forwarding an optional look-up table dictionary, so callers of the
    wrapper can reuse look-up tables exactly like callers of ``to_sparse``.

    Args:
        x (Tensor): A block-sparse tensor in regular form.
        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
        sparsity_block_size (int): The size of the sparsity blocks.
        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

    Returns:
        BlksprsTensor: The block-sparse tensor converted to compressed form.

    """
    return to_sparse(x, sparsity_layout, sparsity_block_size, lut=lut)
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
def
|
|
22
|
-
|
|
23
|
-
"""Converts a block-sparse tensor in
|
|
24
|
-
|
|
21
|
+
def to_sparse(x: Tensor, sparsity_layout: Tensor,
              sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
    """Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
        sparsity layout.

    The input is made contiguous and validated, the look-up tables are built (or completed) from ``lut``,
    and the conversion is delegated to ``to_sparse_forward``.

    Args:
        x (Tensor): A block-sparse tensor in regular form.
        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
        sparsity_block_size (int): The size of the sparsity blocks.
        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

    Returns:
        BlksprsTensor: The block-sparse tensor converted to compressed form.

    """
    x = x.contiguous()

    validate_dimensions(x)
    validate_contiguous(x)
    validate_device(x)
    validate_sparsity_dense(sparsity_block_size, (x, sparsity_layout))
    validate_sparsity_block_size(sparsity_block_size, x)

    lut = to_sparse_build_lut(lut, sparsity_layout)

    # Shortcut: the layout has a single block per batch entry and every block is set,
    # so the input is wrapped and returned without launching the kernel
    one_block_per_batch = sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1
    if one_block_per_batch and torch.all(sparsity_layout):
        return BlksprsTensor(x)

    compressed = to_sparse_forward(x, sparsity_layout,
                                   lut["sparsity_lut"], sparsity_block_size, lut["n_sparse_blocks"])
    return BlksprsTensor(compressed)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@triton_op("blksprs::to_sparse", mutates_args={})
def to_sparse_forward(x: Tensor, _: Tensor,
                      sparsity_lut: Tensor, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
    """Allocates the compressed output and launches ``to_sparse_kernel`` to fill it.

    Args:
        x (Tensor): The tensor to read blocks from (unpacked as three-dimensional).
        _ (Tensor): Not read in this function; ``to_sparse_setup_context`` reads this input slot as the
            sparsity layout.
        sparsity_lut (Tensor): Two-dimensional look-up table, passed to the kernel together with its strides.
        sparsity_block_size (int): The size of the sparsity blocks.
        n_sparse_blocks (int): Number of blocks (first dimension) of the output.

    Returns:
        Tensor: A tensor of shape ``(n_sparse_blocks, sparsity_block_size, sparsity_block_size)``.

    """
    output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                         dtype=x.dtype, device=x.device)

    # Sizes and strides handed to the kernel (unused extents are still unpacked to enforce the rank)
    x_b, _x_r, _x_c = x.size()
    x_b_s, x_r_s, x_c_s = stride(x)
    s_lut_r, _s_lut_c = sparsity_lut.size()
    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
    o_b, o_r, o_c = output.size()
    o_b_s, o_r_s, o_c_s = stride(output)

    def launch_grid(meta):
        # One program per output block and per TRITON_BLOCK_SIZE tile along rows and columns
        tile = meta["TRITON_BLOCK_SIZE"]
        return [o_b,
                triton.cdiv(o_r, tile),
                triton.cdiv(o_c, tile)]

    launcher = wrap_triton(to_sparse_kernel)[launch_grid]
    launcher(x, x_b, x_b_s, x_r_s, x_c_s,
             sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
             output, o_b_s, o_r_s, o_c_s,
             sparsity_block_size)

    return output
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def to_sparse_backward(ctx, grad_output):
    """Backward of ``to_sparse_forward``: converts the incoming gradient with ``to_dense``.

    Args:
        ctx: Autograd context; its first saved tensor is used as the sparsity layout and
            ``ctx.sparsity_block_size`` as the block size.
        grad_output: Gradient with respect to the forward output.

    Returns:
        tuple: The gradient for the first forward input followed by ``None`` for the remaining four inputs.

    """
    layout = ctx.saved_tensors[0]
    grad_x = to_dense(grad_output, layout, ctx.sparsity_block_size)

    # One slot per forward input; only the first input receives a gradient
    return grad_x, None, None, None, None
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@triton.autotune(
    configs=get_autotune_configs(),
    key=[],
)
@triton.jit
def to_sparse_kernel(x,
                     x_b, x_b_s, x_r_s, x_c_s,
                     s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
                     o,
                     o_b_s, o_r_s, o_c_s,
                     sparsity_block_size,
                     TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    # Copies one tile of a block from the regular-form tensor ``x`` into the compressed tensor ``o``.
    # Grid axis 0 selects the output block (and the matching row of ``s_lut``), axes 1 and 2 select
    # the row and column tile inside that block.
    #
    # x, o:                 pointers to the input and output tensors
    # x_b, x_*_s:           first-dimension size and strides of ``x``
    # s_lut, s_lut_r, ...:  pointer, row count and strides of the sparsity look-up table
    # o_*_s:                strides of ``o``
    # sparsity_block_size:  edge length of a sparsity block
    # TRITON_BLOCK_SIZE:    compile-time tile edge length chosen by the autotuner

    # Get triton block indices
    pid_blk = tl.program_id(axis=0)
    pid_row = tl.program_id(axis=1)
    pid_col = tl.program_id(axis=2)

    # Get valid triton block size
    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)

    # Get sparsity index of current output block consisting of its batch, row, and column index
    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
    spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)

    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
    spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)

    spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
    spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
    spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)

    # Load block from dense tensor
    blk_d_idx = (spa_bat * x_b_s +
                 ((pid_row * val_tbs + spa_row * sparsity_block_size +
                   tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                 ((pid_col * val_tbs + spa_col * sparsity_block_size +
                   tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
    # Lanes beyond the valid tile size are masked out in addition to the bounds check
    blk_d_msk = ((blk_d_idx >= 0 and
                  blk_d_idx < x_b * x_b_s) and
                 (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
                  tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
    blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)

    # Store block in sparse tensor
    # The column stride scales the complete column offset (tile start plus lane index),
    # matching the row term; previously only the lane index was scaled
    blk_o_idx = ((pid_blk * o_b_s) +
                 ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                 ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
    blk_o_msk = ((blk_o_idx >= 0 and
                  blk_o_idx < (pid_blk + 1) * o_b_s) and
                 (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
                  tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
    tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def to_sparse_build_lut(lut: dict, sparsity_layout: Tensor):
    """Builds or completes the look-up tables used by ``to_sparse``.

    Entries already present in ``lut`` are kept; missing ones are computed from ``sparsity_layout``.
    A dictionary passed in is updated in place.

    Args:
        lut (dict): Existing look-up tables or ``None``.
        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.

    Returns:
        dict: The look-up tables with the keys ``"sparsity_lut"`` (indices of the non-zero layout entries)
            and ``"n_sparse_blocks"`` (sum of the layout entries as a Python number).

    """
    lut = dict() if lut is None else lut

    if "sparsity_lut" not in lut:
        lut["sparsity_lut"] = torch.nonzero(sparsity_layout).contiguous()

    if "n_sparse_blocks" not in lut:
        lut["n_sparse_blocks"] = torch.sum(sparsity_layout.to(torch.int)).item()

    validate_contiguous(sparsity_layout, lut["sparsity_lut"])

    return lut
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
# noinspection PyUnusedLocal
def to_sparse_setup_context(ctx, inputs, output):
    """Stores what ``to_sparse_backward`` needs on the autograd context.

    Args:
        ctx: Autograd context to populate.
        inputs: The five forward inputs; the second is saved as the sparsity layout and the fourth
            is stored as the sparsity block size.
        output: Forward output (unused).

    """
    _x, layout, _sparsity_lut, block_size, _n_sparse_blocks = inputs

    ctx.sparsity_block_size = block_size
    ctx.save_for_backward(layout)
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
# Attach the backward function and the context setup to the ``blksprs::to_sparse`` op
to_sparse_forward.register_autograd(to_sparse_backward, setup_context=to_sparse_setup_context)
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def from_blksprs(x: BlksprsTensor, sparsity_layout: Tensor,
                 sparsity_block_size: int, fill_value: float = 0, lut: dict = None) -> Tensor:
    """Wrapper for ``to_dense``.

    All arguments are forwarded unchanged.

    Args:
        x (BlksprsTensor): A block-sparse tensor in compressed form.
        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
        sparsity_block_size (int): The size of the sparsity blocks.
        fill_value (float): The value to fill the resulting dense tensor with where the block-sparse tensor is not
            present (default ``0``).
        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

    Returns:
        Tensor: The block-sparse tensor converted to regular form.

    """
    dense = to_dense(x, sparsity_layout, sparsity_block_size, fill_value=fill_value, lut=lut)
    return dense
|
|
172
176
|
|
|
173
177
|
|
|
174
|
-
def
|
|
175
|
-
|
|
176
|
-
"""Converts a block-sparse tensor in
|
|
177
|
-
|
|
178
|
+
def to_dense(x: BlksprsTensor, sparsity_layout: Tensor,
             sparsity_block_size: int, fill_value: float = 0, lut: dict = None) -> Tensor:
    """Converts a block-sparse tensor in compressed form to a block-sparse tensor in regular form based on the given
        sparsity layout.

    The input is made contiguous and validated, the look-up tables are built (or completed) from ``lut``,
    and the conversion is delegated to ``to_dense_forward``.

    Args:
        x (BlksprsTensor): A block-sparse tensor in compressed form.
        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
        sparsity_block_size (int): The size of the sparsity blocks.
        fill_value (float): The value to fill the resulting dense tensor with where the block-sparse tensor is not
            present (default ``0``).
        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

    Returns:
        Tensor: The block-sparse tensor converted to regular form.

    """
    x = x.contiguous()

    validate_dimensions(x)
    validate_contiguous(x, sparsity_layout)
    validate_device(x)
    validate_sparsity(sparsity_block_size, (x, sparsity_layout))
    validate_sparsity_block_size(sparsity_block_size, x)

    lut = to_dense_build_lut(lut, sparsity_layout)

    # Shortcut: the layout has a single block per batch entry and every block is set,
    # so the input is returned as is without launching the kernel
    one_block_per_batch = sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1
    if one_block_per_batch and torch.all(sparsity_layout):
        return x

    reverse_lut = lut["sparsity_reverse_lut"]
    return to_dense_forward(x, sparsity_layout,
                            reverse_lut, sparsity_block_size, fill_value)
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
@triton_op("blksprs::to_dense", mutates_args={})
def to_dense_forward(x: Tensor, sparsity_layout: Tensor,
                     sparsity_reverse_lut: Tensor,
                     sparsity_block_size: int, fill_value: float) -> Tensor:
    """Allocates the regular-form output filled with ``fill_value`` and launches ``to_dense_kernel``.

    Args:
        x (Tensor): The compressed tensor to read blocks from (unpacked as three-dimensional).
        sparsity_layout (Tensor): The sparsity layout; its sizes determine the output shape.
        sparsity_reverse_lut (Tensor): Reverse look-up table passed to the kernel.
        sparsity_block_size (int): The size of the sparsity blocks.
        fill_value (float): Initial value of every output element.

    Returns:
        Tensor: A tensor of shape ``(layout.size(0), layout.size(1) * sparsity_block_size,
            layout.size(2) * sparsity_block_size)``.

    """
    dense_shape = (sparsity_layout.size(0),
                   sparsity_layout.size(1) * sparsity_block_size,
                   sparsity_layout.size(2) * sparsity_block_size)
    output = torch.full(size=dense_shape, fill_value=fill_value,
                        dtype=x.dtype, device=x.device)

    # Sizes and strides handed to the kernel (unused extents are still unpacked to enforce the rank)
    x_b, _x_r, _x_c = x.shape
    x_b_s, x_r_s, x_c_s = stride(x)
    s_l_b, _s_l_r, _s_l_c = sparsity_layout.size()
    s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout)
    o_b, o_r, o_c = output.size()
    o_b_s, o_r_s, o_c_s = stride(output)

    def launch_grid(meta):
        # Tiles never exceed a sparsity block, hence the min with the sparsity block size
        tile = min(meta["sparsity_block_size"], meta["TRITON_BLOCK_SIZE"])
        return [o_b,
                triton.cdiv(o_r, tile),
                triton.cdiv(o_c, tile)]

    launcher = wrap_triton(to_dense_kernel)[launch_grid]
    launcher(x,
             x_b, x_b_s, x_r_s, x_c_s,
             s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
             sparsity_reverse_lut,
             output,
             o_b, o_b_s, o_r_s, o_c_s,
             sparsity_block_size)

    return output
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def to_dense_backward(ctx, grad_output):
    """Backward of ``to_dense_forward``: converts the incoming gradient with ``to_sparse``.

    Args:
        ctx: Autograd context; its first saved tensor is used as the sparsity layout and
            ``ctx.sparsity_block_size`` as the block size.
        grad_output: Gradient with respect to the forward output.

    Returns:
        tuple: The gradient for the first forward input followed by ``None`` for the remaining four inputs.

    """
    layout = ctx.saved_tensors[0]
    grad_x = to_sparse(grad_output, layout, ctx.sparsity_block_size)

    # One slot per forward input; only the first input receives a gradient
    return grad_x, None, None, None, None
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
@triton.autotune(
    configs=get_autotune_configs(),
    key=[],
)
@triton.jit
def to_dense_kernel(x,
                    x_b, x_b_s, x_r_s, x_c_s,
                    s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
                    sparsity_reverse_lut,
                    o,
                    o_b, o_b_s, o_r_s, o_c_s,
                    sparsity_block_size,
                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    # Copies one tile from the compressed tensor ``x`` into the regular-form tensor ``o``.
    # Grid axis 0 selects the batch entry of ``o``, axes 1 and 2 select the row and column tile.
    # Tiles whose reverse look-up entry is negative are skipped, so ``o`` keeps whatever value
    # it held before the launch at those positions.
    #
    # x, o:                  pointers to the input and output tensors
    # x_b, x_*_s:            first-dimension size and strides of ``x``
    # s_l_b, s_l_*_s:        first-dimension size and strides of the sparsity layout
    # sparsity_reverse_lut:  pointer to the reverse look-up table, indexed with the layout strides
    # o_b, o_*_s:            first-dimension size and strides of ``o``
    # sparsity_block_size:   edge length of a sparsity block
    # TRITON_BLOCK_SIZE:     compile-time tile edge length chosen by the autotuner

    # Get triton block indices
    pid_blk = tl.program_id(axis=0)
    pid_row = tl.program_id(axis=1)
    pid_col = tl.program_id(axis=2)

    # Get valid triton block size
    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)

    # Get sparsity index of current block
    spa_row = (pid_row * val_tbs) // sparsity_block_size
    spa_col = (pid_col * val_tbs) // sparsity_block_size

    # Get reverse sparsity index for current block
    rev_idx_spa_idx = (pid_blk * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)
    rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_b * s_l_b_s)
    rev_idx_spa = tl.load(sparsity_reverse_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

    # If block is present commence operations
    if rev_idx_spa >= 0:
        # The modulo yields the tile position inside the sparsity block when a block spans several tiles
        blk_idx = (rev_idx_spa * x_b_s +
                   (((pid_row % (sparsity_block_size // val_tbs)) * val_tbs +
                     tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                   (((pid_col % (sparsity_block_size // val_tbs)) * val_tbs +
                     tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
        # Lanes beyond the valid tile size are masked out in addition to the bounds check
        blk_msk = ((blk_idx >= 0 and
                    blk_idx < x_b * x_b_s) and
                   (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
                    tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
        blk = tl.load(x + blk_idx, mask=blk_msk)

        # Write the tile to its position in the regular-form output
        o_idx = (pid_blk * o_b_s +
                 ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                 ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
        o_msk = ((o_idx >= 0 and o_idx < o_b * o_b_s) and
                 (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
                  tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
        tl.store(o + o_idx, blk, o_msk)
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def to_dense_build_lut(lut: dict, sparsity_layout: Tensor):
    """Builds or completes the look-up tables used by ``to_dense``.

    An existing ``"sparsity_reverse_lut"`` entry is kept; otherwise it is computed from the flattened
    layout. A dictionary passed in is updated in place.

    Args:
        lut (dict): Existing look-up tables or ``None``.
        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.

    Returns:
        dict: The look-up tables with the key ``"sparsity_reverse_lut"``. For a computed table, positions
            where the layout equals ``1`` hold the running count of layout entries minus one and
            positions where it equals ``0`` hold ``-1``.

    """
    lut = dict() if lut is None else lut

    if "sparsity_reverse_lut" not in lut:
        flat_layout = sparsity_layout.reshape(-1)
        running_index = torch.cumsum(flat_layout, dim=-1) - 1
        is_present = flat_layout == 1
        is_absent = flat_layout == 0
        lut["sparsity_reverse_lut"] = running_index * is_present - (1 * is_absent)

    validate_contiguous(lut["sparsity_reverse_lut"])

    return lut
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
# noinspection PyUnusedLocal
def to_dense_setup_context(ctx, inputs, output):
    """Stores what ``to_dense_backward`` needs on the autograd context.

    Args:
        ctx: Autograd context to populate.
        inputs: The five forward inputs; the second is saved as the sparsity layout and the fourth
            is stored as the sparsity block size.
        output: Forward output (unused).

    """
    _x, layout, _reverse_lut, block_size, _fill_value = inputs

    ctx.sparsity_block_size = block_size
    ctx.save_for_backward(layout)
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
# Attach the backward function and the context setup to the ``blksprs::to_dense`` op
to_dense_forward.register_autograd(to_dense_backward, setup_context=to_dense_setup_context)
|
|
314
327
|
|
|
315
328
|
|
|
316
329
|
def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_size_from: int,
|
|
317
|
-
sparsity_block_size_to: int, sparsity_layout_to: Tensor = None,
|
|
318
|
-
triton_block_size: int = None) -> (BlksprsTensor, Tensor):
|
|
330
|
+
sparsity_block_size_to: int, sparsity_layout_to: Tensor = None) -> (BlksprsTensor, Tensor):
|
|
319
331
|
"""Adapts the sparsity layout of a block-sparse tensor, resulting in a new block-sparse tensor in compressed form
|
|
320
332
|
conforming to the new sparsity layout (and sparsity block size) definition.
|
|
321
333
|
|
|
@@ -325,7 +337,6 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
|
|
|
325
337
|
sparsity_block_size_from (int): The size of the sparsity blocks of the input sparsity layout.
|
|
326
338
|
sparsity_block_size_to (int): The size of the sparsity blocks of the output sparsity layout.
|
|
327
339
|
sparsity_layout_to (Tensor): The sparsity layout of the output block-sparse tensor (default ``None``).
|
|
328
|
-
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
329
340
|
|
|
330
341
|
Returns:
|
|
331
342
|
BlksprsTensor: The block-sparse tensor in compressed form with the adapted sparsity layout and sparsity block size.
|
|
@@ -340,8 +351,6 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
|
|
|
340
351
|
validate_sparsity(sparsity_block_size_from, (x, sparsity_layout_from))
|
|
341
352
|
validate_sparsity_block_size(sparsity_block_size_from, x)
|
|
342
353
|
validate_sparsity_block_size(sparsity_block_size_to)
|
|
343
|
-
min_sparsity_block_size = min(sparsity_block_size_from, sparsity_block_size_to)
|
|
344
|
-
validate_triton_block_size(triton_block_size, min_sparsity_block_size)
|
|
345
354
|
|
|
346
355
|
sparsity_layout_from_flat = sparsity_layout_from.reshape(-1)
|
|
347
356
|
sparsity_reverse_lut_from = ((torch.cumsum(sparsity_layout_from_flat, dim=-1) - 1) *
|
|
@@ -350,8 +359,7 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
|
|
|
350
359
|
|
|
351
360
|
if sparsity_layout_to is None:
|
|
352
361
|
sparsity_layout_to = build_sparsity_layout_adaption(x, sparsity_layout_from,
|
|
353
|
-
sparsity_block_size_from, sparsity_block_size_to
|
|
354
|
-
triton_block_size)
|
|
362
|
+
sparsity_block_size_from, sparsity_block_size_to)
|
|
355
363
|
|
|
356
364
|
sparsity_lut_to = torch.nonzero(sparsity_layout_to).contiguous()
|
|
357
365
|
|
|
@@ -362,134 +370,148 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
|
|
|
362
370
|
if (sparsity_block_size_from == sparsity_block_size_to) and torch.equal(sparsity_layout_from, sparsity_layout_to):
|
|
363
371
|
return BlksprsTensor(x), sparsity_layout_to
|
|
364
372
|
|
|
365
|
-
return BlksprsTensor(
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
373
|
+
return BlksprsTensor(adapt_layout_forward(x,
|
|
374
|
+
sparsity_layout_from, sparsity_reverse_lut_from,
|
|
375
|
+
sparsity_block_size_from,
|
|
376
|
+
sparsity_layout_to, sparsity_lut_to,
|
|
377
|
+
sparsity_block_size_to,
|
|
378
|
+
n_sparse_blocks_to)), sparsity_layout_to
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
@triton_op("blksprs::adapt_layout", mutates_args={})
def adapt_layout_forward(x: Tensor,
                         sparsity_layout_from: Tensor, sparsity_reverse_lut_from: Tensor,
                         sparsity_block_size_from: int,
                         _: Tensor, sparsity_lut_to: Tensor,
                         sparsity_block_size_to: int,
                         n_sparse_blocks_to: int) -> Tensor:
    """Allocates the zero-initialised output and launches ``adapt_layout_kernel`` to fill it.

    Args:
        x (Tensor): The compressed tensor to read blocks from (unpacked as three-dimensional).
        sparsity_layout_from (Tensor): The sparsity layout of the input; only its sizes and strides are used here.
        sparsity_reverse_lut_from (Tensor): Reverse look-up table of the input layout, passed to the kernel.
        sparsity_block_size_from (int): The size of the sparsity blocks of the input.
        _ (Tensor): Not read in this function; ``adapt_layout_setup_context`` reads this input slot as the
            output sparsity layout.
        sparsity_lut_to (Tensor): Two-dimensional look-up table of the output layout, passed to the kernel.
        sparsity_block_size_to (int): The size of the sparsity blocks of the output.
        n_sparse_blocks_to (int): Number of blocks (first dimension) of the output.

    Returns:
        Tensor: A tensor of shape ``(n_sparse_blocks_to, sparsity_block_size_to, sparsity_block_size_to)``.

    """
    output = torch.zeros(size=(n_sparse_blocks_to, sparsity_block_size_to, sparsity_block_size_to),
                         dtype=x.dtype, device=x.device)

    # Sizes and strides handed to the kernel (unused extents are still unpacked to enforce the rank)
    x_b, _x_r, _x_c = x.size()
    x_b_s, x_r_s, x_c_s = stride(x)
    s_l_x_b, _s_l_x_r, _s_l_x_c = sparsity_layout_from.size()
    s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_from)
    o_b, o_r, o_c = output.size()
    o_b_s, o_r_s, o_c_s = stride(output)
    s_lut_o_r, _s_lut_o_c = sparsity_lut_to.size()
    s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_to)

    def launch_grid(meta):
        # Tiles never exceed either sparsity block size
        tile = min(meta["sparsity_block_size_from"], meta["sparsity_block_size_to"],
                   meta["TRITON_BLOCK_SIZE"])
        return [o_b,
                triton.cdiv(o_r, tile),
                triton.cdiv(o_c, tile)]

    launcher = wrap_triton(adapt_layout_kernel)[launch_grid]
    launcher(x,
             x_b, x_b_s, x_r_s, x_c_s,
             s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
             sparsity_reverse_lut_from,
             output,
             o_b, o_b_s, o_r_s, o_c_s,
             sparsity_lut_to, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
             sparsity_block_size_from,
             sparsity_block_size_to)

    return output
|
|
418
|
+
|
|
419
|
+
|
|
420
|
+
def adapt_layout_backward(ctx, grad_output):
    """Backward of ``adapt_layout_forward``: adapts the gradient from the output layout back to the input layout.

    Args:
        ctx: Autograd context holding three saved tensors (the second and third are used as the input and
            output sparsity layouts) and both sparsity block sizes.
        grad_output: Gradient with respect to the forward output.

    Returns:
        tuple: The gradient for the first forward input followed by ``None`` for the remaining seven inputs.

    """
    _x, layout_from, layout_to = ctx.saved_tensors
    block_size_from = ctx.sparsity_block_size_from
    block_size_to = ctx.sparsity_block_size_to

    # Run the adaption in the opposite direction; only the adapted tensor (first result) is needed
    adapted = adapt_layout(grad_output, layout_to, block_size_to, block_size_from,
                           sparsity_layout_to=layout_from)
    grad_x = adapted[0]

    # One slot per forward input; only the first input receives a gradient
    return grad_x, None, None, None, None, None, None, None
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
@triton.autotune(
    configs=get_autotune_configs(),
    key=[],
    reset_to_zero=["o"]
)
@triton.jit
def adapt_layout_kernel(x,
                        x_b, x_b_s, x_r_s, x_c_s,
                        s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
                        r_lut_x,
                        o,
                        o_b, o_b_s, o_r_s, o_c_s,
                        s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
                        sparsity_block_size_from,
                        sparsity_block_size_to,
                        TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    # Copies one tile from the compressed input ``x`` into the compressed output ``o``.
    # Grid axis 0 selects the output block (and the matching row of ``s_lut_o``), axes 1 and 2 select
    # the row and column tile inside that block. Tiles whose source block has a negative reverse
    # look-up entry are skipped, so ``o`` keeps whatever value it held before the launch there.
    #
    # x, o:                      pointers to the input and output tensors
    # x_b, x_*_s:                first-dimension size and strides of ``x``
    # s_l_x_b, s_l_x_*_s:        first-dimension size and strides of the input sparsity layout
    # r_lut_x:                   pointer to the reverse look-up table of the input layout
    # o_b, o_*_s:                first-dimension size and strides of ``o``
    # s_lut_o, s_lut_o_r, ...:   pointer, row count and strides of the output sparsity look-up table
    # sparsity_block_size_from:  edge length of the input sparsity blocks
    # sparsity_block_size_to:    edge length of the output sparsity blocks
    # TRITON_BLOCK_SIZE:         compile-time tile edge length chosen by the autotuner

    # Get triton block indices
    pid_blk = tl.program_id(axis=0)
    pid_row = tl.program_id(axis=1)
    pid_col = tl.program_id(axis=2)

    # Get valid triton block size (Triton can only handle 2-valued min)
    val_tbs = min(min(sparsity_block_size_from, sparsity_block_size_to), TRITON_BLOCK_SIZE)

    # Get position of current sparsity block consisting of its batch, row, and column index
    spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
    spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
    spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)

    spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
    spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
    spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)

    spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
    spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
    spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)

    # Get equivalent sparsity block in from layout
    # (element position of this tile in regular form, divided by the input block size)
    spa_bat_x = spa_bat_o
    spa_row_x = (spa_row_o * sparsity_block_size_to + pid_row * val_tbs) // sparsity_block_size_from
    spa_col_x = (spa_col_o * sparsity_block_size_to + pid_col * val_tbs) // sparsity_block_size_from

    # Get reverse sparsity indices for x
    rev_idx_spa_x_idx = (spa_bat_x * s_l_x_b_s +
                         spa_row_x * s_l_x_r_s +
                         spa_col_x * s_l_x_c_s)
    rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
    rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)

    # If block is present commence operations
    if rev_idx_spa_x >= 0:
        # Calculate triton block size shifts
        # (tile position, in units of val_tbs, inside the source block)
        shift_row_x = ((spa_row_o * sparsity_block_size_to + pid_row * val_tbs)
                       % sparsity_block_size_from) // val_tbs
        shift_col_x = ((spa_col_o * sparsity_block_size_to + pid_col * val_tbs)
                       % sparsity_block_size_from) // val_tbs

        # Load x values
        blk_x_idx = ((rev_idx_spa_x * x_b_s) +
                     ((shift_row_x * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                     ((shift_col_x * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
        # Lanes beyond the valid tile size are masked out in addition to the bounds check
        blk_x_msk = ((blk_x_idx >= 0 and
                      blk_x_idx < x_b * x_b_s) and
                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
                      tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

        # Store output
        blk_o_idx = ((pid_blk * o_b_s) +
                     ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                     ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
        blk_o_msk = ((blk_o_idx >= 0 and
                      blk_o_idx < o_b * o_b_s) and
                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
                      tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
        tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
|
|
506
|
+
|
|
507
|
+
|
|
508
|
+
# noinspection PyUnusedLocal
def adapt_layout_setup_context(ctx, inputs, output):
    """Stores what ``adapt_layout_backward`` needs on the autograd context.

    Args:
        ctx: Autograd context to populate.
        inputs: The eight forward inputs; the first, second and fifth are saved as tensors, the fourth and
            seventh are stored as the input and output sparsity block sizes.
        output: Forward output (unused).

    """
    (x, layout_from, _reverse_lut_from, block_size_from,
     layout_to, _lut_to, block_size_to, _n_sparse_blocks_to) = inputs

    ctx.sparsity_block_size_from = block_size_from
    ctx.sparsity_block_size_to = block_size_to
    ctx.save_for_backward(x, layout_from, layout_to)
|
|
515
|
+
|
|
516
|
+
|
|
517
|
+
# Attach the backward function and the context setup to the ``blksprs::adapt_layout`` op
adapt_layout_forward.register_autograd(adapt_layout_backward, setup_context=adapt_layout_setup_context)
|