blksprs 1.8.3__tar.gz → 1.9.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {blksprs-1.8.3 → blksprs-1.9.1}/PKG-INFO +7 -3
- {blksprs-1.8.3 → blksprs-1.9.1}/README.md +6 -2
- {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/__init__.py +4 -1
- {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/layouting/distribution_layout.py +23 -5
- {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/ops/conversion.py +26 -34
- {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/ops/distribution.py +94 -35
- {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/ops/experimental/distribution_mdi.py +8 -0
- blksprs-1.9.1/blksprs/ops/flow.py +147 -0
- {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/ops/misc/row_wise.py +8 -0
- {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/ops/partitioning.py +3 -3
- {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/ops/repeat.py +8 -147
- {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/ops/softmax.py +4 -0
- blksprs-1.9.1/blksprs/utils/layout_utils.py +17 -0
- {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/utils/processing.py +35 -2
- {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/utils/validation.py +2 -1
- {blksprs-1.8.3 → blksprs-1.9.1}/blksprs.egg-info/PKG-INFO +7 -3
- {blksprs-1.8.3 → blksprs-1.9.1}/blksprs.egg-info/SOURCES.txt +2 -0
- {blksprs-1.8.3 → blksprs-1.9.1}/pyproject.toml +1 -1
- {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/layouting/sparsity_layout.py +0 -0
- {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/ops/matmul.py +0 -0
- {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/ops/misc/broadcast_ops.py +0 -0
- {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/ops/misc/exp.py +0 -0
- {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/ops/transpose.py +0 -0
- {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/utils/benchmarking.py +0 -0
- {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/utils/blksprs_tensor.py +0 -0
- {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/utils/tools.py +0 -0
- {blksprs-1.8.3 → blksprs-1.9.1}/blksprs.egg-info/dependency_links.txt +0 -0
- {blksprs-1.8.3 → blksprs-1.9.1}/blksprs.egg-info/requires.txt +0 -0
- {blksprs-1.8.3 → blksprs-1.9.1}/blksprs.egg-info/top_level.txt +0 -0
- {blksprs-1.8.3 → blksprs-1.9.1}/setup.cfg +0 -0
{blksprs-1.8.3 → blksprs-1.9.1}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.8.3
+Version: 1.9.1
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs

@@ -64,8 +64,12 @@ Further helpful operations (included in the ``bs.ops.misc`` module) that do **no
 - Row-wise sum, max, addition, and subtraction
 - Broadcast addition and subtraction between slices

-Furthermore, the library provides a set of utility functions
-
+Furthermore, the library provides a set of utility functions
+
+- for the creation of sparsity layouts based on existing
+dense tensors and for the scatter operation (module ``bs.layouting``),
+- for the application of ``nn.Linear``, ``nn.Dropout``, and ``nn.LayerNorm`` layers to block-sparse tensors,
+- as well as utility functions to apply linear layers,
 ensure correct input dimensionality, and validate input (module ``bs.utils``).

 ## Installation
{blksprs-1.8.3 → blksprs-1.9.1}/README.md

@@ -45,8 +45,12 @@ Further helpful operations (included in the ``bs.ops.misc`` module) that do **no
 - Row-wise sum, max, addition, and subtraction
 - Broadcast addition and subtraction between slices

-Furthermore, the library provides a set of utility functions
-
+Furthermore, the library provides a set of utility functions
+
+- for the creation of sparsity layouts based on existing
+dense tensors and for the scatter operation (module ``bs.layouting``),
+- for the application of ``nn.Linear``, ``nn.Dropout``, and ``nn.LayerNorm`` layers to block-sparse tensors,
+- as well as utility functions to apply linear layers,
 ensure correct input dimensionality, and validate input (module ``bs.utils``).

 ## Installation
{blksprs-1.8.3 → blksprs-1.9.1}/blksprs/__init__.py

@@ -1,5 +1,6 @@
 from blksprs.utils.blksprs_tensor import BlksprsTensor

+
 class ops:
     from blksprs.ops.conversion import to_dense, to_sparse, from_blksprs, to_blksprs, adapt_layout
     from blksprs.ops.distribution import gather, scatter, scatter_reduce

@@ -22,13 +23,15 @@ class layouting:
     from blksprs.layouting.distribution_layout import build_distribution_layout
     from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption, \
         build_sparsity_layout_matmul, build_sparsity_layout_matmul_fast
+    from blksprs.utils.layout_utils import build_full_sparsity_layout

     class experimental:
         from blksprs.ops.experimental.distribution_mdi import build_distribution_layout_mdi


 class utils:
-    from blksprs.utils.processing import apply_torch_linear
+    from blksprs.utils.processing import apply_torch_linear, apply_torch_normalisation, apply_torch_dropout, \
+        apply_function_applicable_row_wise
     from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
     from blksprs.utils.validation import disable_validation

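The namespace classes above re-export the new 1.9.x helpers. For orientation, a minimal usage sketch, not part of the package sources; the tensor shape, the sparsity block size of 32, and the CUDA device are illustrative assumptions:

import torch
import blksprs as bs

x = torch.randn(2, 64, 64, device="cuda")                       # dense input
layout = bs.layouting.build_full_sparsity_layout(x, 32)         # new in 1.9.x
x_bs = bs.ops.to_sparse(x, layout, 32)                          # compressed block-sparse form

dropout = torch.nn.Dropout(p=0.1)
x_bs = bs.utils.apply_torch_dropout(x_bs, layout, 32, dropout)  # new in 1.9.x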
{blksprs-1.8.3 → blksprs-1.9.1}/blksprs/layouting/distribution_layout.py

@@ -10,13 +10,14 @@ from blksprs.utils.validation import validate_triton_block_size, validate_dimens


 def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: Tensor,
-                              size_target: torch.Size,
+                              dim: int, size_target: torch.Size,
                               sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
     """Builds the sparsity layout of either the source of a gather or the target of a scatter operation.

     Args:
         indices (BlksprsTensor): The block-sparse indices tensor in compressed form used for the gather or scatter operation.
         sparsity_layout_indices (Tensor): The sparsity layout of the indices block-sparse tensor.
+        dim (int): The dimension along which the operation is conducted.
         size_target (torch.Size): The size of the block-sparse target tensor in regular form.
         sparsity_block_size (int): The size of the sparsity blocks.
         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).

@@ -31,6 +32,8 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T

     sparsity_lut_i = torch.nonzero(sparsity_layout_indices).contiguous()

+    adjusted_dim = dim % 3
+
     output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size, size_target[2] // sparsity_block_size,
                          dtype=torch.bool, device=indices.device)

@@ -55,6 +58,7 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
          i_b, i_b_s, i_r_s, i_c_s,
          sparsity_lut_i,
          s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+         adjusted_dim,
          output,
          o_b, o_b_s, o_r_s, o_c_s,
          sparsity_block_size,

@@ -68,6 +72,7 @@ def kernel_distribution_layout(i,
                                i_b, i_b_s, i_r_s, i_c_s,
                                s_lut_i,
                                s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+                               dim,
                                o,
                                o_b, o_b_s, o_r_s, o_c_s,
                                sparsity_block_size,

@@ -86,17 +91,30 @@ def kernel_distribution_layout(i,
     spa_row_i_msk = (spa_row_i_idx < s_lut_i_r * s_lut_i_r_s)
     spa_row_i = tl.load(s_lut_i + spa_row_i_idx, mask=spa_row_i_msk)

+    spa_col_i_idx = (pid_blk * s_lut_i_r_s + 2 * s_lut_i_c_s)
+    spa_col_i_msk = (spa_col_i_idx < s_lut_i_r * s_lut_i_r_s)
+    spa_col_i = tl.load(s_lut_i + spa_col_i_idx, mask=spa_col_i_msk)
+
     blk_i_idx = (pid_blk * i_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
     blk_i_msk = (blk_i_idx < i_b * i_b_s)
     blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)

-
+    dst_bat_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_i, dtype=tl.int32)
+    dst_row_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_i, dtype=tl.int32)
+    dst_col_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_i, dtype=tl.int32)
+    if dim == 0:
+        dst_bat_idx = blk_i
+    elif dim == 1:
+        dst_row_idx = blk_i // sparsity_block_size
+    elif dim == 2:
+        dst_col_idx = blk_i // sparsity_block_size
+
     blk_v = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), 1, dtype=tl.int32)

-    blk_o_idx = ((
-    (
-    (
+    blk_o_idx = ((dst_bat_idx * o_b_s) +
+                 (dst_row_idx * o_r_s) +
+                 (dst_col_idx * o_c_s))
     blk_o_msk = (blk_o_idx < o_b * o_b_s)
     tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
|
|
|
289
289
|
|
|
290
290
|
|
|
291
291
|
def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_size_from: int,
|
|
292
|
-
sparsity_block_size_to: int,
|
|
293
|
-
|
|
292
|
+
sparsity_block_size_to: int, sparsity_layout_to: Tensor = None,
|
|
293
|
+
triton_block_size: int = None) -> (BlksprsTensor, Tensor):
|
|
294
294
|
"""Adapts the sparsity layout of a block-sparse tensor, resulting in a new block-sparse tensor in compressed form
|
|
295
295
|
conforming to the new sparsity layout (and sparsity block size) definition.
|
|
296
296
|
|
|
@@ -299,11 +299,12 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
|
|
|
299
299
|
sparsity_layout_from (Tensor): The sparsity layout of the input block-sparse tensor.
|
|
300
300
|
sparsity_block_size_from (int): The size of the sparsity blocks of the input sparsity layout.
|
|
301
301
|
sparsity_block_size_to (int): The size of the sparsity blocks of the output sparsity layout.
|
|
302
|
-
|
|
302
|
+
sparsity_layout_to (Tensor): The sparsity layout of the output block-sparse tensor (default ``None``).
|
|
303
303
|
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
304
304
|
|
|
305
305
|
Returns:
|
|
306
306
|
BlksprsTensor: The block-sparse tensor in compressed form with the adapted sparsity layout and sparsity block size.
|
|
307
|
+
Tensor: The sparsity layout of the resulting output tensor.
|
|
307
308
|
|
|
308
309
|
"""
|
|
309
310
|
x = x.contiguous()
|
|
@@ -317,52 +318,42 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
|
|
|
317
318
|
min_sparsity_block_size = min(sparsity_block_size_from, sparsity_block_size_to)
|
|
318
319
|
validate_triton_block_size(triton_block_size, min_sparsity_block_size)
|
|
319
320
|
|
|
320
|
-
|
|
321
|
-
|
|
321
|
+
sparsity_layout_from_flat = sparsity_layout_from.reshape(-1)
|
|
322
|
+
sparsity_reverse_lut_from = ((torch.cumsum(sparsity_layout_from_flat, dim=-1) - 1) *
|
|
323
|
+
(sparsity_layout_from_flat == 1) -
|
|
324
|
+
(1 * (sparsity_layout_from_flat == 0)))
|
|
322
325
|
|
|
323
|
-
if
|
|
324
|
-
sparsity_layout_from_flat = sparsity_layout_from.reshape(-1)
|
|
325
|
-
sparsity_reverse_lut_from = ((torch.cumsum(sparsity_layout_from_flat, dim=-1) - 1) *
|
|
326
|
-
(sparsity_layout_from_flat == 1) -
|
|
327
|
-
(1 * (sparsity_layout_from_flat == 0)))
|
|
328
|
-
else:
|
|
329
|
-
sparsity_reverse_lut_from = preprocess_data["sparsity_reverse_lut_from"]
|
|
330
|
-
|
|
331
|
-
if "sparsity_layout_to" not in preprocess_data:
|
|
326
|
+
if sparsity_layout_to is None:
|
|
332
327
|
sparsity_layout_to = build_sparsity_layout_adaption(x, sparsity_layout_from,
|
|
333
328
|
sparsity_block_size_from, sparsity_block_size_to,
|
|
334
329
|
triton_block_size)
|
|
335
|
-
else:
|
|
336
|
-
sparsity_layout_to = preprocess_data["sparsity_layout_to"]
|
|
337
330
|
|
|
338
|
-
|
|
339
|
-
sparsity_lut_to = torch.nonzero(sparsity_layout_to).contiguous()
|
|
340
|
-
else:
|
|
341
|
-
sparsity_lut_to = preprocess_data["sparsity_lut_to"]
|
|
331
|
+
sparsity_lut_to = torch.nonzero(sparsity_layout_to).contiguous()
|
|
342
332
|
|
|
343
|
-
|
|
344
|
-
n_sparse_blocks_to = torch.sum(sparsity_layout_to.to(torch.int)).item()
|
|
345
|
-
else:
|
|
346
|
-
n_sparse_blocks_to = preprocess_data["n_sparse_blocks_to"]
|
|
333
|
+
n_sparse_blocks_to = torch.sum(sparsity_layout_to.to(torch.int)).item()
|
|
347
334
|
|
|
348
|
-
validate_contiguous(
|
|
335
|
+
validate_contiguous(sparsity_reverse_lut_from, sparsity_layout_to, sparsity_lut_to)
|
|
349
336
|
|
|
350
337
|
if (sparsity_block_size_from == sparsity_block_size_to) and torch.equal(sparsity_layout_from, sparsity_layout_to):
|
|
351
|
-
return BlksprsTensor(x)
|
|
338
|
+
return BlksprsTensor(x), sparsity_layout_to
|
|
352
339
|
|
|
353
340
|
return BlksprsTensor(_BlocksparseAdaptLayout.apply(x,
|
|
354
341
|
sparsity_layout_from, sparsity_reverse_lut_from,
|
|
355
342
|
sparsity_block_size_from,
|
|
356
|
-
sparsity_layout_to, sparsity_lut_to,
|
|
357
|
-
|
|
343
|
+
sparsity_layout_to, sparsity_lut_to,
|
|
344
|
+
sparsity_block_size_to,
|
|
345
|
+
n_sparse_blocks_to, min_sparsity_block_size,
|
|
346
|
+
triton_block_size)), sparsity_layout_to
|
|
358
347
|
|
|
359
348
|
|
|
360
349
|
class _BlocksparseAdaptLayout(torch.autograd.Function):
|
|
361
350
|
|
|
362
351
|
@staticmethod
|
|
363
352
|
def forward(ctx, x: Tensor,
|
|
364
|
-
sparsity_layout_from: Tensor, sparsity_reverse_lut_from: Tensor,
|
|
365
|
-
|
|
353
|
+
sparsity_layout_from: Tensor, sparsity_reverse_lut_from: Tensor,
|
|
354
|
+
sparsity_block_size_from: int,
|
|
355
|
+
sparsity_layout_to: Tensor, sparsity_lut_to: Tensor,
|
|
356
|
+
sparsity_block_size_to: int,
|
|
366
357
|
n_sparse_blocks_to: int, min_sparsity_block_size: int, triton_block_size: int) -> Tensor:
|
|
367
358
|
output = torch.zeros(size=(n_sparse_blocks_to, sparsity_block_size_to, sparsity_block_size_to),
|
|
368
359
|
dtype=x.dtype, device=x.device)
|
|
@@ -409,9 +400,10 @@ class _BlocksparseAdaptLayout(torch.autograd.Function):
|
|
|
409
400
|
sparsity_block_size_to = ctx.sparsity_block_size_to
|
|
410
401
|
triton_block_size = ctx.triton_block_size
|
|
411
402
|
|
|
412
|
-
return adapt_layout(
|
|
413
|
-
|
|
414
|
-
|
|
403
|
+
return adapt_layout(
|
|
404
|
+
grad_output, sparsity_layout_to, sparsity_block_size_to, sparsity_block_size_from,
|
|
405
|
+
sparsity_layout_to=sparsity_layout_from,
|
|
406
|
+
triton_block_size=triton_block_size)[0], None, None, None, None, None, None, None, None, None
|
|
415
407
|
|
|
416
408
|
@staticmethod
|
|
417
409
|
@triton.jit
|
|
@@ -448,7 +440,7 @@ class _BlocksparseAdaptLayout(torch.autograd.Function):
|
|
|
448
440
|
spa_row_x = (spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size_from
|
|
449
441
|
spa_col_x = (spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size_from
|
|
450
442
|
|
|
451
|
-
#
|
|
443
|
+
# Get reverse sparsity indices for x
|
|
452
444
|
rev_idx_spa_x_idx = (spa_bat_x * s_l_x_b_s +
|
|
453
445
|
spa_row_x * s_l_x_r_s +
|
|
454
446
|
spa_col_x * s_l_x_c_s)
|
|
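adapt_layout now also returns the layout the output conforms to, and optionally accepts a precomputed target layout. A minimal sketch of the updated call; shapes, block sizes, and the CUDA device are illustrative assumptions:

import torch
import blksprs as bs

x = torch.randn(2, 128, 128, device="cuda")
layout_64 = bs.layouting.build_full_sparsity_layout(x, 64)
x_bs = bs.ops.to_sparse(x, layout_64, 64)

# Re-block from 64x64 to 32x32 sparsity blocks; the target layout is derived
# automatically when sparsity_layout_to is left as None.
x_bs_32, layout_32 = bs.ops.adapt_layout(x_bs, layout_64, 64, 32)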
{blksprs-1.8.3 → blksprs-1.9.1}/blksprs/ops/distribution.py

@@ -3,19 +3,23 @@ import triton
 from torch import Tensor
 from triton import language as tl

+from blksprs.ops.conversion import to_dense
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size


-def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
+def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
+           dim: int,
+           idx: BlksprsTensor, sparsity_layout_idx: Tensor,
           sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
     """Applies a gather operation on a block-sparse tensor in compressed form.

     Args:
         src (BlksprsTensor): The source block-sparse tensor in compressed form to gather from.
         sparsity_layout_src (Tensor): The sparsity layout of the source block-sparse tensor.
+        dim (int): The dimension along which to gather.
         idx (BlksprsTensor): The block-sparse indices tensor in compressed form specifying how to gather from the source tensor.
         sparsity_layout_idx (Tensor): The sparsity layout of the indices block-sparse tensor.
         sparsity_block_size (int): The size of the sparsity blocks.

@@ -46,16 +50,18 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor, idx: BlksprsTensor,
     validate_contiguous(sparsity_layout_src, sparsity_reverse_lut_x,
                         sparsity_layout_idx, sparsity_lut_i)

+    adjusted_dim = dim % 3
+
     return BlksprsTensor(_BlocksparseGather.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
-
-
+                                                  adjusted_dim, idx, sparsity_layout_idx, sparsity_lut_i,
+                                                  sparsity_block_size, triton_block_size))


 class _BlocksparseGather(torch.autograd.Function):

     @staticmethod
     def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
-                i: Tensor, sparsity_layout_i: Tensor, sparsity_lut_i: Tensor,
+                dim: int, i: Tensor, sparsity_layout_i: Tensor, sparsity_lut_i: Tensor,
                 sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
         output = torch.empty_like(i, dtype=x.dtype)

@@ -82,6 +88,7 @@ class _BlocksparseGather(torch.autograd.Function):
              x_b, x_b_s, x_r_s, x_c_s,
              s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
              sparsity_reverse_lut_x,
+             dim,
              i,
              i_b, i_b_s, i_r_s, i_c_s,
              output,

@@ -91,6 +98,7 @@ class _BlocksparseGather(torch.autograd.Function):
              triton_block_size))

         ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_i)
+        ctx.dim = dim
         ctx.sparsity_block_size = sparsity_block_size
         ctx.triton_block_size = triton_block_size

@@ -99,15 +107,15 @@ class _BlocksparseGather(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
         sparsity_layout_x, i, sparsity_layout_i = ctx.saved_tensors
+        dim = ctx.dim
         sparsity_block_size = ctx.sparsity_block_size
         triton_block_size = ctx.triton_block_size

         return scatter_reduce(grad_output, sparsity_layout_i,
-                              i,
-                              sparsity_layout_x,
-                              sparsity_block_size,
+                              dim, i,
+                              sparsity_layout_x, sparsity_block_size,
                               reduce_op="sum",
-                              triton_block_size=triton_block_size), None, None, None, None, None, None, None
+                              triton_block_size=triton_block_size), None, None, None, None, None, None, None, None

     @staticmethod
     @triton.jit

@@ -115,6 +123,7 @@ class _BlocksparseGather(torch.autograd.Function):
                          x_b, x_b_s, x_r_s, x_c_s,
                          s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
                          r_lut_x,
+                         dim,
                          i,
                          i_b, i_b_s, i_r_s, i_c_s,
                          o,

@@ -136,6 +145,10 @@ class _BlocksparseGather(torch.autograd.Function):
         spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
         spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)

+        spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
+        spa_col_o_msk = (spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+        spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+
         # Load index values
         blk_i_idx = ((pid_blk * i_b_s) +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +

@@ -143,33 +156,50 @@ class _BlocksparseGather(torch.autograd.Function):
         blk_i_msk = (blk_i_idx < i_b * i_b_s)
         blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)

-        # Get
+        # Get indices of sparsity blocks and positions within the blocks
         pos_spa_blk_x = blk_i // sparsity_block_size
-
+        pos_spa_int_x = blk_i % sparsity_block_size
+
+        rev_dst_bat_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_o, dtype=tl.int32)
+        rev_dst_row_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_o, dtype=tl.int32)
+        rev_dst_col_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_o, dtype=tl.int32)
+        dst_row_x = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
+                     .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+        dst_col_x = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
+                     .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+        if dim == 0:
+            rev_dst_bat_x = blk_i
+        elif dim == 1:
+            rev_dst_row_x = pos_spa_blk_x
+            dst_row_x = pos_spa_int_x * x_r_s
+        elif dim == 2:
+            rev_dst_col_x = pos_spa_blk_x
+            dst_col_x = pos_spa_int_x * x_c_s

         # Load reverse sparsity indices for x
-        rev_idx_spa_x_idx = ((
-        (
-        (
+        rev_idx_spa_x_idx = ((rev_dst_bat_x * s_l_x_b_s) +
+                             (rev_dst_row_x * s_l_x_r_s) +
+                             (rev_dst_col_x * s_l_x_c_s))
         rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
         rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)

         # Load x values
         blk_x_idx = ((rev_idx_spa_x * x_b_s) +
-
-
-        blk_x_msk = (blk_x_idx < x_b * x_b_s)
+                     dst_row_x +
+                     dst_col_x)
+        blk_x_msk = ((blk_x_idx < x_b * x_b_s) & rev_idx_spa_x_msk != -1)
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

         # Store output
         blk_o_idx = ((pid_blk * o_b_s) +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-        blk_o_msk = (blk_o_idx < o_b * o_b_s)
+        blk_o_msk = ((blk_o_idx < o_b * o_b_s) & rev_idx_spa_x_msk != -1)
         tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)


 def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
+            dim: int,
             idx: BlksprsTensor,
             sparsity_layout_tgt: Tensor,
             sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:

@@ -177,6 +207,7 @@ def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,

     """
     return scatter_reduce(src, sparsity_layout_src,
+                          dim,
                           idx,
                           sparsity_layout_tgt,
                           sparsity_block_size,

@@ -184,6 +215,7 @@ def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,


 def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
+                   dim: int,
                    idx: BlksprsTensor,
                    sparsity_layout_tgt: Tensor,
                    sparsity_block_size: int,

@@ -193,6 +225,7 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
     Args:
         src (BlksprsTensor): The source block-sparse tensor in compressed form to scatter from.
         sparsity_layout_src (Tensor): The sparsity layout of the source block-sparse tensor.
+        dim (int): The dimension along which to scatter.
         idx (BlksprsTensor): The block-sparse indices tensor in compressed form specifying how to scatter to the target tensor.
         sparsity_layout_tgt (Tensor): The sparsity layout of the target block-sparse tensor.
         sparsity_block_size (int): The size of the sparsity blocks.

@@ -230,18 +263,20 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
     validate_contiguous(sparsity_layout_src, sparsity_lut_x,
                         sparsity_layout_tgt, sparsity_reverse_lut_o)

+    adjusted_dim = dim % 3
+
     return BlksprsTensor(_BlocksparseScatterReduce.apply(src, sparsity_layout_src, sparsity_lut_x,
-
-
-
-
+                                                         adjusted_dim, idx,
+                                                         sparsity_layout_tgt, sparsity_reverse_lut_o,
+                                                         sparsity_block_size, n_sparse_blocks,
+                                                         reduce_op, triton_block_size))


 class _BlocksparseScatterReduce(torch.autograd.Function):

     @staticmethod
     def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_lut_x: Tensor,
-                i: Tensor,
+                dim: int, i: Tensor,
                 sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
                 sparsity_block_size: int, n_sparse_blocks: int,
                 reduce_op: str, triton_block_size: int) -> Tensor:

@@ -274,10 +309,11 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
             (x,
              x_b, x_b_s, x_r_s, x_c_s,
              sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+             dim,
              i,
              i_b, i_b_s, i_r_s, i_c_s,
              output,
-             o_b, o_b_s,
+             o_b, o_b_s,
              s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
              sparsity_reverse_lut_o,
              reduce_op_ind,

@@ -285,6 +321,7 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
              triton_block_size))

         ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_o)
+        ctx.dim = dim
         ctx.sparsity_block_size = sparsity_block_size
         ctx.reduce_op = reduce_op
         ctx.triton_block_size = triton_block_size

@@ -294,13 +331,14 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
         sparsity_layout_x, i, sparsity_layout_o = ctx.saved_tensors
+        dim = ctx.dim
         sparsity_block_size = ctx.sparsity_block_size
         reduce_op = ctx.reduce_op
         triton_block_size = ctx.triton_block_size

         if reduce_op == "sum":
-            return gather(grad_output, sparsity_layout_o, i, sparsity_layout_x, sparsity_block_size,
-                          triton_block_size=triton_block_size), None, None, None, None, None, None, None, None, None
+            return gather(grad_output, sparsity_layout_o, dim, i, sparsity_layout_x, sparsity_block_size,
+                          triton_block_size=triton_block_size), None, None, None, None, None, None, None, None, None, None
         else:
             raise ValueError(f"Reduction operation '{reduce_op}' does not support backward pass")

@@ -309,10 +347,11 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
     def kernel_blocksparse_scatter(x,
                                    x_b, x_b_s, x_r_s, x_c_s,
                                    s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+                                   dim,
                                    i,
                                    i_b, i_b_s, i_r_s, i_c_s,
                                    o,
-                                   o_b, o_b_s,
+                                   o_b, o_b_s,
                                    s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
                                    r_lut_o,
                                    reduce_op_ind,

@@ -332,6 +371,10 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
         spa_row_x_msk = (spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
         spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)

+        spa_col_x_idx = (pid_blk * s_lut_x_r_s + 2 * s_lut_x_c_s)
+        spa_col_x_msk = (spa_col_x_idx < s_lut_x_r * s_lut_x_r_s)
+        spa_col_x = tl.load(s_lut_x + spa_col_x_idx, mask=spa_col_x_msk)
+
         # Load x values
         blk_x_idx = ((pid_blk * x_b_s) +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +

@@ -346,22 +389,38 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
         blk_i_msk = (blk_i_idx < i_b * i_b_s)
         blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)

-        # Get
-
-
+        # Get indices of sparsity blocks and positions within the blocks
+        pos_spa_blk_x = blk_i // sparsity_block_size
+        pos_spa_int_x = blk_i % sparsity_block_size
+
+        rev_dst_bat_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_x, dtype=tl.int32)
+        rev_dst_row_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_x, dtype=tl.int32)
+        rev_dst_col_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_x, dtype=tl.int32)
+        dst_row_o = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
+                     .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+        dst_col_o = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
+                     .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+        if dim == 0:
+            rev_dst_bat_o = blk_i
+        elif dim == 1:
+            rev_dst_row_o = pos_spa_blk_x
+            dst_row_o = pos_spa_int_x * x_r_s
+        elif dim == 2:
+            rev_dst_col_o = pos_spa_blk_x
+            dst_col_o = pos_spa_int_x * x_c_s

         # Load reverse sparsity indices for o
-        rev_idx_spa_o_idx = ((
-        (
-        (
+        rev_idx_spa_o_idx = ((rev_dst_bat_o * s_l_o_b_s) +
+                             (rev_dst_row_o * s_l_o_r_s) +
+                             (rev_dst_col_o * s_l_o_c_s))
         rev_idx_spa_o_msk = (rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
         rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)

         # Store output
         blk_o_idx = ((rev_idx_spa_o * o_b_s) +
-
-
-        blk_o_msk = (blk_o_idx < o_b * o_b_s)
+                     dst_row_o +
+                     dst_col_o)
+        blk_o_msk = ((blk_o_idx < o_b * o_b_s) & rev_idx_spa_o_msk != -1)

         if reduce_op_ind == 0:
             tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
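gather, scatter, and scatter_reduce now take the distribution dimension explicitly. A minimal sketch of a gather along the column dimension and the matching scatter_reduce; shapes, index values, and the CUDA device are illustrative assumptions:

import torch
import blksprs as bs

sparsity_block_size = 32
src = torch.randn(1, 64, 64, device="cuda")
layout_src = bs.layouting.build_full_sparsity_layout(src, sparsity_block_size)
src_bs = bs.ops.to_sparse(src, layout_src, sparsity_block_size)

# Compressed indices selecting columns of src (dim=2), with their own layout.
layout_idx = torch.ones(1, 2, 2, dtype=torch.bool, device="cuda")
idx = torch.randint(0, 64, (4, 32, 32), dtype=torch.int64, device="cuda")

gathered = bs.ops.gather(src_bs, layout_src, 2, idx, layout_idx, sparsity_block_size)

# The backward of gather is the matching scatter_reduce along the same dim with "sum".
scattered = bs.ops.scatter_reduce(gathered, layout_idx, 2, idx, layout_src,
                                  sparsity_block_size, reduce_op="sum")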
{blksprs-1.8.3 → blksprs-1.9.1}/blksprs/ops/experimental/distribution_mdi.py

@@ -153,6 +153,10 @@ class _BlocksparseGatherMDI(torch.autograd.Function):
         rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
         rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)

+        if rev_idx_spa_x == -1:
+            tl.device_assert(False)
+            return
+
         # Load x values
         blk_x_idx = ((rev_idx_spa_x * x_b_s) +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +

@@ -342,6 +346,10 @@ class _BlocksparseScatterReduceMDI(torch.autograd.Function):
         rev_idx_spa_o_msk = (rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
         rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)

+        if rev_idx_spa_o == -1:
+            tl.device_assert(False)
+            return
+
         # Store output
         blk_o_idx = ((rev_idx_spa_o * o_b_s) +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
blksprs-1.9.1/blksprs/ops/flow.py (new file)

@@ -0,0 +1,147 @@
+import torch
+import triton
+from torch import Tensor
+from triton import language as tl
+
+from blksprs.utils.tools import stride, get_triton_block_size
+
+
+@triton.jit
+def kernel_blocksparse_flow_pull(x,
+                                 x_b, x_b_s, x_r_s, x_c_s,
+                                 o,
+                                 o_b, o_b_s, o_r_s, o_c_s,
+                                 s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+                                 s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                                 r_lut,
+                                 TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get sparsity index of current output block consisting of its batch, row, and column index
+    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+    spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+    spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+    spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
+    spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
+    spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+
+    # Get reverse sparsity index
+    rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
+                       spa_row * s_l_o_r_s +
+                       spa_col * s_l_o_c_s)
+    rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
+    rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
+
+    if rev_idx_spa == -1:
+        tl.device_assert(False)
+        return
+
+    blk_x_idx = (rev_idx_spa * x_b_s +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_x_msk = (blk_x_idx < x_b * x_b_s)
+    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+    blk_o_idx = (pid_blk * o_b_s +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+    blk_o_msk = (blk_o_idx < o_b * o_b_s)
+    tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+@triton.jit
+def kernel_blocksparse_flow_push(x,
+                                 x_b, x_b_s, x_r_s, x_c_s,
+                                 s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+                                 s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                                 r_lut,
+                                 o,
+                                 o_b, o_b_s, o_r_s, o_c_s,
+                                 TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get sparsity index of current input block consisting of its batch, row, and column index
+    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+    spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+    spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+    spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
+    spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
+    spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+
+    # Get reverse sparsity index
+    rev_idx_spa_idx = (spa_bat * s_l_x_b_s +
+                       spa_row * s_l_x_r_s +
+                       spa_col * s_l_x_c_s)
+    rev_idx_spa_msk = (rev_idx_spa_idx < s_l_x_b * s_l_x_b_s)
+    rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
+
+    if rev_idx_spa == -1:
+        tl.device_assert(False)
+        return
+
+    blk_x_idx = (pid_blk * x_b_s +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_x_msk = (blk_x_idx < x_b * x_b_s)
+    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+    blk_o_idx = (rev_idx_spa * o_b_s +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+    blk_o_msk = (blk_o_idx < o_b * o_b_s)
+    tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+def flow_forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                 sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+    output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                         dtype=x.dtype, device=x.device)
+    output = torch.zeros_like(output)
+
+    x_b, x_r, x_c = x.size()
+    x_b_s, x_r_s, x_c_s = stride(x)
+    o_b, o_r, o_c = output.size()
+    o_b_s, o_r_s, o_c_s = stride(output)
+    s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
+    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
+    s_lut_r, s_lut_c = sparsity_lut.size()
+    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+
+    if triton_block_size is None:
+        triton_block_size = get_triton_block_size(sparsity_block_size)
+
+    triton_grid = lambda meta: [o_b,
+                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+    (kernel_blocksparse_flow_pull[triton_grid]
+     (x,
+      x_b, x_b_s, x_r_s, x_c_s,
+      output,
+      o_b, o_b_s, o_r_s, o_c_s,
+      s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+      sparsity_reverse_lut,
+      triton_block_size))
+
+    # Save for backward pass
+    ctx.sparsity_block_size = sparsity_block_size
+    ctx.triton_block_size = triton_block_size
+
+    return output
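The new flow module centralises the pull/push copy kernels and the shared flow_forward helper that previously lived inside repeat.py; partitioning (split/merge) and repeat now delegate their forward passes to it. The sketch below is an illustration, not code from the package (the op class is hypothetical), showing the delegation pattern an autograd Function follows once the output layout, LUT, and reverse LUT are prepared:

import torch
from blksprs.ops.flow import flow_forward

class _BlocksparseExampleOp(torch.autograd.Function):
    # Hypothetical op used only to illustrate the shared forward path.
    @staticmethod
    def forward(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
                sparsity_block_size, n_sparse_blocks, triton_block_size):
        ctx.save_for_backward(sparsity_layout_o)
        # flow_forward launches kernel_blocksparse_flow_pull and returns the
        # re-distributed block-sparse tensor in compressed form.
        return flow_forward(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
                            sparsity_block_size, n_sparse_blocks, triton_block_size)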
{blksprs-1.8.3 → blksprs-1.9.1}/blksprs/ops/misc/row_wise.py

@@ -117,6 +117,10 @@ def kernel_blocksparse_row_wise_sum(x,
     rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
     rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

+    if rev_idx_spa == -1:
+        tl.device_assert(False)
+        return
+
     blk_idx = ((pid_blk * x_b_s) +
                ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])

@@ -240,6 +244,10 @@ def kernel_blocksparse_row_wise_max(x,
     rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
     rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

+    if rev_idx_spa == -1:
+        tl.device_assert(False)
+        return
+
     blk_idx = ((pid_blk * x_b_s) +
                ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
{blksprs-1.8.3 → blksprs-1.9.1}/blksprs/ops/partitioning.py

@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor

-from blksprs.ops.
+from blksprs.ops.flow import flow_forward
 from blksprs.utils.blksprs_tensor import BlksprsTensor

 from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \

@@ -66,7 +66,7 @@ class _BlocksparseSplit(torch.autograd.Function):
         ctx.save_for_backward(sparsity_layout_o)
         ctx.num_partitions = num_partitions

-        return
+        return flow_forward(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
                             n_sparse_blocks, triton_block_size)

     @staticmethod

@@ -140,7 +140,7 @@ class _BlocksparseMerge(torch.autograd.Function):
         ctx.save_for_backward(sparsity_layout_o)
         ctx.num_partitions = num_partitions

-        return
+        return flow_forward(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
                             n_sparse_blocks, triton_block_size)

     @staticmethod
{blksprs-1.8.3 → blksprs-1.9.1}/blksprs/ops/repeat.py

@@ -1,8 +1,8 @@
 import torch
 import triton
-from triton import language as tl
 from torch import Tensor

+from blksprs.ops.flow import kernel_blocksparse_flow_push, flow_forward
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \

@@ -64,8 +64,9 @@ def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int,

     validate_contiguous(sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)

-    return BlksprsTensor(
-
+    return BlksprsTensor(
+        _BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
+                                 sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_o


 def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,

@@ -122,8 +123,9 @@ def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,

     validate_contiguous(sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)

-    return BlksprsTensor(
-
+    return BlksprsTensor(
+        _BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
+                                 sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_o


 class _BlocksparseRepeat(torch.autograd.Function):

@@ -137,7 +139,7 @@ class _BlocksparseRepeat(torch.autograd.Function):
         ctx.x_size = x.size()
         ctx.x_stride = stride(x)

-        return
+        return flow_forward(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
                             n_sparse_blocks, triton_block_size)

     @staticmethod

@@ -180,144 +182,3 @@ class _BlocksparseRepeat(torch.autograd.Function):
              triton_block_size))

         return output, None, None, None, None, None, None, None
-
-
-@triton.jit
-def kernel_blocksparse_flow_pull(x,
-                                 x_b, x_b_s, x_r_s, x_c_s,
-                                 o,
-                                 o_b, o_b_s, o_r_s, o_c_s,
-                                 s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
-                                 s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-                                 r_lut,
-                                 TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-    # Get triton block indices
-    pid_blk = tl.program_id(axis=0)
-    pid_row = tl.program_id(axis=1)
-    pid_col = tl.program_id(axis=2)
-
-    # Get sparsity index of current output block consisting of its batch, row, and column index
-    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-    spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
-    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
-
-    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-    spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
-    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
-
-    spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
-    spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
-    spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
-
-    # Get reverse sparsity index
-    rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
-                       spa_row * s_l_o_r_s +
-                       spa_col * s_l_o_c_s)
-    rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
-    rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
-
-    if rev_idx_spa == -1:
-        tl.device_assert(False)
-        return
-
-    blk_x_idx = (rev_idx_spa * x_b_s +
-                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx < x_b * x_b_s)
-    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
-    blk_o_idx = (pid_blk * o_b_s +
-                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-    blk_o_msk = (blk_o_idx < o_b * o_b_s)
-    tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
-
-
-@triton.jit
-def kernel_blocksparse_flow_push(x,
-                                 x_b, x_b_s, x_r_s, x_c_s,
-                                 s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
-                                 s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-                                 r_lut,
-                                 o,
-                                 o_b, o_b_s, o_r_s, o_c_s,
-                                 TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-    # Get triton block indices
-    pid_blk = tl.program_id(axis=0)
-    pid_row = tl.program_id(axis=1)
-    pid_col = tl.program_id(axis=2)
-
-    # Get sparsity index of current input block consisting of its batch, row, and column index
-    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-    spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
-    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
-
-    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-    spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
-    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
-
-    spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
-    spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
-    spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
-
-    # Get reverse sparsity index
-    rev_idx_spa_idx = (spa_bat * s_l_x_b_s +
-                       spa_row * s_l_x_r_s +
-                       spa_col * s_l_x_c_s)
-    rev_idx_spa_msk = (rev_idx_spa_idx < s_l_x_b * s_l_x_b_s)
-    rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
-
-    if rev_idx_spa == -1:
-        tl.device_assert(False)
-        return
-
-    blk_x_idx = (pid_blk * x_b_s +
-                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx < x_b * x_b_s)
-    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
-    blk_o_idx = (rev_idx_spa * o_b_s +
-                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-    blk_o_msk = (blk_o_idx < o_b * o_b_s)
-    tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
-
-
-def forward_flow(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
-                 sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
-    output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
-                         dtype=x.dtype, device=x.device)
-    output = torch.zeros_like(output)
-
-    x_b, x_r, x_c = x.size()
-    x_b_s, x_r_s, x_c_s = stride(x)
-    o_b, o_r, o_c = output.size()
-    o_b_s, o_r_s, o_c_s = stride(output)
-    s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
-    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
-    s_lut_r, s_lut_c = sparsity_lut.size()
-    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-
-    if triton_block_size is None:
-        triton_block_size = get_triton_block_size(sparsity_block_size)
-
-    triton_grid = lambda meta: [o_b,
-                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-    (kernel_blocksparse_flow_pull[triton_grid]
-     (x,
-      x_b, x_b_s, x_r_s, x_c_s,
-      output,
-      o_b, o_b_s, o_r_s, o_c_s,
-      s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
-      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-      sparsity_reverse_lut,
-      triton_block_size))
-
-    # Save for backward pass
-    ctx.sparsity_block_size = sparsity_block_size
-    ctx.triton_block_size = triton_block_size
-
-    return output
{blksprs-1.8.3 → blksprs-1.9.1}/blksprs/ops/softmax.py

@@ -238,6 +238,10 @@ class _BlocksparseSoftmax(torch.autograd.Function):
         rev_idx_spa_s_msk = (rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
         rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)

+        if rev_idx_spa_s == -1:
+            tl.device_assert(False)
+            return
+
         blk_s_idx = (rev_idx_spa_s * s_b_s +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
                      (tl.arange(0, 1) * s_c_s)[None, :])
blksprs-1.9.1/blksprs/utils/layout_utils.py (new file)

@@ -0,0 +1,17 @@
+import math
+
+import torch
+import triton
+from torch import Tensor
+from torch.xpu import device
+from triton import language as tl
+
+from blksprs.utils.blksprs_tensor import BlksprsTensor
+from blksprs.utils.tools import get_triton_block_size, stride
+from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
+    validate_contiguous, validate_sparsity, validate_sparsity_block_size
+
+
+def build_full_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
+    return torch.ones(size=(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size),
+                      dtype=torch.bool, device=x.device)
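build_full_sparsity_layout simply marks every block as present, which is convenient when a fully dense tensor has to be converted into compressed block-sparse form. A minimal usage sketch; shape, block size, and device are illustrative assumptions:

import torch
import blksprs as bs

x = torch.randn(2, 64, 64, device="cuda")
layout = bs.layouting.build_full_sparsity_layout(x, sparsity_block_size=32)
print(layout.shape, layout.dtype)   # torch.Size([2, 2, 2]) torch.bool
x_bs = bs.ops.to_sparse(x, layout, 32)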
{blksprs-1.8.3 → blksprs-1.9.1}/blksprs/utils/processing.py

@@ -1,7 +1,9 @@
+from collections.abc import Callable
+
 import torch
 from torch import Tensor, nn
-from triton.language import dtype

+import blksprs as bs
 from blksprs.layouting.sparsity_layout import build_sparsity_layout_matmul_fast
 from blksprs.ops.conversion import to_sparse
 from blksprs.ops.matmul import matmul

@@ -10,7 +12,7 @@ from blksprs.utils.blksprs_tensor import BlksprsTensor


 def apply_torch_linear(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
-                       linear: nn.Linear) -> (BlksprsTensor, Tensor):
+                       linear: nn.Linear, bias: nn.Parameter = None) -> (BlksprsTensor, Tensor):
     # Extract weight and bias
     w = linear.weight
     b = linear.bias

@@ -27,6 +29,8 @@ def apply_torch_linear(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block
     interim = xw

     # Apply bias
+    if bias is not None:
+        b = bias
     if b is not None:
         b_slice = b.unsqueeze(0).unsqueeze(0).repeat(1, sparsity_block_size, 1)
         sparsity_layout_b_slice = torch.ones(size=(1, b_slice.size(1) // sparsity_block_size,

@@ -39,3 +43,32 @@ def apply_torch_linear(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block
         interim = interim + b_bs

     return interim, sparsity_layout_xw
+
+
+def apply_torch_normalisation(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
+                              normalisation: nn.Module) -> BlksprsTensor:
+    return apply_function_applicable_row_wise(x, sparsity_layout, sparsity_block_size, normalisation)
+
+
+def apply_torch_dropout(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
+                        dropout: nn.Dropout) -> BlksprsTensor:
+    return apply_function_applicable_row_wise(x, sparsity_layout, sparsity_block_size, dropout)
+
+
+def apply_function_applicable_row_wise(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
+                                       function: Callable) -> BlksprsTensor:
+    sparsity_layout_packed = _pack_layout(sparsity_layout)
+    blksprs_pseudo_dense = bs.ops.to_dense(x, sparsity_layout_packed, sparsity_block_size)
+    normalisation_out = function(blksprs_pseudo_dense)
+    blksprs_sparse = bs.ops.to_sparse(normalisation_out, sparsity_layout_packed, sparsity_block_size)
+
+    return blksprs_sparse
+
+
+def _pack_layout(sparsity_layout: Tensor) -> BlksprsTensor:
+    sparsity_layout_resized = sparsity_layout.resize(1, sparsity_layout.size(0) * sparsity_layout.size(1),
+                                                     sparsity_layout.size(2))
+    non_zero_rows = torch.any(sparsity_layout_resized, dim=-1)
+    sparsity_layout_filtered = sparsity_layout_resized[non_zero_rows].unsqueeze(0)
+
+    return sparsity_layout_filtered
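The new processing helpers apply standard torch layers row-wise to a block-sparse tensor by densifying a packed layout, applying the layer, and re-compressing. A minimal sketch of their use; shapes, the LayerNorm size, and the CUDA device are illustrative assumptions:

import torch
from torch import nn
import blksprs as bs

sparsity_block_size = 32
x = torch.randn(2, 64, 64, device="cuda")
layout = bs.layouting.build_full_sparsity_layout(x, sparsity_block_size)
x_bs = bs.ops.to_sparse(x, layout, sparsity_block_size)

norm = nn.LayerNorm(64, device="cuda")
dropout = nn.Dropout(p=0.1)

x_bs = bs.utils.apply_torch_normalisation(x_bs, layout, sparsity_block_size, norm)
x_bs = bs.utils.apply_torch_dropout(x_bs, layout, sparsity_block_size, dropout)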
{blksprs-1.8.3 → blksprs-1.9.1}/blksprs/utils/validation.py

@@ -36,7 +36,8 @@ def validate_dtype_int(*tensors: Tensor) -> None:
         return

     for tensor in tensors:
-        if tensor.dtype !=
+        if (tensor.dtype !=
+                torch.int32 and tensor.dtype != torch.int64):
             raise ValueError("Tensor must have int32 or int64 dtype")


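For reference, a small sketch of the dtype requirement this check enforces for index tensors passed to the distribution ops; values and shapes are illustrative:

import torch

idx_int64 = torch.randint(0, 64, (4, 32, 32), dtype=torch.int64)  # satisfies the check
idx_int32 = idx_int64.to(torch.int32)                             # satisfies the check
idx_float = idx_int64.to(torch.float32)                           # would raise
                                                                   # "Tensor must have int32 or int64 dtype"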
{blksprs-1.8.3 → blksprs-1.9.1}/blksprs.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.8.3
+Version: 1.9.1
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs

@@ -64,8 +64,12 @@ Further helpful operations (included in the ``bs.ops.misc`` module) that do **no
 - Row-wise sum, max, addition, and subtraction
 - Broadcast addition and subtraction between slices

-Furthermore, the library provides a set of utility functions
-
+Furthermore, the library provides a set of utility functions
+
+- for the creation of sparsity layouts based on existing
+dense tensors and for the scatter operation (module ``bs.layouting``),
+- for the application of ``nn.Linear``, ``nn.Dropout``, and ``nn.LayerNorm`` layers to block-sparse tensors,
+- as well as utility functions to apply linear layers,
 ensure correct input dimensionality, and validate input (module ``bs.utils``).

 ## Installation
{blksprs-1.8.3 → blksprs-1.9.1}/blksprs.egg-info/SOURCES.txt

@@ -10,6 +10,7 @@ blksprs/layouting/distribution_layout.py
 blksprs/layouting/sparsity_layout.py
 blksprs/ops/conversion.py
 blksprs/ops/distribution.py
+blksprs/ops/flow.py
 blksprs/ops/matmul.py
 blksprs/ops/partitioning.py
 blksprs/ops/repeat.py

@@ -21,6 +22,7 @@ blksprs/ops/misc/exp.py
 blksprs/ops/misc/row_wise.py
 blksprs/utils/benchmarking.py
 blksprs/utils/blksprs_tensor.py
+blksprs/utils/layout_utils.py
 blksprs/utils/processing.py
 blksprs/utils/tools.py
 blksprs/utils/validation.py