blksprs 0.2b4__py3-none-any.whl → 1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/layouting/distribution_layout.py +114 -0
- blksprs/layouting/sparsity_layout.py +78 -0
- blksprs/misc/broadcast_addition.py +132 -0
- blksprs/ops/conversion.py +256 -0
- blksprs/ops/distribution.py +362 -0
- blksprs/ops/exp.py +101 -0
- blksprs/ops/matmul.py +221 -0
- blksprs/ops/row_wise_sum.py +231 -0
- blksprs/ops/softmax.py +263 -0
- blksprs/ops/transpose.py +154 -0
- blksprs/utils/tools.py +20 -0
- blksprs/utils/validation.py +97 -0
- blksprs-1.1.dist-info/METADATA +164 -0
- blksprs-1.1.dist-info/RECORD +17 -0
- {blksprs-0.2b4.dist-info → blksprs-1.1.dist-info}/WHEEL +1 -1
- blksprs/ops/blocksparse.py +0 -589
- blksprs-0.2b4.dist-info/METADATA +0 -26
- blksprs-0.2b4.dist-info/RECORD +0 -6
- {blksprs-0.2b4.dist-info → blksprs-1.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,231 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import triton
|
|
3
|
+
from torch import Tensor
|
|
4
|
+
from triton import language as tl
|
|
5
|
+
|
|
6
|
+
from blksprs.utils.tools import get_triton_block_size
|
|
7
|
+
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
8
|
+
validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
                 flag_slice_only: bool = False, triton_block_size: int = None) -> tuple[Tensor, Tensor]:
    """Sum every row of a block-sparse tensor in compressed form.

    The result is again a block-sparse tensor in compressed form. It holds one block for every row of blocks of
    ``x`` that contains at least one present block, and the sum of each row is written to the first column of that
    block.

    Note:
        If ``flag_slice_only`` is set, every output block is a single column wide, i.e. the output has shape
        ``[n, sparsity_block_size, 1]`` where ``n`` is the number of output blocks.

    Args:
        x (Tensor): A block-sparse tensor in compressed form.
        sparsity_layout (Tensor): The sparsity layout of ``x``.
        sparsity_block_size (int): The size of the sparsity blocks.
        flag_slice_only (bool, optional): Only allocate the first column of every output block
            (default ``False``).
        triton_block_size (int, optional): The block size used by the triton kernel (default ``None``, in which case
            it is derived from ``sparsity_block_size``).

    Returns:
        tuple[Tensor, Tensor]: The row-wise sums as a block-sparse tensor in compressed form, together with the
        sparsity layout of that result.

    """

    def _reverse_lut(layout: Tensor) -> Tensor:
        # For every position of the layout: index of its block in compressed storage, or -1 if the block is absent.
        flat = layout.reshape(-1)
        return (torch.cumsum(flat, dim=-1) - 1) * (flat == 1) - (1 * (flat == 0))

    validate_dimensions(x)
    validate_contiguous(x)
    validate_device(x)
    validate_sparsity(sparsity_block_size, (x, sparsity_layout))
    validate_sparsity_block_size(sparsity_block_size, x)
    validate_triton_block_size(triton_block_size, sparsity_block_size)

    # Index tables of the input
    sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
    sparsity_reverse_lut = _reverse_lut(sparsity_layout)

    # The output layout collapses the block columns: a row of blocks is present if any of its blocks is present
    sparsity_layout_output = torch.max(sparsity_layout, dim=-1, keepdim=True).values
    sparsity_lut_output = torch.nonzero(sparsity_layout_output).contiguous()
    sparsity_reverse_lut_output = _reverse_lut(sparsity_layout_output)

    n_sparse_blocks_output = torch.sum(sparsity_layout_output.to(torch.int)).item()

    validate_contiguous(sparsity_layout, sparsity_lut, sparsity_reverse_lut,
                        sparsity_layout_output, sparsity_lut_output, sparsity_reverse_lut_output)

    summed = _BlocksparseRowWiseSum.apply(x,
                                          sparsity_layout, sparsity_lut, sparsity_reverse_lut,
                                          sparsity_layout_output, sparsity_lut_output, sparsity_reverse_lut_output,
                                          n_sparse_blocks_output,
                                          flag_slice_only,
                                          sparsity_block_size, triton_block_size)

    return summed, sparsity_layout_output
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class _BlocksparseRowWiseSum(torch.autograd.Function):
    """Autograd function computing row-wise sums of a compressed block-sparse tensor with triton.

    ``IMPLEMENTATION`` selects the kernel launched by :meth:`forward`:

    * ``"atomic_add"``: one program per (input block, row segment, column segment). Partial sums are accumulated
      atomically into the output.
    * ``"basic"``: one program per (output block, row segment). The kernel walks over all block columns of the
      layout itself.

    Any other value launches nothing and leaves the zero-initialised output as is. The backward pass is not
    implemented.
    """

    IMPLEMENTATION = "atomic_add"

    @staticmethod
    def forward(ctx, x: Tensor,
                sparsity_layout: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
                sparsity_layout_output: Tensor, sparsity_lut_output: Tensor, sparsity_reverse_lut_output: Tensor,
                n_sparse_blocks_output: int,
                flag_slice_only: bool,
                sparsity_block_size: int, triton_block_size: int) -> Tensor:
        # Zero-initialised because the atomic kernel accumulates into it; only the first column of each output
        # block ever receives a value. NOTE(review): allocated with the default dtype, not x.dtype — confirm intended.
        n_out_cols = 1 if flag_slice_only else sparsity_block_size
        output = torch.zeros(size=(n_sparse_blocks_output, sparsity_block_size, n_out_cols), device=x.device)

        x_b, x_r, x_c = x.size()
        x_b_s, x_r_s, x_c_s = x.stride()
        s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout.size()
        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout.stride()
        s_lut_x_r, s_lut_x_c = sparsity_lut.size()
        s_lut_x_r_s, s_lut_x_c_s = sparsity_lut.stride()
        o_b, o_r, o_c = output.size()
        o_b_s, o_r_s, o_c_s = output.stride()
        s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
        s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_output.stride()
        s_lut_o_r, s_lut_o_c = sparsity_lut_output.size()
        s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_output.stride()

        tbs = get_triton_block_size(sparsity_block_size) if triton_block_size is None else triton_block_size

        implementation = _BlocksparseRowWiseSum.IMPLEMENTATION

        if implementation == "basic":
            grid = lambda meta: [o_b, triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"])]

            _BlocksparseRowWiseSum.kernel_blocksparse_row_wise_sum[grid](
                x,
                x_b, x_b_s, x_r_s, x_c_s,
                s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
                sparsity_reverse_lut,
                output,
                o_b, o_b_s, o_r_s,
                sparsity_lut_output, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
                sparsity_block_size,
                tbs)
        elif implementation == "atomic_add":
            grid = lambda meta: [x_b,
                                 triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
                                 triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]

            _BlocksparseRowWiseSum.kernel_blocksparse_row_wise_sum_atomic_add[grid](
                x,
                x_b, x_b_s, x_r_s, x_c_s,
                sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
                output,
                o_b, o_b_s, o_r_s,
                s_l_o_b, s_l_o_b_s, s_l_o_r_s,
                sparsity_reverse_lut_output,
                tbs)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        raise NotImplementedError

    @staticmethod
    @triton.jit
    def kernel_blocksparse_row_wise_sum(x,
                                        x_b, x_b_s, x_r_s, x_c_s,
                                        s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
                                        r_lut_x,
                                        o,
                                        o_b, o_b_s, o_r_s,
                                        s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
                                        sparsity_block_size,
                                        TRITON_BLOCK_SIZE: tl.constexpr) -> None:
        # One program per (output block, row segment)
        pid_blk = tl.program_id(axis=0)
        pid_row = tl.program_id(axis=1)

        # Batch (LUT column 0) and row (LUT column 1) of the output block
        lut_end = s_lut_o_r * s_lut_o_r_s
        bat_pos = pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s
        spa_bat = tl.load(s_lut_o + bat_pos, mask=bat_pos < lut_end)
        row_pos = pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s
        spa_row = tl.load(s_lut_o + row_pos, mask=row_pos < lut_end)

        rows = pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)
        acc = tl.zeros(shape=(TRITON_BLOCK_SIZE, 1), dtype=tl.float32)

        # Walk over the whole row of the layout in triton-block-sized column segments
        for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size, TRITON_BLOCK_SIZE)):
            # Layout column holding this segment, and the segment's position inside that sparsity block
            i_seg_spa = (i_seg_tri * TRITON_BLOCK_SIZE) // sparsity_block_size
            i_seg_tri_mod = i_seg_tri % (sparsity_block_size // TRITON_BLOCK_SIZE)

            rev_pos = spa_bat * s_l_x_b_s + spa_row * s_l_x_r_s + i_seg_spa * s_l_x_c_s
            rev_idx_spa = tl.load(r_lut_x + rev_pos, mask=rev_pos < s_l_x_b * s_l_x_b_s).to(tl.int32)

            # Absent blocks (reverse index -1) contribute nothing
            if rev_idx_spa >= 0:
                cols = i_seg_tri_mod * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)
                blk_idx = rev_idx_spa * x_b_s + (rows * x_r_s)[:, None] + (cols * x_c_s)[None, :]
                blk = tl.load(x + blk_idx, mask=blk_idx < x_b * x_b_s)
                acc = acc + tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))

        # Only the first column of the output block receives the sum
        o_idx = pid_blk * o_b_s + (rows * o_r_s)[:, None] + (tl.arange(0, 1))[None, :]
        tl.store(o + o_idx, acc, o_idx < o_b * o_b_s)

    @staticmethod
    @triton.jit
    def kernel_blocksparse_row_wise_sum_atomic_add(x,
                                                   x_b, x_b_s, x_r_s, x_c_s,
                                                   s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
                                                   o,
                                                   o_b, o_b_s, o_r_s,
                                                   s_l_o_b, s_l_o_b_s, s_l_o_r_s,
                                                   r_lut_o,
                                                   TRITON_BLOCK_SIZE: tl.constexpr) -> None:
        # One program per (input block, row segment, column segment)
        pid_blk = tl.program_id(axis=0)
        pid_row = tl.program_id(axis=1)
        pid_col = tl.program_id(axis=2)

        # Batch (LUT column 0) and row (LUT column 1) of the input block
        lut_end = s_lut_x_r * s_lut_x_r_s
        bat_pos = pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s
        spa_bat = tl.load(s_lut_x + bat_pos, mask=bat_pos < lut_end)
        row_pos = pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s
        spa_row = tl.load(s_lut_x + row_pos, mask=row_pos < lut_end)

        # Index of the output block that collects this (batch, row) pair
        rev_pos = spa_bat * s_l_o_b_s + spa_row * s_l_o_r_s
        rev_idx_spa = tl.load(r_lut_o + rev_pos, mask=rev_pos < s_l_o_b * s_l_o_b_s).to(tl.int32)

        rows = pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)
        cols = pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)

        # Partial sum of this segment, one value per row
        blk_idx = pid_blk * x_b_s + (rows * x_r_s)[:, None] + (cols * x_c_s)[None, :]
        blk = tl.load(x + blk_idx, mask=blk_idx < x_b * x_b_s)
        partial = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))

        # Several programs add into the same output rows, hence the atomic accumulation
        o_idx = rev_idx_spa * o_b_s + (rows * o_r_s)[:, None] + (tl.arange(0, 1))[None, :]
        tl.atomic_add(o + o_idx, partial, o_idx < o_b * o_b_s)
|
blksprs/ops/softmax.py
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import triton
|
|
3
|
+
from torch import Tensor
|
|
4
|
+
from triton import language as tl
|
|
5
|
+
|
|
6
|
+
from blksprs.ops.exp import exp
|
|
7
|
+
from blksprs.ops.row_wise_sum import row_wise_sum
|
|
8
|
+
from blksprs.utils.tools import get_triton_block_size
|
|
9
|
+
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
10
|
+
validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def softmax(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
    """Compute the row-wise softmax of a block-sparse tensor in compressed form.

    Note:
        Blocks that are absent from ``sparsity_layout`` take no part in the computation; they behave as if all of
        their values were ``-inf``.

    Args:
        x (Tensor): A block-sparse tensor in compressed form.
        sparsity_layout (Tensor): The sparsity layout of ``x``.
        sparsity_block_size (int): The size of the sparsity blocks.
        triton_block_size (int, optional): The block size used by the triton kernels (default ``None``).

    Returns:
        Tensor: The softmax of ``x`` as a block-sparse tensor in compressed form, laid out like ``x``.

    """
    for check in (validate_dimensions, validate_contiguous, validate_device):
        check(x)
    validate_sparsity(sparsity_block_size, (x, sparsity_layout))
    validate_sparsity_block_size(sparsity_block_size, x)
    validate_triton_block_size(triton_block_size, sparsity_block_size)

    # Shift by the maximum of the whole tensor before exponentiating; softmax is invariant to a constant shift.
    # NOTE(review): a single global shift can still underflow rows whose values are far below the global maximum.
    max_val = 0 if x.size(0) == 0 else torch.max(x).item()
    x_scaled = x - max_val

    sparsity_lut = torch.nonzero(sparsity_layout).contiguous()

    # Reverse LUT of the row-wise-sum layout (one block per row of blocks holding any present block); -1 marks
    # absent rows.
    sparsity_layout_rws = torch.max(sparsity_layout, dim=-1, keepdim=True).values
    rws_flat = sparsity_layout_rws.reshape(-1)
    sparsity_reverse_lut_rws = (torch.cumsum(rws_flat, dim=-1) - 1) * (rws_flat == 1) - (1 * (rws_flat == 0))

    validate_contiguous(sparsity_layout, sparsity_lut, sparsity_reverse_lut_rws)

    return _BlocksparseSoftmax.apply(x_scaled, sparsity_layout, sparsity_lut, sparsity_reverse_lut_rws,
                                     sparsity_block_size, triton_block_size)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class _BlocksparseSoftmax(torch.autograd.Function):
    """Autograd function computing the row-wise softmax of a compressed block-sparse tensor with triton.

    The forward pass exponentiates ``x`` with :func:`exp`, reduces every row with :func:`row_wise_sum` and divides
    each value by the sum of its row. The backward pass computes ``y * (g - s)``, where ``y`` is the saved softmax
    output, ``g`` the incoming gradient and ``s`` the row-wise sum of ``g * y``.
    """

    @staticmethod
    def forward(ctx, x: Tensor, sparsity_layout: Tensor,
                sparsity_lut: Tensor,
                sparsity_reverse_lut_rws: Tensor,
                sparsity_block_size: int, triton_block_size: int) -> Tensor:
        output = torch.empty_like(x)

        x_b, x_r, x_c = x.shape
        x_b_s, x_r_s, x_c_s = x.stride()
        s_lut_r, s_lut_c = sparsity_lut.shape
        s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
        o_b, o_r, o_c = output.shape

        # Numerator and per-row denominator; the sums are stored as single-column blocks, one per row of blocks
        x_exp = exp(x, sparsity_block_size, triton_block_size=triton_block_size)
        x_exp_row_wise_sum, sparsity_layout_rws = row_wise_sum(x_exp, sparsity_layout, sparsity_block_size,
                                                               flag_slice_only=True,
                                                               triton_block_size=triton_block_size)

        s_b, s_r, s_c = x_exp_row_wise_sum.shape
        s_b_s, s_r_s, s_c_s = x_exp_row_wise_sum.stride()
        s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_rws.shape
        s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = sparsity_layout_rws.stride()

        if triton_block_size is None:
            triton_block_size = get_triton_block_size(sparsity_block_size)

        triton_grid = lambda meta: [o_b,
                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

        (_BlocksparseSoftmax.kernel_blocksparse_softmax[triton_grid]
         (x_exp,
          x_b, x_b_s, x_r_s, x_c_s,
          sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
          x_exp_row_wise_sum, s_b, s_b_s, s_r_s, s_c_s,
          s_l_s_b, s_l_s_b_s, s_l_s_r_s,
          sparsity_reverse_lut_rws,
          output,
          triton_block_size))

        # The softmax output is all the backward pass needs besides the layout information
        ctx.save_for_backward(output, sparsity_layout, sparsity_lut)
        ctx.sparsity_block_size = sparsity_block_size
        ctx.triton_block_size = triton_block_size

        return output

    @staticmethod
    def backward(ctx, grad_output):
        o, sparsity_layout, sparsity_lut = ctx.saved_tensors
        sparsity_block_size = ctx.sparsity_block_size
        triton_block_size = ctx.triton_block_size

        # The incoming gradient is not guaranteed to be contiguous (for ``out.sum().backward()`` it is an expanded,
        # zero-stride tensor). The kernel addresses it with dense strides and bounds-checks against ``b * b_stride``,
        # so materialise a contiguous copy and use its own strides.
        grad_output = grad_output.contiguous()

        # s = sum_j g_j * y_j per row, stored as single-column blocks, one per row of blocks
        s, sparsity_layout_s = row_wise_sum(grad_output * o, sparsity_layout, sparsity_block_size, flag_slice_only=True,
                                            triton_block_size=triton_block_size)

        sparsity_layout_s_flat = sparsity_layout_s.reshape(-1)
        sparsity_reverse_lut_s = ((torch.cumsum(sparsity_layout_s_flat, dim=-1) - 1) *
                                  (sparsity_layout_s_flat == 1) -
                                  (1 * (sparsity_layout_s_flat == 0)))

        g_b, g_r, g_c = grad_output.size()
        g_b_s, g_r_s, g_c_s = grad_output.stride()
        o_b, o_r, o_c = o.size()
        o_b_s, o_r_s, o_c_s = o.stride()
        s_lut_r, s_lut_c = sparsity_lut.size()
        s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
        s_b, s_r, s_c = s.size()
        s_b_s, s_r_s, s_c_s = s.stride()
        s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
        s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = sparsity_layout_s.stride()

        # empty_like keeps o's (dense) strides, so o's sizes/strides describe grad_x as well
        grad_x = torch.empty_like(o)

        triton_grid = lambda meta: [o_b,
                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

        (_BlocksparseSoftmax.kernel_blocksparse_softmax_grad_x[triton_grid]
         (grad_output,
          g_b, g_b_s, g_r_s, g_c_s,
          o,
          o_b, o_b_s, o_r_s, o_c_s,
          sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
          s,
          s_b, s_b_s, s_r_s, s_c_s,
          s_l_s_b, s_l_s_b_s, s_l_s_r_s,
          sparsity_reverse_lut_s,
          grad_x,
          o_b, o_b_s, o_r_s, o_c_s,
          triton_block_size
          ))

        return grad_x, None, None, None, None, None

    @staticmethod
    @triton.jit
    def kernel_blocksparse_softmax(x,
                                   x_b, x_b_s, x_r_s, x_c_s,
                                   s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
                                   s, s_b, s_b_s, s_r_s, s_c_s,
                                   s_l_s_b, s_l_s_b_s, s_l_s_r_s,
                                   r_lut_s,
                                   o,
                                   TRITON_BLOCK_SIZE: tl.constexpr) -> None:
        # Get triton block indices
        pid_blk = tl.program_id(axis=0)
        pid_row = tl.program_id(axis=1)
        pid_col = tl.program_id(axis=2)

        # Get position of current sparsity block consisting of its batch and row index
        spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
        spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
        spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)

        spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
        spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
        spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)

        # Index of the row-sum block belonging to this (batch, row) pair
        rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
                             spa_row * s_l_s_r_s)
        rev_idx_spa_s_msk = (rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
        rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)

        if rev_idx_spa_s == -1:
            assert False, "Invalid sparsity block"

        # Load the exponentiated x block
        blk_x_idx = ((pid_blk * x_b_s) +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
        blk_x_msk = (blk_x_idx < x_b * x_b_s)
        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

        # Load the matching row sums (a single column, broadcast over the block's columns)
        blk_s_idx = (rev_idx_spa_s * s_b_s +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
                     (tl.arange(0, 1) * s_c_s)[None, :])
        blk_s_msk = (blk_s_idx < s_b * s_b_s)
        blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)

        # Compute softmax
        buf = tl.div_rn(blk_x, blk_s)

        # Store output (laid out like x)
        tl.store(o + blk_x_idx, buf, mask=blk_x_msk)

    @staticmethod
    @triton.jit
    def kernel_blocksparse_softmax_grad_x(g,
                                          g_b, g_b_s, g_r_s, g_c_s,
                                          x,
                                          x_b, x_b_s, x_r_s, x_c_s,
                                          s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
                                          s,
                                          s_b, s_b_s, s_r_s, s_c_s,
                                          s_l_s_b, s_l_s_b_s, s_l_s_r_s,
                                          r_lut_s,
                                          o,
                                          o_b, o_b_s, o_r_s, o_c_s,
                                          TRITON_BLOCK_SIZE: tl.constexpr) -> None:
        # Get triton block indices
        pid_blk = tl.program_id(axis=0)
        pid_row = tl.program_id(axis=1)
        pid_col = tl.program_id(axis=2)

        # Get position of current sparsity block consisting of its batch and row index
        spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
        spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
        spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)

        spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
        spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
        spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)

        # Index of the row-sum block belonging to this (batch, row) pair
        rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
                             spa_row * s_l_s_r_s)
        rev_idx_spa_s_msk = (rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
        rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)

        blk_s_idx = (rev_idx_spa_s * s_b_s +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
                     (tl.arange(0, 1) * s_c_s)[None, :])
        blk_s_msk = (blk_s_idx < s_b * s_b_s)
        blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)

        blk_g_idx = ((pid_blk * g_b_s) +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +
                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])
        blk_g_msk = (blk_g_idx < g_b * g_b_s)
        blk_g = tl.load(g + blk_g_idx, mask=blk_g_msk)

        blk_x_idx = ((pid_blk * x_b_s) +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
        blk_x_msk = (blk_x_idx < x_b * x_b_s)
        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

        # Softmax gradient: y * (g - sum(g * y)), with x holding the softmax output y
        buf = blk_x * (blk_g - blk_s)

        blk_o_idx = ((pid_blk * o_b_s) +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
        blk_o_msk = (blk_o_idx < o_b * o_b_s)
        tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
|
blksprs/ops/transpose.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import triton
|
|
3
|
+
from torch import Tensor
|
|
4
|
+
from triton import language as tl
|
|
5
|
+
|
|
6
|
+
from blksprs.utils.tools import get_triton_block_size
|
|
7
|
+
from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
|
|
8
|
+
validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def transpose(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
              triton_block_size: int = None) -> tuple[Tensor, Tensor]:
    """Transposes a block-sparse tensor in compressed form.

    Note:
        Returns the transposed tensor and the sparsity layout of the transposed tensor.

    Args:
        x (Tensor): A block-sparse tensor in compressed form.
        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
        sparsity_block_size (int): The size of the sparsity blocks.
        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).

    Returns:
        tuple[Tensor, Tensor]: The transposed block-sparse tensor in compressed form and the sparsity layout of the
        transposed tensor.

    """
    validate_dimensions(x)
    validate_contiguous(x)
    validate_device(x)
    validate_sparsity(sparsity_block_size, (x, sparsity_layout))
    validate_sparsity_block_size(sparsity_block_size, x)
    validate_triton_block_size(triton_block_size, sparsity_block_size)

    # Layout of the result: the input layout with its last two dimensions swapped
    sparsity_layout_t = sparsity_layout.transpose(-1, -2).contiguous()

    # Coordinates of the output blocks, in the order in which they are stored in compressed form
    sparsity_lut = torch.nonzero(sparsity_layout_t).contiguous()

    # For every position (b, r, c) of the transposed layout: index of the input block at (b, c, r) in the
    # compressed input, or -1 if that block is absent
    sparsity_layout_flat = sparsity_layout.reshape(-1)
    sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
                             (sparsity_layout_flat == 1) -
                             (1 * (sparsity_layout_flat == 0)))
                            .reshape(sparsity_layout.size()).transpose(-1, -2).contiguous().reshape(-1))

    n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()

    validate_contiguous(sparsity_layout_t, sparsity_lut, sparsity_reverse_lut)

    return _BlocksparseTranspose.apply(x, sparsity_layout_t, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
                                       n_sparse_blocks, triton_block_size), sparsity_layout_t
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class _BlocksparseTranspose(torch.autograd.Function):
    """Autograd function transposing every block of a compressed block-sparse tensor with triton.

    The output blocks follow the order of the transposed sparsity layout passed to :meth:`forward`. The backward pass
    transposes the incoming gradient back using that same (transposed) layout.
    """

    @staticmethod
    def forward(ctx, x: Tensor,
                sparsity_layout: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor, sparsity_block_size: int,
                n_sparse_blocks: int, triton_block_size: int) -> Tensor:
        # Keep the input dtype: a transpose only moves values and must not up- or down-cast them
        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                             dtype=x.dtype, device=x.device)

        x_b, x_r, x_c = x.size()
        x_b_s, x_r_s, x_c_s = x.stride()
        s_l_b, s_l_r, s_l_c = sparsity_layout.size()
        s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout.stride()
        s_lut_r, s_lut_c = sparsity_lut.shape
        s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
        o_b, o_r, o_c = output.size()
        o_b_s, o_r_s, o_c_s = output.stride()

        if triton_block_size is None:
            triton_block_size = get_triton_block_size(sparsity_block_size)

        triton_grid = lambda meta: [o_b,
                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

        (_BlocksparseTranspose.kernel_blocksparse_transpose[triton_grid]
         (x,
          x_b, x_b_s, x_r_s, x_c_s,
          s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
          sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
          sparsity_reverse_lut,
          output,
          o_b, o_b_s, o_r_s, o_c_s,
          triton_block_size))

        # Tensors needed in backward go through save_for_backward only, not plain ctx attributes
        ctx.save_for_backward(sparsity_layout)
        ctx.sparsity_block_size = sparsity_block_size
        ctx.triton_block_size = triton_block_size

        return output

    @staticmethod
    def backward(ctx, grad_output):
        sparsity_layout = ctx.saved_tensors[0]
        sparsity_block_size = ctx.sparsity_block_size
        triton_block_size = ctx.triton_block_size

        # transpose() runs validate_contiguous on its input, but the incoming gradient may be non-contiguous
        # (e.g. an expanded tensor for ``out.sum().backward()``)
        grad_x = transpose(grad_output.contiguous(), sparsity_layout, sparsity_block_size, triton_block_size)[0]

        return grad_x, None, None, None, None, None, None

    @staticmethod
    @triton.jit
    def kernel_blocksparse_transpose(x,
                                     x_b, x_b_s, x_r_s, x_c_s,
                                     s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
                                     s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
                                     r_lut,
                                     o,
                                     o_b, o_b_s, o_r_s, o_c_s,
                                     TRITON_BLOCK_SIZE: tl.constexpr) -> None:
        # Get triton block indices
        pid_blk = tl.program_id(axis=0)
        pid_row = tl.program_id(axis=1)
        pid_col = tl.program_id(axis=2)

        # Get sparsity index of current output block consisting of its batch, row, and column index
        spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
        spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
        spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)

        spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
        spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
        spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)

        spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
        spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
        spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)

        # Index of the input block that lands at this output position
        rev_idx_spa_idx = (spa_bat * s_l_b_s +
                           spa_row * s_l_r_s +
                           spa_col * s_l_c_s)
        rev_idx_spa_msk = (rev_idx_spa_idx < s_l_b * s_l_b_s)
        rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

        if rev_idx_spa == -1:
            assert False, "Invalid sparsity block"

        blk_x_idx = (rev_idx_spa * x_b_s +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
        blk_x_msk = (blk_x_idx < x_b * x_b_s)
        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

        blk_x_t = tl.trans(blk_x)

        # The tile read at (row segment, column segment) is written to (column segment, row segment), addressed
        # with the output's own strides
        blk_o_idx = (pid_blk * o_b_s +
                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
        blk_o_msk = (blk_o_idx < o_b * o_b_s)
        tl.store(o + blk_o_idx, blk_x_t, mask=blk_o_msk)
|
blksprs/utils/tools.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import Tensor, Size
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def do_shape_blocksparse(x: Tensor):
    """Bring ``x`` into the three-dimensional shape the block-sparse kernels work on.

    Every dimension except the last two is folded into a single leading dimension. A tensor that is already
    three-dimensional is returned as is.

    Returns:
        The (possibly reshaped) tensor and the original size of ``x``.
    """
    original_size = x.size()
    if x.dim() != 3:
        x = x.reshape(-1, x.size(-2), x.size(-1))
    return x, original_size
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def undo_shape_blocksparse(x: Tensor, shape: Size):
    """Restore the leading dimensions that ``do_shape_blocksparse`` folded together.

    Args:
        x: A three-dimensional tensor, e.g. the result of a kernel run on a reshaped input.
        shape: The original size returned by ``do_shape_blocksparse``.

    Returns:
        ``x`` reshaped to ``(*shape[:-2], *x.shape[-2:])``. The trailing two dimensions of ``x`` are kept because
        the operation may have changed them. ``x`` is returned unchanged if its leading dimensions already match.
    """
    # Compare the leading dimensions: those are what do_shape_blocksparse flattened. Comparing the trailing ones
    # would skip the restore for e.g. a (2, 3, r, c) input that was flattened to (6, r, c).
    if x.shape[:-2] == shape[:-2]:
        return x

    return x.reshape((*shape[:-2], *x.shape[-2:]))
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def get_triton_block_size(sparsity_block_size: int, limit: int = 128):
    """Pick the triton block size for a given sparsity block size.

    Returns ``sparsity_block_size`` capped at ``limit``.
    """
    return sparsity_block_size if sparsity_block_size <= limit else limit
|