blksprs-1.0-py3-none-any.whl → blksprs-1.2-py3-none-any.whl
- blksprs/layouting/distribution_layout.py +114 -0
- blksprs/layouting/sparsity_layout.py +129 -7
- blksprs/misc/broadcast_addition.py +132 -0
- blksprs/ops/conversion.py +237 -17
- blksprs/ops/distribution.py +362 -0
- blksprs/ops/exp.py +18 -8
- blksprs/ops/{matmul_sss.py → matmul.py} +28 -26
- blksprs/ops/row_wise_sum.py +21 -5
- blksprs/ops/softmax.py +23 -12
- blksprs/ops/transpose.py +19 -7
- blksprs/utils/tools.py +1 -28
- blksprs/utils/validation.py +53 -1
- {blksprs-1.0.dist-info → blksprs-1.2.dist-info}/METADATA +39 -14
- blksprs-1.2.dist-info/RECORD +17 -0
- {blksprs-1.0.dist-info → blksprs-1.2.dist-info}/WHEEL +1 -1
- blksprs-1.0.dist-info/RECORD +0 -14
- {blksprs-1.0.dist-info → blksprs-1.2.dist-info}/top_level.txt +0 -0
blksprs/ops/distribution.py
ADDED
@@ -0,0 +1,362 @@
+import torch
+import triton
+from torch import Tensor
+from triton import language as tl
+
+from blksprs.utils.tools import get_triton_block_size
+from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
+    validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
+
+
+def gather(src: Tensor, sparsity_layout_src: Tensor, idx: Tensor, sparsity_layout_idx: Tensor,
+           sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+    """Applies a gather operation on a block-sparse tensor in compressed form.
+
+    Args:
+        src (Tensor): The source block-sparse tensor in compressed form to gather from.
+        sparsity_layout_src (Tensor): The sparsity layout of the source block-sparse tensor.
+        idx (Tensor): The block-sparse indices tensor in compressed form specifying how to gather from the source tensor.
+        sparsity_layout_idx (Tensor): The sparsity layout of the indices block-sparse tensor.
+        sparsity_block_size (int): The size of the sparsity blocks.
+        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+
+    Returns:
+        Tensor: The result of the gather operation as a block-sparse tensor in compressed form.
+
+    """
+    validate_dimensions(src, idx)
+    validate_contiguous(src, idx)
+    validate_dtype_int(idx)
+    validate_device(src, idx)
+    validate_sparsity(sparsity_block_size, (src, sparsity_layout_src), (idx, sparsity_layout_idx))
+    validate_sparsity_block_size(sparsity_block_size, src, idx)
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+    sparsity_layout_x_flat = sparsity_layout_src.reshape(-1)
+    sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
+                              (sparsity_layout_x_flat == 1) -
+                              (1 * (sparsity_layout_x_flat == 0)))
+
+    sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
+
+    validate_contiguous(sparsity_layout_src, sparsity_reverse_lut_x,
+                        sparsity_layout_idx, sparsity_lut_i)
+
+    return _BlocksparseGather.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
+                                    idx, sparsity_layout_idx, sparsity_lut_i,
+                                    sparsity_block_size, triton_block_size)
+
+
+class _BlocksparseGather(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
+                i: Tensor, sparsity_layout_i: Tensor, sparsity_lut_i: Tensor,
+                sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+        output = torch.empty_like(i, dtype=x.dtype)
+
+        x_b, x_r, x_c = x.size()
+        x_b_s, x_r_s, x_c_s = x.stride()
+        s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
+        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_x.stride()
+        i_b, i_r, i_c = i.size()
+        i_b_s, i_r_s, i_c_s = i.stride()
+        s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
+        s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = output.stride()
+
+        if triton_block_size is None:
+            triton_block_size = get_triton_block_size(sparsity_block_size)
+
+        triton_grid = lambda meta: [o_b,
+                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+        (_BlocksparseGather.kernel_blocksparse_gather[triton_grid]
+         (x,
+          x_b, x_b_s, x_r_s, x_c_s,
+          s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+          sparsity_reverse_lut_x,
+          i,
+          i_b, i_b_s, i_r_s, i_c_s,
+          output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          sparsity_lut_i, s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+          sparsity_block_size,
+          triton_block_size))
+
+        ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_i)
+        ctx.sparsity_block_size = sparsity_block_size
+        ctx.triton_block_size = triton_block_size
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        sparsity_layout_x, i, sparsity_layout_i = ctx.saved_tensors
+        sparsity_block_size = ctx.sparsity_block_size
+        triton_block_size = ctx.triton_block_size
+
+        return scatter_reduce(grad_output, sparsity_layout_i,
+                              i,
+                              sparsity_layout_x,
+                              sparsity_block_size,
+                              reduce_op="sum",
+                              triton_block_size=triton_block_size), None, None, None, None, None, None, None
+
+    @staticmethod
+    @triton.jit
+    def kernel_blocksparse_gather(x,
+                                  x_b, x_b_s, x_r_s, x_c_s,
+                                  s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+                                  r_lut_x,
+                                  i,
+                                  i_b, i_b_s, i_r_s, i_c_s,
+                                  o,
+                                  o_b, o_b_s, o_r_s, o_c_s,
+                                  s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+                                  sparsity_block_size,
+                                  TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+        # Get triton block indices
+        pid_blk = tl.program_id(axis=0)
+        pid_row = tl.program_id(axis=1)
+        pid_col = tl.program_id(axis=2)
+
+        # Get position of current sparsity block consisting of its batch, row, and column index
+        spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
+        spa_bat_o_msk = (spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+        spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
+
+        spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
+        spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+        spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+
+        # Load index values
+        blk_i_idx = ((pid_blk * i_b_s) +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+        blk_i_msk = (blk_i_idx < i_b * i_b_s)
+        blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
+
+        # Get positions of sparsity blocks
+        pos_spa_blk_x = blk_i // sparsity_block_size
+        pos_spa_col_x = blk_i % sparsity_block_size
+
+        # Load reverse sparsity indices for x
+        rev_idx_spa_x_idx = ((spa_bat_o * s_l_x_b_s) +
+                             (spa_row_o * s_l_x_r_s) +
+                             (pos_spa_blk_x * s_l_x_c_s))
+        rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+        rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
+
+        # Load x values
+        blk_x_idx = ((rev_idx_spa_x * x_b_s) +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                     (pos_spa_col_x * x_c_s))
+        blk_x_msk = (blk_x_idx < x_b * x_b_s)
+        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+        # Store output
+        blk_o_idx = ((pid_blk * o_b_s) +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+        blk_o_msk = (blk_o_idx < o_b * o_b_s)
+        tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+def scatter(src: Tensor, sparsity_layout_src: Tensor,
+            idx: Tensor,
+            sparsity_layout_tgt: Tensor,
+            sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+    """Wrapper for ``scatter_reduce`` with ``reduce_op="none"``.
+
+    """
+    return scatter_reduce(src, sparsity_layout_src,
+                          idx,
+                          sparsity_layout_tgt,
+                          sparsity_block_size,
+                          reduce_op="none", triton_block_size=triton_block_size)
+
+
+def scatter_reduce(src: Tensor, sparsity_layout_src: Tensor,
+                   idx: Tensor,
+                   sparsity_layout_tgt: Tensor,
+                   sparsity_block_size: int,
+                   reduce_op: str = "sum", triton_block_size: int = None) -> Tensor:
+    """Applies a scatter operation on a block-sparse tensor in compressed form.
+
+    Args:
+        src (Tensor): The source block-sparse tensor in compressed form to scatter from.
+        sparsity_layout_src (Tensor): The sparsity layout of the source block-sparse tensor.
+        idx (Tensor): The block-sparse indices tensor in compressed form specifying how to scatter to the target tensor.
+        sparsity_layout_tgt (Tensor): The sparsity layout of the target block-sparse tensor.
+        sparsity_block_size (int): The size of the sparsity blocks.
+        reduce_op (str, optional): The reduction operation to apply during the scatter operation (default ``"sum"``).
+            Supported operations are ``"none"`` and ``"sum"``.
+        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+
+    Returns:
+        Tensor: The result of the scatter operation as a block-sparse tensor in compressed form.
+
+    """
+    validate_dimensions(src, idx)
+    validate_contiguous(src, idx)
+    validate_dtype_int(idx)
+    validate_device(src, idx)
+    validate_sparsity(sparsity_block_size, (src, sparsity_layout_src), (idx, sparsity_layout_src))
+    validate_sparsity_block_size(sparsity_block_size, src, idx)
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+    if reduce_op not in ["none", "sum"]:
+        raise ValueError(f"Reduction operation '{reduce_op}' is not supported")
+
+    sparsity_lut_x = torch.nonzero(sparsity_layout_src).contiguous()
+
+    sparsity_layout_o_flat = sparsity_layout_tgt.reshape(-1)
+    sparsity_reverse_lut_o = ((torch.cumsum(sparsity_layout_o_flat, dim=-1) - 1) *
+                              (sparsity_layout_o_flat == 1) -
+                              (1 * (sparsity_layout_o_flat == 0)))
+
+    n_sparse_blocks = torch.sum(sparsity_layout_tgt.to(torch.int)).item()
+
+    validate_contiguous(sparsity_layout_src, sparsity_lut_x,
+                        sparsity_layout_tgt, sparsity_reverse_lut_o)
+
+    return _BlocksparseScatterReduce.apply(src, sparsity_layout_src, sparsity_lut_x,
+                                           idx,
+                                           sparsity_layout_tgt, sparsity_reverse_lut_o,
+                                           sparsity_block_size, n_sparse_blocks,
+                                           reduce_op, triton_block_size)
+
+
+class _BlocksparseScatterReduce(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_lut_x: Tensor,
+                i: Tensor,
+                sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
+                sparsity_block_size: int, n_sparse_blocks: int,
+                reduce_op: str, triton_block_size: int) -> Tensor:
+        output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                             dtype=x.dtype, device=x.device)
+
+        x_b, x_r, x_c = x.size()
+        x_b_s, x_r_s, x_c_s = x.stride()
+        s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
+        s_lut_x_r_s, s_lut_x_c_s = sparsity_lut_x.stride()
+        i_b, i_r, i_c = i.size()
+        i_b_s, i_r_s, i_c_s = i.stride()
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = output.stride()
+        s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
+        s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_o.stride()
+
+        if triton_block_size is None:
+            triton_block_size = get_triton_block_size(sparsity_block_size)
+
+        triton_grid = lambda meta: [x_b,
+                                    triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+        reduce_op_ind = 0
+        if reduce_op == "sum":
+            reduce_op_ind = 1
+
+        (_BlocksparseScatterReduce.kernel_blocksparse_scatter[triton_grid]
+         (x,
+          x_b, x_b_s, x_r_s, x_c_s,
+          sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+          i,
+          i_b, i_b_s, i_r_s, i_c_s,
+          output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+          sparsity_reverse_lut_o,
+          reduce_op_ind,
+          sparsity_block_size,
+          triton_block_size))
+
+        ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_o)
+        ctx.sparsity_block_size = sparsity_block_size
+        ctx.reduce_op = reduce_op
+        ctx.triton_block_size = triton_block_size
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        sparsity_layout_x, i, sparsity_layout_o = ctx.saved_tensors
+        sparsity_block_size = ctx.sparsity_block_size
+        reduce_op = ctx.reduce_op
+        triton_block_size = ctx.triton_block_size
+
+        if reduce_op == "sum":
+            return gather(grad_output, sparsity_layout_o, i, sparsity_layout_x, sparsity_block_size,
+                          triton_block_size=triton_block_size), None, None, None, None, None, None, None, None, None
+        else:
+            raise ValueError(f"Reduction operation '{reduce_op}' does not support backward pass")
+
+    @staticmethod
+    @triton.jit
+    def kernel_blocksparse_scatter(x,
+                                   x_b, x_b_s, x_r_s, x_c_s,
+                                   s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+                                   i,
+                                   i_b, i_b_s, i_r_s, i_c_s,
+                                   o,
+                                   o_b, o_b_s, o_r_s, o_c_s,
+                                   s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+                                   r_lut_o,
+                                   reduce_op_ind,
+                                   sparsity_block_size,
+                                   TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+        # Get triton block indices
+        pid_blk = tl.program_id(axis=0)
+        pid_row = tl.program_id(axis=1)
+        pid_col = tl.program_id(axis=2)
+
+        # Get position of current sparsity block consisting of its batch, row, and column index
+        spa_bat_x_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
+        spa_bat_x_msk = (spa_bat_x_idx < s_lut_x_r * s_lut_x_r_s)
+        spa_bat_x = tl.load(s_lut_x + spa_bat_x_idx, mask=spa_bat_x_msk)
+
+        spa_row_x_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
+        spa_row_x_msk = (spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
+        spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)
+
+        # Load x values
+        blk_x_idx = ((pid_blk * x_b_s) +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+        blk_x_msk = (blk_x_idx < x_b * x_b_s)
+        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+        # Load index values
+        blk_i_idx = ((pid_blk * i_b_s) +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+        blk_i_msk = (blk_i_idx < i_b * i_b_s)
+        blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
+
+        # Get positions of sparsity blocks
+        pos_spa_blk_o = blk_i // sparsity_block_size
+        pos_spa_col_o = blk_i % sparsity_block_size
+
+        # Load reverse sparsity indices for o
+        rev_idx_spa_o_idx = ((spa_bat_x * s_l_o_b_s) +
+                             (spa_row_x * s_l_o_r_s) +
+                             (pos_spa_blk_o * s_l_o_c_s))
+        rev_idx_spa_o_msk = (rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
+        rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
+
+        # Store output
+        blk_o_idx = ((rev_idx_spa_o * o_b_s) +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                     (pos_spa_col_o * o_c_s))
+        blk_o_msk = (blk_o_idx < o_b * o_b_s)
+
+        if reduce_op_ind == 0:
+            tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+        elif reduce_op_ind == 1:
+            tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
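
For orientation, a minimal usage sketch of the new gather/scatter pair follows. It is not part of the package diff: it assumes a CUDA device, uses made-up shapes, and infers from the kernels above that compressed form stacks the non-zero blocks as [n_sparse_blocks, sparsity_block_size, sparsity_block_size] and that each index value addresses a column of the uncompressed last dimension.

    import torch
    from blksprs.ops.distribution import gather, scatter_reduce

    sparsity_block_size = 16
    # One batch, a 1x2 grid of sparsity blocks, both present -> 2 compressed blocks.
    sparsity_layout = torch.tensor([[[1, 1]]], device="cuda")
    src = torch.randn(2, sparsity_block_size, sparsity_block_size, device="cuda")

    # Integer indices (required by validate_dtype_int); a value v selects source
    # column v, i.e. block v // sparsity_block_size, offset v % sparsity_block_size.
    idx = torch.randint(0, 2 * sparsity_block_size,
                        (2, sparsity_block_size, sparsity_block_size),
                        dtype=torch.int32, device="cuda")

    out = gather(src, sparsity_layout, idx, sparsity_layout, sparsity_block_size)

    # scatter_reduce with reduce_op="sum" is the adjoint used in gather's backward pass.
    grad_like = scatter_reduce(out, sparsity_layout, idx, sparsity_layout,
                               sparsity_block_size, reduce_op="sum")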
blksprs/ops/exp.py
CHANGED
@@ -1,25 +1,35 @@
 import torch
 import triton
-from triton import language as tl
 from torch import Tensor
+from triton import language as tl
 
 from blksprs.utils.tools import get_triton_block_size
-from blksprs.utils.validation import validate_contiguous, validate_dimensions,
+from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
+    validate_sparsity_block_size, validate_triton_block_size
 
 
 def exp(x: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
-    """Applies the element-wise exponential function to
+    """Applies the element-wise exponential function to a block-sparse tensor.
+
+    Note:
+        This operation does not consider sparse blocks, i.e., these will not be set to ``e^0``.
+        Consider this when converting back to tensors in regular form.
+
+    Args:
+        x (Tensor): A block-sparse tensor in compressed form.
+        sparsity_block_size (int): The size of the sparsity blocks.
+        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
 
-
+    Returns:
+        Tensor: The exponential function applied to all elements of the input tensor as a block-sparse tensor in
+            compressed form.
 
-    Note:
-        This operation does not consider sparse blocks, i.e., these will not be set to ``e^0``.
-        Consider this when converting back to dense tensors.
     """
     validate_dimensions(x)
     validate_contiguous(x)
-    validate_dtype_float(x)
     validate_device(x)
+    validate_sparsity_block_size(sparsity_block_size, x)
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
 
     return _BlocksparseExp.apply(x, sparsity_block_size, triton_block_size)
 
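
A quick sketch of calling the updated exp, under the same assumptions as above (CUDA device, compressed tensor built by hand with two non-zero blocks):

    import torch
    from blksprs.ops.exp import exp

    sparsity_block_size = 16
    x = torch.randn(2, sparsity_block_size, sparsity_block_size, device="cuda")

    y = exp(x, sparsity_block_size)
    # Per the Note in the docstring, absent blocks are not touched, so they do
    # not become e^0 = 1 when the result is converted back to regular form.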
blksprs/ops/{matmul_sss.py → matmul.py}
CHANGED
@@ -5,25 +5,39 @@ from triton import language as tl
 
 from blksprs.ops.transpose import transpose
 from blksprs.utils.tools import get_triton_block_size
-from blksprs.utils.validation import validate_contiguous, validate_dimensions,
-    validate_sparsity
+from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
+    validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
 
 
-def
-
-
-
+def matmul(x: Tensor, sparsity_layout_x: Tensor,
+           y: Tensor, sparsity_layout_y: Tensor,
+           sparsity_layout_output: Tensor,
+           sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+    """Performs matrix multiplication between two block-sparse tensors.
 
-    The
+    The sparsity layout of the output tensor is used to only calculate blocks that will be present in the output.
+
+    Args:
+        x (Tensor): A block-sparse tensor in compressed form.
+        y (Tensor): A block-sparse tensor in compressed form.
+        sparsity_layout_x (Tensor): The sparsity layout of the first block-sparse tensor.
+        sparsity_layout_y (Tensor): The sparsity layout of the second block-sparse tensor.
+        sparsity_layout_output (Tensor): The sparsity layout of the output tensor.
+        sparsity_block_size (int): The size of the sparsity blocks.
+        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+
+    Returns:
+        Tensor: The result of the matrix multiplication as a block-sparse tensor in compressed form.
 
     """
     validate_dimensions(x, y)
     validate_contiguous(x, y)
-    validate_dtype_float(x, y)
     validate_device(x, y)
     validate_sparsity(sparsity_block_size, (x, sparsity_layout_x), (y, sparsity_layout_y))
     if sparsity_layout_x.size(-1) != sparsity_layout_y.size(-2):
         raise ValueError("Inner dimensions of tensors must match")
+    validate_sparsity_block_size(sparsity_block_size, x, y)
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
 
     sparsity_layout_x_flat = sparsity_layout_x.reshape(-1)
     sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
@@ -98,10 +112,7 @@ class _BlocksparseMatmulSSS(torch.autograd.Function):
          sparsity_block_size,
          triton_block_size))
 
-        ctx.save_for_backward(x, y)
-        ctx.sparsity_layout_x = sparsity_layout_x
-        ctx.sparsity_layout_y = sparsity_layout_y
-        ctx.sparsity_layout_o = sparsity_layout_o
+        ctx.save_for_backward(x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o)
         ctx.sparsity_block_size = sparsity_block_size
         ctx.triton_block_size = triton_block_size
 
@@ -109,26 +120,17 @@ class _BlocksparseMatmulSSS(torch.autograd.Function):
 
     @staticmethod
     def backward(ctx, grad_output):
-        x, y = ctx.saved_tensors
-        sparsity_layout_x = ctx.sparsity_layout_x
-        sparsity_layout_y = ctx.sparsity_layout_y
-        sparsity_layout_o = ctx.sparsity_layout_o
+        x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o = ctx.saved_tensors
         sparsity_block_size = ctx.sparsity_block_size
         triton_block_size = ctx.triton_block_size
 
         x_t, sparsity_layout_x_t = transpose(x, sparsity_layout_x, sparsity_block_size, triton_block_size)
         y_t, sparsity_layout_y_t = transpose(y, sparsity_layout_y, sparsity_block_size, triton_block_size)
 
-        grad_x =
-
-
-
-        sparsity_block_size, triton_block_size)
-        grad_y = matmul_sss(x_t, grad_output,
-                            sparsity_layout_x_t,
-                            sparsity_layout_o,
-                            sparsity_layout_y,
-                            sparsity_block_size, triton_block_size)
+        grad_x = matmul(grad_output, sparsity_layout_o, y_t, sparsity_layout_y_t, sparsity_layout_x,
+                        sparsity_block_size, triton_block_size)
+        grad_y = matmul(x_t, sparsity_layout_x_t, grad_output, sparsity_layout_o, sparsity_layout_y,
+                        sparsity_block_size, triton_block_size)
 
         return grad_x, grad_y, None, None, None, None, None, None, None, None, None
 
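
Since the function was renamed from matmul_sss to matmul and now takes layouts alongside each operand, here is a hedged usage sketch (assumes a CUDA device; the block grids are arbitrary illustrations):

    import torch
    from blksprs.ops.matmul import matmul

    sparsity_block_size = 16
    layout_x = torch.tensor([[[1, 1]]], device="cuda")    # 1x2 grid of blocks
    layout_y = torch.tensor([[[1], [1]]], device="cuda")  # 2x1 grid of blocks
    layout_o = torch.tensor([[[1]]], device="cuda")       # 1x1 grid of blocks

    x = torch.randn(2, sparsity_block_size, sparsity_block_size, device="cuda")
    y = torch.randn(2, sparsity_block_size, sparsity_block_size, device="cuda")

    # Only blocks marked in layout_o are computed.
    o = matmul(x, layout_x, y, layout_y, layout_o, sparsity_block_size)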
blksprs/ops/row_wise_sum.py
CHANGED
@@ -4,23 +4,39 @@ from torch import Tensor
 from triton import language as tl
 
 from blksprs.utils.tools import get_triton_block_size
-from blksprs.utils.validation import validate_contiguous, validate_dimensions,
+from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
+    validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
 
 
 def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
                  flag_slice_only: bool = False, triton_block_size: int = None) -> tuple[Tensor, Tensor]:
-    """Computes the row-wise sum of a
+    """Computes the row-wise sum of a block-sparse tensor.
 
-    Returns a
+    Returns a block-sparse tensor in compressed form with only one block per row, where the first entry contains the sum
+    of the corresponding row.
 
     Note:
-        If ``flag_slice_only`` is set the output will be of shape ``[
+        If ``flag_slice_only`` is set the output will be of shape ``[x.size(0), x.size(1), 1]``.
+
+    Args:
+        x (Tensor): A block-sparse tensor in compressed form.
+        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+        sparsity_block_size (int): The size of the sparsity blocks.
+        flag_slice_only (bool, optional): If set the output will be of shape ``[x.size(0), x.size(1), 1]``
+            (default ``False``).
+        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+    Returns:
+        tuple[Tensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise sum
+            of the input and the sparsity layout of the output tensor.
 
     """
     validate_dimensions(x)
     validate_contiguous(x)
-    validate_dtype_float(x)
     validate_device(x)
+    validate_sparsity(sparsity_block_size, (x, sparsity_layout))
+    validate_sparsity_block_size(sparsity_block_size, x)
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
 
     sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
     sparsity_layout_flat = sparsity_layout.reshape(-1)
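
A short sketch of the documented call (assumptions as above: CUDA device, hand-built compressed tensor with three non-zero blocks):

    import torch
    from blksprs.ops.row_wise_sum import row_wise_sum

    sparsity_block_size = 16
    sparsity_layout = torch.tensor([[[1, 1],
                                     [0, 1]]], device="cuda")
    x = torch.randn(3, sparsity_block_size, sparsity_block_size, device="cuda")

    sums, layout_sums = row_wise_sum(x, sparsity_layout, sparsity_block_size)
    # Per the Note, flag_slice_only=True collapses the result to
    # shape [x.size(0), x.size(1), 1].
    slices, _ = row_wise_sum(x, sparsity_layout, sparsity_block_size,
                             flag_slice_only=True)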
blksprs/ops/softmax.py
CHANGED
@@ -6,22 +6,37 @@ from triton import language as tl
 from blksprs.ops.exp import exp
 from blksprs.ops.row_wise_sum import row_wise_sum
 from blksprs.utils.tools import get_triton_block_size
-from blksprs.utils.validation import validate_contiguous, validate_dimensions,
+from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
+    validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
 
 
 def softmax(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
-    """Computes the softmax of a
+    """Computes the softmax of a block-sparse tensor in compressed form.
 
     Note:
-        Sparse blocks are not considered for the calculation of the softmax, i.e., assumed to be ``-inf``.
+        Sparse blocks are not considered for the calculation of the softmax, i.e., all values are assumed to be ``-inf``.
+
+    Args:
+        x (Tensor): A block-sparse tensor in compressed form.
+        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+        sparsity_block_size (int): The size of the sparsity blocks.
+        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+    Returns:
+        Tensor: The result of the softmax operation as a block-sparse tensor in compressed form.
 
     """
     validate_dimensions(x)
     validate_contiguous(x)
-    validate_dtype_float(x)
     validate_device(x)
-
-
+    validate_sparsity(sparsity_block_size, (x, sparsity_layout))
+    validate_sparsity_block_size(sparsity_block_size, x)
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+    if x.size(0) != 0:
+        max_val = torch.max(x).item()
+    else:
+        max_val = 0
     x_scaled = x - max_val
 
     sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
@@ -83,9 +98,7 @@ class _BlocksparseSoftmax(torch.autograd.Function):
          triton_block_size))
 
         # Save for backward pass
-        ctx.save_for_backward(output)
-        ctx.sparsity_layout = sparsity_layout
-        ctx.sparsity_lut = sparsity_lut
+        ctx.save_for_backward(output, sparsity_layout, sparsity_lut)
         ctx.sparsity_block_size = sparsity_block_size
         ctx.triton_block_size = triton_block_size
 
@@ -93,9 +106,7 @@ class _BlocksparseSoftmax(torch.autograd.Function):
 
     @staticmethod
     def backward(ctx, grad_output):
-        o = ctx.saved_tensors
-        sparsity_layout = ctx.sparsity_layout
-        sparsity_lut = ctx.sparsity_lut
+        o, sparsity_layout, sparsity_lut = ctx.saved_tensors
         sparsity_block_size = ctx.sparsity_block_size
         triton_block_size = ctx.triton_block_size
 
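
A sketch of the updated softmax call (same assumptions as the earlier sketches: CUDA device, hand-built compressed input):

    import torch
    from blksprs.ops.softmax import softmax

    sparsity_block_size = 16
    sparsity_layout = torch.tensor([[[1, 0],
                                     [1, 1]]], device="cuda")
    x = torch.randn(3, sparsity_block_size, sparsity_block_size, device="cuda")

    y = softmax(x, sparsity_layout, sparsity_block_size)
    # Each row normalises over the blocks present in that row; absent blocks
    # are treated as -inf, per the Note in the docstring.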
blksprs/ops/transpose.py
CHANGED
@@ -1,26 +1,37 @@
-from typing import Any
-
 import torch
 import triton
-from triton import language as tl
 from torch import Tensor
+from triton import language as tl
 
 from blksprs.utils.tools import get_triton_block_size
-from blksprs.utils.validation import validate_dimensions, validate_contiguous,
+from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
+    validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
 
 
 def transpose(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> (
         Tensor, Tensor):
-    """Transposes a
+    """Transposes a block-sparse tensor in compressed form.
 
     Note:
         Returns the transposed tensor and the sparsity layout of the transposed tensor.
 
+    Args:
+        x (Tensor): A block-sparse tensor in compressed form.
+        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+        sparsity_block_size (int): The size of the sparsity blocks.
+        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+    Returns:
+        Tensor: The transposed block-sparse tensor in compressed form.
+        Tensor: The sparsity layout of the transposed tensor.
+
     """
     validate_dimensions(x)
     validate_contiguous(x)
-    validate_dtype_float(x)
     validate_device(x)
+    validate_sparsity(sparsity_block_size, (x, sparsity_layout))
+    validate_sparsity_block_size(sparsity_block_size, x)
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
 
     sparsity_layout_t = sparsity_layout.transpose(-1, -2).contiguous()
 
@@ -75,6 +86,7 @@ class _BlocksparseTranspose(torch.autograd.Function):
          triton_block_size))
 
         # Save for backward pass
+        ctx.save_for_backward(sparsity_layout)
         ctx.sparsity_layout = sparsity_layout
         ctx.sparsity_block_size = sparsity_block_size
         ctx.triton_block_size = triton_block_size
@@ -83,7 +95,7 @@ class _BlocksparseTranspose(torch.autograd.Function):
 
     @staticmethod
     def backward(ctx, grad_output):
-        sparsity_layout = ctx.
+        sparsity_layout = ctx.saved_tensors[0]
         sparsity_block_size = ctx.sparsity_block_size
         triton_block_size = ctx.triton_block_size
 
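
Finally, a sketch of transpose, which returns both the transposed tensor and its layout (assumptions as above):

    import torch
    from blksprs.ops.transpose import transpose

    sparsity_block_size = 16
    sparsity_layout = torch.tensor([[[1, 0],
                                     [1, 1]]], device="cuda")
    x = torch.randn(3, sparsity_block_size, sparsity_block_size, device="cuda")

    x_t, sparsity_layout_t = transpose(x, sparsity_layout, sparsity_block_size)
    # The returned layout is the block-grid transpose of the input layout,
    # matching sparsity_layout.transpose(-1, -2) in the function body above.
    assert torch.equal(sparsity_layout_t, sparsity_layout.transpose(-1, -2))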