blksprs 1.10.2__py3-none-any.whl → 2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +4 -6
- blksprs/layouting/distribution_layout.py +64 -48
- blksprs/layouting/sparsity_layout.py +96 -72
- blksprs/ops/conversion.py +350 -312
- blksprs/ops/distribution.py +320 -266
- blksprs/ops/flow.py +135 -89
- blksprs/ops/matmul.py +184 -151
- blksprs/ops/misc/broadcast_ops.py +68 -53
- blksprs/ops/misc/row_wise.py +223 -176
- blksprs/ops/partitioning.py +140 -89
- blksprs/ops/repeat.py +118 -108
- blksprs/ops/softmax.py +201 -167
- blksprs/ops/transpose.py +71 -131
- blksprs/utils/autotuning.py +78 -0
- blksprs/utils/benchmarking.py +3 -3
- blksprs/utils/processing.py +2 -1
- blksprs/utils/tools.py +5 -6
- blksprs/utils/validation.py +22 -16
- {blksprs-1.10.2.dist-info → blksprs-2.0.dist-info}/METADATA +55 -36
- blksprs-2.0.dist-info/RECORD +23 -0
- {blksprs-1.10.2.dist-info → blksprs-2.0.dist-info}/WHEEL +1 -1
- blksprs/ops/misc/exp.py +0 -104
- blksprs/utils/layout_utils.py +0 -17
- blksprs-1.10.2.dist-info/RECORD +0 -24
- {blksprs-1.10.2.dist-info → blksprs-2.0.dist-info}/top_level.txt +0 -0
blksprs/ops/conversion.py
CHANGED
@@ -1,25 +1,181 @@
 import torch
 import triton
 from torch import Tensor
+from torch._library.triton import wrap_triton, triton_op
 from triton import language as tl
 
 from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import
+from blksprs.utils.tools import stride
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs, prune_autotune_configs_conversion
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-    validate_sparsity, validate_sparsity_block_size,
+    validate_sparsity, validate_sparsity_block_size, validate_sparsity_dense
 
 
-def
-
+def to_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int) -> BlksprsTensor:
+    """Wrapper for ``to_sparse``.
+
+    """
+    return to_sparse(x, sparsity_layout, sparsity_block_size)
+
+
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
+def to_sparse(x: Tensor, sparsity_layout: Tensor,
+              sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
+    """Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
+    sparsity layout.
+
+    Args:
+        x (Tensor): A block-sparse tensor in regular form.
+        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+        sparsity_block_size (int): The size of the sparsity blocks.
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
+
+    Returns:
+        BlksprsTensor: The block-sparse tensor converted to compressed form.
+
+    """
+    x = x.contiguous()
+
+    validate_dimensions(x)
+    validate_contiguous(x)
+    validate_device(x)
+    validate_sparsity_dense(sparsity_block_size, (x, sparsity_layout))
+    validate_sparsity_block_size(sparsity_block_size, x)
+
+    lut = to_sparse_build_lut(lut, sparsity_layout)
+
+    if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
+        return BlksprsTensor(x)
+
+    return BlksprsTensor(to_sparse_forward(x, sparsity_layout,
+                                           lut["sparsity_lut"], sparsity_block_size, lut["n_sparse_blocks"]))
+
+
+@triton_op("blksprs::to_sparse_forward", mutates_args={})
+def to_sparse_forward(x: Tensor, _: Tensor,
+                      sparsity_lut: Tensor, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+    with torch.no_grad():
+        output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                             dtype=x.dtype, device=x.device)
+
+        x_b, x_r, x_c = x.size()
+        x_b_s, x_r_s, x_c_s = stride(x)
+        s_lut_r, s_lut_c = sparsity_lut.size()
+        s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = stride(output)
+
+        triton_grid = lambda meta: [o_b,
+                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+        (wrap_triton(to_sparse_kernel)[triton_grid]
+         (x, x_b, x_b_s, x_r_s, x_c_s,
+          sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+          output, o_b_s, o_r_s, o_c_s,
+          sparsity_block_size))
+
+        return output
+
+
+def to_sparse_wrapper_backward(ctx, grad_output):
+    sparsity_layout = ctx.saved_tensors[0]
+    sparsity_block_size = ctx.sparsity_block_size
+
+    return to_dense(grad_output, sparsity_layout, sparsity_block_size), None, None, None, None
+
+
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
+    reset_to_zero=["o"]
+)
+@triton.jit
+def to_sparse_kernel(x,
+                     x_b, x_b_s, x_r_s, x_c_s,
+                     s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                     o,
+                     o_b_s, o_r_s, o_c_s,
+                     sparsity_block_size,
+                     TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get sparsity index of current output block consisting of its batch, row, and column index
+    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+    spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
+    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+    spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
+    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+    spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
+    spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
+    spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+
+    # Load block from dense tensor
+    blk_d_idx = (spa_bat * x_b_s +
+                 ((pid_row * TRITON_BLOCK_SIZE + spa_row * sparsity_block_size +
+                   tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + spa_col * sparsity_block_size +
+                   tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_d_msk = (blk_d_idx >= 0 and
+                 blk_d_idx < x_b * x_b_s)
+    blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)
+
+    # Store block in sparse tensor
+    blk_o_idx = ((pid_blk * o_b_s) +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
+    blk_o_msk = (blk_o_idx >= 0 and
+                 blk_o_idx < (pid_blk + 1) * o_b_s)
+    tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
+
+
+def to_sparse_build_lut(lut: dict, sparsity_layout: Tensor):
+    if lut is None:
+        lut = dict()
+
+    if "sparsity_lut" not in lut:
+        sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
+        lut["sparsity_lut"] = sparsity_lut
+
+    if "n_sparse_blocks" not in lut:
+        n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
+        lut["n_sparse_blocks"] = n_sparse_blocks
+
+    validate_contiguous(sparsity_layout, lut["sparsity_lut"])
+
+    return lut
+
+
+# noinspection PyUnusedLocal
+def to_sparse_setup_context(ctx, inputs, output):
+    (_, sparsity_layout, _, sparsity_block_size, _) = inputs
+
+    ctx.save_for_backward(sparsity_layout, )
+    ctx.sparsity_block_size = sparsity_block_size
+
+
+to_sparse_forward.register_autograd(to_sparse_wrapper_backward, setup_context=to_sparse_setup_context)
+
+
+def from_blksprs(x: BlksprsTensor, sparsity_layout: Tensor,
+                 sparsity_block_size: int, fill_value: float = 0, lut: dict = None) -> Tensor:
     """Wrapper for ``to_dense``.
 
     """
-    return to_dense(x, sparsity_layout, sparsity_block_size, fill_value,
+    return to_dense(x, sparsity_layout, sparsity_block_size, fill_value=fill_value, lut=lut)
 
 
-
-
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
+def to_dense(x: BlksprsTensor, sparsity_layout: Tensor,
+             sparsity_block_size: int, fill_value: float = 0, lut: dict = None) -> Tensor:
     """Converts a block-sparse tensor in compressed form to a block-sparse tensor in regular form based on the given
     sparsity layout.
 
@@ -29,7 +185,7 @@ def to_dense(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int
         sparsity_block_size (int): The size of the sparsity blocks.
         fill_value (float): The value to fill the resulting dense tensor with where the block-sparse tensor is not
             present (default ``0``).
-
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
 
     Returns:
         Tensor: The block-sparse tensor converted to regular form.
@@ -42,31 +198,21 @@ def to_dense(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int
     validate_device(x)
     validate_sparsity(sparsity_block_size, (x, sparsity_layout))
     validate_sparsity_block_size(sparsity_block_size, x)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)
-
-    sparsity_layout_flat = sparsity_layout.reshape(-1)
-    sparsity_reverse_lut = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                            (sparsity_layout_flat == 1) -
-                            (1 * (sparsity_layout_flat == 0)))
 
-
+    lut = to_dense_build_lut(lut, sparsity_layout)
 
     if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
         return x
 
-    return
-
-                 sparsity_block_size, fill_value,
-                 triton_block_size)
+    return Tensor(to_dense_forward(x, sparsity_layout,
+                                   lut["sparsity_reverse_lut"], sparsity_block_size, fill_value))
 
 
-
-
-
-
-
-                 sparsity_block_size: int, fill_value: float,
-                 triton_block_size: int) -> Tensor:
+@triton_op("blksprs::to_dense_forward", mutates_args={})
+def to_dense_forward(x: Tensor, sparsity_layout: Tensor,
+                     sparsity_reverse_lut: Tensor,
+                     sparsity_block_size: int, fill_value: float) -> Tensor:
+    with torch.no_grad():
         output = torch.full(size=(sparsity_layout.size(0), sparsity_layout.size(1) * sparsity_block_size,
                                   sparsity_layout.size(2) * sparsity_block_size), fill_value=fill_value,
                             dtype=x.dtype, device=x.device)
@@ -78,217 +224,106 @@ class _BlocksparseToDense(torch.autograd.Function):
         o_b, o_r, o_c = output.size()
         o_b_s, o_r_s, o_c_s = stride(output)
 
-        if triton_block_size is None:
-            triton_block_size = get_triton_block_size(sparsity_block_size)
-
         triton_grid = lambda meta: [o_b,
                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
 
-        (
+        (wrap_triton(to_dense_kernel)[triton_grid]
         (x,
          x_b, x_b_s, x_r_s, x_c_s,
          s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
         sparsity_reverse_lut,
         output,
         o_b, o_b_s, o_r_s, o_c_s,
-         sparsity_block_size
-         triton_block_size))
-
-        ctx.save_for_backward(sparsity_layout)
-        ctx.sparsity_block_size = sparsity_block_size
-        ctx.triton_block_size = triton_block_size
+         sparsity_block_size))
 
         return output
 
-    @staticmethod
-    def backward(ctx, grad_output):
-        sparsity_layout = ctx.saved_tensors[0]
-        sparsity_block_size = ctx.sparsity_block_size
-        triton_block_size = ctx.triton_block_size
-
-        return to_sparse(grad_output, sparsity_layout, sparsity_block_size,
-                         triton_block_size), None, None, None, None, None
-
-    @staticmethod
-    @triton.jit
-    def kernel_blocksparse_to_dense(x,
-                                    x_b, x_b_s, x_r_s, x_c_s,
-                                    s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
-                                    sparsity_reverse_lut,
-                                    o,
-                                    o_b, o_b_s, o_r_s, o_c_s,
-                                    sparsity_block_size,
-                                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-        # Get triton block indices
-        pid_blk = tl.program_id(axis=0)
-        pid_row = tl.program_id(axis=1)
-        pid_col = tl.program_id(axis=2)
-
-        # Get sparsity index of current block
-        spa_row = (pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size
-        spa_col = (pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size
-
-        # Get reverse sparsity index for current block
-        rev_idx_spa_idx = (pid_blk * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)
-        rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_b * s_l_b_s)
-        rev_idx_spa = tl.load(sparsity_reverse_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
-
-        # If block is present commence operations
-        if rev_idx_spa >= 0:
-            blk_idx = (rev_idx_spa * x_b_s +
-                       (((pid_row % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
-                         tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                       (((pid_col % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
-                         tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-            blk_msk = (blk_idx >= 0 and blk_idx < x_b * x_b_s)
-            blk = tl.load(x + blk_idx, mask=blk_msk)
-
-            o_idx = (pid_blk * o_b_s +
-                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-            o_msk = (o_idx >= 0 and o_idx < o_b * o_b_s)
-            tl.store(o + o_idx, blk, o_msk)
-
 
-def
-
-
+def to_dense_wrapper_backward(ctx, grad_output):
+    sparsity_layout = ctx.saved_tensors[0]
+    sparsity_block_size = ctx.sparsity_block_size
 
-
-    return to_sparse(x, sparsity_layout, sparsity_block_size, triton_block_size)
+    return to_sparse(grad_output, sparsity_layout, sparsity_block_size), None, None, None, None
 
 
-
-
-    ""
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    validate_device(x)
-    validate_sparsity_dense(sparsity_block_size, (x, sparsity_layout))
-    validate_sparsity_block_size(sparsity_block_size, x)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
+    restore_value=["o"]
+)
+@triton.jit
+def to_dense_kernel(x,
+                    x_b, x_b_s, x_r_s, x_c_s,
+                    s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+                    sparsity_reverse_lut,
+                    o,
+                    o_b, o_b_s, o_r_s, o_c_s,
+                    sparsity_block_size,
+                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
 
-
-
+    # Get sparsity index of current block
+    spa_row = (pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size
+    spa_col = (pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size
 
-
+    # Get reverse sparsity index for current block
+    rev_idx_spa_idx = (pid_blk * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)
+    rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_b * s_l_b_s)
+    rev_idx_spa = tl.load(sparsity_reverse_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
-
-
+    # If block is present commence operations
+    if rev_idx_spa >= 0:
+        blk_idx = (rev_idx_spa * x_b_s +
+                   (((pid_row % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
+                     tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                   (((pid_col % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
+                     tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+        blk_msk = (blk_idx >= 0 and
+                   blk_idx < x_b * x_b_s)
+        blk = tl.load(x + blk_idx, mask=blk_msk)
 
-
-
-
-
+        o_idx = (pid_blk * o_b_s +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+        o_msk = (o_idx >= 0 and o_idx < o_b * o_b_s)
+        tl.store(o + o_idx, blk, o_msk)
 
 
-
+def to_dense_build_lut(lut: dict, sparsity_layout: Tensor):
+    if lut is None:
+        lut = dict()
 
-
-
-
-
-
-
+    if "sparsity_reverse_lut" not in lut:
+        sparsity_layout_flat = sparsity_layout.reshape(-1)
+        sparsity_reverse_lut = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                (sparsity_layout_flat == 1) -
+                                (1 * (sparsity_layout_flat == 0)))
+        lut["sparsity_reverse_lut"] = sparsity_reverse_lut
 
-
-    x_b_s, x_r_s, x_c_s = stride(x)
-    s_lut_r, s_lut_c = sparsity_lut.size()
-    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-    o_b, o_r, o_c = output.size()
-    o_b_s, o_r_s, o_c_s = stride(output)
+    validate_contiguous(lut["sparsity_reverse_lut"])
 
-
-        triton_block_size = get_triton_block_size(sparsity_block_size)
+    return lut
 
-    triton_grid = lambda meta: [o_b,
-                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
 
-
-
-
-    output, o_b_s, o_r_s, o_c_s,
-    sparsity_block_size,
-    triton_block_size))
+# noinspection PyUnusedLocal
+def to_dense_setup_context(ctx, inputs, output):
+    (_, sparsity_layout, _, sparsity_block_size, _) = inputs
 
-
-
-    ctx.triton_block_size = triton_block_size
+    ctx.save_for_backward(sparsity_layout)
+    ctx.sparsity_block_size = sparsity_block_size
 
-        return output
 
-
-    def backward(ctx, grad_output):
-        sparsity_layout = ctx.saved_tensors[0]
-        sparsity_block_size = ctx.sparsity_block_size
-        triton_block_size = ctx.triton_block_size
-
-        return to_dense(grad_output, sparsity_layout, sparsity_block_size,
-                        triton_block_size=triton_block_size), None, None, None, None, None
-
-    @staticmethod
-    @triton.jit
-    def kernel_blocksparse_to_sparse(x,
-                                     x_b, x_b_s, x_r_s, x_c_s,
-                                     s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-                                     o,
-                                     o_b_s, o_r_s, o_c_s,
-                                     sparsity_block_size,
-                                     TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-        # Get triton block indices
-        pid_blk = tl.program_id(axis=0)
-        pid_row = tl.program_id(axis=1)
-        pid_col = tl.program_id(axis=2)
-
-        # Get sparsity index of current output block consisting of its batch, row, and column index
-        spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-        spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
-        spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
-
-        spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-        spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
-        spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
-
-        spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
-        spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
-        spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
-
-        # Load block from dense tensor
-        blk_d_idx = (spa_bat * x_b_s +
-                     ((spa_row * sparsity_block_size + pid_row * TRITON_BLOCK_SIZE +
-                       tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                     ((spa_col * sparsity_block_size + pid_col * TRITON_BLOCK_SIZE +
-                       tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_d_msk = (blk_d_idx >= 0 and blk_d_idx < x_b * x_b_s)
-        blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)
-
-        # Store block in sparse tensor
-        blk_o_idx = ((pid_blk * o_b_s) +
-                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
-        blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < (pid_blk + 1) * o_b_s)
-        tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
+to_dense_forward.register_autograd(to_dense_wrapper_backward, setup_context=to_dense_setup_context)
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_size_from: int,
-                 sparsity_block_size_to: int, sparsity_layout_to: Tensor = None,
-                 triton_block_size: int = None) -> (BlksprsTensor, Tensor):
+                 sparsity_block_size_to: int, sparsity_layout_to: Tensor = None) -> (BlksprsTensor, Tensor):
     """Adapts the sparsity layout of a block-sparse tensor, resulting in a new block-sparse tensor in compressed form
     conforming to the new sparsity layout (and sparsity block size) definition.
 
@@ -298,7 +333,6 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
         sparsity_block_size_from (int): The size of the sparsity blocks of the input sparsity layout.
         sparsity_block_size_to (int): The size of the sparsity blocks of the output sparsity layout.
         sparsity_layout_to (Tensor): The sparsity layout of the output block-sparse tensor (default ``None``).
-        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
 
     Returns:
         BlksprsTensor: The block-sparse tensor in compressed form with the adapted sparsity layout and sparsity block size.
@@ -313,8 +347,6 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
     validate_sparsity(sparsity_block_size_from, (x, sparsity_layout_from))
     validate_sparsity_block_size(sparsity_block_size_from, x)
     validate_sparsity_block_size(sparsity_block_size_to)
-    min_sparsity_block_size = min(sparsity_block_size_from, sparsity_block_size_to)
-    validate_triton_block_size(triton_block_size, min_sparsity_block_size)
 
     sparsity_layout_from_flat = sparsity_layout_from.reshape(-1)
     sparsity_reverse_lut_from = ((torch.cumsum(sparsity_layout_from_flat, dim=-1) - 1) *
@@ -323,8 +355,7 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
 
     if sparsity_layout_to is None:
         sparsity_layout_to = build_sparsity_layout_adaption(x, sparsity_layout_from,
-                                                            sparsity_block_size_from, sparsity_block_size_to
-                                                            triton_block_size)
+                                                            sparsity_block_size_from, sparsity_block_size_to)
 
     sparsity_lut_to = torch.nonzero(sparsity_layout_to).contiguous()
 
@@ -335,24 +366,22 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
     if (sparsity_block_size_from == sparsity_block_size_to) and torch.equal(sparsity_layout_from, sparsity_layout_to):
         return BlksprsTensor(x), sparsity_layout_to
 
-    return BlksprsTensor(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                         sparsity_block_size_to: int,
-                         n_sparse_blocks_to: int, min_sparsity_block_size: int, triton_block_size: int) -> Tensor:
+    return BlksprsTensor(adapt_layout_forward(x,
+                                              sparsity_layout_from, sparsity_reverse_lut_from,
+                                              sparsity_block_size_from,
+                                              sparsity_layout_to, sparsity_lut_to,
+                                              sparsity_block_size_to,
+                                              n_sparse_blocks_to)), sparsity_layout_to
+
+
+@triton_op("blksprs::adapt_layout_forward", mutates_args={})
+def adapt_layout_forward(x: Tensor,
+                         sparsity_layout_from: Tensor, sparsity_reverse_lut_from: Tensor,
+                         sparsity_block_size_from: int,
+                         _: Tensor, sparsity_lut_to: Tensor,
+                         sparsity_block_size_to: int,
+                         n_sparse_blocks_to: int) -> Tensor:
+    with torch.no_grad():
         output = torch.zeros(size=(n_sparse_blocks_to, sparsity_block_size_to, sparsity_block_size_to),
                              dtype=x.dtype, device=x.device)
 
@@ -365,14 +394,11 @@ class _BlocksparseAdaptLayout(torch.autograd.Function):
         s_lut_o_r, s_lut_o_c = sparsity_lut_to.size()
         s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_to)
 
-        if triton_block_size is None:
-            triton_block_size = get_triton_block_size(min_sparsity_block_size)
-
         triton_grid = lambda meta: [o_b,
                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
 
-        (
+        (wrap_triton(adapt_layout_kernel)[triton_grid]
         (x,
          x_b, x_b_s, x_r_s, x_c_s,
          s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
@@ -381,88 +407,100 @@ class _BlocksparseAdaptLayout(torch.autograd.Function):
          o_b, o_b_s, o_r_s, o_c_s,
          sparsity_lut_to, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
          sparsity_block_size_from,
-         sparsity_block_size_to
-         triton_block_size))
-
-        ctx.save_for_backward(x, sparsity_layout_from, sparsity_layout_to)
-        ctx.sparsity_block_size_from = sparsity_block_size_from
-        ctx.sparsity_block_size_to = sparsity_block_size_to
-        ctx.triton_block_size = triton_block_size
+         sparsity_block_size_to))
 
         return output
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+def adapt_layout_wrapper_backward(ctx, grad_output):
+    x, sparsity_layout_from, sparsity_layout_to = ctx.saved_tensors
+    sparsity_block_size_from = ctx.sparsity_block_size_from
+    sparsity_block_size_to = ctx.sparsity_block_size_to
+
+    return adapt_layout(
+        grad_output, sparsity_layout_to, sparsity_block_size_to, sparsity_block_size_from,
+        sparsity_layout_to=sparsity_layout_from)[0], None, None, None, None, None, None, None
+
+
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=["sparsity_block_size_from", "sparsity_block_size_to"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs_conversion},
+    reset_to_zero=["o"]
+)
+@triton.jit
+def adapt_layout_kernel(x,
+                        x_b, x_b_s, x_r_s, x_c_s,
+                        s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+                        r_lut_x,
+                        o,
+                        o_b, o_b_s, o_r_s, o_c_s,
+                        s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+                        sparsity_block_size_from,
+                        sparsity_block_size_to,
+                        TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get position of current sparsity block consisting of its batch, row, and column index
+    spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
+    spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
+
+    spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
+    spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+
+    spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
+    spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+
+    # Get equivalent sparsity block in from layout
+    spa_bat_x = spa_bat_o
+    spa_row_x = (spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size_from
+    spa_col_x = (spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size_from
+
+    # Get reverse sparsity indices for x
+    rev_idx_spa_x_idx = (spa_bat_x * s_l_x_b_s +
+                         spa_row_x * s_l_x_r_s +
+                         spa_col_x * s_l_x_c_s)
+    rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+    rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
+
+    # If block is present commence operations
+    if rev_idx_spa_x >= 0:
+        # Calculate triton block size shifts
+        shift_row_x = ((spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE)
+                       % sparsity_block_size_from) // TRITON_BLOCK_SIZE
+        shift_col_x = ((spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE)
+                       % sparsity_block_size_from) // TRITON_BLOCK_SIZE
+
+        # Load x values
+        blk_x_idx = ((rev_idx_spa_x * x_b_s) +
+                     ((shift_row_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                     ((shift_col_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+        blk_x_msk = (blk_x_idx >= 0 and
+                     blk_x_idx < x_b * x_b_s)
+        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+        # Store output
+        blk_o_idx = ((pid_blk * o_b_s) +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+        blk_o_msk = (blk_o_idx >= 0 and
+                     blk_o_idx < o_b * o_b_s)
+        tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+# noinspection PyUnusedLocal
+def adapt_layout_setup_context(ctx, inputs, output):
+    (x, sparsity_layout_from, _, sparsity_block_size_from, sparsity_layout_to, _, sparsity_block_size_to, _) = inputs
+
+    ctx.save_for_backward(x, sparsity_layout_from, sparsity_layout_to)
+    ctx.sparsity_block_size_from = sparsity_block_size_from
+    ctx.sparsity_block_size_to = sparsity_block_size_to
+
+
+adapt_layout_forward.register_autograd(adapt_layout_wrapper_backward, setup_context=adapt_layout_setup_context)
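
As the docstrings above describe, the 2.0 conversion entry points drop the explicit `triton_block_size` argument (block sizes are now picked by `triton.autotune` over `get_autotune_configs()`) and instead accept an optional `lut` dictionary of precomputed look-up tables. The sketch below is a minimal, hypothetical round trip through the signatures shown in this diff; the device, shapes, and dtypes are illustrative assumptions, not values taken from the package.

```python
import torch

from blksprs.ops.conversion import to_sparse, to_dense

sparsity_block_size = 32

# Assumed input: a 3-dim (batch, rows, cols) CUDA tensor.
x = torch.randn(2, 64, 64, device="cuda")

# One layout flag per sparsity block: (2, 64 // 32, 64 // 32) = (2, 2, 2).
sparsity_layout = torch.ones(2, 2, 2, dtype=torch.bool, device="cuda")

# Compress; on first use the dict is filled with "sparsity_lut" and
# "n_sparse_blocks" and can be passed again to skip rebuilding them.
lut = {}
x_sparse = to_sparse(x, sparsity_layout, sparsity_block_size, lut=lut)

# Decompress; to_dense caches its own "sparsity_reverse_lut" entry.
x_dense = to_dense(x_sparse, sparsity_layout, sparsity_block_size)

assert torch.allclose(x, x_dense)
```

Reusing the `lut` dictionary across repeated calls with the same layout is the point of the new parameter: `to_sparse_build_lut` and `to_dense_build_lut` only compute the entries that are missing.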