blksprs 1.10.2__py3-none-any.whl → 2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +4 -6
- blksprs/layouting/distribution_layout.py +64 -48
- blksprs/layouting/sparsity_layout.py +96 -72
- blksprs/ops/conversion.py +350 -312
- blksprs/ops/distribution.py +320 -266
- blksprs/ops/flow.py +135 -89
- blksprs/ops/matmul.py +184 -151
- blksprs/ops/misc/broadcast_ops.py +68 -53
- blksprs/ops/misc/row_wise.py +223 -176
- blksprs/ops/partitioning.py +140 -89
- blksprs/ops/repeat.py +118 -108
- blksprs/ops/softmax.py +201 -167
- blksprs/ops/transpose.py +71 -131
- blksprs/utils/autotuning.py +78 -0
- blksprs/utils/benchmarking.py +3 -3
- blksprs/utils/processing.py +2 -1
- blksprs/utils/tools.py +5 -6
- blksprs/utils/validation.py +22 -16
- {blksprs-1.10.2.dist-info → blksprs-2.0.dist-info}/METADATA +55 -36
- blksprs-2.0.dist-info/RECORD +23 -0
- {blksprs-1.10.2.dist-info → blksprs-2.0.dist-info}/WHEEL +1 -1
- blksprs/ops/misc/exp.py +0 -104
- blksprs/utils/layout_utils.py +0 -17
- blksprs-1.10.2.dist-info/RECORD +0 -24
- {blksprs-1.10.2.dist-info → blksprs-2.0.dist-info}/top_level.txt +0 -0
blksprs/ops/partitioning.py
CHANGED
|
@@ -1,15 +1,17 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
from torch import Tensor
|
|
3
|
+
from torch._library import triton_op
|
|
3
4
|
|
|
4
|
-
from blksprs.ops.flow import
|
|
5
|
+
from blksprs.ops.flow import flow_pull_forward
|
|
5
6
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
6
|
-
|
|
7
7
|
from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
|
|
8
|
-
validate_sparsity, validate_sparsity_block_size
|
|
8
|
+
validate_sparsity, validate_sparsity_block_size
|
|
9
9
|
|
|
10
10
|
|
|
11
|
+
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
|
|
11
12
|
def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
|
|
12
|
-
dim: int, sparsity_block_size: int,
|
|
13
|
+
dim: int, sparsity_block_size: int, lut: dict = None) -> (
|
|
14
|
+
BlksprsTensor, Tensor):
|
|
13
15
|
"""Splits a block-sparse tensor in compressed form along the last dimension into partitions.
|
|
14
16
|
|
|
15
17
|
Args:
|
|
@@ -18,7 +20,7 @@ def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
|
|
|
18
20
|
partitions (int): The number of partitions to split the block-sparse tensor into.
|
|
19
21
|
dim (int): The dimension along which to split the tensor. Currently only supports dim=2.
|
|
20
22
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
21
|
-
|
|
23
|
+
lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
|
|
22
24
|
|
|
23
25
|
Returns:
|
|
24
26
|
BlksprsTensor: The block-sparse tensor split into partitions in compressed form.
|
|
@@ -32,63 +34,89 @@ def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
|
|
|
32
34
|
validate_device(x)
|
|
33
35
|
validate_sparsity(sparsity_block_size, (x, sparsity_layout))
|
|
34
36
|
validate_sparsity_block_size(sparsity_block_size, x)
|
|
35
|
-
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
36
37
|
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
.permute(0, 2, 1, 3)
|
|
41
|
-
.reshape(sparsity_layout.size(0) * partitions, sparsity_layout.size(1),
|
|
42
|
-
sparsity_layout.size(2) // partitions).contiguous())
|
|
38
|
+
adjusted_dim = dim % 3
|
|
39
|
+
if adjusted_dim != 2:
|
|
40
|
+
raise NotImplementedError("Currently only supports dim=2")
|
|
43
41
|
|
|
44
|
-
|
|
42
|
+
lut = split_build_lut(lut, sparsity_layout, partitions)
|
|
45
43
|
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
(1 * (sparsity_layout_flat == 0)))
|
|
50
|
-
.reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
|
|
51
|
-
sparsity_layout.size(2) // partitions)
|
|
52
|
-
.permute(0, 2, 1, 3).reshape(-1).contiguous())
|
|
44
|
+
return BlksprsTensor(split_forward(
|
|
45
|
+
x, lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
|
|
46
|
+
partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_output"]
|
|
53
47
|
|
|
54
|
-
n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
|
|
55
48
|
|
|
56
|
-
|
|
49
|
+
@triton_op("blksprs::split_forward", mutates_args={})
|
|
50
|
+
def split_forward(x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
51
|
+
_: int, __: int, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
|
|
52
|
+
with torch.no_grad():
|
|
53
|
+
return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
|
|
54
|
+
n_sparse_blocks)
|
|
57
55
|
|
|
58
|
-
adjusted_dim = dim % 3
|
|
59
|
-
if adjusted_dim != 2:
|
|
60
|
-
raise NotImplementedError("Currently only supports dim=2")
|
|
61
56
|
|
|
62
|
-
|
|
63
|
-
|
|
57
|
+
def split_wrapper_backward(ctx, grad_output):
|
|
58
|
+
sparsity_layout = ctx.saved_tensors[0]
|
|
59
|
+
num_partitions = ctx.num_partitions
|
|
60
|
+
dim = ctx.dim
|
|
61
|
+
sparsity_block_size = ctx.sparsity_block_size
|
|
62
|
+
|
|
63
|
+
return merge(grad_output, sparsity_layout, num_partitions, dim,
|
|
64
|
+
sparsity_block_size)[0], None, None, None, None, None, None, None
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def split_build_lut(lut: dict, sparsity_layout: Tensor, partitions: int):
|
|
68
|
+
if lut is None:
|
|
69
|
+
lut = dict()
|
|
70
|
+
|
|
71
|
+
if "sparsity_layout_output" not in lut:
|
|
72
|
+
sparsity_layout_output = (sparsity_layout
|
|
73
|
+
.reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
|
|
74
|
+
sparsity_layout.size(2) // partitions)
|
|
75
|
+
.permute(0, 2, 1, 3)
|
|
76
|
+
.reshape(sparsity_layout.size(0) * partitions, sparsity_layout.size(1),
|
|
77
|
+
sparsity_layout.size(2) // partitions).contiguous())
|
|
78
|
+
lut["sparsity_layout_output"] = sparsity_layout_output
|
|
79
|
+
|
|
80
|
+
if "sparsity_lut" not in lut:
|
|
81
|
+
sparsity_lut = torch.nonzero(lut["sparsity_layout_output"]).contiguous()
|
|
82
|
+
lut["sparsity_lut"] = sparsity_lut
|
|
83
|
+
|
|
84
|
+
if "sparsity_reverse_lut" not in lut:
|
|
85
|
+
sparsity_layout_flat = sparsity_layout.reshape(-1)
|
|
86
|
+
sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
|
|
87
|
+
(sparsity_layout_flat == 1) -
|
|
88
|
+
(1 * (sparsity_layout_flat == 0)))
|
|
89
|
+
.reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
|
|
90
|
+
sparsity_layout.size(2) // partitions)
|
|
91
|
+
.permute(0, 2, 1, 3).reshape(-1).contiguous())
|
|
92
|
+
lut["sparsity_reverse_lut"] = sparsity_reverse_lut
|
|
93
|
+
|
|
94
|
+
if "n_sparse_blocks" not in lut:
|
|
95
|
+
n_sparse_blocks = torch.sum(lut["sparsity_layout_output"].to(torch.int)).item()
|
|
96
|
+
lut["n_sparse_blocks"] = n_sparse_blocks
|
|
64
97
|
|
|
98
|
+
validate_contiguous(lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])
|
|
65
99
|
|
|
66
|
-
|
|
100
|
+
return lut
|
|
67
101
|
|
|
68
|
-
@staticmethod
|
|
69
|
-
def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
70
|
-
num_partitions: int, dim: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
|
|
71
|
-
ctx.save_for_backward(sparsity_layout_o)
|
|
72
|
-
ctx.num_partitions = num_partitions
|
|
73
|
-
ctx.dim = dim
|
|
74
102
|
|
|
75
|
-
|
|
76
|
-
|
|
103
|
+
# noinspection PyUnusedLocal
|
|
104
|
+
def split_setup_context(ctx, inputs, output):
|
|
105
|
+
(_, sparsity_layout_o, _, _, num_partitions, dim, sparsity_block_size, _) = inputs
|
|
77
106
|
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
dim = ctx.dim
|
|
83
|
-
sparsity_block_size = ctx.sparsity_block_size
|
|
84
|
-
triton_block_size = ctx.triton_block_size
|
|
107
|
+
ctx.save_for_backward(sparsity_layout_o)
|
|
108
|
+
ctx.num_partitions = num_partitions
|
|
109
|
+
ctx.dim = dim
|
|
110
|
+
ctx.sparsity_block_size = sparsity_block_size
|
|
85
111
|
|
|
86
|
-
return merge(grad_output, sparsity_layout, num_partitions, dim,
|
|
87
|
-
sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None, None
|
|
88
112
|
|
|
113
|
+
split_forward.register_autograd(split_wrapper_backward, setup_context=split_setup_context)
|
|
89
114
|
|
|
115
|
+
|
|
116
|
+
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
|
|
90
117
|
def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
|
|
91
|
-
dim: int, sparsity_block_size: int,
|
|
118
|
+
dim: int, sparsity_block_size: int, lut: dict = None) -> (
|
|
119
|
+
BlksprsTensor, Tensor):
|
|
92
120
|
"""Merges the specified partitions of a block-sparse tensor in compressed form along the last dimension.
|
|
93
121
|
|
|
94
122
|
Args:
|
|
@@ -97,7 +125,7 @@ def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
|
|
|
97
125
|
partitions (int): The number of partitions to be merged.
|
|
98
126
|
dim (int): The dimension along which to merge the tensor. Currently only supports dim=2.
|
|
99
127
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
100
|
-
|
|
128
|
+
lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
|
|
101
129
|
|
|
102
130
|
Returns:
|
|
103
131
|
BlksprsTensor: The merged block-sparse tensor in compressed form.
|
|
@@ -111,60 +139,83 @@ def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
|
|
|
111
139
|
validate_device(x)
|
|
112
140
|
validate_sparsity(sparsity_block_size, (x, sparsity_layout))
|
|
113
141
|
validate_sparsity_block_size(sparsity_block_size, x)
|
|
114
|
-
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
115
142
|
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
.reshape(sparsity_layout.size(0) // partitions,
|
|
120
|
-
sparsity_layout.size(1), sparsity_layout.size(2) * partitions).contiguous())
|
|
143
|
+
adjusted_dim = dim % 3
|
|
144
|
+
if adjusted_dim != 2:
|
|
145
|
+
raise NotImplementedError("Currently only supports dim=2")
|
|
121
146
|
|
|
122
|
-
|
|
147
|
+
lut = merge_build_lut(lut, sparsity_layout, partitions)
|
|
123
148
|
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
(1 * (sparsity_layout_flat == 0)))
|
|
128
|
-
.reshape(sparsity_layout.size(0) // partitions, partitions,
|
|
129
|
-
sparsity_layout.size(1), sparsity_layout.size(2))
|
|
130
|
-
.permute(0, 2, 1, 3)
|
|
131
|
-
.reshape(sparsity_layout.size(0) // partitions,
|
|
132
|
-
sparsity_layout.size(1), sparsity_layout.size(2) * partitions)
|
|
133
|
-
.reshape(-1).contiguous())
|
|
149
|
+
return BlksprsTensor(merge_forward(
|
|
150
|
+
x, lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
|
|
151
|
+
partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_output"]
|
|
134
152
|
|
|
135
|
-
n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
|
|
136
153
|
|
|
137
|
-
|
|
154
|
+
@triton_op("blksprs::merge_forward", mutates_args={})
|
|
155
|
+
def merge_forward(x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
156
|
+
_: int, __: int, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
|
|
157
|
+
with torch.no_grad():
|
|
158
|
+
return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
|
|
159
|
+
n_sparse_blocks)
|
|
138
160
|
|
|
139
|
-
adjusted_dim = dim % 3
|
|
140
|
-
if adjusted_dim != 2:
|
|
141
|
-
raise NotImplementedError("Currently only supports dim=2")
|
|
142
161
|
|
|
143
|
-
|
|
144
|
-
|
|
162
|
+
def merge_wrapper_backward(ctx, grad_output):
|
|
163
|
+
sparsity_layout = ctx.saved_tensors[0]
|
|
164
|
+
num_partitions = ctx.num_partitions
|
|
165
|
+
dim = ctx.dim
|
|
166
|
+
sparsity_block_size = ctx.sparsity_block_size
|
|
167
|
+
|
|
168
|
+
return split(grad_output, sparsity_layout, num_partitions, dim,
|
|
169
|
+
sparsity_block_size)[0], None, None, None, None, None, None, None
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def merge_build_lut(lut: dict, sparsity_layout: Tensor, partitions: int):
|
|
173
|
+
if lut is None:
|
|
174
|
+
lut = dict()
|
|
175
|
+
|
|
176
|
+
if "sparsity_layout_output" not in lut:
|
|
177
|
+
sparsity_layout_output = (sparsity_layout.reshape(sparsity_layout.size(0) // partitions, partitions,
|
|
178
|
+
sparsity_layout.size(1), sparsity_layout.size(2))
|
|
179
|
+
.permute(0, 2, 1, 3)
|
|
180
|
+
.reshape(sparsity_layout.size(0) // partitions,
|
|
181
|
+
sparsity_layout.size(1),
|
|
182
|
+
sparsity_layout.size(2) * partitions).contiguous())
|
|
183
|
+
lut["sparsity_layout_output"] = sparsity_layout_output
|
|
184
|
+
|
|
185
|
+
if "sparsity_lut" not in lut:
|
|
186
|
+
sparsity_lut = torch.nonzero(lut["sparsity_layout_output"]).contiguous()
|
|
187
|
+
lut["sparsity_lut"] = sparsity_lut
|
|
188
|
+
|
|
189
|
+
if "sparsity_reverse_lut" not in lut:
|
|
190
|
+
sparsity_layout_flat = sparsity_layout.reshape(-1)
|
|
191
|
+
sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
|
|
192
|
+
(sparsity_layout_flat == 1) -
|
|
193
|
+
(1 * (sparsity_layout_flat == 0)))
|
|
194
|
+
.reshape(sparsity_layout.size(0) // partitions, partitions,
|
|
195
|
+
sparsity_layout.size(1), sparsity_layout.size(2))
|
|
196
|
+
.permute(0, 2, 1, 3)
|
|
197
|
+
.reshape(sparsity_layout.size(0) // partitions,
|
|
198
|
+
sparsity_layout.size(1), sparsity_layout.size(2) * partitions)
|
|
199
|
+
.reshape(-1).contiguous())
|
|
200
|
+
lut["sparsity_reverse_lut"] = sparsity_reverse_lut
|
|
145
201
|
|
|
202
|
+
if "n_sparse_blocks" not in lut:
|
|
203
|
+
n_sparse_blocks = torch.sum(lut["sparsity_layout_output"].to(torch.int)).item()
|
|
204
|
+
lut["n_sparse_blocks"] = n_sparse_blocks
|
|
146
205
|
|
|
147
|
-
|
|
206
|
+
validate_contiguous(lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])
|
|
148
207
|
|
|
149
|
-
|
|
150
|
-
def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
151
|
-
num_partitions: int, dim: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
|
|
152
|
-
ctx.save_for_backward(sparsity_layout_o)
|
|
153
|
-
ctx.num_partitions = num_partitions
|
|
154
|
-
ctx.dim = dim
|
|
208
|
+
return lut
|
|
155
209
|
|
|
156
|
-
return flow_forward(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
|
|
157
|
-
n_sparse_blocks, triton_block_size)
|
|
158
210
|
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
num_partitions = ctx.num_partitions
|
|
163
|
-
dim = ctx.dim
|
|
164
|
-
sparsity_block_size = ctx.sparsity_block_size
|
|
165
|
-
triton_block_size = ctx.triton_block_size
|
|
211
|
+
# noinspection PyUnusedLocal
|
|
212
|
+
def merge_setup_context(ctx, inputs, output):
|
|
213
|
+
(_, sparsity_layout_o, _, _, num_partitions, dim, sparsity_block_size, _) = inputs
|
|
166
214
|
|
|
167
|
-
|
|
168
|
-
|
|
215
|
+
ctx.save_for_backward(sparsity_layout_o)
|
|
216
|
+
ctx.num_partitions = num_partitions
|
|
217
|
+
ctx.dim = dim
|
|
218
|
+
ctx.sparsity_block_size = sparsity_block_size
|
|
169
219
|
|
|
170
220
|
|
|
221
|
+
merge_forward.register_autograd(merge_wrapper_backward, setup_context=merge_setup_context)
|
blksprs/ops/repeat.py
CHANGED
|
@@ -1,16 +1,16 @@
|
|
|
1
1
|
import torch
|
|
2
|
-
import triton
|
|
3
2
|
from torch import Tensor
|
|
3
|
+
from torch._library import triton_op
|
|
4
4
|
|
|
5
|
-
from blksprs.ops.flow import
|
|
5
|
+
from blksprs.ops.flow import flow_pull_forward, flow_push_forward
|
|
6
6
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
7
|
-
from blksprs.utils.tools import get_triton_block_size, stride
|
|
8
7
|
from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
|
|
9
|
-
validate_sparsity, validate_sparsity_block_size
|
|
8
|
+
validate_sparsity, validate_sparsity_block_size
|
|
10
9
|
|
|
11
10
|
|
|
11
|
+
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
|
|
12
12
|
def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
|
|
13
|
-
sparsity_block_size: int, sparsity_layout_output: Tensor = None,
|
|
13
|
+
sparsity_block_size: int, sparsity_layout_output: Tensor = None, lut: dict = None) -> (
|
|
14
14
|
BlksprsTensor, Tensor):
|
|
15
15
|
"""Repeats a block-spare tensor in compressed form according to the given repeats.
|
|
16
16
|
|
|
@@ -29,7 +29,7 @@ def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int,
|
|
|
29
29
|
third dimension respectively.
|
|
30
30
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
31
31
|
sparsity_layout_output (Tensor): The desired sparsity layout of the output tensor (default ``None``).
|
|
32
|
-
|
|
32
|
+
lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
|
|
33
33
|
|
|
34
34
|
Returns:
|
|
35
35
|
BlksprsTensor: A block-sparse tensor in compressed form containing the repeated values.
|
|
@@ -43,35 +43,17 @@ def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int,
|
|
|
43
43
|
validate_device(x)
|
|
44
44
|
validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
|
|
45
45
|
validate_sparsity_block_size(sparsity_block_size, x)
|
|
46
|
-
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
47
46
|
|
|
48
|
-
|
|
47
|
+
lut = repeat_build_lut(lut, sparsity_layout_x, repeats, sparsity_layout_output)
|
|
49
48
|
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
sparsity_lut = torch.nonzero(sparsity_layout_o).contiguous()
|
|
54
|
-
|
|
55
|
-
sparsity_layout_flat = sparsity_layout_x.reshape(-1)
|
|
56
|
-
sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
|
|
57
|
-
(sparsity_layout_flat == 1) -
|
|
58
|
-
(1 * (sparsity_layout_flat == 0)))
|
|
59
|
-
.reshape(sparsity_layout_x.size())
|
|
60
|
-
.repeat(repeats[0], repeats[1], repeats[2])
|
|
61
|
-
.reshape(-1).contiguous())
|
|
62
|
-
|
|
63
|
-
n_sparse_blocks = torch.sum(sparsity_layout_o.to(torch.int)).item()
|
|
64
|
-
|
|
65
|
-
validate_contiguous(sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
|
|
66
|
-
|
|
67
|
-
return BlksprsTensor(
|
|
68
|
-
_BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
|
|
69
|
-
sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_o
|
|
49
|
+
return BlksprsTensor(repeat_forward(
|
|
50
|
+
x, sparsity_layout_x, lut["sparsity_layout_o"], lut["sparsity_lut"],
|
|
51
|
+
lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_o"]
|
|
70
52
|
|
|
71
53
|
|
|
54
|
+
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
|
|
72
55
|
def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
|
|
73
|
-
sparsity_block_size: int, sparsity_layout_output: Tensor = None,
|
|
74
|
-
triton_block_size: int = None) -> (
|
|
56
|
+
sparsity_block_size: int, sparsity_layout_output: Tensor = None, lut: dict = None) -> (
|
|
75
57
|
BlksprsTensor, Tensor):
|
|
76
58
|
"""Repeats and interleaves the block-sparse tensor in compressed form.
|
|
77
59
|
|
|
@@ -88,7 +70,7 @@ def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
|
|
|
88
70
|
repeats (int): The number of times to repeat the matrices.
|
|
89
71
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
90
72
|
sparsity_layout_output (Tensor): The desired sparsity layout of the output tensor (default ``None``).
|
|
91
|
-
|
|
73
|
+
lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
|
|
92
74
|
|
|
93
75
|
Returns:
|
|
94
76
|
BlksprsTensor: A block-sparse tensor in compressed form containing the repeated and interleaved matrices.
|
|
@@ -102,83 +84,111 @@ def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
|
|
|
102
84
|
validate_device(x)
|
|
103
85
|
validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
|
|
104
86
|
validate_sparsity_block_size(sparsity_block_size, x)
|
|
105
|
-
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
106
87
|
|
|
107
|
-
|
|
88
|
+
lut = repeat_interleave_build_lut(lut, sparsity_layout_x, repeats, sparsity_layout_output)
|
|
89
|
+
|
|
90
|
+
return BlksprsTensor(repeat_forward(
|
|
91
|
+
x, sparsity_layout_x, lut["sparsity_layout_o"], lut["sparsity_lut"],
|
|
92
|
+
lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_o"]
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@triton_op("blksprs::repeat_forward", mutates_args={})
|
|
96
|
+
def repeat_forward(x: Tensor, _: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor,
|
|
97
|
+
sparsity_reverse_lut: Tensor,
|
|
98
|
+
sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
|
|
99
|
+
with torch.no_grad():
|
|
100
|
+
return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
|
|
101
|
+
n_sparse_blocks)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def repeat_wrapper_backward(ctx, grad_output):
|
|
105
|
+
sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut = ctx.saved_tensors
|
|
106
|
+
sparsity_block_size = ctx.sparsity_block_size
|
|
107
|
+
n_sparse_blocks = torch.sum(sparsity_layout_x.to(torch.int)).item()
|
|
108
|
+
|
|
109
|
+
return flow_push_forward(grad_output, sparsity_layout_o, sparsity_lut,
|
|
110
|
+
sparsity_reverse_lut, sparsity_block_size,
|
|
111
|
+
n_sparse_blocks), None, None, None, None, None, None
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def repeat_build_lut(lut: dict, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
|
|
115
|
+
sparsity_layout_output: Tensor):
|
|
116
|
+
if lut is None:
|
|
117
|
+
lut = dict()
|
|
118
|
+
|
|
119
|
+
if "sparsity_layout_o" not in lut:
|
|
120
|
+
sparsity_layout_o = sparsity_layout_x.repeat(repeats[0], repeats[1], repeats[2])
|
|
121
|
+
lut["sparsity_layout_o"] = sparsity_layout_o
|
|
108
122
|
|
|
109
123
|
if sparsity_layout_output is not None:
|
|
110
|
-
sparsity_layout_o = torch.logical_and(sparsity_layout_o, sparsity_layout_output)
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
o_b, o_b_s, o_r_s, o_c_s,
|
|
182
|
-
triton_block_size))
|
|
183
|
-
|
|
184
|
-
return output, None, None, None, None, None, None, None
|
|
124
|
+
sparsity_layout_o = torch.logical_and(lut["sparsity_layout_o"], sparsity_layout_output)
|
|
125
|
+
lut["sparsity_layout_o"] = sparsity_layout_o
|
|
126
|
+
|
|
127
|
+
if "sparsity_lut" not in lut:
|
|
128
|
+
sparsity_lut = torch.nonzero(lut["sparsity_layout_o"]).contiguous()
|
|
129
|
+
lut["sparsity_lut"] = sparsity_lut
|
|
130
|
+
|
|
131
|
+
if "sparsity_reverse_lut" not in lut:
|
|
132
|
+
sparsity_layout_flat = sparsity_layout_x.reshape(-1)
|
|
133
|
+
sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
|
|
134
|
+
(sparsity_layout_flat == 1) -
|
|
135
|
+
(1 * (sparsity_layout_flat == 0)))
|
|
136
|
+
.reshape(sparsity_layout_x.size())
|
|
137
|
+
.repeat(repeats[0], repeats[1], repeats[2])
|
|
138
|
+
.reshape(-1).contiguous())
|
|
139
|
+
lut["sparsity_reverse_lut"] = sparsity_reverse_lut
|
|
140
|
+
|
|
141
|
+
if "n_sparse_blocks" not in lut:
|
|
142
|
+
n_sparse_blocks = torch.sum(lut["sparsity_layout_o"].to(torch.int)).item()
|
|
143
|
+
lut["n_sparse_blocks"] = n_sparse_blocks
|
|
144
|
+
|
|
145
|
+
validate_contiguous(sparsity_layout_o, lut["sparsity_lut"], lut["sparsity_reverse_lut"])
|
|
146
|
+
|
|
147
|
+
return lut
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def repeat_interleave_build_lut(lut: dict, sparsity_layout_x: Tensor, repeats: int,
|
|
151
|
+
sparsity_layout_output: Tensor):
|
|
152
|
+
if lut is None:
|
|
153
|
+
lut = dict()
|
|
154
|
+
|
|
155
|
+
if "sparsity_layout_o" not in lut:
|
|
156
|
+
sparsity_layout_o = torch.repeat_interleave(sparsity_layout_x, repeats, dim=0).contiguous()
|
|
157
|
+
lut["sparsity_layout_o"] = sparsity_layout_o
|
|
158
|
+
|
|
159
|
+
if sparsity_layout_output is not None:
|
|
160
|
+
sparsity_layout_o = torch.logical_and(lut["sparsity_layout_o"], sparsity_layout_output)
|
|
161
|
+
lut["sparsity_layout_o"] = sparsity_layout_o
|
|
162
|
+
|
|
163
|
+
if "sparsity_lut" not in lut:
|
|
164
|
+
sparsity_lut = torch.nonzero(lut["sparsity_layout_o"]).contiguous()
|
|
165
|
+
lut["sparsity_lut"] = sparsity_lut
|
|
166
|
+
|
|
167
|
+
if "sparsity_reverse_lut" not in lut:
|
|
168
|
+
sparsity_layout_flat = sparsity_layout_x.reshape(-1)
|
|
169
|
+
sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
|
|
170
|
+
(sparsity_layout_flat == 1) -
|
|
171
|
+
(1 * (sparsity_layout_flat == 0)))
|
|
172
|
+
.reshape(sparsity_layout_x.size())
|
|
173
|
+
.repeat_interleave(repeats, dim=0)
|
|
174
|
+
.reshape(-1).contiguous())
|
|
175
|
+
lut["sparsity_reverse_lut"] = sparsity_reverse_lut
|
|
176
|
+
|
|
177
|
+
if "n_sparse_blocks" not in lut:
|
|
178
|
+
n_sparse_blocks = torch.sum(lut["sparsity_layout_o"].to(torch.int)).item()
|
|
179
|
+
lut["n_sparse_blocks"] = n_sparse_blocks
|
|
180
|
+
|
|
181
|
+
validate_contiguous(sparsity_layout_o, lut["sparsity_lut"], lut["sparsity_reverse_lut"])
|
|
182
|
+
|
|
183
|
+
return lut
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
# noinspection PyUnusedLocal
|
|
187
|
+
def repeat_setup_context(ctx, inputs, output):
|
|
188
|
+
(_, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size, _) = inputs
|
|
189
|
+
|
|
190
|
+
ctx.save_for_backward(sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
|
|
191
|
+
ctx.sparsity_block_size = sparsity_block_size
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
repeat_forward.register_autograd(repeat_wrapper_backward, setup_context=repeat_setup_context)
|