blksprs 1.11-py3-none-any.whl → 2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +4 -5
- blksprs/layouting/distribution_layout.py +64 -48
- blksprs/layouting/sparsity_layout.py +96 -72
- blksprs/ops/conversion.py +349 -338
- blksprs/ops/distribution.py +318 -294
- blksprs/ops/flow.py +113 -100
- blksprs/ops/matmul.py +187 -172
- blksprs/ops/misc/broadcast_ops.py +68 -53
- blksprs/ops/misc/row_wise.py +223 -176
- blksprs/ops/partitioning.py +140 -132
- blksprs/ops/repeat.py +118 -120
- blksprs/ops/softmax.py +240 -214
- blksprs/ops/transpose.py +55 -52
- blksprs/utils/autotuning.py +78 -0
- blksprs/utils/benchmarking.py +3 -3
- blksprs/utils/processing.py +2 -1
- blksprs/utils/tools.py +5 -6
- blksprs/utils/validation.py +22 -16
- {blksprs-1.11.dist-info → blksprs-2.0.dist-info}/METADATA +55 -36
- blksprs-2.0.dist-info/RECORD +23 -0
- {blksprs-1.11.dist-info → blksprs-2.0.dist-info}/WHEEL +1 -1
- blksprs/utils/layout_utils.py +0 -17
- blksprs-1.11.dist-info/RECORD +0 -23
- {blksprs-1.11.dist-info → blksprs-2.0.dist-info}/top_level.txt +0 -0
blksprs/ops/conversion.py
CHANGED
@@ -1,25 +1,181 @@
 import torch
 import triton
 from torch import Tensor
+from torch._library.triton import wrap_triton, triton_op
 from triton import language as tl

 from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import
+from blksprs.utils.tools import stride
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs, prune_autotune_configs_conversion
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-    validate_sparsity, validate_sparsity_block_size,
+    validate_sparsity, validate_sparsity_block_size, validate_sparsity_dense


-def
-
+def to_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int) -> BlksprsTensor:
+    """Wrapper for ``to_sparse``.
+
+    """
+    return to_sparse(x, sparsity_layout, sparsity_block_size)
+
+
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
+def to_sparse(x: Tensor, sparsity_layout: Tensor,
+              sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
+    """Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
+    sparsity layout.
+
+    Args:
+        x (Tensor): A block-sparse tensor in regular form.
+        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+        sparsity_block_size (int): The size of the sparsity blocks.
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
+
+    Returns:
+        BlksprsTensor: The block-sparse tensor converted to compressed form.
+
+    """
+    x = x.contiguous()
+
+    validate_dimensions(x)
+    validate_contiguous(x)
+    validate_device(x)
+    validate_sparsity_dense(sparsity_block_size, (x, sparsity_layout))
+    validate_sparsity_block_size(sparsity_block_size, x)
+
+    lut = to_sparse_build_lut(lut, sparsity_layout)
+
+    if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
+        return BlksprsTensor(x)
+
+    return BlksprsTensor(to_sparse_forward(x, sparsity_layout,
+                                           lut["sparsity_lut"], sparsity_block_size, lut["n_sparse_blocks"]))
+
+
+@triton_op("blksprs::to_sparse_forward", mutates_args={})
+def to_sparse_forward(x: Tensor, _: Tensor,
+                      sparsity_lut: Tensor, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+    with torch.no_grad():
+        output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                             dtype=x.dtype, device=x.device)
+
+        x_b, x_r, x_c = x.size()
+        x_b_s, x_r_s, x_c_s = stride(x)
+        s_lut_r, s_lut_c = sparsity_lut.size()
+        s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = stride(output)
+
+        triton_grid = lambda meta: [o_b,
+                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+        (wrap_triton(to_sparse_kernel)[triton_grid]
+         (x, x_b, x_b_s, x_r_s, x_c_s,
+          sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+          output, o_b_s, o_r_s, o_c_s,
+          sparsity_block_size))
+
+        return output
+
+
+def to_sparse_wrapper_backward(ctx, grad_output):
+    sparsity_layout = ctx.saved_tensors[0]
+    sparsity_block_size = ctx.sparsity_block_size
+
+    return to_dense(grad_output, sparsity_layout, sparsity_block_size), None, None, None, None
+
+
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
+    reset_to_zero=["o"]
+)
+@triton.jit
+def to_sparse_kernel(x,
+                     x_b, x_b_s, x_r_s, x_c_s,
+                     s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                     o,
+                     o_b_s, o_r_s, o_c_s,
+                     sparsity_block_size,
+                     TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get sparsity index of current output block consisting of its batch, row, and column index
+    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+    spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
+    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+    spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
+    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+    spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
+    spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
+    spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+
+    # Load block from dense tensor
+    blk_d_idx = (spa_bat * x_b_s +
+                 ((pid_row * TRITON_BLOCK_SIZE + spa_row * sparsity_block_size +
+                   tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + spa_col * sparsity_block_size +
+                   tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_d_msk = (blk_d_idx >= 0 and
+                 blk_d_idx < x_b * x_b_s)
+    blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)
+
+    # Store block in sparse tensor
+    blk_o_idx = ((pid_blk * o_b_s) +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
+    blk_o_msk = (blk_o_idx >= 0 and
+                 blk_o_idx < (pid_blk + 1) * o_b_s)
+    tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
+
+
+def to_sparse_build_lut(lut: dict, sparsity_layout: Tensor):
+    if lut is None:
+        lut = dict()
+
+    if "sparsity_lut" not in lut:
+        sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
+        lut["sparsity_lut"] = sparsity_lut
+
+    if "n_sparse_blocks" not in lut:
+        n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
+        lut["n_sparse_blocks"] = n_sparse_blocks
+
+    validate_contiguous(sparsity_layout, lut["sparsity_lut"])
+
+    return lut
+
+
+# noinspection PyUnusedLocal
+def to_sparse_setup_context(ctx, inputs, output):
+    (_, sparsity_layout, _, sparsity_block_size, _) = inputs
+
+    ctx.save_for_backward(sparsity_layout, )
+    ctx.sparsity_block_size = sparsity_block_size
+
+
+to_sparse_forward.register_autograd(to_sparse_wrapper_backward, setup_context=to_sparse_setup_context)
+
+
+def from_blksprs(x: BlksprsTensor, sparsity_layout: Tensor,
+                 sparsity_block_size: int, fill_value: float = 0, lut: dict = None) -> Tensor:
     """Wrapper for ``to_dense``.

     """
-    return to_dense(x, sparsity_layout, sparsity_block_size, fill_value,
+    return to_dense(x, sparsity_layout, sparsity_block_size, fill_value=fill_value, lut=lut)


-
-
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
+def to_dense(x: BlksprsTensor, sparsity_layout: Tensor,
+             sparsity_block_size: int, fill_value: float = 0, lut: dict = None) -> Tensor:
     """Converts a block-sparse tensor in compressed form to a block-sparse tensor in regular form based on the given
     sparsity layout.

@@ -29,7 +185,6 @@ def to_dense(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int
         sparsity_block_size (int): The size of the sparsity blocks.
         fill_value (float): The value to fill the resulting dense tensor with where the block-sparse tensor is not
             present (default ``0``).
-        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
         lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

     Returns:
@@ -43,42 +198,21 @@ def to_dense(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int
     validate_device(x)
     validate_sparsity(sparsity_block_size, (x, sparsity_layout))
     validate_sparsity_block_size(sparsity_block_size, x)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)

-    lut =
+    lut = to_dense_build_lut(lut, sparsity_layout)

     if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
         return x

-    return
-
-        sparsity_block_size, fill_value,
-        triton_block_size)
-
-
-class _BlocksparseToDense(torch.autograd.Function):
-
-    @staticmethod
-    def build_lut(lut: dict, sparsity_layout: Tensor):
-        if lut is None:
-            lut = dict()
+    return Tensor(to_dense_forward(x, sparsity_layout,
+                                   lut["sparsity_reverse_lut"], sparsity_block_size, fill_value))

-        if "sparsity_reverse_lut" not in lut:
-            sparsity_layout_flat = sparsity_layout.reshape(-1)
-            sparsity_reverse_lut = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                                    (sparsity_layout_flat == 1) -
-                                    (1 * (sparsity_layout_flat == 0)))
-            lut["sparsity_reverse_lut"] = sparsity_reverse_lut

-
-
-
-
-
-    def forward(ctx, x: Tensor,
-                sparsity_layout: Tensor, sparsity_reverse_lut: Tensor,
-                sparsity_block_size: int, fill_value: float,
-                triton_block_size: int) -> Tensor:
+@triton_op("blksprs::to_dense_forward", mutates_args={})
+def to_dense_forward(x: Tensor, sparsity_layout: Tensor,
+                     sparsity_reverse_lut: Tensor,
+                     sparsity_block_size: int, fill_value: float) -> Tensor:
+    with torch.no_grad():
         output = torch.full(size=(sparsity_layout.size(0), sparsity_layout.size(1) * sparsity_block_size,
                                   sparsity_layout.size(2) * sparsity_block_size), fill_value=fill_value,
                             dtype=x.dtype, device=x.device)
@@ -90,232 +224,106 @@ class _BlocksparseToDense(torch.autograd.Function):
         o_b, o_r, o_c = output.size()
         o_b_s, o_r_s, o_c_s = stride(output)

-        if triton_block_size is None:
-            triton_block_size = get_triton_block_size(sparsity_block_size)
-
         triton_grid = lambda meta: [o_b,
                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

-        (
+        (wrap_triton(to_dense_kernel)[triton_grid]
         (x,
          x_b, x_b_s, x_r_s, x_c_s,
          s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
          sparsity_reverse_lut,
          output,
          o_b, o_b_s, o_r_s, o_c_s,
-          sparsity_block_size,
-          triton_block_size))
-
-        ctx.save_for_backward(sparsity_layout)
-        ctx.sparsity_block_size = sparsity_block_size
-        ctx.triton_block_size = triton_block_size
+          sparsity_block_size))

         return output

-    @staticmethod
-    def backward(ctx, grad_output):
-        sparsity_layout = ctx.saved_tensors[0]
-        sparsity_block_size = ctx.sparsity_block_size
-        triton_block_size = ctx.triton_block_size
-
-        return to_sparse(grad_output, sparsity_layout, sparsity_block_size,
-                         triton_block_size), None, None, None, None, None
-
-    @staticmethod
-    @triton.jit
-    def kernel_blocksparse_to_dense(x,
-                                    x_b, x_b_s, x_r_s, x_c_s,
-                                    s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
-                                    sparsity_reverse_lut,
-                                    o,
-                                    o_b, o_b_s, o_r_s, o_c_s,
-                                    sparsity_block_size,
-                                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-        # Get triton block indices
-        pid_blk = tl.program_id(axis=0)
-        pid_row = tl.program_id(axis=1)
-        pid_col = tl.program_id(axis=2)
-
-        # Get sparsity index of current block
-        spa_row = (pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size
-        spa_col = (pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size
-
-        # Get reverse sparsity index for current block
-        rev_idx_spa_idx = (pid_blk * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)
-        rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_b * s_l_b_s)
-        rev_idx_spa = tl.load(sparsity_reverse_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
-
-        # If block is present commence operations
-        if rev_idx_spa >= 0:
-            blk_idx = (rev_idx_spa * x_b_s +
-                       (((pid_row % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
-                         tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                       (((pid_col % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
-                         tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-            blk_msk = (blk_idx >= 0 and blk_idx < x_b * x_b_s)
-            blk = tl.load(x + blk_idx, mask=blk_msk)
-
-            o_idx = (pid_blk * o_b_s +
-                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-            o_msk = (o_idx >= 0 and o_idx < o_b * o_b_s)
-            tl.store(o + o_idx, blk, o_msk)
-
-
-def to_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
-               triton_block_size: int = None) -> BlksprsTensor:
-    """Wrapper for ``to_sparse``.
-
-    """
-    return to_sparse(x, sparsity_layout, sparsity_block_size, triton_block_size)
-
-
-def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
-              triton_block_size: int = None, lut: dict = None) -> BlksprsTensor:
-    """Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
-    sparsity layout.
-
-    Args:
-        x (Tensor): A block-sparse tensor in regular form.
-        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
-        sparsity_block_size (int): The size of the sparsity blocks.
-        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
-        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
-
-    Returns:
-        BlksprsTensor: The block-sparse tensor converted to compressed form.
-
-    """
-    x = x.contiguous()
-
-    validate_dimensions(x)
-    validate_contiguous(x)
-    validate_device(x)
-    validate_sparsity_dense(sparsity_block_size, (x, sparsity_layout))
-    validate_sparsity_block_size(sparsity_block_size, x)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)

-
+def to_dense_wrapper_backward(ctx, grad_output):
+    sparsity_layout = ctx.saved_tensors[0]
+    sparsity_block_size = ctx.sparsity_block_size

-
-        return BlksprsTensor(x)
+    return to_sparse(grad_output, sparsity_layout, sparsity_block_size), None, None, None, None

-    return BlksprsTensor(_BlocksparseToSparse.apply(x,
-                                                    sparsity_layout, lut["sparsity_lut"],
-                                                    sparsity_block_size, lut["n_sparse_blocks"],
-                                                    triton_block_size))

+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
+    restore_value=["o"]
+)
+@triton.jit
+def to_dense_kernel(x,
+                    x_b, x_b_s, x_r_s, x_c_s,
+                    s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+                    sparsity_reverse_lut,
+                    o,
+                    o_b, o_b_s, o_r_s, o_c_s,
+                    sparsity_block_size,
+                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)

-
+    # Get sparsity index of current block
+    spa_row = (pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size
+    spa_col = (pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size

-
-
-
-
+    # Get reverse sparsity index for current block
+    rev_idx_spa_idx = (pid_blk * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)
+    rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_b * s_l_b_s)
+    rev_idx_spa = tl.load(sparsity_reverse_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

-
-
-
+    # If block is present commence operations
+    if rev_idx_spa >= 0:
+        blk_idx = (rev_idx_spa * x_b_s +
+                   (((pid_row % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
+                     tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                   (((pid_col % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
+                     tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+        blk_msk = (blk_idx >= 0 and
+                   blk_idx < x_b * x_b_s)
+        blk = tl.load(x + blk_idx, mask=blk_msk)

-
-
-
+        o_idx = (pid_blk * o_b_s +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+        o_msk = (o_idx >= 0 and o_idx < o_b * o_b_s)
+        tl.store(o + o_idx, blk, o_msk)

-    validate_contiguous(sparsity_layout, lut["sparsity_lut"])

-
+def to_dense_build_lut(lut: dict, sparsity_layout: Tensor):
+    if lut is None:
+        lut = dict()

-
-
-
-
-
-
+    if "sparsity_reverse_lut" not in lut:
+        sparsity_layout_flat = sparsity_layout.reshape(-1)
+        sparsity_reverse_lut = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                (sparsity_layout_flat == 1) -
+                                (1 * (sparsity_layout_flat == 0)))
+        lut["sparsity_reverse_lut"] = sparsity_reverse_lut

-
-    x_b_s, x_r_s, x_c_s = stride(x)
-    s_lut_r, s_lut_c = sparsity_lut.size()
-    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-    o_b, o_r, o_c = output.size()
-    o_b_s, o_r_s, o_c_s = stride(output)
+    validate_contiguous(lut["sparsity_reverse_lut"])

-
-        triton_block_size = get_triton_block_size(sparsity_block_size)
+    return lut

-    triton_grid = lambda meta: [o_b,
-                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

-
-
-
-                 output, o_b_s, o_r_s, o_c_s,
-                 sparsity_block_size,
-                 triton_block_size))
+# noinspection PyUnusedLocal
+def to_dense_setup_context(ctx, inputs, output):
+    (_, sparsity_layout, _, sparsity_block_size, _) = inputs

-
-
-    ctx.triton_block_size = triton_block_size
+    ctx.save_for_backward(sparsity_layout)
+    ctx.sparsity_block_size = sparsity_block_size

-    return output

-
-    def backward(ctx, grad_output):
-        sparsity_layout = ctx.saved_tensors[0]
-        sparsity_block_size = ctx.sparsity_block_size
-        triton_block_size = ctx.triton_block_size
-
-        return to_dense(grad_output, sparsity_layout, sparsity_block_size,
-                        triton_block_size=triton_block_size), None, None, None, None, None
-
-    @staticmethod
-    @triton.jit
-    def kernel_blocksparse_to_sparse(x,
-                                     x_b, x_b_s, x_r_s, x_c_s,
-                                     s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-                                     o,
-                                     o_b_s, o_r_s, o_c_s,
-                                     sparsity_block_size,
-                                     TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-        # Get triton block indices
-        pid_blk = tl.program_id(axis=0)
-        pid_row = tl.program_id(axis=1)
-        pid_col = tl.program_id(axis=2)
-
-        # Get sparsity index of current output block consisting of its batch, row, and column index
-        spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-        spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
-        spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
-
-        spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-        spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
-        spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
-
-        spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
-        spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
-        spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
-
-        # Load block from dense tensor
-        blk_d_idx = (spa_bat * x_b_s +
-                     ((spa_row * sparsity_block_size + pid_row * TRITON_BLOCK_SIZE +
-                       tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                     ((spa_col * sparsity_block_size + pid_col * TRITON_BLOCK_SIZE +
-                       tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_d_msk = (blk_d_idx >= 0 and blk_d_idx < x_b * x_b_s)
-        blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)
-
-        # Store block in sparse tensor
-        blk_o_idx = ((pid_blk * o_b_s) +
-                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
-        blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < (pid_blk + 1) * o_b_s)
-        tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
+to_dense_forward.register_autograd(to_dense_wrapper_backward, setup_context=to_dense_setup_context)


+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_size_from: int,
-                 sparsity_block_size_to: int, sparsity_layout_to: Tensor = None,
-                 triton_block_size: int = None) -> (BlksprsTensor, Tensor):
+                 sparsity_block_size_to: int, sparsity_layout_to: Tensor = None) -> (BlksprsTensor, Tensor):
     """Adapts the sparsity layout of a block-sparse tensor, resulting in a new block-sparse tensor in compressed form
     conforming to the new sparsity layout (and sparsity block size) definition.

@@ -325,7 +333,6 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
         sparsity_block_size_from (int): The size of the sparsity blocks of the input sparsity layout.
         sparsity_block_size_to (int): The size of the sparsity blocks of the output sparsity layout.
        sparsity_layout_to (Tensor): The sparsity layout of the output block-sparse tensor (default ``None``).
-        triton_block_size (int): The block size to use for the triton kernel (default ``None``).

     Returns:
         BlksprsTensor: The block-sparse tensor in compressed form with the adapted sparsity layout and sparsity block size.
@@ -340,8 +347,6 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
     validate_sparsity(sparsity_block_size_from, (x, sparsity_layout_from))
     validate_sparsity_block_size(sparsity_block_size_from, x)
     validate_sparsity_block_size(sparsity_block_size_to)
-    min_sparsity_block_size = min(sparsity_block_size_from, sparsity_block_size_to)
-    validate_triton_block_size(triton_block_size, min_sparsity_block_size)

     sparsity_layout_from_flat = sparsity_layout_from.reshape(-1)
     sparsity_reverse_lut_from = ((torch.cumsum(sparsity_layout_from_flat, dim=-1) - 1) *
@@ -350,8 +355,7 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_

     if sparsity_layout_to is None:
         sparsity_layout_to = build_sparsity_layout_adaption(x, sparsity_layout_from,
-                                                            sparsity_block_size_from, sparsity_block_size_to,
-                                                            triton_block_size)
+                                                            sparsity_block_size_from, sparsity_block_size_to)

     sparsity_lut_to = torch.nonzero(sparsity_layout_to).contiguous()

@@ -362,24 +366,22 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
     if (sparsity_block_size_from == sparsity_block_size_to) and torch.equal(sparsity_layout_from, sparsity_layout_to):
         return BlksprsTensor(x), sparsity_layout_to

-    return BlksprsTensor(
- [… 15 deleted lines not preserved in this extraction …]
-                sparsity_block_size_to: int,
-                n_sparse_blocks_to: int, min_sparsity_block_size: int, triton_block_size: int) -> Tensor:
+    return BlksprsTensor(adapt_layout_forward(x,
+                                              sparsity_layout_from, sparsity_reverse_lut_from,
+                                              sparsity_block_size_from,
+                                              sparsity_layout_to, sparsity_lut_to,
+                                              sparsity_block_size_to,
+                                              n_sparse_blocks_to)), sparsity_layout_to
+
+
+@triton_op("blksprs::adapt_layout_forward", mutates_args={})
+def adapt_layout_forward(x: Tensor,
+                         sparsity_layout_from: Tensor, sparsity_reverse_lut_from: Tensor,
+                         sparsity_block_size_from: int,
+                         _: Tensor, sparsity_lut_to: Tensor,
+                         sparsity_block_size_to: int,
+                         n_sparse_blocks_to: int) -> Tensor:
+    with torch.no_grad():
         output = torch.zeros(size=(n_sparse_blocks_to, sparsity_block_size_to, sparsity_block_size_to),
                              dtype=x.dtype, device=x.device)

@@ -392,14 +394,11 @@ class _BlocksparseAdaptLayout(torch.autograd.Function):
         s_lut_o_r, s_lut_o_c = sparsity_lut_to.size()
         s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_to)

-        if triton_block_size is None:
-            triton_block_size = get_triton_block_size(min_sparsity_block_size)
-
         triton_grid = lambda meta: [o_b,
                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

-        (
+        (wrap_triton(adapt_layout_kernel)[triton_grid]
         (x,
          x_b, x_b_s, x_r_s, x_c_s,
          s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
@@ -408,88 +407,100 @@ class _BlocksparseAdaptLayout(torch.autograd.Function):
          o_b, o_b_s, o_r_s, o_c_s,
          sparsity_lut_to, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
          sparsity_block_size_from,
-          sparsity_block_size_to,
-          triton_block_size))
-
-        ctx.save_for_backward(x, sparsity_layout_from, sparsity_layout_to)
-        ctx.sparsity_block_size_from = sparsity_block_size_from
-        ctx.sparsity_block_size_to = sparsity_block_size_to
-        ctx.triton_block_size = triton_block_size
+          sparsity_block_size_to))

         return output

- [… 75 deleted lines not preserved in this extraction …]
+
+def adapt_layout_wrapper_backward(ctx, grad_output):
+    x, sparsity_layout_from, sparsity_layout_to = ctx.saved_tensors
+    sparsity_block_size_from = ctx.sparsity_block_size_from
+    sparsity_block_size_to = ctx.sparsity_block_size_to
+
+    return adapt_layout(
+        grad_output, sparsity_layout_to, sparsity_block_size_to, sparsity_block_size_from,
+        sparsity_layout_to=sparsity_layout_from)[0], None, None, None, None, None, None, None
+
+
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=["sparsity_block_size_from", "sparsity_block_size_to"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs_conversion},
+    reset_to_zero=["o"]
+)
+@triton.jit
+def adapt_layout_kernel(x,
+                        x_b, x_b_s, x_r_s, x_c_s,
+                        s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+                        r_lut_x,
+                        o,
+                        o_b, o_b_s, o_r_s, o_c_s,
+                        s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+                        sparsity_block_size_from,
+                        sparsity_block_size_to,
+                        TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get position of current sparsity block consisting of its batch, row, and column index
+    spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
+    spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
+
+    spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
+    spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+
+    spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
+    spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+
+    # Get equivalent sparsity block in from layout
+    spa_bat_x = spa_bat_o
+    spa_row_x = (spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size_from
+    spa_col_x = (spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size_from
+
+    # Get reverse sparsity indices for x
+    rev_idx_spa_x_idx = (spa_bat_x * s_l_x_b_s +
+                         spa_row_x * s_l_x_r_s +
+                         spa_col_x * s_l_x_c_s)
+    rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+    rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
+
+    # If block is present commence operations
+    if rev_idx_spa_x >= 0:
+        # Calculate triton block size shifts
+        shift_row_x = ((spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE)
+                       % sparsity_block_size_from) // TRITON_BLOCK_SIZE
+        shift_col_x = ((spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE)
+                       % sparsity_block_size_from) // TRITON_BLOCK_SIZE
+
+        # Load x values
+        blk_x_idx = ((rev_idx_spa_x * x_b_s) +
+                     ((shift_row_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                     ((shift_col_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+        blk_x_msk = (blk_x_idx >= 0 and
+                     blk_x_idx < x_b * x_b_s)
+        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+        # Store output
+        blk_o_idx = ((pid_blk * o_b_s) +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+        blk_o_msk = (blk_o_idx >= 0 and
+                     blk_o_idx < o_b * o_b_s)
+        tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+# noinspection PyUnusedLocal
+def adapt_layout_setup_context(ctx, inputs, output):
+    (x, sparsity_layout_from, _, sparsity_block_size_from, sparsity_layout_to, _, sparsity_block_size_to, _) = inputs
+
+    ctx.save_for_backward(x, sparsity_layout_from, sparsity_layout_to)
+    ctx.sparsity_block_size_from = sparsity_block_size_from
+    ctx.sparsity_block_size_to = sparsity_block_size_to
+
+
+adapt_layout_forward.register_autograd(adapt_layout_wrapper_backward, setup_context=adapt_layout_setup_context)