blksprs 1.0__py3-none-any.whl → 1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/layouting/distribution_layout.py +114 -0
- blksprs/layouting/sparsity_layout.py +129 -7
- blksprs/misc/broadcast_addition.py +132 -0
- blksprs/ops/conversion.py +237 -17
- blksprs/ops/distribution.py +362 -0
- blksprs/ops/exp.py +18 -8
- blksprs/ops/{matmul_sss.py → matmul.py} +28 -26
- blksprs/ops/row_wise_sum.py +21 -5
- blksprs/ops/softmax.py +23 -12
- blksprs/ops/transpose.py +19 -7
- blksprs/utils/tools.py +1 -28
- blksprs/utils/validation.py +53 -1
- {blksprs-1.0.dist-info → blksprs-1.2.dist-info}/METADATA +39 -14
- blksprs-1.2.dist-info/RECORD +17 -0
- {blksprs-1.0.dist-info → blksprs-1.2.dist-info}/WHEEL +1 -1
- blksprs-1.0.dist-info/RECORD +0 -14
- {blksprs-1.0.dist-info → blksprs-1.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import triton
|
|
3
|
+
from torch import Tensor
|
|
4
|
+
from triton import language as tl
|
|
5
|
+
|
|
6
|
+
from blksprs.utils.tools import get_triton_block_size
|
|
7
|
+
from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
|
|
8
|
+
validate_contiguous
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
                              size_target: torch.Size,
                              sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
    """Builds the sparsity layout of either the source of a gather or the target of a scatter operation.

    Args:
        indices (Tensor): The block-sparse indices tensor in compressed form used for the gather or scatter operation.
        sparsity_layout_indices (Tensor): The sparsity layout of the indices block-sparse tensor.
        size_target (torch.Size): The size of the block-sparse target tensor in regular form.
            NOTE(review): indexed as ``size_target[0..2]`` below, so a 3-d (batch, rows, cols) size is assumed
            and rows/cols are presumed divisible by ``sparsity_block_size`` — confirm with callers.
        sparsity_block_size (int): The size of the sparsity blocks.
        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).

    Returns:
        Tensor: The sparsity layout of the source or target tensor.

    """
    validate_dimensions(indices)
    validate_contiguous(indices)
    validate_device(indices)

    # Look-up table of the (batch, row, col) position of every non-zero block in the indices layout;
    # row k of the LUT corresponds to compressed block k of `indices`.
    sparsity_lut_i = torch.nonzero(sparsity_layout_indices).contiguous()

    # Resulting layout, one int32 flag per sparsity block of the target tensor; the kernel sets entries to 1.
    output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size, size_target[2] // sparsity_block_size,
                         device=indices.device, dtype=torch.int32)

    # Sizes and strides are passed to the kernel as scalars since Triton kernels cannot query them.
    i_b, i_r, i_c = indices.size()
    i_b_s, i_r_s, i_c_s = indices.stride()
    s_l_i_b, s_l_i_r, s_l_i_c = sparsity_layout_indices.size()
    s_l_i_b_s, s_l_i_r_s, s_l_i_c_s = sparsity_layout_indices.stride()
    s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
    s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
    o_b, o_r, o_c = output.size()
    o_b_s, o_r_s, o_c_s = output.stride()

    if triton_block_size is None:
        triton_block_size = get_triton_block_size(sparsity_block_size)

    validate_triton_block_size(triton_block_size, sparsity_block_size)

    # One program per (compressed block, tile row, tile col) of the indices tensor.
    triton_grid = lambda meta: [i_b,
                                triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
                                triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]

    (kernel_distribution_layout[triton_grid]
     (indices,
      i_b, i_b_s, i_r_s, i_c_s,
      sparsity_layout_indices,
      s_l_i_b, s_l_i_b_s, s_l_i_r, s_l_i_r_s, s_l_i_c, s_l_i_c_s,
      sparsity_lut_i,
      s_lut_i_r, s_lut_i_r_s, s_lut_i_c, s_lut_i_c_s,
      output,
      o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
      sparsity_block_size,
      triton_block_size))

    return output
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@triton.jit
def kernel_distribution_layout(i,
                               i_b, i_b_s, i_r_s, i_c_s,
                               s_l_i,
                               s_l_i_b, s_l_i_b_s, s_l_i_r, s_l_i_r_s, s_l_i_c, s_l_i_c_s,
                               s_lut_i,
                               s_lut_i_r, s_lut_i_r_s, s_lut_i_c, s_lut_i_c_s,
                               o,
                               o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
                               sparsity_block_size,
                               TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    # Marks as non-zero every output sparsity block addressed by the values in the indices tensor.
    # NOTE(review): the s_l_i* layout parameters are accepted but not read in this kernel body.

    # Get triton block indices
    pid_blk = tl.program_id(axis=0)
    pid_row = tl.program_id(axis=1)
    pid_col = tl.program_id(axis=2)

    # Get position of current sparsity block consisting of its batch, row, and column index
    # (LUT columns are 0=batch, 1=row; the masks guard against reading past the LUT).
    spa_bat_i_idx = (pid_blk * s_lut_i_r_s + 0 * s_lut_i_c_s)
    spa_bat_i_msk = (spa_bat_i_idx < s_lut_i_r * s_lut_i_r_s)
    spa_bat_i = tl.load(s_lut_i + spa_bat_i_idx, mask=spa_bat_i_msk)

    spa_row_i_idx = (pid_blk * s_lut_i_r_s + 1 * s_lut_i_c_s)
    spa_row_i_msk = (spa_row_i_idx < s_lut_i_r * s_lut_i_r_s)
    spa_row_i = tl.load(s_lut_i + spa_row_i_idx, mask=spa_row_i_msk)

    # Load the tile of index values handled by this program.
    blk_i_idx = (pid_blk * i_b_s +
                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
    blk_i_msk = (blk_i_idx < i_b * i_b_s)
    blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)

    # Convert element indices to sparsity-block column indices.
    blk_i = blk_i // sparsity_block_size
    blk_v = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), 1, dtype=tl.int32)

    # Scatter a 1 into the output layout at (batch, row, mapped column) for every index element.
    blk_o_idx = ((spa_bat_i * o_b_s) +
                 (spa_row_i * o_r_s) +
                 (blk_i * o_c_s))
    blk_o_msk = (blk_o_idx < o_b * o_b_s)
    tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
import math
|
|
2
|
+
|
|
1
3
|
import torch
|
|
2
4
|
import triton
|
|
3
5
|
from torch import Tensor
|
|
@@ -5,13 +7,23 @@ from triton import language as tl
|
|
|
5
7
|
|
|
6
8
|
from blksprs.utils.tools import get_triton_block_size
|
|
7
9
|
from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
|
|
8
|
-
|
|
10
|
+
validate_contiguous, validate_sparsity, validate_sparsity_block_size
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def build_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
|
|
14
|
+
"""Builds the sparsity layout of a dense tensor in regular form covering its sparse blocks.
|
|
9
15
|
|
|
16
|
+
Args:
|
|
17
|
+
x (Tensor): A block-sparse (or dense) tensor in regular form.
|
|
18
|
+
sparsity_block_size (int): The size of the sparsity blocks.
|
|
19
|
+
triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
|
|
10
20
|
|
|
11
|
-
|
|
21
|
+
Returns:
|
|
22
|
+
Tensor: The sparsity layout of the input block-sparse (or dense) tensor.
|
|
23
|
+
|
|
24
|
+
"""
|
|
12
25
|
validate_dimensions(x)
|
|
13
26
|
validate_contiguous(x)
|
|
14
|
-
validate_dtype_float(x)
|
|
15
27
|
validate_device(x)
|
|
16
28
|
|
|
17
29
|
output = torch.zeros(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size,
|
|
@@ -33,9 +45,9 @@ def create_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_siz
|
|
|
33
45
|
|
|
34
46
|
(kernel_sparsity_layout[triton_grid]
|
|
35
47
|
(x,
|
|
36
|
-
x_b, x_b_s,
|
|
48
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
37
49
|
output,
|
|
38
|
-
o_b, o_b_s,
|
|
50
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
39
51
|
sparsity_block_size,
|
|
40
52
|
triton_block_size))
|
|
41
53
|
|
|
@@ -44,9 +56,9 @@ def create_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_siz
|
|
|
44
56
|
|
|
45
57
|
@triton.jit
|
|
46
58
|
def kernel_sparsity_layout(x,
|
|
47
|
-
x_b, x_b_s,
|
|
59
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
48
60
|
o,
|
|
49
|
-
o_b, o_b_s,
|
|
61
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
50
62
|
sparsity_block_size,
|
|
51
63
|
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
52
64
|
# Get triton block indices
|
|
@@ -54,15 +66,125 @@ def kernel_sparsity_layout(x,
|
|
|
54
66
|
pid_row = tl.program_id(axis=1)
|
|
55
67
|
pid_col = tl.program_id(axis=2)
|
|
56
68
|
|
|
69
|
+
# Load x values
|
|
57
70
|
blk_x_idx = (pid_bat * x_b_s +
|
|
58
71
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
59
72
|
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
60
73
|
blk_x_msk = (blk_x_idx < x_b * x_b_s)
|
|
61
74
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
62
75
|
|
|
76
|
+
# Store sparsity layout value
|
|
63
77
|
if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
|
|
64
78
|
blk_o_idx = (pid_bat * o_b_s +
|
|
65
79
|
(((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +
|
|
66
80
|
((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))
|
|
67
81
|
blk_o_msk = (blk_o_idx < o_b * o_b_s)
|
|
68
82
|
tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def build_sparsity_layout_adaption(x: Tensor, sparsity_layout_from: Tensor,
                                   sparsity_block_size_from: int, sparsity_block_size_to: int,
                                   triton_block_size: int = None) -> Tensor:
    """Builds the sparsity layout of a block-sparse tensor in compressed form if a different sparsity block size were
    used.

    Args:
        x (Tensor): A block-sparse tensor in compressed form.
        sparsity_layout_from (Tensor): The sparsity layout of the input block-sparse tensor.
        sparsity_block_size_from (int): The size of the sparsity blocks of the input tensor.
        sparsity_block_size_to (int): The desired size of the sparsity blocks for the resulting layout.
        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).

    Returns:
        Tensor: The sparsity layout in regular form using the new sparsity block size of the input block-sparse tensor
            in compressed form.

    """
    validate_dimensions(x)
    validate_contiguous(x, sparsity_layout_from)
    validate_device(x)
    validate_sparsity(sparsity_block_size_from, (x, sparsity_layout_from))
    validate_sparsity_block_size(sparsity_block_size_from, x)
    validate_sparsity_block_size(sparsity_block_size_to)
    min_sparsity_block_size = min(sparsity_block_size_from, sparsity_block_size_to)
    # NOTE(review): triton_block_size may still be None here (default resolved further below) —
    # confirm validate_triton_block_size accepts None.
    validate_triton_block_size(triton_block_size, min_sparsity_block_size)

    # Look-up table of the (batch, row, col) position of every non-zero block in the source layout.
    sparsity_lut = torch.nonzero(sparsity_layout_from).contiguous()

    validate_contiguous(sparsity_layout_from, sparsity_lut)

    # Output layout dimensions under the new block size. Bug fix: the previous code used floor
    # division (``//``) inside ``math.ceil``, which floored the value before ``ceil`` could round
    # up; true division (``/``) is required for the intended ceiling division when the total size
    # is not an exact multiple of ``sparsity_block_size_to``.
    o_b = sparsity_layout_from.size(0)
    o_r = math.ceil(sparsity_layout_from.size(1) * sparsity_block_size_from / sparsity_block_size_to)
    o_c = math.ceil(sparsity_layout_from.size(2) * sparsity_block_size_from / sparsity_block_size_to)

    output = torch.zeros(o_b, o_r, o_c, device=x.device, dtype=torch.int32)

    # Sizes and strides are passed to the kernel as scalars since Triton kernels cannot query them.
    x_b, x_r, x_c = x.size()
    x_b_s, x_r_s, x_c_s = x.stride()
    s_lut_r, s_lut_c = sparsity_lut.size()
    s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
    o_b_s, o_r_s, o_c_s = output.stride()

    if triton_block_size is None:
        triton_block_size = get_triton_block_size(sparsity_block_size_from)

    # One program per (compressed block, tile row, tile col) of the input tensor.
    triton_grid = lambda meta: [x_b,
                                triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
                                triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]

    (kernel_sparsity_layout_adaption[triton_grid]
     (x,
      x_b, x_b_s, x_r_s, x_c_s,
      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
      output,
      o_b, o_b_s, o_r_s, o_c_s,
      sparsity_block_size_from,
      sparsity_block_size_to,
      triton_block_size))

    return output
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
@triton.jit
def kernel_sparsity_layout_adaption(x,
                                    x_b, x_b_s, x_r_s, x_c_s,
                                    s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
                                    o,
                                    o_b, o_b_s, o_r_s, o_c_s,
                                    sparsity_block_size_from,
                                    sparsity_block_size_to,
                                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    # For every tile of the compressed input, marks the output layout block (under the new
    # sparsity block size) that the tile's dense position falls into, if the tile is non-zero.

    # Get triton block indices
    pid_blk = tl.program_id(axis=0)
    pid_row = tl.program_id(axis=1)
    pid_col = tl.program_id(axis=2)

    # Get sparsity index of current output block consisting of its batch, row, and column index
    # (LUT columns are 0=batch, 1=row, 2=col; masks guard against reading past the LUT).
    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
    spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)

    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
    spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)

    spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
    spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
    spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)

    # Load x values
    blk_x_idx = ((pid_blk * x_b_s) +
                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
    blk_x_msk = (blk_x_idx < x_b * x_b_s)
    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

    # Store sparsity layout value
    # Only tiles containing at least one non-zero value contribute to the new layout.
    if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
        # Translate the tile's dense (row, col) position from source block-size coordinates
        # to destination block-size coordinates before flagging the output block.
        blk_o_idx = ((spa_bat * o_b_s) +
                     (((spa_row * sparsity_block_size_from + pid_row * TRITON_BLOCK_SIZE)
                       // sparsity_block_size_to) * o_r_s) +
                     (((spa_col * sparsity_block_size_from + pid_col * TRITON_BLOCK_SIZE)
                       // sparsity_block_size_to) * o_c_s))
        blk_o_msk = (blk_o_idx < o_b * o_b_s)
        tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import triton
|
|
3
|
+
from torch import Tensor
|
|
4
|
+
from triton import language as tl
|
|
5
|
+
|
|
6
|
+
from blksprs.utils.tools import get_triton_block_size
|
|
7
|
+
from blksprs.utils.validation import validate_contiguous, validate_device, \
|
|
8
|
+
validate_sparsity_block_size, validate_triton_block_size
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def broadcast_addition(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
                       sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
    """Performs a broadcast and subsequent addition of two dense tensors x and y. Returns a block-sparse tensor in
    compressed form.

    Args:
        x (Tensor): A dense input tensor.
        y (Tensor): A dense input tensor.
        sparsity_layout_output (Tensor): The sparsity layout of the output tensor.
        sparsity_block_size (int): The size of the sparsity blocks.
        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).

    Returns:
        Tensor: The result of the operation as a block-sparse tensor in compressed form. Each element o(i, j) of the
            output tensor corresponds to x(i) + y(j).

    """
    validate_device(x, y)
    validate_contiguous(x, y)
    if x.size(-1) != y.size(-1):
        raise ValueError("Dimensions of tensors must match")
    validate_sparsity_block_size(sparsity_block_size)
    validate_triton_block_size(triton_block_size, sparsity_block_size)

    # Look-up table of the (batch, row, col) position of every non-zero block in the output layout.
    sparsity_lut_o = torch.nonzero(sparsity_layout_output).contiguous()

    n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()

    validate_contiguous(sparsity_layout_output, sparsity_lut_o)

    # Bug fix: allocate the output with the input's dtype. Previously the dtype was omitted, so the
    # output was always float32 and results were silently cast when x/y were e.g. float16/float64.
    output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size,
                         device=x.device, dtype=x.dtype)

    # Sizes and strides are passed to the kernel as scalars since Triton kernels cannot query them.
    x_b, x_c = x.size()
    x_b_s, x_c_s = x.stride()
    y_b, y_c = y.size()
    y_b_s, y_c_s = y.stride()
    o_b, o_r, o_c = output.size()
    o_b_s, o_r_s, o_c_s = output.stride()
    s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
    s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_o.stride()

    if triton_block_size is None:
        triton_block_size = get_triton_block_size(sparsity_block_size)

    # One program per (compressed output block, tile row, tile col).
    triton_grid = lambda meta: [o_b,
                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

    (kernel_broadcast_addition[triton_grid]
     (x,
      x_b, x_b_s, x_c_s,
      y,
      y_b, y_b_s, y_c_s,
      output,
      o_b, o_b_s, o_r_s, o_c_s,
      sparsity_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
      sparsity_block_size,
      triton_block_size))

    return output
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def broadcast_subtraction(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
                          sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
    """Performs a broadcast subtraction by delegating to ``broadcast_addition`` with a negated ``y``.

    Each element o(i, j) of the output corresponds to x(i) - y(j); see ``broadcast_addition`` for
    the documentation of the parameters and the return value.

    """
    negated_y = torch.neg(y)
    return broadcast_addition(x, negated_y, sparsity_layout_output, sparsity_block_size, triton_block_size)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@triton.jit
def kernel_broadcast_addition(x,
                              x_b, x_b_s, x_c_s,
                              y,
                              y_b, y_b_s, y_c_s,
                              o,
                              o_b, o_b_s, o_r_s, o_c_s,
                              s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
                              sparsity_block_size,
                              TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    # Computes one output tile of the broadcast addition: o(i, j) = x(i) + y(j), where the slices
    # of x and y to read are derived from the output block's position in the sparsity layout.

    # Get triton block indices
    pid_blk = tl.program_id(axis=0)
    pid_row = tl.program_id(axis=1)
    pid_col = tl.program_id(axis=2)

    # Get position of current sparsity block consisting of its batch, row, and column index
    # (LUT columns are 0=batch, 1=row, 2=col; masks guard against reading past the LUT).
    spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
    spa_bat_o_msk = (spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
    spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)

    spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
    spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
    spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)

    spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
    spa_col_o_msk = (spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
    spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)

    # Load x block
    # Loaded as a 1 x TRITON_BLOCK_SIZE row vector; the output row index of this tile selects
    # which slice of x is read.
    blk_x_idx = (spa_bat_o * x_b_s +
                 ((spa_row_o * sparsity_block_size + pid_row * TRITON_BLOCK_SIZE +
                   tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
    blk_x_msk = (blk_x_idx < x_b * x_b_s)
    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

    # Load y block
    # Also a 1 x TRITON_BLOCK_SIZE row vector, selected by the output column index of this tile.
    blk_y_idx = (spa_bat_o * y_b_s +
                 ((spa_col_o * sparsity_block_size + pid_col * TRITON_BLOCK_SIZE +
                   tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
    blk_y_msk = (blk_y_idx < y_b * y_b_s)
    blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)

    # Compute sum
    # Transposing x to a column vector and broadcasting against the y row vector yields the full
    # TRITON_BLOCK_SIZE x TRITON_BLOCK_SIZE outer-sum tile.
    blk_x, blk_y = tl.broadcast(tl.trans(blk_x), blk_y)
    buf = blk_x + blk_y

    # Store result
    blk_o_idx = ((pid_blk * o_b_s) +
                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
    blk_o_msk = (blk_o_idx < o_b * o_b_s)
    tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
|