blksprs 1.10.2__py3-none-any.whl → 2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +4 -6
- blksprs/layouting/distribution_layout.py +64 -48
- blksprs/layouting/sparsity_layout.py +96 -72
- blksprs/ops/conversion.py +350 -312
- blksprs/ops/distribution.py +320 -266
- blksprs/ops/flow.py +135 -89
- blksprs/ops/matmul.py +184 -151
- blksprs/ops/misc/broadcast_ops.py +68 -53
- blksprs/ops/misc/row_wise.py +223 -176
- blksprs/ops/partitioning.py +140 -89
- blksprs/ops/repeat.py +118 -108
- blksprs/ops/softmax.py +201 -167
- blksprs/ops/transpose.py +71 -131
- blksprs/utils/autotuning.py +78 -0
- blksprs/utils/benchmarking.py +3 -3
- blksprs/utils/processing.py +2 -1
- blksprs/utils/tools.py +5 -6
- blksprs/utils/validation.py +22 -16
- {blksprs-1.10.2.dist-info → blksprs-2.0.dist-info}/METADATA +55 -36
- blksprs-2.0.dist-info/RECORD +23 -0
- {blksprs-1.10.2.dist-info → blksprs-2.0.dist-info}/WHEEL +1 -1
- blksprs/ops/misc/exp.py +0 -104
- blksprs/utils/layout_utils.py +0 -17
- blksprs-1.10.2.dist-info/RECORD +0 -24
- {blksprs-1.10.2.dist-info → blksprs-2.0.dist-info}/top_level.txt +0 -0
blksprs/__init__.py
CHANGED
@@ -1,5 +1,7 @@
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 
+__version__ = "2.0"
+
 
 class ops:
     from blksprs.ops.conversion import to_dense, to_sparse, from_blksprs, to_blksprs, adapt_layout
@@ -13,25 +15,21 @@ class ops:
     class misc:
         from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_add, row_wise_sub
         from blksprs.ops.misc.broadcast_ops import broadcast_add, broadcast_sub
-        from blksprs.ops.misc.exp import exp
 
 
 class layouting:
     from blksprs.layouting.distribution_layout import build_distribution_layout
     from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption, \
-        build_sparsity_layout_matmul, build_sparsity_layout_matmul_fast
-    from blksprs.utils.layout_utils import build_full_sparsity_layout
+        build_sparsity_layout_matmul, build_sparsity_layout_matmul_fast, build_sparsity_layout_full
 
 
 class utils:
     from blksprs.utils.processing import apply_torch_linear, apply_torch_normalisation, apply_torch_dropout, \
        apply_function_applicable_row_wise
     from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
-    from blksprs.utils.validation import disable_validation
 
 class validation:
     from blksprs.utils.validation import disable_validation
     from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_dtype_float, \
         validate_dtype_int, validate_device, validate_sparsity, validate_sparsity_dense, \
-        validate_sparsity_block_size
-        validate_triton_block_size
+        validate_sparsity_block_size
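Notable moves in this file: blksprs.ops.misc.exp is dropped, the all-ones layout helper moves from blksprs.utils.layout_utils (exposed as layouting.build_full_sparsity_layout) to layouting.build_sparsity_layout_full, disable_validation is now exposed only under blksprs.validation, and validate_triton_block_size disappears together with the triton_block_size arguments removed throughout the diffs below. A minimal migration sketch for the renamed layout helper (hypothetical shapes; the 1.10.2 call is inferred from the rename, and a CUDA device is assumed):

    import torch
    import blksprs as bs

    x = torch.randn(2, 128, 128, device="cuda")

    # 1.10.2 (removed):
    # layout_full = bs.layouting.build_full_sparsity_layout(x, sparsity_block_size=32)

    # 2.0:
    layout_full = bs.layouting.build_sparsity_layout_full(x, sparsity_block_size=32)  # shape (2, 4, 4), all True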
blksprs/layouting/distribution_layout.py
CHANGED
@@ -1,17 +1,23 @@
+import typing
+
 import torch
 import triton
 from torch import Tensor
+from torch._library import triton_op
+from torch._library.triton import wrap_triton
 from triton import language as tl
 
 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import
-from blksprs.utils.
+from blksprs.utils.tools import stride
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
+from blksprs.utils.validation import validate_dimensions, validate_device, \
     validate_contiguous
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: Tensor,
                               dim: int, size_target: torch.Size,
-                              sparsity_block_size: int
+                              sparsity_block_size: int) -> Tensor:
     """Builds the sparsity layout of either the source of a gather or the target of a scatter operation.
 
     Args:
@@ -20,7 +26,6 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
         dim (int): The dimension along which the operation is conducted.
         size_target (torch.Size): The size of the block-sparse target tensor in regular form.
         sparsity_block_size (int): The size of the sparsity blocks.
-        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
 
     Returns:
         Tensor: The sparsity layout of the source or target tensor.
@@ -34,49 +39,58 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
 
     adjusted_dim = dim % 3
 
-… (33 removed lines not shown in the original diff view)
+    return build_distribution_layout_operation(indices, sparsity_lut_i, adjusted_dim, size_target, sparsity_block_size)
+
+
+@triton_op("blksprs::build_distribution_layout", mutates_args={})
+def build_distribution_layout_operation(indices: Tensor, sparsity_lut_i: Tensor,
+                                        adjusted_dim: int, size_target: typing.List[int],
+                                        sparsity_block_size: int) -> Tensor:
+    with torch.no_grad():
+        output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size,
+                             size_target[2] // sparsity_block_size,
+                             dtype=torch.bool, device=indices.device)
+
+        i_b, i_r, i_c = indices.size()
+        i_b_s, i_r_s, i_c_s = stride(indices)
+        s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
+        s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = stride(output)
+
+        triton_grid = lambda meta: [i_b,
+                                    triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]
+
+        (wrap_triton(build_distribution_layout_kernel)[triton_grid]
+         (indices,
+          i_b, i_b_s, i_r_s, i_c_s,
+          sparsity_lut_i,
+          s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+          adjusted_dim,
+          output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          sparsity_block_size))
+
+        return output
+
+
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
+    reset_to_zero=["o"]
+)
 @triton.jit
-def
-… (8 removed lines not shown)
+def build_distribution_layout_kernel(i,
+                                     i_b, i_b_s, i_r_s, i_c_s,
+                                     s_lut_i,
+                                     s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+                                     dim,
+                                     o,
+                                     o_b, o_b_s, o_r_s, o_c_s,
+                                     sparsity_block_size,
+                                     TRITON_BLOCK_SIZE: tl.constexpr) -> None:
     # Get triton block indices
     pid_blk = tl.program_id(axis=0)
     pid_row = tl.program_id(axis=1)
@@ -98,7 +112,8 @@ def kernel_distribution_layout(i,
     blk_i_idx = (pid_blk * i_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
-    blk_i_msk = (blk_i_idx >= 0 and
+    blk_i_msk = (blk_i_idx >= 0 and
+                 blk_i_idx < i_b * i_b_s)
     blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)
 
     dst_bat_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_i, dtype=tl.int32)
@@ -116,5 +131,6 @@ def kernel_distribution_layout(i,
     blk_o_idx = ((dst_bat_idx * o_b_s) +
                  (dst_row_idx * o_r_s) +
                  (dst_col_idx * o_c_s))
-    blk_o_msk = (blk_o_idx >= 0 and
+    blk_o_msk = (blk_o_idx >= 0 and
+                 blk_o_idx < o_b * o_b_s)
     tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
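The hunks above show the pattern 2.0 applies throughout the package: the Triton launch is moved into a helper registered as a PyTorch custom op via triton_op and invoked through wrap_triton (so it composes with torch.compile and the dispatcher), and the old triton_block_size argument is replaced by triton.autotune with configurations supplied by the new blksprs.utils.autotuning helpers. A self-contained sketch of that pattern on a toy kernel (hypothetical op name example::scale; the public torch.library import path of recent PyTorch is used here, while the diff imports the private torch._library modules):

    import torch
    import triton
    from torch.library import triton_op, wrap_triton
    from triton import language as tl


    @triton.autotune(
        # Candidate launch configurations; blksprs generates these via get_autotune_configs().
        configs=[triton.Config({"BLOCK_SIZE": b}) for b in (64, 128, 256)],
        key=["n_elements"],  # re-tune whenever the problem size changes
    )
    @triton.jit
    def scale_kernel(x_ptr, o_ptr, n_elements, factor, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        tl.store(o_ptr + offsets, x * factor, mask=mask)


    @triton_op("example::scale", mutates_args={})
    def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
        output = torch.empty_like(x)
        n_elements = output.numel()
        # The grid may refer to the meta-parameters chosen by the autotuner.
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        # wrap_triton makes the launch traceable; BLOCK_SIZE is injected by
        # @triton.autotune and therefore not passed at the call site.
        wrap_triton(scale_kernel)[grid](x, output, n_elements, factor)
        return output

The real decorators additionally pass reset_to_zero=["o"]: the autotuner re-executes the kernel once per candidate configuration, and since these kernels only set flags in a pre-zeroed output buffer, that buffer must be cleared between trial runs.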
blksprs/layouting/sparsity_layout.py
CHANGED
@@ -3,21 +3,23 @@ import math
 import torch
 import triton
 from torch import Tensor
+from torch._library.triton import wrap_triton, triton_op
 from triton import language as tl
 
 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import
-from blksprs.utils.
+from blksprs.utils.tools import stride
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs, prune_autotune_configs_conversion
+from blksprs.utils.validation import validate_dimensions, validate_device, \
     validate_contiguous, validate_sparsity, validate_sparsity_block_size
 
 
-… (1 removed line not shown)
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
+def build_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
     """Builds the sparsity layout of a dense tensor in regular form covering its sparse blocks.
 
     Args:
         x (Tensor): A block-sparse (or dense) tensor in regular form.
         sparsity_block_size (int): The size of the sparsity blocks.
-        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
 
     Returns:
         Tensor: The sparsity layout of the input block-sparse (or dense) tensor.
@@ -27,41 +29,47 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_size
     validate_contiguous(x)
     validate_device(x)
 
-…
-                         dtype=torch.bool, device=x.device)
+    return build_sparsity_layout_operation(x, sparsity_block_size)
 
-    x_b, x_r, x_c = x.size()
-    x_b_s, x_r_s, x_c_s = stride(x)
-    o_b, o_r, o_c = output.size()
-    o_b_s, o_r_s, o_c_s = stride(output)
 
-… (2 lines)
+@triton_op("blksprs::build_sparsity_layout", mutates_args={})
+def build_sparsity_layout_operation(x: Tensor, sparsity_block_size: int) -> Tensor:
+    with torch.no_grad():
+        output = torch.zeros(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size,
+                             dtype=torch.bool, device=x.device)
 
-…
+        x_b, x_r, x_c = x.size()
+        x_b_s, x_r_s, x_c_s = stride(x)
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = stride(output)
 
-… (3 lines)
+        triton_grid = lambda meta: [x_b,
+                                    triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
 
-… (6 lines)
-                   triton_block_size))
+        (wrap_triton(build_sparsity_layout_kernel)[triton_grid]
+         (x,
+          x_b, x_b_s, x_r_s, x_c_s,
+          output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          sparsity_block_size))
 
-…
+        return output
 
 
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
+    reset_to_zero=["o"]
+)
 @triton.jit
-def
-… (5 lines)
+def build_sparsity_layout_kernel(x,
+                                 x_b, x_b_s, x_r_s, x_c_s,
+                                 o,
+                                 o_b, o_b_s, o_r_s, o_c_s,
+                                 sparsity_block_size,
+                                 TRITON_BLOCK_SIZE: tl.constexpr) -> None:
     # Get triton block indices
     pid_bat = tl.program_id(axis=0)
     pid_row = tl.program_id(axis=1)
@@ -71,7 +79,8 @@ def kernel_sparsity_layout(x,
     blk_x_idx = (pid_bat * x_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx >= 0 and
+    blk_x_msk = (blk_x_idx >= 0 and
+                 blk_x_idx < x_b * x_b_s)
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Store sparsity layout value
@@ -83,9 +92,9 @@ def kernel_sparsity_layout(x,
         tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tensor,
-                                   sparsity_block_size_from: int, sparsity_block_size_to: int
-                                   triton_block_size: int = None) -> Tensor:
+                                   sparsity_block_size_from: int, sparsity_block_size_to: int) -> Tensor:
     """Builds the sparsity layout of a block-sparse tensor in compressed form if a different sparsity block size were
     used.
 
@@ -94,7 +103,6 @@ def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tenso
         sparsity_layout_from (Tensor): The sparsity layout of the input block-sparse tensor.
         sparsity_block_size_from (int): The size of the sparsity blocks of the input tensor.
         sparsity_block_size_to (int): The desired size of the sparsity blocks for the resulting layout.
-        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
 
     Returns:
         Tensor: The sparsity layout in regular form using the new sparsity block size of the input block-sparse tensor
@@ -107,54 +115,62 @@ def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tenso
     validate_sparsity(sparsity_block_size_from, (x, sparsity_layout_from))
     validate_sparsity_block_size(sparsity_block_size_from, x)
     validate_sparsity_block_size(sparsity_block_size_to)
-    min_sparsity_block_size = min(sparsity_block_size_from, sparsity_block_size_to)
-    validate_triton_block_size(triton_block_size, min_sparsity_block_size)
 
     sparsity_lut = torch.nonzero(sparsity_layout_from).contiguous()
 
     validate_contiguous(sparsity_layout_from, sparsity_lut)
 
-… (2 lines)
-    o_c = math.ceil(sparsity_layout_from.size(2) * sparsity_block_size_from // sparsity_block_size_to)
+    return build_sparsity_layout_adaption_operation(x, sparsity_layout_from, sparsity_lut,
+                                                    sparsity_block_size_from, sparsity_block_size_to)
 
-    output = torch.zeros(o_b, o_r, o_c, dtype=torch.bool, device=x.device)
 
-… (5 lines)
+@triton_op("blksprs::build_sparsity_layout_adaption", mutates_args={})
+def build_sparsity_layout_adaption_operation(x: Tensor, sparsity_layout_from: Tensor, sparsity_lut: Tensor,
+                                             sparsity_block_size_from: int, sparsity_block_size_to: int) -> Tensor:
+    with torch.no_grad():
+        o_b = sparsity_layout_from.size(0)
+        o_r = math.ceil(sparsity_layout_from.size(1) * sparsity_block_size_from // sparsity_block_size_to)
+        o_c = math.ceil(sparsity_layout_from.size(2) * sparsity_block_size_from // sparsity_block_size_to)
 
-…
-    triton_block_size = get_triton_block_size(sparsity_block_size_from)
+        output = torch.zeros(o_b, o_r, o_c, dtype=torch.bool, device=x.device)
 
-… (3 lines)
+        x_b, x_r, x_c = x.size()
+        x_b_s, x_r_s, x_c_s = stride(x)
+        s_lut_r, s_lut_c = sparsity_lut.size()
+        s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+        o_b_s, o_r_s, o_c_s = stride(output)
 
-… (3 lines)
-                   sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-                   output,
-                   o_b, o_b_s, o_r_s, o_c_s,
-                   sparsity_block_size_from,
-                   sparsity_block_size_to,
-                   triton_block_size))
+        triton_grid = lambda meta: [x_b,
+                                    triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
 
-…
+        (wrap_triton(build_sparsity_layout_adaption_kernel)[triton_grid]
+         (x,
+          x_b, x_b_s, x_r_s, x_c_s,
+          sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+          output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          sparsity_block_size_from,
+          sparsity_block_size_to))
 
+        return output
 
+
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=["sparsity_block_size_from", "sparsity_block_size_to"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs_conversion},
+    reset_to_zero=["o"]
+)
 @triton.jit
-def
-… (7 lines)
+def build_sparsity_layout_adaption_kernel(x,
+                                          x_b, x_b_s, x_r_s, x_c_s,
+                                          s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                                          o,
+                                          o_b, o_b_s, o_r_s, o_c_s,
+                                          sparsity_block_size_from,
+                                          sparsity_block_size_to,
+                                          TRITON_BLOCK_SIZE: tl.constexpr) -> None:
     # Get triton block indices
     pid_blk = tl.program_id(axis=0)
     pid_row = tl.program_id(axis=1)
@@ -177,21 +193,23 @@ def kernel_sparsity_layout_adaption(x,
     blk_x_idx = ((pid_blk * x_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx >= 0 and
+    blk_x_msk = (blk_x_idx >= 0 and
+                 blk_x_idx < x_b * x_b_s)
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Store sparsity layout value
     if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
         blk_o_idx = ((spa_bat * o_b_s) +
-                     (((
+                     (((pid_row * TRITON_BLOCK_SIZE + spa_row * sparsity_block_size_from)
                        // sparsity_block_size_to) * o_r_s) +
-                     (((
+                     (((pid_col * TRITON_BLOCK_SIZE + spa_col * sparsity_block_size_from)
                        // sparsity_block_size_to) * o_c_s))
         blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
         tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
 
 
-… (1 removed line not shown)
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
+def build_sparsity_layout_matmul(sparsity_layout_x: Tensor, sparsity_layout_y: Tensor) -> Tensor:
     """Builds the precise sparsity layout of the result of a matrix multiplication between the two input tensors.
 
     Args:
@@ -205,6 +223,7 @@ def build_sparsity_layout_matmul(sparsity_layout_x: Tensor, sparsity_layout_y: T
     return torch.matmul(sparsity_layout_x.to(torch.float), sparsity_layout_y.to(torch.float)).to(torch.bool)
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def build_sparsity_layout_matmul_fast(sparsity_layout_x: Tensor, sparsity_layout_y: Tensor):
     """Builds the approximate sparsity layout of the result of a matrix multiplication between the two input tensors.
 
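The torch.matmul one-liner in build_sparsity_layout_matmul works because a layout matmul marks exactly the output blocks that can receive a contribution: block (i, k) of X @ Y can be non-zero only if some inner index j has both X-block (i, j) and Y-block (j, k) present. A small worked sketch (hypothetical single-batch layouts):

    import torch

    # X has 2 x 3 blocks, Y has 3 x 2 blocks.
    sl_x = torch.tensor([[[1, 0, 0],
                          [0, 0, 1]]], dtype=torch.bool)
    sl_y = torch.tensor([[[0, 1],
                          [0, 0],
                          [1, 0]]], dtype=torch.bool)

    # Same computation as build_sparsity_layout_matmul.
    sl_o = torch.matmul(sl_x.to(torch.float), sl_y.to(torch.float)).to(torch.bool)
    # tensor([[[False,  True],
    #          [ True, False]]])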
@@ -225,3 +244,8 @@ def build_sparsity_layout_matmul_fast(sparsity_layout_x: Tensor, sparsity_layout
     sparsity_layout_y_slice = torch.max(sparsity_layout_y, dim=-2).values.unsqueeze(1)
 
     return torch.logical_or(sparsity_layout_x_slice, sparsity_layout_y_slice)
+
+
+def build_sparsity_layout_full(x: Tensor, sparsity_block_size: int) -> Tensor:
+    return torch.ones(size=(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size),
+                      dtype=torch.bool, device=x.device)