blksprs 1.10.2__py3-none-any.whl → 2.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +2 -6
- blksprs/layouting/distribution_layout.py +32 -25
- blksprs/layouting/sparsity_layout.py +65 -52
- blksprs/ops/conversion.py +423 -374
- blksprs/ops/distribution.py +403 -335
- blksprs/ops/flow.py +135 -83
- blksprs/ops/matmul.py +221 -187
- blksprs/ops/misc/broadcast_ops.py +53 -35
- blksprs/ops/misc/row_wise.py +151 -91
- blksprs/ops/partitioning.py +136 -89
- blksprs/ops/repeat.py +115 -108
- blksprs/ops/softmax.py +244 -208
- blksprs/ops/transpose.py +69 -131
- blksprs/utils/benchmarking.py +3 -3
- blksprs/utils/tools.py +31 -4
- blksprs/utils/validation.py +0 -14
- {blksprs-1.10.2.dist-info → blksprs-2.0rc1.dist-info}/METADATA +42 -36
- blksprs-2.0rc1.dist-info/RECORD +22 -0
- {blksprs-1.10.2.dist-info → blksprs-2.0rc1.dist-info}/WHEEL +1 -1
- blksprs/ops/misc/exp.py +0 -104
- blksprs/utils/layout_utils.py +0 -17
- blksprs-1.10.2.dist-info/RECORD +0 -24
- {blksprs-1.10.2.dist-info → blksprs-2.0rc1.dist-info}/top_level.txt +0 -0
blksprs/ops/partitioning.py
CHANGED
|
@@ -1,15 +1,16 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
from torch import Tensor
|
|
3
|
+
from torch._library import triton_op
|
|
3
4
|
|
|
4
|
-
from blksprs.ops.flow import
|
|
5
|
+
from blksprs.ops.flow import flow_pull_forward
|
|
5
6
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
6
|
-
|
|
7
7
|
from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
|
|
8
|
-
validate_sparsity, validate_sparsity_block_size
|
|
8
|
+
validate_sparsity, validate_sparsity_block_size
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
|
|
12
|
-
dim: int, sparsity_block_size: int,
|
|
12
|
+
dim: int, sparsity_block_size: int, lut: dict = None) -> (
|
|
13
|
+
BlksprsTensor, Tensor):
|
|
13
14
|
"""Splits a block-sparse tensor in compressed form along the last dimension into partitions.
|
|
14
15
|
|
|
15
16
|
Args:
|
|
@@ -18,7 +19,7 @@ def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
|
|
|
18
19
|
partitions (int): The number of partitions to split the block-sparse tensor into.
|
|
19
20
|
dim (int): The dimension along which to split the tensor. Currently only supports dim=2.
|
|
20
21
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
21
|
-
|
|
22
|
+
lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
|
|
22
23
|
|
|
23
24
|
Returns:
|
|
24
25
|
BlksprsTensor: The block-sparse tensor split into partitions in compressed form.
|
|
@@ -32,63 +33,87 @@ def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
|
|
|
32
33
|
validate_device(x)
|
|
33
34
|
validate_sparsity(sparsity_block_size, (x, sparsity_layout))
|
|
34
35
|
validate_sparsity_block_size(sparsity_block_size, x)
|
|
35
|
-
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
36
36
|
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
.permute(0, 2, 1, 3)
|
|
41
|
-
.reshape(sparsity_layout.size(0) * partitions, sparsity_layout.size(1),
|
|
42
|
-
sparsity_layout.size(2) // partitions).contiguous())
|
|
37
|
+
adjusted_dim = dim % 3
|
|
38
|
+
if adjusted_dim != 2:
|
|
39
|
+
raise NotImplementedError("Currently only supports dim=2")
|
|
43
40
|
|
|
44
|
-
|
|
41
|
+
lut = split_build_lut(lut, sparsity_layout, partitions)
|
|
45
42
|
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
(1 * (sparsity_layout_flat == 0)))
|
|
50
|
-
.reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
|
|
51
|
-
sparsity_layout.size(2) // partitions)
|
|
52
|
-
.permute(0, 2, 1, 3).reshape(-1).contiguous())
|
|
43
|
+
return BlksprsTensor(split_forward(
|
|
44
|
+
x, lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
|
|
45
|
+
partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_output"]
|
|
53
46
|
|
|
54
|
-
n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
|
|
55
47
|
|
|
56
|
-
|
|
48
|
+
@triton_op("blksprs::split", mutates_args={})
|
|
49
|
+
def split_forward(x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
50
|
+
_: int, __: int, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
|
|
51
|
+
return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
|
|
52
|
+
n_sparse_blocks)
|
|
57
53
|
|
|
58
|
-
adjusted_dim = dim % 3
|
|
59
|
-
if adjusted_dim != 2:
|
|
60
|
-
raise NotImplementedError("Currently only supports dim=2")
|
|
61
54
|
|
|
62
|
-
|
|
63
|
-
|
|
55
|
+
def split_backward(ctx, grad_output):
|
|
56
|
+
sparsity_layout = ctx.saved_tensors[0]
|
|
57
|
+
num_partitions = ctx.num_partitions
|
|
58
|
+
dim = ctx.dim
|
|
59
|
+
sparsity_block_size = ctx.sparsity_block_size
|
|
60
|
+
|
|
61
|
+
return merge(grad_output, sparsity_layout, num_partitions, dim,
|
|
62
|
+
sparsity_block_size)[0], None, None, None, None, None, None, None
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def split_build_lut(lut: dict, sparsity_layout: Tensor, partitions: int):
|
|
66
|
+
if lut is None:
|
|
67
|
+
lut = dict()
|
|
68
|
+
|
|
69
|
+
if "sparsity_layout_output" not in lut:
|
|
70
|
+
sparsity_layout_output = (sparsity_layout
|
|
71
|
+
.reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
|
|
72
|
+
sparsity_layout.size(2) // partitions)
|
|
73
|
+
.permute(0, 2, 1, 3)
|
|
74
|
+
.reshape(sparsity_layout.size(0) * partitions, sparsity_layout.size(1),
|
|
75
|
+
sparsity_layout.size(2) // partitions).contiguous())
|
|
76
|
+
lut["sparsity_layout_output"] = sparsity_layout_output
|
|
77
|
+
|
|
78
|
+
if "sparsity_lut" not in lut:
|
|
79
|
+
sparsity_lut = torch.nonzero(lut["sparsity_layout_output"]).contiguous()
|
|
80
|
+
lut["sparsity_lut"] = sparsity_lut
|
|
81
|
+
|
|
82
|
+
if "sparsity_reverse_lut" not in lut:
|
|
83
|
+
sparsity_layout_flat = sparsity_layout.reshape(-1)
|
|
84
|
+
sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
|
|
85
|
+
(sparsity_layout_flat == 1) -
|
|
86
|
+
(1 * (sparsity_layout_flat == 0)))
|
|
87
|
+
.reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
|
|
88
|
+
sparsity_layout.size(2) // partitions)
|
|
89
|
+
.permute(0, 2, 1, 3).reshape(-1).contiguous())
|
|
90
|
+
lut["sparsity_reverse_lut"] = sparsity_reverse_lut
|
|
91
|
+
|
|
92
|
+
if "n_sparse_blocks" not in lut:
|
|
93
|
+
n_sparse_blocks = torch.sum(lut["sparsity_layout_output"].to(torch.int)).item()
|
|
94
|
+
lut["n_sparse_blocks"] = n_sparse_blocks
|
|
64
95
|
|
|
96
|
+
validate_contiguous(lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])
|
|
65
97
|
|
|
66
|
-
|
|
98
|
+
return lut
|
|
67
99
|
|
|
68
|
-
@staticmethod
|
|
69
|
-
def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
70
|
-
num_partitions: int, dim: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
|
|
71
|
-
ctx.save_for_backward(sparsity_layout_o)
|
|
72
|
-
ctx.num_partitions = num_partitions
|
|
73
|
-
ctx.dim = dim
|
|
74
100
|
|
|
75
|
-
|
|
76
|
-
|
|
101
|
+
# noinspection PyUnusedLocal
|
|
102
|
+
def split_setup_context(ctx, inputs, output):
|
|
103
|
+
(_, sparsity_layout_o, _, _, num_partitions, dim, sparsity_block_size, _) = inputs
|
|
77
104
|
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
dim = ctx.dim
|
|
83
|
-
sparsity_block_size = ctx.sparsity_block_size
|
|
84
|
-
triton_block_size = ctx.triton_block_size
|
|
105
|
+
ctx.save_for_backward(sparsity_layout_o)
|
|
106
|
+
ctx.num_partitions = num_partitions
|
|
107
|
+
ctx.dim = dim
|
|
108
|
+
ctx.sparsity_block_size = sparsity_block_size
|
|
85
109
|
|
|
86
|
-
|
|
87
|
-
|
|
110
|
+
|
|
111
|
+
split_forward.register_autograd(split_backward, setup_context=split_setup_context)
|
|
88
112
|
|
|
89
113
|
|
|
90
114
|
def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
|
|
91
|
-
dim: int, sparsity_block_size: int,
|
|
115
|
+
dim: int, sparsity_block_size: int, lut: dict = None) -> (
|
|
116
|
+
BlksprsTensor, Tensor):
|
|
92
117
|
"""Merges the specified partitions of a block-sparse tensor in compressed form along the last dimension.
|
|
93
118
|
|
|
94
119
|
Args:
|
|
@@ -97,7 +122,7 @@ def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
|
|
|
97
122
|
partitions (int): The number of partitions to be merged.
|
|
98
123
|
dim (int): The dimension along which to merge the tensor. Currently only supports dim=2.
|
|
99
124
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
100
|
-
|
|
125
|
+
lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
|
|
101
126
|
|
|
102
127
|
Returns:
|
|
103
128
|
BlksprsTensor: The merged block-sparse tensor in compressed form.
|
|
@@ -111,60 +136,82 @@ def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
|
|
|
111
136
|
validate_device(x)
|
|
112
137
|
validate_sparsity(sparsity_block_size, (x, sparsity_layout))
|
|
113
138
|
validate_sparsity_block_size(sparsity_block_size, x)
|
|
114
|
-
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
115
139
|
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
.reshape(sparsity_layout.size(0) // partitions,
|
|
120
|
-
sparsity_layout.size(1), sparsity_layout.size(2) * partitions).contiguous())
|
|
140
|
+
adjusted_dim = dim % 3
|
|
141
|
+
if adjusted_dim != 2:
|
|
142
|
+
raise NotImplementedError("Currently only supports dim=2")
|
|
121
143
|
|
|
122
|
-
|
|
144
|
+
lut = merge_build_lut(lut, sparsity_layout, partitions)
|
|
123
145
|
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
(1 * (sparsity_layout_flat == 0)))
|
|
128
|
-
.reshape(sparsity_layout.size(0) // partitions, partitions,
|
|
129
|
-
sparsity_layout.size(1), sparsity_layout.size(2))
|
|
130
|
-
.permute(0, 2, 1, 3)
|
|
131
|
-
.reshape(sparsity_layout.size(0) // partitions,
|
|
132
|
-
sparsity_layout.size(1), sparsity_layout.size(2) * partitions)
|
|
133
|
-
.reshape(-1).contiguous())
|
|
146
|
+
return BlksprsTensor(merge_forward(
|
|
147
|
+
x, lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
|
|
148
|
+
partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_output"]
|
|
134
149
|
|
|
135
|
-
n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
|
|
136
150
|
|
|
137
|
-
|
|
151
|
+
@triton_op("blksprs::merge", mutates_args={})
|
|
152
|
+
def merge_forward(x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
153
|
+
_: int, __: int, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
|
|
154
|
+
return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
|
|
155
|
+
n_sparse_blocks)
|
|
138
156
|
|
|
139
|
-
adjusted_dim = dim % 3
|
|
140
|
-
if adjusted_dim != 2:
|
|
141
|
-
raise NotImplementedError("Currently only supports dim=2")
|
|
142
157
|
|
|
143
|
-
|
|
144
|
-
|
|
158
|
+
def merge_backward(ctx, grad_output):
|
|
159
|
+
sparsity_layout = ctx.saved_tensors[0]
|
|
160
|
+
num_partitions = ctx.num_partitions
|
|
161
|
+
dim = ctx.dim
|
|
162
|
+
sparsity_block_size = ctx.sparsity_block_size
|
|
163
|
+
|
|
164
|
+
return split(grad_output, sparsity_layout, num_partitions, dim,
|
|
165
|
+
sparsity_block_size)[0], None, None, None, None, None, None, None
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def merge_build_lut(lut: dict, sparsity_layout: Tensor, partitions: int):
|
|
169
|
+
if lut is None:
|
|
170
|
+
lut = dict()
|
|
171
|
+
|
|
172
|
+
if "sparsity_layout_output" not in lut:
|
|
173
|
+
sparsity_layout_output = (sparsity_layout.reshape(sparsity_layout.size(0) // partitions, partitions,
|
|
174
|
+
sparsity_layout.size(1), sparsity_layout.size(2))
|
|
175
|
+
.permute(0, 2, 1, 3)
|
|
176
|
+
.reshape(sparsity_layout.size(0) // partitions,
|
|
177
|
+
sparsity_layout.size(1),
|
|
178
|
+
sparsity_layout.size(2) * partitions).contiguous())
|
|
179
|
+
lut["sparsity_layout_output"] = sparsity_layout_output
|
|
180
|
+
|
|
181
|
+
if "sparsity_lut" not in lut:
|
|
182
|
+
sparsity_lut = torch.nonzero(lut["sparsity_layout_output"]).contiguous()
|
|
183
|
+
lut["sparsity_lut"] = sparsity_lut
|
|
184
|
+
|
|
185
|
+
if "sparsity_reverse_lut" not in lut:
|
|
186
|
+
sparsity_layout_flat = sparsity_layout.reshape(-1)
|
|
187
|
+
sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
|
|
188
|
+
(sparsity_layout_flat == 1) -
|
|
189
|
+
(1 * (sparsity_layout_flat == 0)))
|
|
190
|
+
.reshape(sparsity_layout.size(0) // partitions, partitions,
|
|
191
|
+
sparsity_layout.size(1), sparsity_layout.size(2))
|
|
192
|
+
.permute(0, 2, 1, 3)
|
|
193
|
+
.reshape(sparsity_layout.size(0) // partitions,
|
|
194
|
+
sparsity_layout.size(1), sparsity_layout.size(2) * partitions)
|
|
195
|
+
.reshape(-1).contiguous())
|
|
196
|
+
lut["sparsity_reverse_lut"] = sparsity_reverse_lut
|
|
145
197
|
|
|
198
|
+
if "n_sparse_blocks" not in lut:
|
|
199
|
+
n_sparse_blocks = torch.sum(lut["sparsity_layout_output"].to(torch.int)).item()
|
|
200
|
+
lut["n_sparse_blocks"] = n_sparse_blocks
|
|
146
201
|
|
|
147
|
-
|
|
202
|
+
validate_contiguous(lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])
|
|
148
203
|
|
|
149
|
-
|
|
150
|
-
def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
151
|
-
num_partitions: int, dim: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
|
|
152
|
-
ctx.save_for_backward(sparsity_layout_o)
|
|
153
|
-
ctx.num_partitions = num_partitions
|
|
154
|
-
ctx.dim = dim
|
|
204
|
+
return lut
|
|
155
205
|
|
|
156
|
-
return flow_forward(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
|
|
157
|
-
n_sparse_blocks, triton_block_size)
|
|
158
206
|
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
num_partitions = ctx.num_partitions
|
|
163
|
-
dim = ctx.dim
|
|
164
|
-
sparsity_block_size = ctx.sparsity_block_size
|
|
165
|
-
triton_block_size = ctx.triton_block_size
|
|
207
|
+
# noinspection PyUnusedLocal
|
|
208
|
+
def merge_setup_context(ctx, inputs, output):
|
|
209
|
+
(_, sparsity_layout_o, _, _, num_partitions, dim, sparsity_block_size, _) = inputs
|
|
166
210
|
|
|
167
|
-
|
|
168
|
-
|
|
211
|
+
ctx.save_for_backward(sparsity_layout_o)
|
|
212
|
+
ctx.num_partitions = num_partitions
|
|
213
|
+
ctx.dim = dim
|
|
214
|
+
ctx.sparsity_block_size = sparsity_block_size
|
|
169
215
|
|
|
170
216
|
|
|
217
|
+
merge_forward.register_autograd(merge_backward, setup_context=merge_setup_context)
|
blksprs/ops/repeat.py
CHANGED
|
@@ -1,16 +1,15 @@
|
|
|
1
1
|
import torch
|
|
2
|
-
import triton
|
|
3
2
|
from torch import Tensor
|
|
3
|
+
from torch._library import triton_op
|
|
4
4
|
|
|
5
|
-
from blksprs.ops.flow import
|
|
5
|
+
from blksprs.ops.flow import flow_pull_forward, flow_push_forward
|
|
6
6
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
7
|
-
from blksprs.utils.tools import get_triton_block_size, stride
|
|
8
7
|
from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
|
|
9
|
-
validate_sparsity, validate_sparsity_block_size
|
|
8
|
+
validate_sparsity, validate_sparsity_block_size
|
|
10
9
|
|
|
11
10
|
|
|
12
11
|
def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
|
|
13
|
-
sparsity_block_size: int, sparsity_layout_output: Tensor = None,
|
|
12
|
+
sparsity_block_size: int, sparsity_layout_output: Tensor = None, lut: dict = None) -> (
|
|
14
13
|
BlksprsTensor, Tensor):
|
|
15
14
|
"""Repeats a block-spare tensor in compressed form according to the given repeats.
|
|
16
15
|
|
|
@@ -29,7 +28,7 @@ def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int,
|
|
|
29
28
|
third dimension respectively.
|
|
30
29
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
31
30
|
sparsity_layout_output (Tensor): The desired sparsity layout of the output tensor (default ``None``).
|
|
32
|
-
|
|
31
|
+
lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
|
|
33
32
|
|
|
34
33
|
Returns:
|
|
35
34
|
BlksprsTensor: A block-sparse tensor in compressed form containing the repeated values.
|
|
@@ -43,35 +42,16 @@ def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int,
|
|
|
43
42
|
validate_device(x)
|
|
44
43
|
validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
|
|
45
44
|
validate_sparsity_block_size(sparsity_block_size, x)
|
|
46
|
-
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
47
45
|
|
|
48
|
-
|
|
46
|
+
lut = repeat_build_lut(lut, sparsity_layout_x, repeats, sparsity_layout_output)
|
|
49
47
|
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
sparsity_lut = torch.nonzero(sparsity_layout_o).contiguous()
|
|
54
|
-
|
|
55
|
-
sparsity_layout_flat = sparsity_layout_x.reshape(-1)
|
|
56
|
-
sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
|
|
57
|
-
(sparsity_layout_flat == 1) -
|
|
58
|
-
(1 * (sparsity_layout_flat == 0)))
|
|
59
|
-
.reshape(sparsity_layout_x.size())
|
|
60
|
-
.repeat(repeats[0], repeats[1], repeats[2])
|
|
61
|
-
.reshape(-1).contiguous())
|
|
62
|
-
|
|
63
|
-
n_sparse_blocks = torch.sum(sparsity_layout_o.to(torch.int)).item()
|
|
64
|
-
|
|
65
|
-
validate_contiguous(sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
|
|
66
|
-
|
|
67
|
-
return BlksprsTensor(
|
|
68
|
-
_BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
|
|
69
|
-
sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_o
|
|
48
|
+
return BlksprsTensor(repeat_forward(
|
|
49
|
+
x, sparsity_layout_x, lut["sparsity_layout_o"], lut["sparsity_lut"],
|
|
50
|
+
lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_o"]
|
|
70
51
|
|
|
71
52
|
|
|
72
53
|
def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
|
|
73
|
-
sparsity_block_size: int, sparsity_layout_output: Tensor = None,
|
|
74
|
-
triton_block_size: int = None) -> (
|
|
54
|
+
sparsity_block_size: int, sparsity_layout_output: Tensor = None, lut: dict = None) -> (
|
|
75
55
|
BlksprsTensor, Tensor):
|
|
76
56
|
"""Repeats and interleaves the block-sparse tensor in compressed form.
|
|
77
57
|
|
|
@@ -88,7 +68,7 @@ def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
|
|
|
88
68
|
repeats (int): The number of times to repeat the matrices.
|
|
89
69
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
90
70
|
sparsity_layout_output (Tensor): The desired sparsity layout of the output tensor (default ``None``).
|
|
91
|
-
|
|
71
|
+
lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
|
|
92
72
|
|
|
93
73
|
Returns:
|
|
94
74
|
BlksprsTensor: A block-sparse tensor in compressed form containing the repeated and interleaved matrices.
|
|
@@ -102,83 +82,110 @@ def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
|
|
|
102
82
|
validate_device(x)
|
|
103
83
|
validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
|
|
104
84
|
validate_sparsity_block_size(sparsity_block_size, x)
|
|
105
|
-
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
106
85
|
|
|
107
|
-
|
|
86
|
+
lut = repeat_interleave_build_lut(lut, sparsity_layout_x, repeats, sparsity_layout_output)
|
|
87
|
+
|
|
88
|
+
return BlksprsTensor(repeat_forward(
|
|
89
|
+
x, sparsity_layout_x, lut["sparsity_layout_o"], lut["sparsity_lut"],
|
|
90
|
+
lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_o"]
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@triton_op("blksprs::repeat", mutates_args={})
|
|
94
|
+
def repeat_forward(x: Tensor, _: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor,
|
|
95
|
+
sparsity_reverse_lut: Tensor,
|
|
96
|
+
sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
|
|
97
|
+
return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
|
|
98
|
+
n_sparse_blocks)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def repeat_backward(ctx, grad_output):
|
|
102
|
+
sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut = ctx.saved_tensors
|
|
103
|
+
sparsity_block_size = ctx.sparsity_block_size
|
|
104
|
+
n_sparse_blocks = torch.sum(sparsity_layout_x.to(torch.int)).item()
|
|
105
|
+
|
|
106
|
+
return flow_push_forward(grad_output, sparsity_layout_o, sparsity_lut,
|
|
107
|
+
sparsity_reverse_lut, sparsity_block_size,
|
|
108
|
+
n_sparse_blocks), None, None, None, None, None, None
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def repeat_build_lut(lut: dict, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
|
|
112
|
+
sparsity_layout_output: Tensor):
|
|
113
|
+
if lut is None:
|
|
114
|
+
lut = dict()
|
|
115
|
+
|
|
116
|
+
if "sparsity_layout_o" not in lut:
|
|
117
|
+
sparsity_layout_o = sparsity_layout_x.repeat(repeats[0], repeats[1], repeats[2])
|
|
118
|
+
lut["sparsity_layout_o"] = sparsity_layout_o
|
|
108
119
|
|
|
109
120
|
if sparsity_layout_output is not None:
|
|
110
|
-
sparsity_layout_o = torch.logical_and(sparsity_layout_o, sparsity_layout_output)
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
o_b, o_b_s, o_r_s, o_c_s,
|
|
182
|
-
triton_block_size))
|
|
183
|
-
|
|
184
|
-
return output, None, None, None, None, None, None, None
|
|
121
|
+
sparsity_layout_o = torch.logical_and(lut["sparsity_layout_o"], sparsity_layout_output)
|
|
122
|
+
lut["sparsity_layout_o"] = sparsity_layout_o
|
|
123
|
+
|
|
124
|
+
if "sparsity_lut" not in lut:
|
|
125
|
+
sparsity_lut = torch.nonzero(lut["sparsity_layout_o"]).contiguous()
|
|
126
|
+
lut["sparsity_lut"] = sparsity_lut
|
|
127
|
+
|
|
128
|
+
if "sparsity_reverse_lut" not in lut:
|
|
129
|
+
sparsity_layout_flat = sparsity_layout_x.reshape(-1)
|
|
130
|
+
sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
|
|
131
|
+
(sparsity_layout_flat == 1) -
|
|
132
|
+
(1 * (sparsity_layout_flat == 0)))
|
|
133
|
+
.reshape(sparsity_layout_x.size())
|
|
134
|
+
.repeat(repeats[0], repeats[1], repeats[2])
|
|
135
|
+
.reshape(-1).contiguous())
|
|
136
|
+
lut["sparsity_reverse_lut"] = sparsity_reverse_lut
|
|
137
|
+
|
|
138
|
+
if "n_sparse_blocks" not in lut:
|
|
139
|
+
n_sparse_blocks = torch.sum(lut["sparsity_layout_o"].to(torch.int)).item()
|
|
140
|
+
lut["n_sparse_blocks"] = n_sparse_blocks
|
|
141
|
+
|
|
142
|
+
validate_contiguous(sparsity_layout_o, lut["sparsity_lut"], lut["sparsity_reverse_lut"])
|
|
143
|
+
|
|
144
|
+
return lut
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def repeat_interleave_build_lut(lut: dict, sparsity_layout_x: Tensor, repeats: int,
|
|
148
|
+
sparsity_layout_output: Tensor):
|
|
149
|
+
if lut is None:
|
|
150
|
+
lut = dict()
|
|
151
|
+
|
|
152
|
+
if "sparsity_layout_o" not in lut:
|
|
153
|
+
sparsity_layout_o = torch.repeat_interleave(sparsity_layout_x, repeats, dim=0).contiguous()
|
|
154
|
+
lut["sparsity_layout_o"] = sparsity_layout_o
|
|
155
|
+
|
|
156
|
+
if sparsity_layout_output is not None:
|
|
157
|
+
sparsity_layout_o = torch.logical_and(lut["sparsity_layout_o"], sparsity_layout_output)
|
|
158
|
+
lut["sparsity_layout_o"] = sparsity_layout_o
|
|
159
|
+
|
|
160
|
+
if "sparsity_lut" not in lut:
|
|
161
|
+
sparsity_lut = torch.nonzero(lut["sparsity_layout_o"]).contiguous()
|
|
162
|
+
lut["sparsity_lut"] = sparsity_lut
|
|
163
|
+
|
|
164
|
+
if "sparsity_reverse_lut" not in lut:
|
|
165
|
+
sparsity_layout_flat = sparsity_layout_x.reshape(-1)
|
|
166
|
+
sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
|
|
167
|
+
(sparsity_layout_flat == 1) -
|
|
168
|
+
(1 * (sparsity_layout_flat == 0)))
|
|
169
|
+
.reshape(sparsity_layout_x.size())
|
|
170
|
+
.repeat_interleave(repeats, dim=0)
|
|
171
|
+
.reshape(-1).contiguous())
|
|
172
|
+
lut["sparsity_reverse_lut"] = sparsity_reverse_lut
|
|
173
|
+
|
|
174
|
+
if "n_sparse_blocks" not in lut:
|
|
175
|
+
n_sparse_blocks = torch.sum(lut["sparsity_layout_o"].to(torch.int)).item()
|
|
176
|
+
lut["n_sparse_blocks"] = n_sparse_blocks
|
|
177
|
+
|
|
178
|
+
validate_contiguous(sparsity_layout_o, lut["sparsity_lut"], lut["sparsity_reverse_lut"])
|
|
179
|
+
|
|
180
|
+
return lut
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
# noinspection PyUnusedLocal
|
|
184
|
+
def repeat_setup_context(ctx, inputs, output):
|
|
185
|
+
(_, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size, _) = inputs
|
|
186
|
+
|
|
187
|
+
ctx.save_for_backward(sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
|
|
188
|
+
ctx.sparsity_block_size = sparsity_block_size
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
repeat_forward.register_autograd(repeat_backward, setup_context=repeat_setup_context)
|