blksprs 1.6.1__py3-none-any.whl → 1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +11 -6
- blksprs/experimental/distribution_mdi.py +14 -14
- blksprs/layouting/distribution_layout.py +4 -4
- blksprs/layouting/sparsity_layout.py +6 -6
- blksprs/misc/broadcast_ops.py +5 -5
- blksprs/{ops → misc}/exp.py +3 -3
- blksprs/{ops → misc}/partitioning.py +9 -98
- blksprs/misc/row_wise.py +16 -15
- blksprs/ops/conversion.py +23 -12
- blksprs/ops/distribution.py +11 -11
- blksprs/ops/matmul.py +7 -7
- blksprs/ops/repeat.py +322 -0
- blksprs/ops/softmax.py +12 -11
- blksprs/ops/transpose.py +7 -6
- blksprs/utils/tools.py +3 -0
- blksprs/utils/validation.py +20 -1
- {blksprs-1.6.1.dist-info → blksprs-1.8.dist-info}/METADATA +12 -5
- blksprs-1.8.dist-info/RECORD +21 -0
- blksprs/misc/repeat_interleave.py +0 -132
- blksprs-1.6.1.dist-info/RECORD +0 -21
- {blksprs-1.6.1.dist-info → blksprs-1.8.dist-info}/WHEEL +0 -0
- {blksprs-1.6.1.dist-info → blksprs-1.8.dist-info}/top_level.txt +0 -0
blksprs/ops/repeat.py
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import triton
|
|
3
|
+
from triton import language as tl
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
from blksprs.utils.tools import get_triton_block_size, stride
|
|
7
|
+
from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
|
|
8
|
+
validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def repeat(x: Tensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
           sparsity_block_size: int, sparsity_layout_output: Tensor = None, triton_block_size: int = None) -> (
        Tensor, Tensor):
    """Repeats a block-sparse tensor in compressed form according to the given repeats.

    Repeats is a 3-tuple of integers, where each integer represents the number of times the tensor should be repeated in
    the first, second and third dimension respectively.

    Note:
        An output sparsity layout can be provided, in which case only the indicated blocks are filled. This may result
        in blocks not being present in the output that were present in the input if the output sparsity layout indicates
        them to be sparse.

    Args:
        x (Tensor): A block-sparse tensor in compressed form.
        sparsity_layout_x (Tensor): The sparsity layout of the block-sparse tensor.
        repeats (tuple[int, int, int]): The number of times the tensor should be repeated in the first, second and
            third dimension respectively.
        sparsity_block_size (int): The size of the sparsity blocks.
        sparsity_layout_output (Tensor): The desired sparsity layout of the output tensor (default ``None``).
        triton_block_size (int): The block size to use for the triton kernel (default ``None``).

    Returns:
        Tensor: A block-sparse tensor in compressed form containing the repeated values.
        Tensor: The sparsity layout of the resulting output tensor.

    """
    x = x.contiguous()

    validate_dimensions(x)
    validate_contiguous(x)
    validate_device(x)
    validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
    validate_sparsity_block_size(sparsity_block_size, x)
    validate_triton_block_size(triton_block_size, sparsity_block_size)

    # Output layout is the input layout tiled along all three dimensions.
    sparsity_layout_o = sparsity_layout_x.repeat(repeats[0], repeats[1], repeats[2])

    if sparsity_layout_output is not None:
        # Restrict the output to blocks that are non-sparse in both the repeated
        # layout and the user-supplied target layout.
        sparsity_layout_o = torch.logical_and(sparsity_layout_o, sparsity_layout_output)

    # LUT of (batch, row, col) indices of every non-sparse output block.
    sparsity_lut = torch.nonzero(sparsity_layout_o).contiguous()

    # Reverse LUT: maps each (repeated) layout position to the index of the
    # corresponding source block in compressed form, or -1 for sparse positions.
    sparsity_layout_flat = sparsity_layout_x.reshape(-1)
    sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
                             (sparsity_layout_flat == 1) -
                             (1 * (sparsity_layout_flat == 0)))
                            .reshape(sparsity_layout_x.size())
                            .repeat(repeats[0], repeats[1], repeats[2])
                            .reshape(-1).contiguous())

    n_sparse_blocks = torch.sum(sparsity_layout_o.to(torch.int)).item()

    validate_contiguous(sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)

    return _BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
                                    sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_o
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def repeat_interleave(x: Tensor, sparsity_layout_x: Tensor, repeats: int,
                      sparsity_block_size: int, sparsity_layout_output: Tensor = None,
                      triton_block_size: int = None) -> (
        Tensor, Tensor):
    """Repeats and interleaves the block-sparse tensor in compressed form.

    Repeats each matrix contained in the tensors by ``repeats`` amount and places them consecutively in the output
    tensor.

    Note:
        In similar fashion to the regular ``repeat`` an output sparsity layout can be provided. In this case only
        non-sparse blocks will be filled.

    Args:
        x (Tensor): A block-sparse tensor in compressed form.
        sparsity_layout_x (Tensor): The sparsity layout of the block-sparse tensor.
        repeats (int): The number of times to repeat the matrices.
        sparsity_block_size (int): The size of the sparsity blocks.
        sparsity_layout_output (Tensor): The desired sparsity layout of the output tensor (default ``None``).
        triton_block_size (int): The block size to use for the triton kernel (default ``None``).

    Returns:
        Tensor: A block-sparse tensor in compressed form containing the repeated and interleaved matrices.
        Tensor: The sparsity layout of the resulting output tensor.

    """
    x = x.contiguous()

    validate_dimensions(x)
    validate_contiguous(x)
    validate_device(x)
    validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
    validate_sparsity_block_size(sparsity_block_size, x)
    validate_triton_block_size(triton_block_size, sparsity_block_size)

    # Interleave the layout along the batch dimension to obtain the output layout.
    sparsity_layout_o = torch.repeat_interleave(sparsity_layout_x, repeats, dim=0).contiguous()

    if sparsity_layout_output is not None:
        # Only keep blocks that are non-sparse in the requested target layout as well.
        sparsity_layout_o = torch.logical_and(sparsity_layout_o, sparsity_layout_output)

    # LUT of (batch, row, col) indices of every non-sparse output block.
    sparsity_lut = torch.nonzero(sparsity_layout_o).contiguous()

    # Reverse LUT: for each layout position of the interleaved layout, the index
    # of the corresponding source block in compressed form, or -1 if sparse.
    flat_layout = sparsity_layout_x.reshape(-1)
    running_block_index = torch.cumsum(flat_layout, dim=-1) - 1
    reverse_flat = running_block_index * (flat_layout == 1) - (1 * (flat_layout == 0))
    sparsity_reverse_lut = (reverse_flat
                            .reshape(sparsity_layout_x.size())
                            .repeat_interleave(repeats, dim=0)
                            .reshape(-1)
                            .contiguous())

    n_sparse_blocks = torch.sum(sparsity_layout_o.to(torch.int)).item()

    validate_contiguous(sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)

    return _BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
                                    sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_o
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class _BlocksparseRepeat(torch.autograd.Function):
    # Autograd wrapper shared by ``repeat`` and ``repeat_interleave``: the
    # forward pass pulls source blocks into the repeated output, the backward
    # pass pushes (accumulates) output gradients back onto the source blocks.

    @staticmethod
    def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor,
                sparsity_reverse_lut: Tensor,
                sparsity_block_size: int, n_sparse_blocks: int,
                triton_block_size: int) -> Tensor:
        # Save layouts/LUTs plus the input geometry; forward_flow additionally
        # stashes sparsity_block_size and triton_block_size on ctx.
        ctx.save_for_backward(sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
        ctx.x_size = x.size()
        ctx.x_stride = stride(x)

        return forward_flow(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
                            n_sparse_blocks, triton_block_size)

    @staticmethod
    def backward(ctx, grad_output):
        sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut = ctx.saved_tensors
        x_size = ctx.x_size
        x_stride = ctx.x_stride
        sparsity_block_size = ctx.sparsity_block_size
        triton_block_size = ctx.triton_block_size

        # Gradient tensor has one block per non-sparse entry of the *input* layout.
        n_sparse_blocks = torch.sum(sparsity_layout_x.to(torch.int)).item()

        # Zero-initialised since the push kernel accumulates via atomic adds.
        output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                             dtype=grad_output.dtype, device=grad_output.device)

        # Sizes/strides of the incoming gradient (kernel input) ...
        x_b, x_r, x_c = grad_output.size()
        x_b_s, x_r_s, x_c_s = stride(grad_output)
        # ... of the output sparsity layout the gradient conforms to ...
        s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_o.size()
        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_o)
        # ... of the block LUT ...
        s_lut_r, s_lut_c = sparsity_lut.size()
        s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
        # ... and of the gradient w.r.t. the original input.
        o_b, o_r, o_c = x_size
        o_b_s, o_r_s, o_c_s = x_stride

        if triton_block_size is None:
            triton_block_size = get_triton_block_size(sparsity_block_size)

        # One program per (block, row-tile, col-tile) of the incoming gradient.
        triton_grid = lambda meta: [x_b,
                                    triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
                                    triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]

        (kernel_blocksparse_flow_push[triton_grid]
         (grad_output,
          x_b, x_b_s, x_r_s, x_c_s,
          s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
          sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
          sparsity_reverse_lut,
          output,
          o_b, o_b_s, o_r_s, o_c_s,
          triton_block_size))

        # Only x receives a gradient; the remaining seven forward args are
        # layouts/LUTs/ints and get None.
        return output, None, None, None, None, None, None, None
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
@triton.jit
def kernel_blocksparse_flow_pull(x,
                                 x_b, x_b_s, x_r_s, x_c_s,
                                 o,
                                 o_b, o_b_s, o_r_s, o_c_s,
                                 s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
                                 s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
                                 r_lut,
                                 TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    # Pull kernel: each program copies one tile of one output block by looking up
    # (via r_lut) which compressed input block it originates from.
    # Get triton block indices
    pid_blk = tl.program_id(axis=0)
    pid_row = tl.program_id(axis=1)
    pid_col = tl.program_id(axis=2)

    # Get sparsity index of current output block consisting of its batch, row, and column index
    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
    spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)

    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
    spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)

    spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
    spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
    spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)

    # Get reverse sparsity index
    rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
                       spa_row * s_l_o_r_s +
                       spa_col * s_l_o_c_s)
    rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
    rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

    # -1 marks a sparse source position; the LUTs are built so every non-sparse
    # output block has a source, so hitting this indicates a bug upstream.
    if rev_idx_spa == -1:
        tl.device_assert(False)
        return

    # Load the tile from the source block in compressed form.
    blk_x_idx = (rev_idx_spa * x_b_s +
                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
    blk_x_msk = (blk_x_idx < x_b * x_b_s)
    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

    # Store the tile into the output block addressed by this program's id.
    blk_o_idx = (pid_blk * o_b_s +
                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
    blk_o_msk = (blk_o_idx < o_b * o_b_s)
    tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
@triton.jit
def kernel_blocksparse_flow_push(x,
                                 x_b, x_b_s, x_r_s, x_c_s,
                                 s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
                                 s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
                                 r_lut,
                                 o,
                                 o_b, o_b_s, o_r_s, o_c_s,
                                 TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    # Push kernel (backward): each program reads one tile of an input block and
    # atomically accumulates it onto the source block it maps to via r_lut
    # (several repeated blocks may push onto the same source block).
    # Get triton block indices
    pid_blk = tl.program_id(axis=0)
    pid_row = tl.program_id(axis=1)
    pid_col = tl.program_id(axis=2)

    # Get sparsity index of current input block consisting of its batch, row, and column index
    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
    spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)

    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
    spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)

    spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
    spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
    spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)

    # Get reverse sparsity index
    rev_idx_spa_idx = (spa_bat * s_l_x_b_s +
                       spa_row * s_l_x_r_s +
                       spa_col * s_l_x_c_s)
    rev_idx_spa_msk = (rev_idx_spa_idx < s_l_x_b * s_l_x_b_s)
    rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

    # -1 marks a sparse target position; by construction every non-sparse input
    # block has a target, so hitting this indicates a bug upstream.
    if rev_idx_spa == -1:
        tl.device_assert(False)
        return

    # Load the tile from the input block addressed by this program's id.
    blk_x_idx = (pid_blk * x_b_s +
                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
    blk_x_msk = (blk_x_idx < x_b * x_b_s)
    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

    # Accumulate onto the source block; atomic because multiple repeated blocks
    # can map to the same rev_idx_spa.
    blk_o_idx = (rev_idx_spa * o_b_s +
                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
    blk_o_msk = (blk_o_idx < o_b * o_b_s)
    tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
def forward_flow(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
                 sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
    """Launches the pull kernel to build the repeated tensor in compressed form.

    Args:
        ctx: Autograd context; ``sparsity_block_size`` and ``triton_block_size``
            are stashed on it for the backward pass.
        x (Tensor): The block-sparse input tensor in compressed form.
        sparsity_layout_o (Tensor): The sparsity layout of the output tensor.
        sparsity_lut (Tensor): LUT of (batch, row, col) indices of the
            non-sparse output blocks.
        sparsity_reverse_lut (Tensor): Maps each output layout position to the
            compressed index of its source block (-1 for sparse positions).
        sparsity_block_size (int): The size of the sparsity blocks.
        n_sparse_blocks (int): The number of non-sparse blocks in the output.
        triton_block_size (int): The block size to use for the triton kernel,
            derived from ``sparsity_block_size`` if ``None``.

    Returns:
        Tensor: The repeated block-sparse tensor in compressed form.

    """
    # Allocate zero-initialised directly instead of empty + zeros_like, which
    # performed a redundant second allocation.
    output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                         dtype=x.dtype, device=x.device)

    x_b, x_r, x_c = x.size()
    x_b_s, x_r_s, x_c_s = stride(x)
    o_b, o_r, o_c = output.size()
    o_b_s, o_r_s, o_c_s = stride(output)
    s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
    s_lut_r, s_lut_c = sparsity_lut.size()
    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)

    if triton_block_size is None:
        triton_block_size = get_triton_block_size(sparsity_block_size)

    # One program per (output block, row-tile, col-tile).
    triton_grid = lambda meta: [o_b,
                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

    (kernel_blocksparse_flow_pull[triton_grid]
     (x,
      x_b, x_b_s, x_r_s, x_c_s,
      output,
      o_b, o_b_s, o_r_s, o_c_s,
      s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
      sparsity_reverse_lut,
      triton_block_size))

    # Save for backward pass
    ctx.sparsity_block_size = sparsity_block_size
    ctx.triton_block_size = triton_block_size

    return output
|
blksprs/ops/softmax.py
CHANGED
|
@@ -3,9 +3,9 @@ import triton
|
|
|
3
3
|
from torch import Tensor
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
|
-
from blksprs.
|
|
6
|
+
from blksprs.misc.exp import exp
|
|
7
7
|
from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
|
|
8
|
-
from blksprs.utils.tools import get_triton_block_size
|
|
8
|
+
from blksprs.utils.tools import get_triton_block_size, stride
|
|
9
9
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
10
10
|
validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
|
|
11
11
|
|
|
@@ -61,9 +61,9 @@ class _BlocksparseSoftmax(torch.autograd.Function):
|
|
|
61
61
|
output = torch.empty_like(x)
|
|
62
62
|
|
|
63
63
|
x_b, x_r, x_c = x.size()
|
|
64
|
-
x_b_s, x_r_s, x_c_s =
|
|
64
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
65
65
|
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
66
|
-
s_lut_r_s, s_lut_c_s =
|
|
66
|
+
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
67
67
|
o_b, o_r, o_c = output.size()
|
|
68
68
|
|
|
69
69
|
x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,
|
|
@@ -76,9 +76,9 @@ class _BlocksparseSoftmax(torch.autograd.Function):
|
|
|
76
76
|
triton_block_size=triton_block_size)
|
|
77
77
|
|
|
78
78
|
s_b, s_r, s_c = x_exp_row_wise_sum.shape
|
|
79
|
-
s_b_s, s_r_s, s_c_s =
|
|
79
|
+
s_b_s, s_r_s, s_c_s = stride(x_exp_row_wise_sum)
|
|
80
80
|
s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_rws.shape
|
|
81
|
-
s_l_s_b_s, s_l_s_r_s, s_l_s_c_s =
|
|
81
|
+
s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_rws)
|
|
82
82
|
|
|
83
83
|
if triton_block_size is None:
|
|
84
84
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -119,13 +119,13 @@ class _BlocksparseSoftmax(torch.autograd.Function):
|
|
|
119
119
|
(1 * (sparsity_layout_s_flat == 0)))
|
|
120
120
|
|
|
121
121
|
o_b, o_r, o_c = o.size()
|
|
122
|
-
o_b_s, o_r_s, o_c_s =
|
|
122
|
+
o_b_s, o_r_s, o_c_s = stride(o)
|
|
123
123
|
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
124
|
-
s_lut_r_s, s_lut_c_s =
|
|
124
|
+
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
125
125
|
s_b, s_r, s_c = s.size()
|
|
126
|
-
s_b_s, s_r_s, s_c_s =
|
|
126
|
+
s_b_s, s_r_s, s_c_s = stride(s)
|
|
127
127
|
s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
|
|
128
|
-
s_l_s_b_s, s_l_s_r_s, s_l_s_c_s =
|
|
128
|
+
s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)
|
|
129
129
|
|
|
130
130
|
grad_x = torch.empty_like(o, dtype=torch.float)
|
|
131
131
|
|
|
@@ -181,7 +181,8 @@ class _BlocksparseSoftmax(torch.autograd.Function):
|
|
|
181
181
|
rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
|
|
182
182
|
|
|
183
183
|
if rev_idx_spa_s == -1:
|
|
184
|
-
|
|
184
|
+
tl.device_assert(False)
|
|
185
|
+
return
|
|
185
186
|
|
|
186
187
|
# Load x block
|
|
187
188
|
blk_x_idx = ((pid_blk * x_b_s) +
|
blksprs/ops/transpose.py
CHANGED
|
@@ -3,7 +3,7 @@ import triton
|
|
|
3
3
|
from torch import Tensor
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
|
-
from blksprs.utils.tools import get_triton_block_size
|
|
6
|
+
from blksprs.utils.tools import get_triton_block_size, stride
|
|
7
7
|
from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
|
|
8
8
|
validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
|
|
9
9
|
|
|
@@ -63,13 +63,13 @@ class _BlocksparseTranspose(torch.autograd.Function):
|
|
|
63
63
|
dtype=x.dtype, device=x.device)
|
|
64
64
|
|
|
65
65
|
x_b, x_r, x_c = x.size()
|
|
66
|
-
x_b_s, x_r_s, x_c_s =
|
|
66
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
67
67
|
s_l_b, s_l_r, s_l_c = sparsity_layout_o.size()
|
|
68
|
-
s_l_b_s, s_l_r_s, s_l_c_s =
|
|
68
|
+
s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout_o)
|
|
69
69
|
s_lut_r, s_lut_c = sparsity_lut.shape
|
|
70
|
-
s_lut_r_s, s_lut_c_s =
|
|
70
|
+
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
71
71
|
o_b, o_r, o_c = output.size()
|
|
72
|
-
o_b_s, o_r_s, o_c_s =
|
|
72
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
73
73
|
|
|
74
74
|
if triton_block_size is None:
|
|
75
75
|
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
@@ -140,7 +140,8 @@ class _BlocksparseTranspose(torch.autograd.Function):
|
|
|
140
140
|
rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
|
|
141
141
|
|
|
142
142
|
if rev_idx_spa == -1:
|
|
143
|
-
|
|
143
|
+
tl.device_assert(False)
|
|
144
|
+
return
|
|
144
145
|
|
|
145
146
|
blk_x_idx = (rev_idx_spa * x_b_s +
|
|
146
147
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
blksprs/utils/tools.py
CHANGED
blksprs/utils/validation.py
CHANGED
|
@@ -3,6 +3,7 @@ from torch import Tensor
|
|
|
3
3
|
|
|
4
4
|
VALIDATION = True
|
|
5
5
|
|
|
6
|
+
|
|
6
7
|
def validate_dimensions(*tensors: Tensor, dims=3) -> None:
|
|
7
8
|
if _check_skip_validation():
|
|
8
9
|
return
|
|
@@ -71,10 +72,25 @@ def validate_sparsity(sparsity_block_size: int, *tensor_sparsity_layout_tuples:
|
|
|
71
72
|
raise ValueError("Mismatch between sparsity layout and blocks")
|
|
72
73
|
|
|
73
74
|
|
|
75
|
+
def validate_sparsity_dense(sparsity_block_size: int, *tensor_sparsity_layout_tuples: tuple[Tensor, Tensor]) -> None:
    """Validates that each dense tensor conforms to its sparsity layout.

    Each layout must be 3-dimensional, contain only 0/1 values, and its last two
    dimensions must equal the tensor's last two dimensions divided by the
    sparsity block size.
    """
    if _check_skip_validation():
        return

    for tensor, sparsity_layout in tensor_sparsity_layout_tuples:
        _validate_sparsity_layout_values(sparsity_layout)

        if sparsity_layout.dim() != 3:
            raise ValueError("Sparsity layout must have exactly 3 dimensions")

        cols_match = tensor.size(-1) // sparsity_block_size == sparsity_layout.size(-1)
        rows_match = tensor.size(-2) // sparsity_block_size == sparsity_layout.size(-2)
        if not (cols_match and rows_match):
            raise ValueError("Tensor not conforming to sparsity layout")
|
|
87
|
+
|
|
88
|
+
|
|
74
89
|
def _validate_sparsity_layout_values(sparsity_layout: Tensor):
|
|
75
90
|
if not torch.all(torch.logical_or(sparsity_layout == 0, sparsity_layout == 1)):
|
|
76
91
|
raise ValueError("Sparsity layout values must be either 0 or 1")
|
|
77
92
|
|
|
93
|
+
|
|
78
94
|
def validate_sparsity_block_size(sparsity_block_size: int, *tensors):
|
|
79
95
|
if _check_skip_validation():
|
|
80
96
|
return
|
|
@@ -86,6 +102,7 @@ def validate_sparsity_block_size(sparsity_block_size: int, *tensors):
|
|
|
86
102
|
if not (tensor.size(-1) % sparsity_block_size == 0 and tensor.size(-2) % sparsity_block_size == 0):
|
|
87
103
|
raise ValueError("Tensor sizes must be divisible by sparsity block size")
|
|
88
104
|
|
|
105
|
+
|
|
89
106
|
def validate_triton_block_size(triton_block_size: int, sparsity_block_size: int):
|
|
90
107
|
if _check_skip_validation():
|
|
91
108
|
return
|
|
@@ -99,9 +116,11 @@ def validate_triton_block_size(triton_block_size: int, sparsity_block_size: int)
|
|
|
99
116
|
if triton_block_size > sparsity_block_size:
|
|
100
117
|
raise ValueError("Triton block size cannot be larger than sparsity block size")
|
|
101
118
|
|
|
119
|
+
|
|
102
120
|
def _check_skip_validation():
|
|
103
121
|
return not VALIDATION
|
|
104
122
|
|
|
123
|
+
|
|
105
124
|
def _set_skip_validation(skip_validation: bool):
|
|
106
125
|
global VALIDATION
|
|
107
|
-
VALIDATION = not skip_validation
|
|
126
|
+
VALIDATION = not skip_validation
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: blksprs
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.8
|
|
4
4
|
Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
|
|
5
5
|
Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
|
|
6
6
|
Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
|
|
@@ -28,12 +28,13 @@ A lightweight and efficient library for operations on block-sparse matrices in P
|
|
|
28
28
|
|
|
29
29
|
Currently supported operations (includes gradient calculation):
|
|
30
30
|
|
|
31
|
-
-
|
|
32
|
-
for `sparse = sparse @ sparse` matmul_)
|
|
31
|
+
- Matrix multiplication
|
|
33
32
|
- Softmax
|
|
34
33
|
- Transpose
|
|
35
34
|
- Gather
|
|
36
35
|
- Scatter (_supports either no reduction or summation, gradients are only available for summation_)
|
|
36
|
+
- Repeat (_supports target sparsity layout_)
|
|
37
|
+
- Repeat Interleave (_supports target sparsity layout_)
|
|
37
38
|
- Splitting and merging of matrices along the last dimension
|
|
38
39
|
- Conversion to and from sparse form
|
|
39
40
|
- Conversion to different sparsity layouts and different sparsity block sizes
|
|
@@ -50,8 +51,14 @@ These include, e.g.,
|
|
|
50
51
|
Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
|
|
51
52
|
match.
|
|
52
53
|
|
|
54
|
+
Further helpful operations (included in the ``bs.misc`` module) that do **not** support gradient calculation include:
|
|
55
|
+
|
|
56
|
+
- Row-wise sum, max, addition, and subtraction
|
|
57
|
+
- Broadcast addition and subtraction between slices
|
|
58
|
+
|
|
53
59
|
Furthermore, the library provides a set of utility functions for the creation of sparsity layouts based on existing
|
|
54
|
-
dense tensors.
|
|
60
|
+
dense tensors and for the scatter operation (module ``bs.layout``), as well as utility functions to ensure correct input
|
|
61
|
+
dimensionality (module ``bs.util``).
|
|
55
62
|
|
|
56
63
|
## Installation
|
|
57
64
|
|
|
@@ -64,7 +71,7 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u
|
|
|
64
71
|
|
|
65
72
|
### Dependencies
|
|
66
73
|
|
|
67
|
-
- [PyTorch](https://pytorch.org/) (built with v2.
|
|
74
|
+
- [PyTorch](https://pytorch.org/) (built with v2.5.0)
|
|
68
75
|
- _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_
|
|
69
76
|
|
|
70
77
|
## Changelog
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
blksprs/__init__.py,sha256=qDqoB-X5vo5_3PlrN54sp59XR5hg6EanIsADS67QnH0,1058
|
|
2
|
+
blksprs/experimental/distribution_mdi.py,sha256=jE_SbB0SMGPcCoeM0699cceWAikBkBpGM_8Oo5A7Ets,20083
|
|
3
|
+
blksprs/layouting/distribution_layout.py,sha256=nCR3fCF6eNpi97DI6MMLF2hq_9Lwyo6_aUCIZiJfiX4,4170
|
|
4
|
+
blksprs/layouting/sparsity_layout.py,sha256=JNwbJ4L-418zCiCbt-vLfZ0xU7ReP0zr_tLHs_dytRA,9686
|
|
5
|
+
blksprs/misc/broadcast_ops.py,sha256=0RLnLMYV7GAPI2YL8RotcxjIUSBZKGxdVcsGaJFeL_I,5327
|
|
6
|
+
blksprs/misc/exp.py,sha256=cdF0s93Q9iucIXuEE3howsB0N6D60xgvem7C-a-yiGI,3704
|
|
7
|
+
blksprs/misc/partitioning.py,sha256=nBRZzfi3XYAhDLEBzYflQkvGa3MIZ-qNeIlrZ16k44g,7533
|
|
8
|
+
blksprs/misc/row_wise.py,sha256=0vDJA8uCocmebSIPIbFeND5_PQIE10pUj3DBOQXlTvE,16888
|
|
9
|
+
blksprs/ops/conversion.py,sha256=9xVdCrj38m1cMh43LQs-GrXZ5pNRjhQyKx6paaw3C6A,21898
|
|
10
|
+
blksprs/ops/distribution.py,sha256=V3TK5SlNT_JdGHNaDNl-U4U5vwAYsgkAOg4eTmYxbuA,16877
|
|
11
|
+
blksprs/ops/matmul.py,sha256=uqVe6Dz2aaCbCglM1uS2eRHVKh7PQcuecaIBWFubPEw,11256
|
|
12
|
+
blksprs/ops/repeat.py,sha256=OSsa2rj6BHL3Kedfu3wr0D82mn4HmbJ1l7XEmT-6ehg,14423
|
|
13
|
+
blksprs/ops/softmax.py,sha256=5nAgeT68nucgOugjtCy1aBIMa7Kyk1KNN-j8fgmeVuk,11996
|
|
14
|
+
blksprs/ops/transpose.py,sha256=67pDdCEb7r-Xifupl82fBKAYsxKcCUDy--cPPfduRvU,6761
|
|
15
|
+
blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
|
|
16
|
+
blksprs/utils/tools.py,sha256=S3836Zuc-BMigv-5mLTjRznCzuaF6oYW-Ir9zzUnr3o,655
|
|
17
|
+
blksprs/utils/validation.py,sha256=WzihRPibXYzss3PMkhDt5_d3Q3NHA_d1TzTz3CoGPGg,4136
|
|
18
|
+
blksprs-1.8.dist-info/METADATA,sha256=koey4w8ynY84Z0dM5u9y_P831rtR0w-Z-dBcje4O6ko,8007
|
|
19
|
+
blksprs-1.8.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
|
|
20
|
+
blksprs-1.8.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
|
|
21
|
+
blksprs-1.8.dist-info/RECORD,,
|