blksprs 0.2b4__py3-none-any.whl → 1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/layouting/distribution_layout.py +114 -0
- blksprs/layouting/sparsity_layout.py +78 -0
- blksprs/misc/broadcast_addition.py +132 -0
- blksprs/ops/conversion.py +256 -0
- blksprs/ops/distribution.py +362 -0
- blksprs/ops/exp.py +101 -0
- blksprs/ops/matmul.py +221 -0
- blksprs/ops/row_wise_sum.py +231 -0
- blksprs/ops/softmax.py +263 -0
- blksprs/ops/transpose.py +154 -0
- blksprs/utils/tools.py +20 -0
- blksprs/utils/validation.py +97 -0
- blksprs-1.1.dist-info/METADATA +164 -0
- blksprs-1.1.dist-info/RECORD +17 -0
- {blksprs-0.2b4.dist-info → blksprs-1.1.dist-info}/WHEEL +1 -1
- blksprs/ops/blocksparse.py +0 -589
- blksprs-0.2b4.dist-info/METADATA +0 -26
- blksprs-0.2b4.dist-info/RECORD +0 -6
- {blksprs-0.2b4.dist-info → blksprs-1.1.dist-info}/top_level.txt +0 -0
blksprs/ops/blocksparse.py
DELETED
|
@@ -1,589 +0,0 @@
|
|
|
1
|
-
from abc import ABC
|
|
2
|
-
|
|
3
|
-
import torch
|
|
4
|
-
import triton
|
|
5
|
-
import triton.language as tl
|
|
6
|
-
from torch import Tensor, Size
|
|
7
|
-
from torch.nn import Module
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
class BaseBlocksparse(Module, ABC):
    """Common base class of all blocksparse operation modules.

    Holds the configuration shared by every operation (sparsity block size, device and an optional triton
    block size) and provides the input validation helpers. Validation is governed by the class-level flag
    ``_validate``: it is switched on when the first instance is constructed and can be switched off for all
    instances at once via :meth:`disable_validation`.
    """

    # Global validation switch shared by all subclasses; ``None`` until the first instance is created.
    _validate = None

    def __init__(self, sparsity_block_size: int, device: torch.device, triton_block_size: int = None) -> None:
        """Store the shared configuration.

        Args:
            sparsity_block_size: Edge length of the square blocks that make up a blocksparse tensor.
            device: Device the input tensors are expected to reside on.
            triton_block_size: Edge length of the tiles processed by a single triton program. ``None`` lets
                the operations fall back to :meth:`get_triton_block_size`.

        Raises:
            AssertionError: If validation is active and ``triton_block_size`` does not evenly divide
                ``sparsity_block_size``.
        """
        super().__init__()

        self.sparsity_block_size = sparsity_block_size
        self.device = device

        self.triton_block_size = triton_block_size

        if BaseBlocksparse._validate is None:
            BaseBlocksparse._validate = True

        self.validate_triton_block_size()

    def validate_tensors(self, *tensors: Tensor, flag_dim: bool = True, flag_contiguous: bool = True,
                         flag_dtype: bool = True,
                         flag_device: bool = True) -> None:
        """Assert basic properties of the given tensors (no-op while validation is disabled).

        Args:
            *tensors: Tensors to check.
            flag_dim: Require exactly 3 dimensions.
            flag_contiguous: Require a contiguous memory layout.
            flag_dtype: Require ``torch.float32``.
            flag_device: Require the tensors to live on ``self.device``.

        Raises:
            AssertionError: If any enabled check fails.
        """
        if not BaseBlocksparse._validate:
            return

        for tensor in tensors:
            if flag_dim:
                assert tensor.dim() == 3, "Input tensors must have 3 dimensions"
            if flag_contiguous:
                assert tensor.is_contiguous(), "Input tensors must be contiguous"
            if flag_dtype:
                assert tensor.dtype == torch.float32, "Input tensors must be of type float32"
            if flag_device:
                assert tensor.device == self.device, "Input tensors must be on the same device"

    def validate_sparsity(self, *tensor_sparsity_layout_tuples: tuple[Tensor, Tensor]) -> None:
        """Assert that each sparse tensor agrees with its sparsity layout (no-op while validation is disabled).

        Each tensor must consist of square blocks of edge length ``self.sparsity_block_size``, and it must hold
        exactly as many blocks as its layout has non-zero entries.

        Raises:
            AssertionError: If any pair is inconsistent.
        """
        if not BaseBlocksparse._validate:
            return

        for tensor_sparsity_layout_tuple in tensor_sparsity_layout_tuples:
            tensor, sparsity_layout = tensor_sparsity_layout_tuple

            assert tensor.size(-1) == tensor.size(-2) == self.sparsity_block_size, \
                "Tensor not conforming to sparsity specification"
            assert tensor.size(0) == torch.sum(sparsity_layout.reshape(-1)), \
                "Number of blocks in tensor does not match number of blocks in sparsity layout"

    def validate_triton_block_size(self) -> None:
        """Assert that the configured triton block size tiles a sparsity block exactly.

        The kernels in this module use ``sparsity_block_size // TRITON_BLOCK_SIZE`` as a modulus. A triton
        block size larger than the sparsity block size would make that a division by zero, and one that does
        not divide it would misalign tiles. Does nothing while validation is disabled or when no triton block
        size was given.

        Raises:
            AssertionError: If the triton block size is invalid.
        """
        if not BaseBlocksparse._validate or self.triton_block_size is None:
            return

        assert 0 < self.triton_block_size <= self.sparsity_block_size, \
            "Triton block size must be positive and not larger than the sparsity block size"
        assert self.sparsity_block_size % self.triton_block_size == 0, \
            "Triton block size must evenly divide the sparsity block size"

    @staticmethod
    def get_triton_block_size(sparsity_block_size: int, limit: int = 128):
        """Return the default triton block size: the sparsity block size, capped at ``limit``."""
        return min(sparsity_block_size, limit)

    @staticmethod
    def disable_validation():
        """Switch off all validation checks for every blocksparse module."""
        BaseBlocksparse._validate = False
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
# --- Matmul SSS ---
|
|
63
|
-
|
|
64
|
-
class BlocksparseMatmulSSS(BaseBlocksparse):
    """Blocksparse matrix multiplication in which both operands and the result are sparse ("SSS").

    Only the output blocks marked in the output sparsity layout are computed.
    """

    def __init__(self, sparsity_block_size: int, device: torch.device, triton_block_size: int = None) -> None:
        super().__init__(sparsity_block_size, device, triton_block_size=triton_block_size)

    def forward(self, x: Tensor, y: Tensor,
                sparsity_layout_x: Tensor, sparsity_layout_y: Tensor, sparsity_layout_output: Tensor) -> Tensor:
        """Compute the blocksparse product ``x @ y``.

        Args:
            x: Sparse blocks of the left operand, one block per non-zero entry of ``sparsity_layout_x``.
            y: Sparse blocks of the right operand, one block per non-zero entry of ``sparsity_layout_y``.
            sparsity_layout_x: Block layout of ``x`` with shape (batch, block rows, block cols).
            sparsity_layout_y: Block layout of ``y`` with shape (batch, block rows, block cols).
            sparsity_layout_output: Block layout selecting which output blocks are computed.

        Returns:
            The output blocks, one per non-zero entry of ``sparsity_layout_output``.

        Raises:
            AssertionError: If the operands or layouts are inconsistent.
        """
        self.validate_tensors(x, y)
        self.validate_sparsity((x, sparsity_layout_x), (y, sparsity_layout_y))
        assert x.size(2) == y.size(1), "Inner dimensions must match"
        # All blocks are square, so the check above is always satisfied. The inner dimensions that actually
        # have to agree are the block-column count of x and the block-row count of y.
        assert sparsity_layout_x.size(-1) == sparsity_layout_y.size(-2), "Inner dimensions must match"
        assert sparsity_layout_x.size(0) == sparsity_layout_y.size(0) == sparsity_layout_output.size(0), \
            "Batch dimensions must match"
        assert (sparsity_layout_output.size(-2) == sparsity_layout_x.size(-2) and
                sparsity_layout_output.size(-1) == sparsity_layout_y.size(-1)), \
            "Output sparsity layout must match the shape of the product"

        o_n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()

        # Reverse look-up tables: for every position of the flattened layout, the index of that block in the
        # compressed tensor, or -1 if the block is absent.
        sparsity_layout_x_flat = sparsity_layout_x.reshape(-1)
        sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
                                  (sparsity_layout_x_flat == 1) -
                                  (1 * (sparsity_layout_x_flat == 0)))

        sparsity_layout_y_flat = sparsity_layout_y.reshape(-1)
        sparsity_reverse_lut_y = ((torch.cumsum(sparsity_layout_y_flat, dim=-1) - 1) *
                                  (sparsity_layout_y_flat == 1) -
                                  (1 * (sparsity_layout_y_flat == 0)))

        # (batch, row, col) coordinates of every output block to compute
        sparsity_lut_o = torch.nonzero(sparsity_layout_output)

        return _BlocksparseMatmulSSS.apply(x, y,
                                           sparsity_layout_x, sparsity_reverse_lut_x,
                                           sparsity_layout_y, sparsity_reverse_lut_y,
                                           sparsity_layout_output, sparsity_lut_o,
                                           self.sparsity_block_size,
                                           o_n_sparse_blocks,
                                           self.triton_block_size,
                                           self.device)
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
class _BlocksparseMatmulSSS(torch.autograd.Function):
    """Autograd function backing :class:`BlocksparseMatmulSSS` (sparse x sparse -> sparse)."""

    @staticmethod
    def forward(ctx, x: Tensor, y: Tensor,
                sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
                sparsity_layout_y: Tensor, sparsity_reverse_lut_y: Tensor,
                sparsity_layout_o: Tensor, sparsity_lut_o: Tensor,
                sparsity_block_size: int, o_n_sparse_blocks: int, triton_block_size: int,
                device: torch.device) -> Tensor:
        """Launch the matmul kernel and return the ``o_n_sparse_blocks`` output blocks."""
        output = torch.zeros(size=(o_n_sparse_blocks, sparsity_block_size, sparsity_block_size), device=device)

        x_b, x_r, x_c = x.size()
        x_b_s, x_r_s, x_c_s = x.stride()
        s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_x.stride()
        y_b, y_r, y_c = y.size()
        y_b_s, y_r_s, y_c_s = y.stride()
        s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_y.size()
        s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = sparsity_layout_y.stride()
        o_b, o_r, o_c = output.size()
        o_b_s, o_r_s, o_c_s = output.stride()
        s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
        s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_o.stride()

        if triton_block_size is None:
            triton_block_size = BaseBlocksparse.get_triton_block_size(sparsity_block_size)

        # One program per (output block, tile row, tile column)
        triton_grid = lambda meta: [o_b,
                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

        (_BlocksparseMatmulSSS.kernel_blocksparse_matmul_sss[triton_grid]
         (x,
          x_b, x_b_s, x_r, x_r_s, x_c, x_c_s,
          s_l_x_b, s_l_x_b_s, s_l_x_r, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
          sparsity_reverse_lut_x,
          y,
          y_b, y_b_s, y_r, y_r_s, y_c, y_c_s,
          s_l_y_b, s_l_y_b_s, s_l_y_r, s_l_y_r_s, s_l_y_c, s_l_y_c_s,
          sparsity_reverse_lut_y,
          output,
          o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
          sparsity_lut_o,
          s_lut_o_r, s_lut_o_r_s,
          s_lut_o_c, s_lut_o_c_s,
          sparsity_block_size,
          triton_block_size))

        ctx.save_for_backward(x, y)
        ctx.sparsity_layout_x = sparsity_layout_x
        ctx.sparsity_layout_y = sparsity_layout_y
        ctx.sparsity_layout_o = sparsity_layout_o
        ctx.sparsity_block_size = sparsity_block_size
        ctx.triton_block_size = triton_block_size
        ctx.device = device

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Compute ``grad_x = grad_output @ y^T`` and ``grad_y = x^T @ grad_output`` as blocksparse products.

        The gradients are restricted to the layouts of ``x`` and ``y`` respectively. All non-tensor inputs
        receive ``None``.
        """
        x, y = ctx.saved_tensors
        sparsity_layout_x = ctx.sparsity_layout_x
        sparsity_layout_y = ctx.sparsity_layout_y
        sparsity_layout_o = ctx.sparsity_layout_o
        sparsity_block_size = ctx.sparsity_block_size
        triton_block_size = ctx.triton_block_size
        device = ctx.device

        # Autograd may hand in expanded (non-contiguous) gradients, e.g. from ``.sum()``. The matmul module
        # below asserts contiguity, so materialise the gradient first.
        grad_output = grad_output.contiguous()

        blksprs_transpose = BlocksparseTranspose(sparsity_block_size, device, triton_block_size)

        x_t, sparsity_layout_x_t = blksprs_transpose(x, sparsity_layout_x)
        y_t, sparsity_layout_y_t = blksprs_transpose(y, sparsity_layout_y)

        grad_x = BlocksparseMatmulSSS(sparsity_block_size, device, triton_block_size)(grad_output, y_t,
                                                                                      sparsity_layout_o,
                                                                                      sparsity_layout_y_t,
                                                                                      sparsity_layout_x)
        grad_y = BlocksparseMatmulSSS(sparsity_block_size, device, triton_block_size)(x_t, grad_output,
                                                                                      sparsity_layout_x_t,
                                                                                      sparsity_layout_o,
                                                                                      sparsity_layout_y)

        return grad_x, grad_y, None, None, None, None, None, None, None, None, None, None

    @staticmethod
    @triton.jit
    def kernel_blocksparse_matmul_sss(x,
                                      x_b, x_b_s, x_r, x_r_s, x_c, x_c_s,
                                      s_l_x_b, s_l_x_b_s, s_l_x_r, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
                                      r_lut_x,
                                      y,
                                      y_b, y_b_s, y_r, y_r_s, y_c, y_c_s,
                                      s_l_y_b, s_l_y_b_s, s_l_y_r, s_l_y_r_s, s_l_y_c, s_l_y_c_s,
                                      r_lut_y,
                                      o,
                                      o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
                                      s_lut_o,
                                      s_lut_o_r, s_lut_o_r_s,
                                      s_lut_o_c, s_lut_o_c_s,
                                      sparsity_block_size,
                                      TRITON_BLOCK_SIZE: tl.constexpr) -> None:
        # Computes one TRITON_BLOCK_SIZE x TRITON_BLOCK_SIZE tile of one output block by accumulating the
        # products of all x/y tile pairs along the inner dimension whose blocks are both present.

        # Get triton block indices
        pid_blk = tl.program_id(axis=0)
        pid_row = tl.program_id(axis=1)
        pid_col = tl.program_id(axis=2)

        # Get position of current sparsity block consisting of its batch, row, and column index
        spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
        spa_bat_o_msk = (spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s + s_lut_o_c * s_lut_o_c_s)
        spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)

        spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
        spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s + s_lut_o_c * s_lut_o_c_s)
        spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)

        spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
        spa_col_o_msk = (spa_col_o_idx < s_lut_o_r * s_lut_o_r_s + s_lut_o_c * s_lut_o_c_s)
        spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)

        # Setup buffer (accumulates in float32)
        buf = tl.zeros(shape=(TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), dtype=tl.float32)

        # Slide over triton block sized segments of input tensors
        for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size, TRITON_BLOCK_SIZE)):
            # Convert to segment index of sparsity layout
            i_seg_spa = (i_seg_tri * TRITON_BLOCK_SIZE) // sparsity_block_size
            # Calculate the triton segment index within a block
            i_seg_tri_mod = i_seg_tri % (sparsity_block_size // TRITON_BLOCK_SIZE)

            # Get reverse sparsity indices for input tensors x and y
            # These are either -1 if the block is empty or equal to the index of the block in the sparse tensor

            # Get reverse sparsity indices for x
            rev_idx_spa_x_idx = (spa_bat_o * s_l_x_b_s +
                                 spa_row_o * s_l_x_r_s +
                                 i_seg_spa * s_l_x_c_s)
            rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s + s_l_x_r * s_l_x_r_s + s_l_x_c * s_l_x_c_s)
            rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)

            # Get reverse sparsity indices for y
            rev_idx_spa_y_idx = (spa_bat_o * s_l_y_b_s + i_seg_spa * s_l_y_r_s + spa_col_o * s_l_y_c_s)
            rev_idx_spa_y_msk = (rev_idx_spa_y_idx < s_l_y_b * s_l_y_b_s + s_l_y_r * s_l_y_r_s + s_l_y_c * s_l_y_c_s)
            rev_idx_spa_y = tl.load(r_lut_y + rev_idx_spa_y_idx, mask=rev_idx_spa_y_msk).to(tl.int32)

            # If both blocks are present commence calculation
            if rev_idx_spa_x >= 0 and rev_idx_spa_y >= 0:
                blk_x_idx = ((rev_idx_spa_x * x_b_s) +
                             ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                             ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
                               tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
                blk_x_msk = (blk_x_idx < x_b * x_b_s + x_r * x_r_s + x_c * x_c_s)
                blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

                blk_y_idx = ((rev_idx_spa_y * y_b_s) +
                             ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
                               tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
                             ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
                blk_y_msk = (blk_y_idx < y_b * y_b_s + y_r * y_r_s + y_c * y_c_s)
                blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)

                # Perform matrix multiplication
                buf += tl.dot(blk_x, blk_y)

        # Store output
        blk_o_idx = ((pid_blk * o_b_s) +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
        blk_o_msk = (blk_o_idx < o_b * o_b_s + o_r * o_r_s + o_c * o_c_s)
        tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
# --- Softmax ---
|
|
271
|
-
|
|
272
|
-
class BlocksparseSoftmax(BaseBlocksparse):
    """Softmax over the last dimension of a blocksparse tensor.

    Absent blocks are filled with ``fill_value`` (``-inf`` by default) before the softmax. A ``-inf`` entry
    has an exponential of zero, so absent blocks do not contribute to the normalisation.
    """
    # TODO At the moment uses standard softmax instead of blocksparse improvements

    def __init__(self, sparsity_block_size: int, device: torch.device, triton_block_size: int = None) -> None:
        super().__init__(sparsity_block_size, device, triton_block_size=triton_block_size)

        # Forward the triton block size so that the conversions honour the caller's configuration.
        self.blksprs_to_dense = BlocksparseToDense(sparsity_block_size, device, triton_block_size)
        self.blksprs_to_sparse = BlocksparseToSparse(sparsity_block_size, device, triton_block_size)

    def forward(self, x: Tensor, sparsity_layout: Tensor, fill_value: float = float("-inf")) -> Tensor:
        """Apply the softmax to ``x`` and return the result in the same blocksparse layout.

        Args:
            x: Sparse blocks, one per non-zero entry of ``sparsity_layout``.
            sparsity_layout: Block layout of ``x``.
            fill_value: Value assumed for absent blocks during the dense softmax.
        """
        self.validate_tensors(x)

        x_dense = self.blksprs_to_dense(x, sparsity_layout, fill_value=fill_value)
        x_softmax = torch.softmax(x_dense, dim=-1)
        x_sparse = self.blksprs_to_sparse(x_softmax, sparsity_layout)

        return x_sparse
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
# --- Transpose ---
|
|
292
|
-
|
|
293
|
-
class BlocksparseTranspose(BaseBlocksparse):
    """Transposes a blocksparse tensor together with its sparsity layout."""

    def __init__(self, sparsity_block_size: int, device: torch.device, triton_block_size: int = None) -> None:
        super().__init__(sparsity_block_size, device, triton_block_size=triton_block_size)

    def forward(self, x: Tensor, sparsity_layout: Tensor, shuffle_blocks: bool = True) -> tuple[Tensor, Tensor]:
        """Transpose ``x`` (the last two dimensions of the dense tensor it represents).

        Args:
            x: Sparse blocks, one per non-zero entry of ``sparsity_layout``.
            sparsity_layout: Block layout of ``x``.
            shuffle_blocks: If ``True``, reorder the blocks so that they follow the row-major order of the
                transposed layout. If ``False``, only the individual blocks are transposed.

        Returns:
            The transposed blocks and the transposed sparsity layout.
        """
        self.validate_tensors(x)

        # Transpose every block individually, as well as the layout of blocks.
        x_t = x.transpose(1, 2).contiguous()
        sparsity_layout_t = sparsity_layout.transpose(-1, -2).contiguous()

        if shuffle_blocks:
            # Reverse LUT over the ORIGINAL layout: index of each block in x, or -1 where the block is absent.
            sparsity_layout_flat = sparsity_layout.reshape(-1)
            sparsity_reverse_lut = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
                                    (sparsity_layout_flat == 1) -
                                    (1 * (sparsity_layout_flat == 0)))
            # Read that LUT in transposed order. The surviving non-negative entries give, for each block of
            # the transposed layout in row-major order, the index of its source block in x.
            shuffle_layout = (sparsity_reverse_lut.reshape(sparsity_layout.size()).transpose(-1, -2).contiguous()
                              .reshape(-1).to(torch.int))
            shuffle_layout = shuffle_layout[shuffle_layout >= 0]
            x_t = x_t[shuffle_layout, :, :]

        return x_t, sparsity_layout_t
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
# --- To Dense ---
|
|
318
|
-
|
|
319
|
-
class BlocksparseToDense(BaseBlocksparse):
    """Converts a blocksparse tensor into its dense representation."""

    def __init__(self, sparsity_block_size: int, device: torch.device, triton_block_size: int = None) -> None:
        super().__init__(sparsity_block_size, device, triton_block_size=triton_block_size)

    def forward(self, x: Tensor, sparsity_layout: Tensor, fill_value: float = 0) -> Tensor:
        """Scatter the blocks of ``x`` into a dense tensor.

        Args:
            x: Sparse blocks, one per non-zero entry of ``sparsity_layout``.
            sparsity_layout: Block layout with shape (batch, block rows, block cols).
            fill_value: Value written to positions covered by absent blocks. May be a float, e.g. ``-inf`` as
                used by the softmax.

        Returns:
            Dense tensor of shape (batch, block rows * sparsity_block_size, block cols * sparsity_block_size).
        """
        self.validate_tensors(x)
        self.validate_sparsity((x, sparsity_layout))

        # Index of each block in x for every layout position, or -1 where the block is absent
        sparsity_layout_flat = sparsity_layout.reshape(-1)
        sparsity_reverse_lut = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
                                (sparsity_layout_flat == 1) -
                                (1 * (sparsity_layout_flat == 0)))

        return _BlocksparseToDense.apply(x,
                                         sparsity_layout, sparsity_reverse_lut,
                                         self.sparsity_block_size, fill_value,
                                         self.triton_block_size, self.device)
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
class _BlocksparseToDense(torch.autograd.Function):
    """Autograd function backing :class:`BlocksparseToDense`."""

    @staticmethod
    def forward(ctx, x: Tensor,
                sparsity_layout: Tensor, sparsity_reverse_lut: Tensor,
                sparsity_block_size: int, fill_value: float,
                triton_block_size: int, device: torch.device) -> Tensor:
        """Allocate a dense tensor filled with ``fill_value`` and copy the present blocks of ``x`` into it."""
        output = torch.full(size=(sparsity_layout.size(0), sparsity_layout.size(1) * sparsity_block_size,
                                  sparsity_layout.size(2) * sparsity_block_size), fill_value=fill_value,
                            dtype=x.dtype, device=device)

        x_b, x_r, x_c = x.shape
        x_b_s, x_r_s, x_c_s = x.stride()
        s_l_b, s_l_r, s_l_c = sparsity_layout.size()
        s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout.stride()
        o_b, o_r, o_c = output.size()
        o_b_s, o_r_s, o_c_s = output.stride()

        if triton_block_size is None:
            triton_block_size = BaseBlocksparse.get_triton_block_size(sparsity_block_size)

        # One program per (batch, tile row, tile column) of the dense output
        triton_grid = lambda meta: [o_b,
                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

        (_BlocksparseToDense.kernel_blocksparse_to_dense[triton_grid]
         (x,
          x_b, x_b_s, x_r, x_r_s, x_c, x_c_s,
          s_l_b, s_l_b_s, s_l_r, s_l_r_s, s_l_c, s_l_c_s,
          sparsity_reverse_lut,
          output,
          o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
          sparsity_block_size,
          triton_block_size))

        ctx.sparsity_layout = sparsity_layout
        ctx.sparsity_block_size = sparsity_block_size
        ctx.triton_block_size = triton_block_size
        ctx.device = device

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Gradient w.r.t. ``x`` is the dense gradient restricted to the present blocks; others get ``None``."""
        sparsity_layout = ctx.sparsity_layout
        sparsity_block_size = ctx.sparsity_block_size
        triton_block_size = ctx.triton_block_size
        device = ctx.device

        # Autograd may hand in expanded (non-contiguous) gradients, e.g. from ``.sum()``. BlocksparseToSparse
        # asserts contiguity, so materialise the gradient first.
        grad_output = grad_output.contiguous()

        return BlocksparseToSparse(sparsity_block_size, device, triton_block_size)(grad_output,
                                                                                   sparsity_layout), None, None, None, None, None, None

    @staticmethod
    @triton.jit
    def kernel_blocksparse_to_dense(x,
                                    x_b, x_b_s, x_r, x_r_s, x_c, x_c_s,
                                    s_l_b, s_l_b_s, s_l_r, s_l_r_s, s_l_c, s_l_c_s,
                                    sparsity_reverse_lut,
                                    o,
                                    o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
                                    sparsity_block_size,
                                    TRITON_BLOCK_SIZE: tl.constexpr):
        # Copies one TRITON_BLOCK_SIZE x TRITON_BLOCK_SIZE tile of the dense output from the matching block of
        # x. Tiles of absent blocks are left untouched, so they keep the fill value.

        # Get triton block indices
        pid_bat = tl.program_id(axis=0)
        pid_row = tl.program_id(axis=1)
        pid_col = tl.program_id(axis=2)

        # Get sparsity index of current block
        spa_row = (pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size
        spa_col = (pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size

        # Get reverse sparsity index for current block
        rev_idx_spa_idx = (pid_bat * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)
        rev_idx_spa_msk = (rev_idx_spa_idx < s_l_b * s_l_b_s + s_l_r * s_l_r_s + s_l_c * s_l_c_s)
        rev_idx_spa = tl.load(sparsity_reverse_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

        # If block is present commence operations
        if rev_idx_spa >= 0:
            # The position of the tile within its block is the tile index modulo the tiles per block
            blk_idx = (rev_idx_spa * x_b_s +
                       (((pid_row % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
                         tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                       (((pid_col % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
                         tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
            blk_msk = (blk_idx < x_b * x_b_s + x_r * x_r_s + x_c * x_c_s)
            blk = tl.load(x + blk_idx, mask=blk_msk)

            o_idx = (pid_bat * o_b_s +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
            o_msk = (o_idx < o_b * o_b_s + o_r * o_r_s + o_c * o_c_s)
            tl.store(o + o_idx, blk, o_msk)
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
# --- To Sparse ---
|
|
432
|
-
|
|
433
|
-
class BlocksparseToSparse(BaseBlocksparse):
    """Extracts the blocks selected by a sparsity layout from a dense tensor."""

    def __init__(self, sparsity_block_size: int, device: torch.device, triton_block_size: int = None) -> None:
        super().__init__(sparsity_block_size, device, triton_block_size=triton_block_size)

    def forward(self, x: Tensor, sparsity_layout: Tensor) -> Tensor:
        """Gather the blocks of the dense tensor ``x`` that are marked in ``sparsity_layout``.

        Args:
            x: Dense tensor.
            sparsity_layout: Block layout with shape (batch, block rows, block cols).

        Returns:
            One block of shape (sparsity_block_size, sparsity_block_size) per non-zero layout entry, in the
            row-major order produced by ``torch.nonzero``.
        """
        self.validate_tensors(x)

        # (batch, row, col) coordinates of every block to extract
        block_coordinates = torch.nonzero(sparsity_layout)
        block_count = torch.sum(sparsity_layout.to(torch.int)).item()

        return _BlocksparseToSparse.apply(
            x,
            sparsity_layout,
            block_coordinates,
            self.sparsity_block_size,
            block_count,
            self.triton_block_size,
            self.device,
        )
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
class _BlocksparseToSparse(torch.autograd.Function):
    """Autograd function backing :class:`BlocksparseToSparse`."""

    @staticmethod
    def forward(ctx, x: Tensor,
                sparsity_layout: Tensor, sparsity_lut: Tensor,
                sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int, device: torch.device) -> Tensor:
        """Copy the ``n_sparse_blocks`` blocks listed in ``sparsity_lut`` out of the dense tensor ``x``."""
        # Keep the input's dtype, as the dense conversion does, instead of the global default dtype.
        output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                             dtype=x.dtype, device=device)

        x_b, x_r, x_c = x.size()
        x_b_s, x_r_s, x_c_s = x.stride()
        o_b, o_r, o_c = output.size()
        o_b_s, o_r_s, o_c_s = output.stride()
        s_lut_r, s_lut_c = sparsity_lut.size()
        s_lut_r_s, s_lut_c_s = sparsity_lut.stride()

        if triton_block_size is None:
            triton_block_size = BaseBlocksparse.get_triton_block_size(sparsity_block_size)

        # One program per (output block, tile row, tile column)
        triton_grid = lambda meta: [o_b,
                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

        (_BlocksparseToSparse.kernel_blocksparse_to_sparse[triton_grid]
         (x, x_b, x_b_s, x_r, x_r_s, x_c, x_c_s,
          sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c,
          s_lut_c_s,
          output, o_b_s, o_r_s, o_c_s,
          sparsity_block_size,
          triton_block_size))

        ctx.sparsity_layout = sparsity_layout
        ctx.sparsity_block_size = sparsity_block_size
        ctx.triton_block_size = triton_block_size
        ctx.device = device

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Gradient w.r.t. ``x`` scatters the block gradients back into a zero dense tensor; others get ``None``."""
        sparsity_layout = ctx.sparsity_layout
        sparsity_block_size = ctx.sparsity_block_size
        triton_block_size = ctx.triton_block_size
        device = ctx.device

        # Autograd may hand in expanded (non-contiguous) gradients, e.g. from ``.sum()``. BlocksparseToDense
        # asserts contiguity, so materialise the gradient first.
        grad_output = grad_output.contiguous()

        return BlocksparseToDense(sparsity_block_size, device, triton_block_size)(grad_output,
                                                                                  sparsity_layout), None, None, None, None, None, None

    @staticmethod
    @triton.jit
    def kernel_blocksparse_to_sparse(x,
                                     x_b, x_b_s, x_r, x_r_s, x_c: tl.constexpr, x_c_s,
                                     s_lut, s_lut_r, s_lut_r_s, s_lut_c, s_lut_c_s,
                                     o,
                                     o_b_s, o_r_s, o_c_s,
                                     sparsity_block_size,
                                     TRITON_BLOCK_SIZE: tl.constexpr):
        # Copies one TRITON_BLOCK_SIZE x TRITON_BLOCK_SIZE tile of output block ``pid_blk`` from the position
        # of that block in the dense input.

        # Get triton block indices
        pid_blk = tl.program_id(axis=0)
        pid_row = tl.program_id(axis=1)
        pid_col = tl.program_id(axis=2)

        # Get sparsity index of current output block consisting of its batch, row, and column index
        spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
        spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s + s_lut_c * s_lut_c_s)
        spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)

        spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
        spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s + s_lut_c * s_lut_c_s)
        spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)

        spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
        spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s + s_lut_c * s_lut_c_s)
        spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)

        # Load block from dense tensor
        blk_d_idx = (spa_bat * x_b_s +
                     ((spa_row * sparsity_block_size + pid_row * TRITON_BLOCK_SIZE +
                       tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                     ((spa_col * sparsity_block_size + pid_col * TRITON_BLOCK_SIZE +
                       tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
        blk_d_msk = (blk_d_idx < x_b * x_b_s + x_r * x_r_s + x_c * x_c_s)
        blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)

        # Store block in sparse tensor. The whole column offset (tile start + lane) is scaled by the column
        # stride, mirroring the row offset above.
        blk_o_idx = ((pid_blk * o_b_s) +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
        blk_o_msk = (blk_o_idx < (pid_blk + 1) * o_b_s)
        tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
class BlocksparseTools:
    """Shape helpers and slow reference conversions used for verification."""

    @staticmethod
    def do_shape_blocksparse(x: Tensor):
        """Flatten all leading dimensions of ``x`` into a single batch dimension.

        A 3-dimensional tensor is returned unchanged. Otherwise the result has shape
        ``(prod(x.shape[:-2]), x.size(-2), x.size(-1))``.
        """
        if x.dim() == 3:
            return x

        return x.reshape(-1, x.size(-2), x.size(-1))

    @staticmethod
    def undo_shape_blocksparse(x: Tensor, shape: Size):
        """Restore the leading dimensions recorded in ``shape`` on a tensor flattened by do_shape_blocksparse.

        Args:
            x: Tensor whose leading dimensions were flattened into one batch dimension.
            shape: Original shape of the tensor before flattening.

        Returns:
            ``x`` unchanged if ``shape`` is 3-dimensional (do_shape_blocksparse did not touch such tensors).
            Otherwise ``x`` reshaped to ``(*shape[:-2], x.size(-2), x.size(-1))``. The trailing two dimensions
            are taken from ``x``, since an operation may have changed them.
        """
        # The check must be on the target shape: x itself is always 3-dimensional at this point, so testing
        # x.dim() would make this function a permanent no-op.
        if len(shape) == 3:
            return x

        return x.reshape((*shape[:-2], *x.shape[-2:]))

    # Methods used for verification

    @staticmethod
    def slow_to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int):
        """Reference (loop-based) conversion of blocksparse ``x`` into a zero-filled dense tensor."""
        output = torch.zeros(size=(sparsity_layout.size(0), sparsity_layout.size(1) * sparsity_block_size,
                                   sparsity_layout.size(2) * sparsity_block_size),
                             dtype=x.dtype, device=x.device)
        indices_sparse_blocks = sparsity_layout.nonzero(as_tuple=True)

        # Blocks of x are stored in the row-major order of the layout's non-zero entries
        for idx, (b, r, c) in enumerate(zip(*indices_sparse_blocks)):
            t_r = r * sparsity_block_size
            t_c = c * sparsity_block_size
            output[b, t_r:t_r + sparsity_block_size, t_c:t_c + sparsity_block_size] = x[idx]

        return output

    @staticmethod
    def slow_to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int):
        """Reference (loop-based) extraction of the blocks of dense ``x`` marked in ``sparsity_layout``."""
        n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
        output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                             dtype=x.dtype, device=x.device)
        indices_sparse_blocks = sparsity_layout.nonzero(as_tuple=True)

        for idx, (b, r, c) in enumerate(zip(*indices_sparse_blocks)):
            t_r = r * sparsity_block_size
            t_c = c * sparsity_block_size
            output[idx] = x[b, t_r:t_r + sparsity_block_size, t_c:t_c + sparsity_block_size]

        return output
|
blksprs-0.2b4.dist-info/METADATA
DELETED
|
@@ -1,26 +0,0 @@
|
|
|
1
|
-
Metadata-Version: 2.1
|
|
2
|
-
Name: blksprs
|
|
3
|
-
Version: 0.2b4
|
|
4
|
-
Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
|
|
5
|
-
Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
|
|
6
|
-
Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
|
|
7
|
-
Project-URL: Bugtracker, https://github.com/FelixSchoen/blksprs/issues
|
|
8
|
-
Requires-Python: >=3.12
|
|
9
|
-
Description-Content-Type: text/markdown
|
|
10
|
-
Requires-Dist: torch
|
|
11
|
-
Provides-Extra: deploy
|
|
12
|
-
Requires-Dist: build; extra == "deploy"
|
|
13
|
-
Requires-Dist: twine; extra == "deploy"
|
|
14
|
-
Requires-Dist: pdoc3; extra == "deploy"
|
|
15
|
-
Provides-Extra: test
|
|
16
|
-
Requires-Dist: pytest; extra == "test"
|
|
17
|
-
Requires-Dist: pytest-xdist; extra == "test"
|
|
18
|
-
Requires-Dist: pytest-cov; extra == "test"
|
|
19
|
-
Requires-Dist: coverage; extra == "test"
|
|
20
|
-
Requires-Dist: matplotlib; extra == "test"
|
|
21
|
-
|
|
22
|
-
# blksprs
|
|
23
|
-
|
|
24
|
-
## Overview
|
|
25
|
-
|
|
26
|
-
A lightweight library for operations on blocksparse matrices in PyTorch.
|
blksprs-0.2b4.dist-info/RECORD
DELETED
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
blksprs/ops/blocksparse.py,sha256=4vATdQicjMgEmULct-955vyJ4rRoIqk572tIGu5RjPU,27630
|
|
2
|
-
blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
|
|
3
|
-
blksprs-0.2b4.dist-info/METADATA,sha256=PGuf_WUjS7KT7dvkPoiApvSasWGKVZtH-EF_XX_Ffos,876
|
|
4
|
-
blksprs-0.2b4.dist-info/WHEEL,sha256=UvcQYKBHoFqaQd6LKyqHw9fxEolWLQnlzP0h_LgJAfI,91
|
|
5
|
-
blksprs-0.2b4.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
|
|
6
|
-
blksprs-0.2b4.dist-info/RECORD,,
|
|
File without changes
|