blksprs 1.10.1__py3-none-any.whl → 1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- blksprs/__init__.py +0 -1
- blksprs/ops/conversion.py +42 -15
- blksprs/ops/distribution.py +60 -30
- blksprs/ops/flow.py +63 -31
- blksprs/ops/matmul.py +40 -22
- blksprs/ops/partitioning.py +102 -59
- blksprs/ops/repeat.py +88 -76
- blksprs/ops/softmax.py +71 -63
- blksprs/ops/transpose.py +38 -101
- blksprs/utils/tools.py +7 -1
- {blksprs-1.10.1.dist-info → blksprs-1.11.dist-info}/METADATA +2 -2
- blksprs-1.11.dist-info/RECORD +23 -0
- {blksprs-1.10.1.dist-info → blksprs-1.11.dist-info}/WHEEL +1 -1
- blksprs/ops/misc/exp.py +0 -104
- blksprs-1.10.1.dist-info/RECORD +0 -24
- {blksprs-1.10.1.dist-info → blksprs-1.11.dist-info}/top_level.txt +0 -0
blksprs/ops/partitioning.py
CHANGED
@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor
 
-from blksprs.ops.flow import
+from blksprs.ops.flow import flow_forward_pull
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 
 from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
@@ -9,7 +9,8 @@ from blksprs.utils.validation import validate_dimensions, validate_contiguous, v
 
 
 def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
-          dim: int, sparsity_block_size: int, triton_block_size: int = None) -> (
+          dim: int, sparsity_block_size: int, triton_block_size: int = None, lut: dict = None) -> (
+        BlksprsTensor, Tensor):
     """Splits a block-sparse tensor in compressed form along the last dimension into partitions.
 
     Args:
@@ -19,6 +20,7 @@ def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
         dim (int): The dimension along which to split the tensor. Currently only supports dim=2.
         sparsity_block_size (int): The size of the sparsity blocks.
         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
 
     Returns:
         BlksprsTensor: The block-sparse tensor split into partitions in compressed form.
@@ -34,46 +36,66 @@ def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
     validate_sparsity_block_size(sparsity_block_size, x)
     validate_triton_block_size(triton_block_size, sparsity_block_size)
 
-    sparsity_layout_output = (sparsity_layout
-                              .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
-                                       sparsity_layout.size(2) // partitions)
-                              .permute(0, 2, 1, 3)
-                              .reshape(sparsity_layout.size(0) * partitions, sparsity_layout.size(1),
-                                       sparsity_layout.size(2) // partitions).contiguous())
-
-    sparsity_lut = torch.nonzero(sparsity_layout_output).contiguous()
-
-    sparsity_layout_flat = sparsity_layout.reshape(-1)
-    sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                             (sparsity_layout_flat == 1) -
-                             (1 * (sparsity_layout_flat == 0)))
-                            .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
-                                     sparsity_layout.size(2) // partitions)
-                            .permute(0, 2, 1, 3).reshape(-1).contiguous())
-
-    n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
-
-    validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
-
     adjusted_dim = dim % 3
     if adjusted_dim != 2:
         raise NotImplementedError("Currently only supports dim=2")
 
-
-
+    lut = _BlocksparseSplit.build_lut(lut, sparsity_layout, partitions)
+
+    return BlksprsTensor(
+        _BlocksparseSplit.apply(x, lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
+                                partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"],
+                                triton_block_size)), lut["sparsity_layout_output"]
 
 
 class _BlocksparseSplit(torch.autograd.Function):
 
+    @staticmethod
+    def build_lut(lut: dict, sparsity_layout: Tensor, partitions: int):
+        if lut is None:
+            lut = dict()
+
+        if "sparsity_layout_output" not in lut:
+            sparsity_layout_output = (sparsity_layout
+                                      .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
+                                               sparsity_layout.size(2) // partitions)
+                                      .permute(0, 2, 1, 3)
+                                      .reshape(sparsity_layout.size(0) * partitions, sparsity_layout.size(1),
+                                               sparsity_layout.size(2) // partitions).contiguous())
+            lut["sparsity_layout_output"] = sparsity_layout_output
+
+        if "sparsity_lut" not in lut:
+            sparsity_lut = torch.nonzero(lut["sparsity_layout_output"]).contiguous()
+            lut["sparsity_lut"] = sparsity_lut
+
+        if "sparsity_reverse_lut" not in lut:
+            sparsity_layout_flat = sparsity_layout.reshape(-1)
+            sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                     (sparsity_layout_flat == 1) -
+                                     (1 * (sparsity_layout_flat == 0)))
+                                    .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
+                                             sparsity_layout.size(2) // partitions)
+                                    .permute(0, 2, 1, 3).reshape(-1).contiguous())
+            lut["sparsity_reverse_lut"] = sparsity_reverse_lut
+
+        if "n_sparse_blocks" not in lut:
+            n_sparse_blocks = torch.sum(lut["sparsity_layout_output"].to(torch.int)).item()
+            lut["n_sparse_blocks"] = n_sparse_blocks
+
+        validate_contiguous(lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])
+
+        return lut
+
     @staticmethod
     def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
-                num_partitions: int, dim: int, sparsity_block_size: int, n_sparse_blocks: int,
+                num_partitions: int, dim: int, sparsity_block_size: int, n_sparse_blocks: int,
+                triton_block_size: int) -> Tensor:
         ctx.save_for_backward(sparsity_layout_o)
         ctx.num_partitions = num_partitions
         ctx.dim = dim
 
-        return
-
+        return flow_forward_pull(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                                 n_sparse_blocks, triton_block_size)
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -88,7 +110,8 @@ class _BlocksparseSplit(torch.autograd.Function):
 
 
 def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
-          dim: int, sparsity_block_size: int, triton_block_size: int = None) -> (
+          dim: int, sparsity_block_size: int, triton_block_size: int = None, lut: dict = None) -> (
+        BlksprsTensor, Tensor):
     """Merges the specified partitions of a block-sparse tensor in compressed form along the last dimension.
 
     Args:
@@ -98,6 +121,7 @@ def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
         dim (int): The dimension along which to merge the tensor. Currently only supports dim=2.
         sparsity_block_size (int): The size of the sparsity blocks.
         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
 
     Returns:
         BlksprsTensor: The merged block-sparse tensor in compressed form.
@@ -113,48 +137,69 @@ def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
     validate_sparsity_block_size(sparsity_block_size, x)
     validate_triton_block_size(triton_block_size, sparsity_block_size)
 
-    sparsity_layout_output = (sparsity_layout.reshape(sparsity_layout.size(0) // partitions, partitions,
-                                                      sparsity_layout.size(1), sparsity_layout.size(2))
-                              .permute(0, 2, 1, 3)
-                              .reshape(sparsity_layout.size(0) // partitions,
-                                       sparsity_layout.size(1), sparsity_layout.size(2) * partitions).contiguous())
-
-    sparsity_lut = torch.nonzero(sparsity_layout_output).contiguous()
-
-    sparsity_layout_flat = sparsity_layout.reshape(-1)
-    sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                             (sparsity_layout_flat == 1) -
-                             (1 * (sparsity_layout_flat == 0)))
-                            .reshape(sparsity_layout.size(0) // partitions, partitions,
-                                     sparsity_layout.size(1), sparsity_layout.size(2))
-                            .permute(0, 2, 1, 3)
-                            .reshape(sparsity_layout.size(0) // partitions,
-                                     sparsity_layout.size(1), sparsity_layout.size(2) * partitions)
-                            .reshape(-1).contiguous())
-
-    n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
-
-    validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
-
     adjusted_dim = dim % 3
     if adjusted_dim != 2:
         raise NotImplementedError("Currently only supports dim=2")
 
-
-
+    lut = _BlocksparseMerge.build_lut(lut, sparsity_layout, partitions)
+
+    return BlksprsTensor(
+        _BlocksparseMerge.apply(x, lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
+                                partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"],
+                                triton_block_size)), lut["sparsity_layout_output"]
 
 
 class _BlocksparseMerge(torch.autograd.Function):
 
+    @staticmethod
+    def build_lut(lut: dict, sparsity_layout: Tensor, partitions: int):
+        if lut is None:
+            lut = dict()
+
+        if "sparsity_layout_output" not in lut:
+            sparsity_layout_output = (sparsity_layout.reshape(sparsity_layout.size(0) // partitions, partitions,
+                                                              sparsity_layout.size(1), sparsity_layout.size(2))
+                                      .permute(0, 2, 1, 3)
+                                      .reshape(sparsity_layout.size(0) // partitions,
+                                               sparsity_layout.size(1),
+                                               sparsity_layout.size(2) * partitions).contiguous())
+            lut["sparsity_layout_output"] = sparsity_layout_output
+
+        if "sparsity_lut" not in lut:
+            sparsity_lut = torch.nonzero(lut["sparsity_layout_output"]).contiguous()
+            lut["sparsity_lut"] = sparsity_lut
+
+        if "sparsity_reverse_lut" not in lut:
+            sparsity_layout_flat = sparsity_layout.reshape(-1)
+            sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                     (sparsity_layout_flat == 1) -
+                                     (1 * (sparsity_layout_flat == 0)))
+                                    .reshape(sparsity_layout.size(0) // partitions, partitions,
+                                             sparsity_layout.size(1), sparsity_layout.size(2))
+                                    .permute(0, 2, 1, 3)
+                                    .reshape(sparsity_layout.size(0) // partitions,
+                                             sparsity_layout.size(1), sparsity_layout.size(2) * partitions)
+                                    .reshape(-1).contiguous())
+            lut["sparsity_reverse_lut"] = sparsity_reverse_lut
+
+        if "n_sparse_blocks" not in lut:
+            n_sparse_blocks = torch.sum(lut["sparsity_layout_output"].to(torch.int)).item()
+            lut["n_sparse_blocks"] = n_sparse_blocks
+
+        validate_contiguous(lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])
+
+        return lut
+
     @staticmethod
     def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
-                num_partitions: int, dim: int, sparsity_block_size: int, n_sparse_blocks: int,
+                num_partitions: int, dim: int, sparsity_block_size: int, n_sparse_blocks: int,
+                triton_block_size: int) -> Tensor:
         ctx.save_for_backward(sparsity_layout_o)
         ctx.num_partitions = num_partitions
        ctx.dim = dim
 
-        return
-
+        return flow_forward_pull(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                                 n_sparse_blocks, triton_block_size)
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -166,5 +211,3 @@ class _BlocksparseMerge(torch.autograd.Function):
 
         return split(grad_output, sparsity_layout, num_partitions, dim,
                      sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None, None
-
-
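The practical effect of this refactor is that the reshape/permute/cumsum work above is built once into a lut dictionary and reused on later calls. Below is a minimal sketch of that usage for split; the shapes, the all-ones layout, and the direct construction of a compressed BlksprsTensor (assumed here to be one block_size x block_size tile per non-zero layout entry) are illustrative assumptions, not taken from the package's documentation:

    import torch

    from blksprs.ops.partitioning import split
    from blksprs.utils.blksprs_tensor import BlksprsTensor

    # Assumed setup: a fully dense layout of 2 x 4 x 8 sparsity blocks of size 16.
    sparsity_block_size = 16
    partitions = 2
    sparsity_layout = torch.ones(2, 4, 8, dtype=torch.int, device="cuda")
    n_blocks = int(sparsity_layout.sum())

    lut = {}  # shared dict; build_lut fills in any missing entries in place

    for _ in range(8):
        # Hypothetical compressed input matching the layout above.
        x = BlksprsTensor(torch.randn(n_blocks, sparsity_block_size,
                                      sparsity_block_size, device="cuda"))
        y, sparsity_layout_y = split(x, sparsity_layout, partitions, dim=2,
                                     sparsity_block_size=sparsity_block_size,
                                     lut=lut)
        # After the first iteration, build_lut finds "sparsity_layout_output",
        # "sparsity_lut", "sparsity_reverse_lut" and "n_sparse_blocks" already
        # present in `lut` and skips recomputing them.

Passing a prebuilt or partially filled dict therefore amortizes LUT construction across calls that share the same layout and partitioning; passing lut=None preserves the old per-call behavior.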
blksprs/ops/repeat.py
CHANGED
@@ -2,7 +2,7 @@ import torch
 import triton
 from torch import Tensor
 
-from blksprs.ops.flow import kernel_blocksparse_flow_push,
+from blksprs.ops.flow import kernel_blocksparse_flow_push, flow_forward_pull, flow_forward_push
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
@@ -10,7 +10,8 @@ from blksprs.utils.validation import validate_dimensions, validate_contiguous, v
 
 
 def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
-           sparsity_block_size: int, sparsity_layout_output: Tensor = None, triton_block_size: int = None
+           sparsity_block_size: int, sparsity_layout_output: Tensor = None, triton_block_size: int = None,
+           lut: dict = None) -> (
         BlksprsTensor, Tensor):
     """Repeats a block-spare tensor in compressed form according to the given repeats.
 
@@ -30,6 +31,7 @@ def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int,
         sparsity_block_size (int): The size of the sparsity blocks.
         sparsity_layout_output (Tensor): The desired sparsity layout of the output tensor (default ``None``).
         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
 
     Returns:
         BlksprsTensor: A block-sparse tensor in compressed form containing the repeated values.
@@ -45,33 +47,17 @@ def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int,
     validate_sparsity_block_size(sparsity_block_size, x)
     validate_triton_block_size(triton_block_size, sparsity_block_size)
 
-
-
-    if sparsity_layout_output is not None:
-        sparsity_layout_o = torch.logical_and(sparsity_layout_o, sparsity_layout_output)
-
-    sparsity_lut = torch.nonzero(sparsity_layout_o).contiguous()
-
-    sparsity_layout_flat = sparsity_layout_x.reshape(-1)
-    sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                             (sparsity_layout_flat == 1) -
-                             (1 * (sparsity_layout_flat == 0)))
-                            .reshape(sparsity_layout_x.size())
-                            .repeat(repeats[0], repeats[1], repeats[2])
-                            .reshape(-1).contiguous())
-
-    n_sparse_blocks = torch.sum(sparsity_layout_o.to(torch.int)).item()
-
-    validate_contiguous(sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
+    lut = _BlocksparseRepeat.build_lut_repeat(lut, sparsity_layout_x, repeats, sparsity_layout_output)
 
     return BlksprsTensor(
-        _BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut,
-                                 sparsity_block_size, n_sparse_blocks,
+        _BlocksparseRepeat.apply(x, sparsity_layout_x, lut["sparsity_layout_o"], lut["sparsity_lut"],
+                                 lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"],
+                                 triton_block_size)), lut["sparsity_layout_o"]
 
 
 def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
                       sparsity_block_size: int, sparsity_layout_output: Tensor = None,
-                      triton_block_size: int = None) -> (
+                      triton_block_size: int = None, lut: dict = None) -> (
         BlksprsTensor, Tensor):
     """Repeats and interleaves the block-sparse tensor in compressed form.
 
@@ -89,6 +75,7 @@ def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
         sparsity_block_size (int): The size of the sparsity blocks.
         sparsity_layout_output (Tensor): The desired sparsity layout of the output tensor (default ``None``).
         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
 
     Returns:
         BlksprsTensor: A block-sparse tensor in compressed form containing the repeated and interleaved matrices.
@@ -104,31 +91,87 @@ def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
     validate_sparsity_block_size(sparsity_block_size, x)
     validate_triton_block_size(triton_block_size, sparsity_block_size)
 
-
+    lut = _BlocksparseRepeat.build_lut_repeat_interleave(lut, sparsity_layout_x, repeats, sparsity_layout_output)
 
-
-
+    return BlksprsTensor(
+        _BlocksparseRepeat.apply(x, sparsity_layout_x, lut["sparsity_layout_o"], lut["sparsity_lut"],
+                                 lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"],
+                                 triton_block_size)), lut["sparsity_layout_o"]
 
-    sparsity_lut = torch.nonzero(sparsity_layout_o).contiguous()
 
-
-    sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                             (sparsity_layout_flat == 1) -
-                             (1 * (sparsity_layout_flat == 0)))
-                            .reshape(sparsity_layout_x.size())
-                            .repeat_interleave(repeats, dim=0)
-                            .reshape(-1).contiguous())
+class _BlocksparseRepeat(torch.autograd.Function):
 
-
+    @staticmethod
+    def build_lut_repeat(lut: dict, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
+                         sparsity_layout_output: Tensor):
+        if lut is None:
+            lut = dict()
 
-
+        if "sparsity_layout_o" not in lut:
+            sparsity_layout_o = sparsity_layout_x.repeat(repeats[0], repeats[1], repeats[2])
+            lut["sparsity_layout_o"] = sparsity_layout_o
 
-
-
-
+        if sparsity_layout_output is not None:
+            sparsity_layout_o = torch.logical_and(lut["sparsity_layout_o"], sparsity_layout_output)
+            lut["sparsity_layout_o"] = sparsity_layout_o
 
+        if "sparsity_lut" not in lut:
+            sparsity_lut = torch.nonzero(lut["sparsity_layout_o"]).contiguous()
+            lut["sparsity_lut"] = sparsity_lut
 
-
+        if "sparsity_reverse_lut" not in lut:
+            sparsity_layout_flat = sparsity_layout_x.reshape(-1)
+            sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                     (sparsity_layout_flat == 1) -
+                                     (1 * (sparsity_layout_flat == 0)))
+                                    .reshape(sparsity_layout_x.size())
+                                    .repeat(repeats[0], repeats[1], repeats[2])
+                                    .reshape(-1).contiguous())
+            lut["sparsity_reverse_lut"] = sparsity_reverse_lut
+
+        if "n_sparse_blocks" not in lut:
+            n_sparse_blocks = torch.sum(lut["sparsity_layout_o"].to(torch.int)).item()
+            lut["n_sparse_blocks"] = n_sparse_blocks
+
+        validate_contiguous(sparsity_layout_o, lut["sparsity_lut"], lut["sparsity_reverse_lut"])
+
+        return lut
+
+    @staticmethod
+    def build_lut_repeat_interleave(lut: dict, sparsity_layout_x: Tensor, repeats: int,
+                                    sparsity_layout_output: Tensor):
+        if lut is None:
+            lut = dict()
+
+        if "sparsity_layout_o" not in lut:
+            sparsity_layout_o = torch.repeat_interleave(sparsity_layout_x, repeats, dim=0).contiguous()
+            lut["sparsity_layout_o"] = sparsity_layout_o
+
+        if sparsity_layout_output is not None:
+            sparsity_layout_o = torch.logical_and(lut["sparsity_layout_o"], sparsity_layout_output)
+            lut["sparsity_layout_o"] = sparsity_layout_o
+
+        if "sparsity_lut" not in lut:
+            sparsity_lut = torch.nonzero(lut["sparsity_layout_o"]).contiguous()
+            lut["sparsity_lut"] = sparsity_lut
+
+        if "sparsity_reverse_lut" not in lut:
+            sparsity_layout_flat = sparsity_layout_x.reshape(-1)
+            sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                     (sparsity_layout_flat == 1) -
+                                     (1 * (sparsity_layout_flat == 0)))
+                                    .reshape(sparsity_layout_x.size())
+                                    .repeat_interleave(repeats, dim=0)
+                                    .reshape(-1).contiguous())
+            lut["sparsity_reverse_lut"] = sparsity_reverse_lut
+
+        if "n_sparse_blocks" not in lut:
+            n_sparse_blocks = torch.sum(lut["sparsity_layout_o"].to(torch.int)).item()
+            lut["n_sparse_blocks"] = n_sparse_blocks
+
+        validate_contiguous(sparsity_layout_o, lut["sparsity_lut"], lut["sparsity_reverse_lut"])
+
+        return lut
 
     @staticmethod
     def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor,
@@ -136,49 +179,18 @@ class _BlocksparseRepeat(torch.autograd.Function):
                 sparsity_block_size: int, n_sparse_blocks: int,
                 triton_block_size: int) -> Tensor:
         ctx.save_for_backward(sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
-        ctx.x_size = x.size()
-        ctx.x_stride = stride(x)
 
-        return
-
+        return flow_forward_pull(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                                 n_sparse_blocks, triton_block_size)
 
     @staticmethod
     def backward(ctx, grad_output):
         sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut = ctx.saved_tensors
-        x_size = ctx.x_size
-        x_stride = ctx.x_stride
         sparsity_block_size = ctx.sparsity_block_size
         triton_block_size = ctx.triton_block_size
 
         n_sparse_blocks = torch.sum(sparsity_layout_x.to(torch.int)).item()
 
-
-
-
-        x_b, x_r, x_c = grad_output.size()
-        x_b_s, x_r_s, x_c_s = stride(grad_output)
-        s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_o.size()
-        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_o)
-        s_lut_r, s_lut_c = sparsity_lut.size()
-        s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-        o_b, o_r, o_c = x_size
-        o_b_s, o_r_s, o_c_s = x_stride
-
-        if triton_block_size is None:
-            triton_block_size = get_triton_block_size(sparsity_block_size)
-
-        triton_grid = lambda meta: [x_b,
-                                    triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
-                                    triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
-
-        (kernel_blocksparse_flow_push[triton_grid]
-         (grad_output,
-          x_b, x_b_s, x_r_s, x_c_s,
-          s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
-          sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-          sparsity_reverse_lut,
-          output,
-          o_b, o_b_s, o_r_s, o_c_s,
-          triton_block_size))
-
-        return output, None, None, None, None, None, None, None
+        return flow_forward_push(None, grad_output, sparsity_layout_o, sparsity_lut,
+                                 sparsity_reverse_lut, sparsity_block_size, n_sparse_blocks,
+                                 triton_block_size), None, None, None, None, None, None, None
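The same caching pattern applies to repeat and repeat_interleave, whose look-up tables now live in _BlocksparseRepeat.build_lut_repeat and build_lut_repeat_interleave. A sketch under the same illustrative assumptions as the split example above (hypothetical shapes, dense layout, directly constructed compressed tensor):

    import torch

    from blksprs.ops.repeat import repeat
    from blksprs.utils.blksprs_tensor import BlksprsTensor

    # Assumed setup, as before: 2 x 4 x 8 dense layout, 16x16 sparsity blocks.
    sparsity_block_size = 16
    sparsity_layout_x = torch.ones(2, 4, 8, dtype=torch.int, device="cuda")
    n_blocks = int(sparsity_layout_x.sum())

    lut = {}  # reused across iterations, exactly as in the split example

    for _ in range(8):
        x = BlksprsTensor(torch.randn(n_blocks, sparsity_block_size,
                                      sparsity_block_size, device="cuda"))
        # Repeat twice along the batch dimension; the sparsity LUTs and the
        # n_sparse_blocks count are computed only on the first pass.
        y, sparsity_layout_y = repeat(x, sparsity_layout_x, (2, 1, 1),
                                      sparsity_block_size, lut=lut)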