blksprs 1.10.2__py3-none-any.whl → 2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +4 -6
- blksprs/layouting/distribution_layout.py +64 -48
- blksprs/layouting/sparsity_layout.py +96 -72
- blksprs/ops/conversion.py +350 -312
- blksprs/ops/distribution.py +320 -266
- blksprs/ops/flow.py +135 -89
- blksprs/ops/matmul.py +184 -151
- blksprs/ops/misc/broadcast_ops.py +68 -53
- blksprs/ops/misc/row_wise.py +223 -176
- blksprs/ops/partitioning.py +140 -89
- blksprs/ops/repeat.py +118 -108
- blksprs/ops/softmax.py +201 -167
- blksprs/ops/transpose.py +71 -131
- blksprs/utils/autotuning.py +78 -0
- blksprs/utils/benchmarking.py +3 -3
- blksprs/utils/processing.py +2 -1
- blksprs/utils/tools.py +5 -6
- blksprs/utils/validation.py +22 -16
- {blksprs-1.10.2.dist-info → blksprs-2.0.dist-info}/METADATA +55 -36
- blksprs-2.0.dist-info/RECORD +23 -0
- {blksprs-1.10.2.dist-info → blksprs-2.0.dist-info}/WHEEL +1 -1
- blksprs/ops/misc/exp.py +0 -104
- blksprs/utils/layout_utils.py +0 -17
- blksprs-1.10.2.dist-info/RECORD +0 -24
- {blksprs-1.10.2.dist-info → blksprs-2.0.dist-info}/top_level.txt +0 -0
blksprs/ops/flow.py
CHANGED

```diff
@@ -1,20 +1,65 @@
 import torch
 import triton
 from torch import Tensor
+from torch._library import triton_op
+from torch._library.triton import wrap_triton
 from triton import language as tl
 
-from blksprs.utils.tools import stride
-
-
+from blksprs.utils.tools import stride
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
+
+
+@triton_op("blksprs::flow_pull_forward", mutates_args={})
+def flow_pull_forward(x: Tensor, sparsity_layout_o: Tensor,
+                      sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                      sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+    with torch.no_grad():
+        output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                             dtype=x.dtype, device=x.device)
+
+        x_b, x_r, x_c = x.size()
+        x_b_s, x_r_s, x_c_s = stride(x)
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = stride(output)
+        s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
+        s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
+        s_lut_r, s_lut_c = sparsity_lut.size()
+        s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+
+        triton_grid = lambda meta: [o_b,
+                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+        (wrap_triton(flow_pull_kernel)[triton_grid]
+         (x,
+          x_b, x_b_s, x_r_s, x_c_s,
+          output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+          sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+          sparsity_reverse_lut,
+          sparsity_block_size))
+
+        return output
+
+
+# noinspection PyUnusedLocal
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
+    reset_to_zero=["o"]
+)
 @triton.jit
-def kernel_blocksparse_flow_pull(x,
-    …
+def flow_pull_kernel(x,
+                     x_b, x_b_s, x_r_s, x_c_s,
+                     o,
+                     o_b, o_b_s, o_r_s, o_c_s,
+                     s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+                     s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                     r_lut,
+                     sparsity_block_size,
+                     TRITON_BLOCK_SIZE: tl.constexpr) -> None:
     # Get triton block indices
     pid_blk = tl.program_id(axis=0)
     pid_row = tl.program_id(axis=1)
```
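The hunk above shows the kernel-registration pattern that 2.0 applies throughout the package: a `@triton.jit` kernel launched through `wrap_triton` from inside a `@triton_op`-registered function, autotuned via `get_autotune_configs()` keyed on `sparsity_block_size`. Below is a minimal, self-contained sketch of that pattern, assuming PyTorch ≥ 2.6 (public `torch.library` path) and a CUDA device; the `myops::double` op and its kernel are hypothetical illustrations, not blksprs API.

```python
# Minimal sketch of the triton_op/wrap_triton registration pattern used above.
# Assumes PyTorch >= 2.6 and Triton; "myops::double" is a made-up example op.
import torch
import triton
from torch.library import triton_op, wrap_triton
from triton import language as tl


@triton.jit
def double_kernel(x_ptr, o_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    msk = offs < n
    tl.store(o_ptr + offs, tl.load(x_ptr + offs, mask=msk) * 2, mask=msk)


@triton_op("myops::double", mutates_args={})
def double(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK"]),)
    # wrap_triton keeps the kernel launch visible to torch.compile tracing
    wrap_triton(double_kernel)[grid](x, out, n, BLOCK=1024)
    return out


# e.g. torch.ops.myops.double(torch.ones(4096, device="cuda"))
```

Registering the wrapper as a custom op lets `torch.compile` trace through the Triton launch instead of graph-breaking on it, which is presumably why every forward pass in this release moved to this structure.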
```diff
@@ -40,32 +85,72 @@ def kernel_blocksparse_flow_pull(x,
     rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
     rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
-    if rev_idx_spa …
-        …
+    if rev_idx_spa >= 0:
+        blk_x_idx = (rev_idx_spa * x_b_s +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+        blk_x_msk = (blk_x_idx >= 0 and
+                     blk_x_idx < x_b * x_b_s)
+        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+        blk_o_idx = (pid_blk * o_b_s +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+        blk_o_msk = (blk_o_idx >= 0 and
+                     blk_o_idx < o_b * o_b_s)
+        tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+@triton_op("blksprs::flow_push_forward", mutates_args={})
+def flow_push_forward(x: Tensor, sparsity_layout_x: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                      sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+    with torch.no_grad():
+        output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                             dtype=x.dtype, device=x.device)
+
+        x_b, x_r, x_c = x.size()
+        x_b_s, x_r_s, x_c_s = stride(x)
+        s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
+        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
+        s_lut_r, s_lut_c = sparsity_lut.size()
+        s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = stride(output)
+
+        triton_grid = lambda meta: [x_b,
+                                    triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+        (wrap_triton(flow_push_kernel)[triton_grid]
+         (x,
+          x_b, x_b_s, x_r_s, x_c_s,
+          s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+          sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+          sparsity_reverse_lut,
+          output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          sparsity_block_size))
+
+        return output
+
+
+# noinspection PyUnusedLocal
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
+    reset_to_zero=["o"]
+)
 @triton.jit
-def kernel_blocksparse_flow_push(x,
-    …
+def flow_push_kernel(x,
+                     x_b, x_b_s, x_r_s, x_c_s,
+                     s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+                     s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                     r_lut,
+                     o,
+                     o_b, o_b_s, o_r_s, o_c_s,
+                     sparsity_block_size,
+                     TRITON_BLOCK_SIZE: tl.constexpr) -> None:
     # Get triton block indices
     pid_blk = tl.program_id(axis=0)
     pid_row = tl.program_id(axis=1)
```
```diff
@@ -91,56 +176,17 @@ def kernel_blocksparse_flow_push(x,
     rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_x_b * s_l_x_b_s)
     rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
-    if rev_idx_spa …
-        …
-    tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
-
-
-def flow_forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
-                 sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
-    output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
-                         dtype=x.dtype, device=x.device)
-
-    x_b, x_r, x_c = x.size()
-    x_b_s, x_r_s, x_c_s = stride(x)
-    o_b, o_r, o_c = output.size()
-    o_b_s, o_r_s, o_c_s = stride(output)
-    s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
-    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
-    s_lut_r, s_lut_c = sparsity_lut.size()
-    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-
-    if triton_block_size is None:
-        triton_block_size = get_triton_block_size(sparsity_block_size)
-
-    triton_grid = lambda meta: [o_b,
-                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-    (kernel_blocksparse_flow_pull[triton_grid]
-     (x,
-      x_b, x_b_s, x_r_s, x_c_s,
-      output,
-      o_b, o_b_s, o_r_s, o_c_s,
-      s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
-      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-      sparsity_reverse_lut,
-      triton_block_size))
-
-    # Save for backward pass
-    ctx.sparsity_block_size = sparsity_block_size
-    ctx.triton_block_size = triton_block_size
-
-    return output
+    if rev_idx_spa >= 0:
+        blk_x_idx = (pid_blk * x_b_s +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+        blk_x_msk = (blk_x_idx >= 0 and
+                     blk_x_idx < x_b * x_b_s)
+        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+        blk_o_idx = (rev_idx_spa * o_b_s +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+        blk_o_msk = (blk_o_idx >= 0 and
+                     blk_o_idx < o_b * o_b_s)
+        tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
```
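Note the asymmetry between the two kernels reconstructed above: `flow_pull_kernel` gathers into its own output block with `tl.store`, while `flow_push_kernel` scatters with `tl.atomic_add`, since several source blocks may map to the same target block (this is also why the autotuner passes `reset_to_zero=["o"]`, so the output buffer is re-zeroed between config trials). A plain-PyTorch sketch of the same accumulate-on-collision semantics, using `index_add_` on made-up toy data rather than blksprs API:

```python
# Accumulate-on-collision semantics analogous to flow_push's tl.atomic_add.
import torch

src = torch.arange(12, dtype=torch.float32).reshape(4, 3)  # 4 source "blocks"
target_idx = torch.tensor([0, 1, 1, 2])  # sources 1 and 2 collide on target 1

out = torch.zeros(3, 3)
out.index_add_(0, target_idx, src)  # colliding blocks are summed, not overwritten
print(out)
```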
blksprs/ops/matmul.py
CHANGED

```diff
@@ -1,19 +1,22 @@
 import torch
 import triton
 from torch import Tensor
+from torch.library import triton_op, wrap_triton
 from triton import language as tl
 
 from blksprs.ops.transpose import transpose
 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import …
+from blksprs.utils.tools import stride
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-    validate_sparsity, validate_sparsity_block_size, …
+    validate_sparsity, validate_sparsity_block_size, validate_dtype_float
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
            y: BlksprsTensor, sparsity_layout_y: Tensor,
            sparsity_layout_output: Tensor,
-           sparsity_block_size: int, …
+           sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
     """Performs matrix multiplication between two block-sparse tensors.
 
     The sparsity layout of the output tensor is used to only calculate blocks that will be present in the output.
```
```diff
@@ -25,7 +28,7 @@ def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
         sparsity_layout_y (Tensor): The sparsity layout of the second block-sparse tensor.
         sparsity_layout_output (Tensor): The sparsity layout of the output tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
-        …
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
 
     Returns:
         BlksprsTensor: The result of the matrix multiplication as a block-sparse tensor in compressed form.
```
```diff
@@ -42,44 +45,24 @@ def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
     if sparsity_layout_x.size(-1) != sparsity_layout_y.size(-2):
         raise ValueError("Inner dimensions of tensors must match")
     validate_sparsity_block_size(sparsity_block_size, x, y)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)
 
-    sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
-                              (sparsity_layout_x_flat == 1) -
-                              (1 * (sparsity_layout_x_flat == 0)))
-    …
+    lut = matmul_build_lut(lut, sparsity_layout_x, sparsity_layout_y, sparsity_layout_output)
 
+    return BlksprsTensor(matmul_forward(x, y,
+                                        sparsity_layout_x, lut["sparsity_reverse_lut_x"],
+                                        sparsity_layout_y, lut["sparsity_reverse_lut_y"],
+                                        sparsity_layout_output, lut["sparsity_lut_o"],
+                                        sparsity_block_size, lut["n_sparse_blocks"]))
 
-    sparsity_lut_o = torch.nonzero(sparsity_layout_output).contiguous()
-    …
-      sparsity_layout_y, sparsity_reverse_lut_y,
-      sparsity_layout_output, sparsity_lut_o,
-      sparsity_block_size,
-      n_sparse_blocks,
-      triton_block_size))
-
-
-class _BlocksparseMatmulSSS(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, x: Tensor, y: Tensor,
-                sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
-                sparsity_layout_y: Tensor, sparsity_reverse_lut_y: Tensor,
-                sparsity_layout_o: Tensor, sparsity_lut_o: Tensor,
-                sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
-        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+
+@triton_op("blksprs::matmul_forward", mutates_args={})
+def matmul_forward(x: Tensor, y: Tensor,
+                   sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
+                   sparsity_layout_y: Tensor, sparsity_reverse_lut_y: Tensor,
+                   _: Tensor, sparsity_lut_o: Tensor,
+                   sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+    with torch.no_grad():
+        output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                              dtype=x.dtype, device=x.device)
 
         x_b, x_r, x_c = x.size()
```
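The reverse look-up tables removed from `matmul` here (and rebuilt inside `matmul_build_lut` in the next hunk) map each flat block position of a sparsity layout to its index in the compressed tensor, or to `-1` when the block is absent. A small worked example with a made-up layout:

```python
# Worked example of the reverse-LUT formula used in matmul_build_lut.
import torch

sparsity_layout = torch.tensor([[[1, 0], [1, 1]]])  # one batch of 2x2 blocks
flat = sparsity_layout.reshape(-1)                  # tensor([1, 0, 1, 1])

reverse_lut = ((torch.cumsum(flat, dim=-1) - 1) * (flat == 1)
               - (1 * (flat == 0)))

print(reverse_lut)  # tensor([ 0, -1,  1,  2])
# Entry i is the compressed index of block i, or -1 if the block is absent.
```

With this encoding a kernel only has to test `rev_idx_spa >= 0` to skip absent blocks, exactly as the kernels in this diff do.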
```diff
@@ -95,133 +78,183 @@ class _BlocksparseMatmulSSS(torch.autograd.Function):
         s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
         s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
 
-        if triton_block_size is None:
-            triton_block_size = get_triton_block_size(sparsity_block_size)
-
         triton_grid = lambda meta: [o_b,
                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
 
-        (…
+        (wrap_triton(matmul_kernel)[triton_grid]
         (x,
          x_b, x_b_s, x_r_s, x_c_s,
-         s_l_x_b, s_l_x_b_s, s_l_x_r_s, …
+         s_l_x_b, s_l_x_b_s, s_l_x_r_s,
+         s_l_x_c, s_l_x_c_s,
          sparsity_reverse_lut_x,
          y,
          y_b, y_b_s, y_r_s, y_c_s,
-         s_l_y_b, s_l_y_b_s, s_l_y_r_s, …
+         s_l_y_b, s_l_y_b_s, s_l_y_r_s,
+         s_l_y_c_s,
          sparsity_reverse_lut_y,
          output,
          o_b, o_b_s, o_r_s, o_c_s,
         sparsity_lut_o,
         s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
-         sparsity_block_size,
-         triton_block_size))
-
-        ctx.save_for_backward(x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o)
-        ctx.sparsity_block_size = sparsity_block_size
-        ctx.triton_block_size = triton_block_size
+         sparsity_block_size))
 
         return output
 
-        …
+
+def matmul_wrapper_backward(ctx, grad_output):
+    x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o = ctx.saved_tensors
+    sparsity_block_size = ctx.sparsity_block_size
+
+    x_t, sparsity_layout_x_t = transpose(x, sparsity_layout_x, sparsity_block_size)
+    y_t, sparsity_layout_y_t = transpose(y, sparsity_layout_y, sparsity_block_size)
+
+    grad_x = matmul(grad_output, sparsity_layout_o, y_t, sparsity_layout_y_t, sparsity_layout_x,
+                    sparsity_block_size)
+    grad_y = matmul(x_t, sparsity_layout_x_t, grad_output, sparsity_layout_o, sparsity_layout_y,
+                    sparsity_block_size)
+
+    return grad_x, grad_y, None, None, None, None, None, None, None, None
+
+
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
+    reset_to_zero=["o"]
+)
+@triton.jit
+def matmul_kernel(x,
+                  x_b, x_b_s, x_r_s, x_c_s,
+                  s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
+                  r_lut_x,
+                  y,
+                  y_b, y_b_s, y_r_s, y_c_s,
+                  s_l_y_b, s_l_y_b_s, s_l_y_r_s, s_l_y_c_s,
+                  r_lut_y,
+                  o,
+                  o_b, o_b_s, o_r_s, o_c_s,
+                  s_lut_o,
+                  s_lut_o_r, s_lut_o_r_s,
+                  s_lut_o_c_s,
+                  sparsity_block_size,
+                  TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get position of current sparsity block consisting of its batch, row, and column index
+    spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
+    spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
+
+    spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
+    spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+
+    spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
+    spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+
+    # Setup buffer
+    buf = tl.zeros(shape=(TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), dtype=tl.float32)
+
+    # Slide over triton block sized segments of input tensors
+    for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size, TRITON_BLOCK_SIZE)):
+        # Convert to segment index of sparsity layout
+        i_seg_spa = (i_seg_tri * TRITON_BLOCK_SIZE) // sparsity_block_size
+        # Calculate the triton segment index within a block
+        i_seg_tri_mod = i_seg_tri % (sparsity_block_size // TRITON_BLOCK_SIZE)
+
+        # Get reverse sparsity indices for input tensors x and y
+        # These are either -1 if the block is empty or equal to the index of the block in the sparse tensor
+
+        # Get reverse sparsity indices for x
+        rev_idx_spa_x_idx = (spa_bat_o * s_l_x_b_s +
+                             spa_row_o * s_l_x_r_s +
+                             i_seg_spa * s_l_x_c_s)
+        rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+        rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
+
+        # Get reverse sparsity indices for y
+        rev_idx_spa_y_idx = (spa_bat_o * s_l_y_b_s + i_seg_spa * s_l_y_r_s + spa_col_o * s_l_y_c_s)
+        rev_idx_spa_y_msk = (rev_idx_spa_y_idx >= 0 and rev_idx_spa_y_idx < s_l_y_b * s_l_y_b_s)
+        rev_idx_spa_y = tl.load(r_lut_y + rev_idx_spa_y_idx, mask=rev_idx_spa_y_msk).to(tl.int32)
+
+        # If both blocks are present commence calculation
+        if rev_idx_spa_x >= 0 and rev_idx_spa_y >= 0:
+            blk_x_idx = ((rev_idx_spa_x * x_b_s) +
+                         ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                         ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
+                           tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+            blk_x_msk = (blk_x_idx >= 0 and
+                         blk_x_idx < x_b * x_b_s)
+            blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+            blk_y_idx = ((rev_idx_spa_y * y_b_s) +
+                         ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
+                           tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
+                         ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
+            blk_y_msk = (blk_y_idx >= 0 and
+                         blk_y_idx < y_b * y_b_s)
+            blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
+
+            # Perform matrix multiplication
+            buf += tl.dot(blk_x, blk_y)
+
+    # Cast buffer
+    buf = buf.to(o.dtype.element_ty)
+
+    # Store output
+    blk_o_idx = ((pid_blk * o_b_s) +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+    blk_o_msk = (blk_o_idx >= 0 and
+                 blk_o_idx < o_b * o_b_s)
+    tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
+
+
+def matmul_build_lut(lut: dict, sparsity_layout_x: Tensor, sparsity_layout_y: Tensor, sparsity_layout_output: Tensor):
+    if lut is None:
+        lut = dict()
+
+    if "sparsity_reverse_lut_x" not in lut:
+        sparsity_layout_x_flat = sparsity_layout_x.reshape(-1)
+        sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
+                                  (sparsity_layout_x_flat == 1) -
+                                  (1 * (sparsity_layout_x_flat == 0)))
+        lut["sparsity_reverse_lut_x"] = sparsity_reverse_lut_x
+
+    if "sparsity_reverse_lut_y" not in lut:
+        sparsity_layout_y_flat = sparsity_layout_y.reshape(-1)
+        sparsity_reverse_lut_y = ((torch.cumsum(sparsity_layout_y_flat, dim=-1) - 1) *
+                                  (sparsity_layout_y_flat == 1) -
+                                  (1 * (sparsity_layout_y_flat == 0)))
+        lut["sparsity_reverse_lut_y"] = sparsity_reverse_lut_y
+
+    if "sparsity_lut_o" not in lut:
+        sparsity_lut_o = torch.nonzero(sparsity_layout_output).contiguous()
+        lut["sparsity_lut_o"] = sparsity_lut_o
+
+    if "n_sparse_blocks" not in lut:
+        n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
+        lut["n_sparse_blocks"] = n_sparse_blocks
+
+    validate_contiguous(sparsity_layout_x, lut["sparsity_reverse_lut_x"],
+                        sparsity_layout_y, lut["sparsity_reverse_lut_y"],
+                        sparsity_layout_output, lut["sparsity_lut_o"])
+
+    return lut
+
+
+# noinspection PyUnusedLocal
+def matmul_setup_context(ctx, inputs, output):
+    (x, y, sparsity_layout_x, _, sparsity_layout_y, _,
+     sparsity_layout_o, _, sparsity_block_size, _) = inputs
+
+    ctx.save_for_backward(x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o)
+    ctx.sparsity_block_size = sparsity_block_size
+
+
+matmul_forward.register_autograd(matmul_wrapper_backward, setup_context=matmul_setup_context)
```
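Because `matmul` now takes an optional `lut` dictionary and `matmul_build_lut` only fills in missing keys, callers with a fixed sparsity pattern can pay the cumsum/nonzero cost once and reuse the tables across calls. Below is a usage sketch based only on the signatures visible in this diff; the tensor contents are illustrative, a CUDA device is assumed, and constructing `BlksprsTensor` directly from compressed-form tensors is an assumption rather than documented API.

```python
import torch
from blksprs.ops.matmul import matmul, matmul_build_lut
from blksprs.utils.blksprs_tensor import BlksprsTensor

device = torch.device("cuda")
sparsity_block_size = 64

# Fully dense 4x4 block layout for a single batch (illustrative only).
layout = torch.ones(1, 4, 4, dtype=torch.int, device=device)
n_sparse_blocks = int(layout.sum())

# Compressed-form operands: (n_sparse_blocks, block, block).
x = BlksprsTensor(torch.randn(n_sparse_blocks, sparsity_block_size,
                              sparsity_block_size, device=device))
y = BlksprsTensor(torch.randn(n_sparse_blocks, sparsity_block_size,
                              sparsity_block_size, device=device))

# Build the look-up tables once ...
lut = matmul_build_lut(None, layout, layout, layout)

# ... then reuse them for every call that shares these layouts, skipping
# the per-call cumsum/nonzero recomputation inside matmul.
for _ in range(10):
    o = matmul(x, layout, y, layout, layout, sparsity_block_size, lut=lut)
```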