blksprs 1.11__py3-none-any.whl → 2.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +2 -5
- blksprs/layouting/distribution_layout.py +32 -25
- blksprs/layouting/sparsity_layout.py +65 -52
- blksprs/ops/conversion.py +421 -399
- blksprs/ops/distribution.py +404 -366
- blksprs/ops/flow.py +125 -106
- blksprs/ops/matmul.py +220 -204
- blksprs/ops/misc/broadcast_ops.py +53 -35
- blksprs/ops/misc/row_wise.py +151 -91
- blksprs/ops/partitioning.py +136 -132
- blksprs/ops/repeat.py +115 -120
- blksprs/ops/softmax.py +274 -246
- blksprs/ops/transpose.py +52 -51
- blksprs/utils/benchmarking.py +3 -3
- blksprs/utils/tools.py +31 -4
- blksprs/utils/validation.py +0 -14
- {blksprs-1.11.dist-info → blksprs-2.0rc1.dist-info}/METADATA +42 -36
- blksprs-2.0rc1.dist-info/RECORD +22 -0
- {blksprs-1.11.dist-info → blksprs-2.0rc1.dist-info}/WHEEL +1 -1
- blksprs/utils/layout_utils.py +0 -17
- blksprs-1.11.dist-info/RECORD +0 -23
- {blksprs-1.11.dist-info → blksprs-2.0rc1.dist-info}/top_level.txt +0 -0
blksprs/__init__.py
CHANGED
@@ -18,19 +18,16 @@ class ops:
 class layouting:
     from blksprs.layouting.distribution_layout import build_distribution_layout
     from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption, \
-        build_sparsity_layout_matmul, build_sparsity_layout_matmul_fast
-    from blksprs.utils.layout_utils import build_full_sparsity_layout
+        build_sparsity_layout_matmul, build_sparsity_layout_matmul_fast, build_sparsity_layout_full


 class utils:
     from blksprs.utils.processing import apply_torch_linear, apply_torch_normalisation, apply_torch_dropout, \
         apply_function_applicable_row_wise
     from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
-    from blksprs.utils.validation import disable_validation

 class validation:
     from blksprs.utils.validation import disable_validation
     from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_dtype_float, \
         validate_dtype_int, validate_device, validate_sparsity, validate_sparsity_dense, \
-        validate_sparsity_block_size
-        validate_triton_block_size
+        validate_sparsity_block_size
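Two changes to the public namespace are visible in this hunk: the duplicate disable_validation re-export is dropped from utils (it remains available under validation), and the full-layout helper moves out of the deleted blksprs/utils/layout_utils.py, with build_full_sparsity_layout replaced by build_sparsity_layout_full from blksprs.layouting.sparsity_layout. A minimal import-migration sketch (the alias bs is an assumption, not taken from the package):

# blksprs 1.11
# from blksprs.utils.layout_utils import build_full_sparsity_layout

# blksprs 2.0rc1
import blksprs as bs
from blksprs.layouting.sparsity_layout import build_sparsity_layout_full

# The layouting namespace re-exports the new name, so both references resolve to the same function.
assert bs.layouting.build_sparsity_layout_full is build_sparsity_layout_full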
blksprs/layouting/distribution_layout.py
CHANGED
@@ -4,14 +4,14 @@ from torch import Tensor
 from triton import language as tl

 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import
-from blksprs.utils.validation import
+from blksprs.utils.tools import stride, get_autotune_configs
+from blksprs.utils.validation import validate_dimensions, validate_device, \
     validate_contiguous


 def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: Tensor,
                               dim: int, size_target: torch.Size,
-                              sparsity_block_size: int
+                              sparsity_block_size: int) -> Tensor:
     """Builds the sparsity layout of either the source of a gather or the target of a scatter operation.

     Args:
@@ -20,7 +20,6 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
         dim (int): The dimension along which the operation is conducted.
         size_target (torch.Size): The size of the block-sparse target tensor in regular form.
         sparsity_block_size (int): The size of the sparsity blocks.
-        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).

     Returns:
         Tensor: The sparsity layout of the source or target tensor.
@@ -44,16 +43,11 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
     o_b, o_r, o_c = output.size()
     o_b_s, o_r_s, o_c_s = stride(output)

-    if triton_block_size is None:
-        triton_block_size = get_triton_block_size(sparsity_block_size)
-
-    validate_triton_block_size(triton_block_size, sparsity_block_size)
-
     triton_grid = lambda meta: [i_b,
                                 triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
                                 triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]

-    (
+    (build_distribution_layout_kernel[triton_grid]
     (indices,
      i_b, i_b_s, i_r_s, i_c_s,
      sparsity_lut_i,
@@ -61,27 +55,34 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
      adjusted_dim,
      output,
      o_b, o_b_s, o_r_s, o_c_s,
-     sparsity_block_size
-     triton_block_size))
+     sparsity_block_size))

     return output


+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=[],
+    reset_to_zero=["o"]
+)
 @triton.jit
-def
-
-
-
-
-
-
-
-
+def build_distribution_layout_kernel(i,
+                                     i_b, i_b_s, i_r_s, i_c_s,
+                                     s_lut_i,
+                                     s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+                                     dim,
+                                     o,
+                                     o_b, o_b_s, o_r_s, o_c_s,
+                                     sparsity_block_size,
+                                     TRITON_BLOCK_SIZE: tl.constexpr) -> None:
     # Get triton block indices
     pid_blk = tl.program_id(axis=0)
     pid_row = tl.program_id(axis=1)
     pid_col = tl.program_id(axis=2)

+    # Get valid triton block size
+    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
+
     # Get position of current sparsity block consisting of its batch, row, and column index
     spa_bat_i_idx = (pid_blk * s_lut_i_r_s + 0 * s_lut_i_c_s)
     spa_bat_i_msk = (spa_bat_i_idx >= 0 and spa_bat_i_idx < s_lut_i_r * s_lut_i_r_s)
@@ -96,9 +97,12 @@ def kernel_distribution_layout(i,
     spa_col_i = tl.load(s_lut_i + spa_col_i_idx, mask=spa_col_i_msk)

     blk_i_idx = (pid_blk * i_b_s +
-                 ((pid_row *
-                 ((pid_col *
-    blk_i_msk = (blk_i_idx >= 0 and
+                 ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                 ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+    blk_i_msk = ((blk_i_idx >= 0 and
+                  blk_i_idx < i_b * i_b_s) and
+                 (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                  tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
     blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)

     dst_bat_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_i, dtype=tl.int32)
@@ -116,5 +120,8 @@ def kernel_distribution_layout(i,
     blk_o_idx = ((dst_bat_idx * o_b_s) +
                  (dst_row_idx * o_r_s) +
                  (dst_col_idx * o_c_s))
-    blk_o_msk = (blk_o_idx >= 0 and
+    blk_o_msk = ((blk_o_idx >= 0 and
+                  blk_o_idx < o_b * o_b_s) and
+                 (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                  tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
     tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
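build_distribution_layout loses its manual triton_block_size argument: the kernel is now wrapped in @triton.autotune over get_autotune_configs() (with reset_to_zero=["o"] so the output is re-zeroed between benchmarking runs) and clamps its effective tile edge to val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE), masking out the excess lanes. The toy kernel below shows the same decorator pattern in isolation; get_autotune_configs_sketch and its config values are stand-ins, since the contents of blksprs' get_autotune_configs() are not part of this diff.

import triton
import triton.language as tl


def get_autotune_configs_sketch():
    # Hypothetical stand-in: tile sizes and warp counts for the autotuner to search.
    return [triton.Config({"TRITON_BLOCK_SIZE": tbs}, num_warps=w)
            for tbs in (16, 32, 64)
            for w in (2, 4)]


@triton.autotune(
    configs=get_autotune_configs_sketch(),
    key=[],               # empty key: tune once and reuse the chosen config for every launch
    reset_to_zero=["o"],  # re-zero the output buffer between benchmarking runs
)
@triton.jit
def mark_elements_kernel(o, n, sparsity_block_size, TRITON_BLOCK_SIZE: tl.constexpr):
    # Clamp the effective tile edge to the sparsity block, mirroring val_tbs above.
    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
    pid = tl.program_id(axis=0)
    offs = pid * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)
    # Mask both the out-of-bounds tail and the lanes beyond the valid tile edge.
    msk = (offs < n) & (tl.arange(0, TRITON_BLOCK_SIZE) < val_tbs)
    tl.store(o + offs, 1, mask=msk)

# Launch sketch (requires a CUDA device):
#   out = torch.zeros(256, dtype=torch.int32, device="cuda")
#   grid = lambda meta: (triton.cdiv(256, min(64, meta["TRITON_BLOCK_SIZE"])),)
#   mark_elements_kernel[grid](out, 256, 64)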
blksprs/layouting/sparsity_layout.py
CHANGED
@@ -3,21 +3,21 @@ import math
 import torch
 import triton
 from torch import Tensor
+from torch._library.triton import wrap_triton
 from triton import language as tl

 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import
-from blksprs.utils.validation import
+from blksprs.utils.tools import stride, get_autotune_configs
+from blksprs.utils.validation import validate_dimensions, validate_device, \
     validate_contiguous, validate_sparsity, validate_sparsity_block_size


-def build_sparsity_layout(x: Tensor, sparsity_block_size: int
+def build_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
     """Builds the sparsity layout of a dense tensor in regular form covering its sparse blocks.

     Args:
         x (Tensor): A block-sparse (or dense) tensor in regular form.
         sparsity_block_size (int): The size of the sparsity blocks.
-        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).

     Returns:
         Tensor: The sparsity layout of the input block-sparse (or dense) tensor.
@@ -35,57 +35,61 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_size
     o_b, o_r, o_c = output.size()
     o_b_s, o_r_s, o_c_s = stride(output)

-    if triton_block_size is None:
-        triton_block_size = get_triton_block_size(sparsity_block_size)
-
-    validate_triton_block_size(triton_block_size, sparsity_block_size)
-
     triton_grid = lambda meta: [x_b,
-                                triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
-                                triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+                                triton.cdiv(x_r, min(meta["sparsity_block_size"], meta["TRITON_BLOCK_SIZE"])),
+                                triton.cdiv(x_c, min(meta["sparsity_block_size"], meta["TRITON_BLOCK_SIZE"]))]

-    (
+    (wrap_triton(build_sparsity_layout_kernel)[triton_grid]
     (x,
      x_b, x_b_s, x_r_s, x_c_s,
      output,
      o_b, o_b_s, o_r_s, o_c_s,
-     sparsity_block_size
-     triton_block_size))
+     sparsity_block_size))

     return output


+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=[],
+    reset_to_zero=["o"]
+)
 @triton.jit
-def
-
-
-
-
-
+def build_sparsity_layout_kernel(x,
+                                 x_b, x_b_s, x_r_s, x_c_s,
+                                 o,
+                                 o_b, o_b_s, o_r_s, o_c_s,
+                                 sparsity_block_size,
+                                 TRITON_BLOCK_SIZE: tl.constexpr) -> None:
     # Get triton block indices
     pid_bat = tl.program_id(axis=0)
     pid_row = tl.program_id(axis=1)
     pid_col = tl.program_id(axis=2)

+    # Get valid triton block size
+    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
+
     # Load x values
     blk_x_idx = (pid_bat * x_b_s +
-                 ((pid_row *
-                 ((pid_col *
-    blk_x_msk = (blk_x_idx >= 0 and
+                 ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                 ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_x_msk = ((blk_x_idx >= 0 and
+                  blk_x_idx < x_b * x_b_s) and
+                 (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                  tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

     # Store sparsity layout value
     if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
         blk_o_idx = (pid_bat * o_b_s +
-                     (((pid_row *
-                     ((pid_col *
+                     (((pid_row * val_tbs) // sparsity_block_size) * o_r_s +
+                     ((pid_col * val_tbs) // sparsity_block_size) * o_c_s))
         blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
         tl.store(o + blk_o_idx, 1, mask=blk_o_msk)


 def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tensor,
-                                   sparsity_block_size_from: int, sparsity_block_size_to: int
-                                   triton_block_size: int = None) -> Tensor:
+                                   sparsity_block_size_from: int, sparsity_block_size_to: int) -> Tensor:
     """Builds the sparsity layout of a block-sparse tensor in compressed form if a different sparsity block size were
     used.

@@ -94,7 +98,6 @@ def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tenso
         sparsity_layout_from (Tensor): The sparsity layout of the input block-sparse tensor.
         sparsity_block_size_from (int): The size of the sparsity blocks of the input tensor.
         sparsity_block_size_to (int): The desired size of the sparsity blocks for the resulting layout.
-        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).

     Returns:
         Tensor: The sparsity layout in regular form using the new sparsity block size of the input block-sparse tensor
@@ -107,8 +110,6 @@ def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tenso
     validate_sparsity(sparsity_block_size_from, (x, sparsity_layout_from))
     validate_sparsity_block_size(sparsity_block_size_from, x)
     validate_sparsity_block_size(sparsity_block_size_to)
-    min_sparsity_block_size = min(sparsity_block_size_from, sparsity_block_size_to)
-    validate_triton_block_size(triton_block_size, min_sparsity_block_size)

     sparsity_lut = torch.nonzero(sparsity_layout_from).contiguous()

@@ -126,40 +127,44 @@ def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tenso
     s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
     o_b_s, o_r_s, o_c_s = stride(output)

-    if triton_block_size is None:
-        triton_block_size = get_triton_block_size(sparsity_block_size_from)
-
     triton_grid = lambda meta: [x_b,
-                                triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
-                                triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+                                triton.cdiv(x_r, min(meta["sparsity_block_size_to"], meta["TRITON_BLOCK_SIZE"])),
+                                triton.cdiv(x_c, min(meta["sparsity_block_size_to"], meta["TRITON_BLOCK_SIZE"]))]

-    (
+    (wrap_triton(build_sparsity_layout_adaption_kernel)[triton_grid]
     (x,
      x_b, x_b_s, x_r_s, x_c_s,
      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
      output,
      o_b, o_b_s, o_r_s, o_c_s,
      sparsity_block_size_from,
-     sparsity_block_size_to
-     triton_block_size))
+     sparsity_block_size_to))

     return output


+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=[],
+    reset_to_zero=["o"]
+)
 @triton.jit
-def
-
-
-
-
-
-
-
+def build_sparsity_layout_adaption_kernel(x,
+                                          x_b, x_b_s, x_r_s, x_c_s,
+                                          s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                                          o,
+                                          o_b, o_b_s, o_r_s, o_c_s,
+                                          sparsity_block_size_from,
+                                          sparsity_block_size_to,
+                                          TRITON_BLOCK_SIZE: tl.constexpr) -> None:
     # Get triton block indices
     pid_blk = tl.program_id(axis=0)
     pid_row = tl.program_id(axis=1)
     pid_col = tl.program_id(axis=2)

+    # Get valid triton block size
+    val_tbs = min(sparsity_block_size_to, TRITON_BLOCK_SIZE)
+
     # Get sparsity index of current output block consisting of its batch, row, and column index
     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
     spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
@@ -175,23 +180,26 @@ def kernel_sparsity_layout_adaption(x,

     # Load x values
     blk_x_idx = ((pid_blk * x_b_s) +
-                 ((pid_row *
-                 ((pid_col *
-    blk_x_msk = (blk_x_idx >= 0 and
+                 ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                 ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_x_msk = ((blk_x_idx >= 0 and
+                  blk_x_idx < x_b * x_b_s) and
+                 (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < sparsity_block_size_from and
+                  tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < sparsity_block_size_from))
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

     # Store sparsity layout value
     if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
         blk_o_idx = ((spa_bat * o_b_s) +
-                     (((
+                     (((pid_row * val_tbs + spa_row * sparsity_block_size_from)
                        // sparsity_block_size_to) * o_r_s) +
-                     (((
+                     (((pid_col * val_tbs + spa_col * sparsity_block_size_from)
                        // sparsity_block_size_to) * o_c_s))
         blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
         tl.store(o + blk_o_idx, 1, mask=blk_o_msk)


-def build_sparsity_layout_matmul(sparsity_layout_x: Tensor, sparsity_layout_y: Tensor):
+def build_sparsity_layout_matmul(sparsity_layout_x: Tensor, sparsity_layout_y: Tensor) -> Tensor:
     """Builds the precise sparsity layout of the result of a matrix multiplication between the two input tensors.

     Args:
@@ -225,3 +233,8 @@ def build_sparsity_layout_matmul_fast(sparsity_layout_x: Tensor, sparsity_layout
     sparsity_layout_y_slice = torch.max(sparsity_layout_y, dim=-2).values.unsqueeze(1)

     return torch.logical_or(sparsity_layout_x_slice, sparsity_layout_y_slice)
+
+
+def build_sparsity_layout_full(x: Tensor, sparsity_block_size: int) -> Tensor:
+    return torch.ones(size=(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size),
+                      dtype=torch.bool, device=x.device)