blksprs 1.11__py3-none-any.whl → 2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +4 -5
- blksprs/layouting/distribution_layout.py +64 -48
- blksprs/layouting/sparsity_layout.py +96 -72
- blksprs/ops/conversion.py +349 -338
- blksprs/ops/distribution.py +318 -294
- blksprs/ops/flow.py +113 -100
- blksprs/ops/matmul.py +187 -172
- blksprs/ops/misc/broadcast_ops.py +68 -53
- blksprs/ops/misc/row_wise.py +223 -176
- blksprs/ops/partitioning.py +140 -132
- blksprs/ops/repeat.py +118 -120
- blksprs/ops/softmax.py +240 -214
- blksprs/ops/transpose.py +55 -52
- blksprs/utils/autotuning.py +78 -0
- blksprs/utils/benchmarking.py +3 -3
- blksprs/utils/processing.py +2 -1
- blksprs/utils/tools.py +5 -6
- blksprs/utils/validation.py +22 -16
- {blksprs-1.11.dist-info → blksprs-2.0.dist-info}/METADATA +55 -36
- blksprs-2.0.dist-info/RECORD +23 -0
- {blksprs-1.11.dist-info → blksprs-2.0.dist-info}/WHEEL +1 -1
- blksprs/utils/layout_utils.py +0 -17
- blksprs-1.11.dist-info/RECORD +0 -23
- {blksprs-1.11.dist-info → blksprs-2.0.dist-info}/top_level.txt +0 -0
blksprs/ops/partitioning.py
CHANGED
|
@@ -1,15 +1,16 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
from torch import Tensor
|
|
3
|
+
from torch._library import triton_op
|
|
3
4
|
|
|
4
|
-
from blksprs.ops.flow import
|
|
5
|
+
from blksprs.ops.flow import flow_pull_forward
|
|
5
6
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
6
|
-
|
|
7
7
|
from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
|
|
8
|
-
validate_sparsity, validate_sparsity_block_size
|
|
8
|
+
validate_sparsity, validate_sparsity_block_size
|
|
9
9
|
|
|
10
10
|
|
|
11
|
+
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
|
|
11
12
|
def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
|
|
12
|
-
dim: int, sparsity_block_size: int,
|
|
13
|
+
dim: int, sparsity_block_size: int, lut: dict = None) -> (
|
|
13
14
|
BlksprsTensor, Tensor):
|
|
14
15
|
"""Splits a block-sparse tensor in compressed form along the last dimension into partitions.
|
|
15
16
|
|
|
@@ -19,7 +20,6 @@ def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
|
|
|
19
20
|
partitions (int): The number of partitions to split the block-sparse tensor into.
|
|
20
21
|
dim (int): The dimension along which to split the tensor. Currently only supports dim=2.
|
|
21
22
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
22
|
-
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
23
23
|
lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
|
|
24
24
|
|
|
25
25
|
Returns:
|
|
@@ -34,83 +34,88 @@ def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
|
|
|
34
34
|
validate_device(x)
|
|
35
35
|
validate_sparsity(sparsity_block_size, (x, sparsity_layout))
|
|
36
36
|
validate_sparsity_block_size(sparsity_block_size, x)
|
|
37
|
-
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
38
37
|
|
|
39
38
|
adjusted_dim = dim % 3
|
|
40
39
|
if adjusted_dim != 2:
|
|
41
40
|
raise NotImplementedError("Currently only supports dim=2")
|
|
42
41
|
|
|
43
|
-
lut =
|
|
42
|
+
lut = split_build_lut(lut, sparsity_layout, partitions)
|
|
43
|
+
|
|
44
|
+
return BlksprsTensor(split_forward(
|
|
45
|
+
x, lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
|
|
46
|
+
partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_output"]
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@triton_op("blksprs::split_forward", mutates_args={})
|
|
50
|
+
def split_forward(x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
51
|
+
_: int, __: int, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
|
|
52
|
+
with torch.no_grad():
|
|
53
|
+
return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
|
|
54
|
+
n_sparse_blocks)
|
|
44
55
|
|
|
45
|
-
return BlksprsTensor(
|
|
46
|
-
_BlocksparseSplit.apply(x, lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
|
|
47
|
-
partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"],
|
|
48
|
-
triton_block_size)), lut["sparsity_layout_output"]
|
|
49
56
|
|
|
57
|
+
def split_wrapper_backward(ctx, grad_output):
|
|
58
|
+
sparsity_layout = ctx.saved_tensors[0]
|
|
59
|
+
num_partitions = ctx.num_partitions
|
|
60
|
+
dim = ctx.dim
|
|
61
|
+
sparsity_block_size = ctx.sparsity_block_size
|
|
50
62
|
|
|
51
|
-
|
|
63
|
+
return merge(grad_output, sparsity_layout, num_partitions, dim,
|
|
64
|
+
sparsity_block_size)[0], None, None, None, None, None, None, None
|
|
52
65
|
|
|
53
|
-
@staticmethod
|
|
54
|
-
def build_lut(lut: dict, sparsity_layout: Tensor, partitions: int):
|
|
55
|
-
if lut is None:
|
|
56
|
-
lut = dict()
|
|
57
66
|
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
sparsity_layout.size(2) // partitions)
|
|
62
|
-
.permute(0, 2, 1, 3)
|
|
63
|
-
.reshape(sparsity_layout.size(0) * partitions, sparsity_layout.size(1),
|
|
64
|
-
sparsity_layout.size(2) // partitions).contiguous())
|
|
65
|
-
lut["sparsity_layout_output"] = sparsity_layout_output
|
|
67
|
+
def split_build_lut(lut: dict, sparsity_layout: Tensor, partitions: int):
|
|
68
|
+
if lut is None:
|
|
69
|
+
lut = dict()
|
|
66
70
|
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
71
|
+
if "sparsity_layout_output" not in lut:
|
|
72
|
+
sparsity_layout_output = (sparsity_layout
|
|
73
|
+
.reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
|
|
74
|
+
sparsity_layout.size(2) // partitions)
|
|
75
|
+
.permute(0, 2, 1, 3)
|
|
76
|
+
.reshape(sparsity_layout.size(0) * partitions, sparsity_layout.size(1),
|
|
77
|
+
sparsity_layout.size(2) // partitions).contiguous())
|
|
78
|
+
lut["sparsity_layout_output"] = sparsity_layout_output
|
|
70
79
|
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
(sparsity_layout_flat == 1) -
|
|
75
|
-
(1 * (sparsity_layout_flat == 0)))
|
|
76
|
-
.reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
|
|
77
|
-
sparsity_layout.size(2) // partitions)
|
|
78
|
-
.permute(0, 2, 1, 3).reshape(-1).contiguous())
|
|
79
|
-
lut["sparsity_reverse_lut"] = sparsity_reverse_lut
|
|
80
|
+
if "sparsity_lut" not in lut:
|
|
81
|
+
sparsity_lut = torch.nonzero(lut["sparsity_layout_output"]).contiguous()
|
|
82
|
+
lut["sparsity_lut"] = sparsity_lut
|
|
80
83
|
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
+
if "sparsity_reverse_lut" not in lut:
|
|
85
|
+
sparsity_layout_flat = sparsity_layout.reshape(-1)
|
|
86
|
+
sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
|
|
87
|
+
(sparsity_layout_flat == 1) -
|
|
88
|
+
(1 * (sparsity_layout_flat == 0)))
|
|
89
|
+
.reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
|
|
90
|
+
sparsity_layout.size(2) // partitions)
|
|
91
|
+
.permute(0, 2, 1, 3).reshape(-1).contiguous())
|
|
92
|
+
lut["sparsity_reverse_lut"] = sparsity_reverse_lut
|
|
84
93
|
|
|
85
|
-
|
|
94
|
+
if "n_sparse_blocks" not in lut:
|
|
95
|
+
n_sparse_blocks = torch.sum(lut["sparsity_layout_output"].to(torch.int)).item()
|
|
96
|
+
lut["n_sparse_blocks"] = n_sparse_blocks
|
|
86
97
|
|
|
87
|
-
|
|
98
|
+
validate_contiguous(lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])
|
|
88
99
|
|
|
89
|
-
|
|
90
|
-
def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
91
|
-
num_partitions: int, dim: int, sparsity_block_size: int, n_sparse_blocks: int,
|
|
92
|
-
triton_block_size: int) -> Tensor:
|
|
93
|
-
ctx.save_for_backward(sparsity_layout_o)
|
|
94
|
-
ctx.num_partitions = num_partitions
|
|
95
|
-
ctx.dim = dim
|
|
100
|
+
return lut
|
|
96
101
|
|
|
97
|
-
return flow_forward_pull(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
|
|
98
|
-
n_sparse_blocks, triton_block_size)
|
|
99
102
|
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
num_partitions = ctx.num_partitions
|
|
104
|
-
dim = ctx.dim
|
|
105
|
-
sparsity_block_size = ctx.sparsity_block_size
|
|
106
|
-
triton_block_size = ctx.triton_block_size
|
|
103
|
+
# noinspection PyUnusedLocal
|
|
104
|
+
def split_setup_context(ctx, inputs, output):
|
|
105
|
+
(_, sparsity_layout_o, _, _, num_partitions, dim, sparsity_block_size, _) = inputs
|
|
107
106
|
|
|
108
|
-
|
|
109
|
-
|
|
107
|
+
ctx.save_for_backward(sparsity_layout_o)
|
|
108
|
+
ctx.num_partitions = num_partitions
|
|
109
|
+
ctx.dim = dim
|
|
110
|
+
ctx.sparsity_block_size = sparsity_block_size
|
|
110
111
|
|
|
111
112
|
|
|
113
|
+
split_forward.register_autograd(split_wrapper_backward, setup_context=split_setup_context)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
|
|
112
117
|
def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
|
|
113
|
-
dim: int, sparsity_block_size: int,
|
|
118
|
+
dim: int, sparsity_block_size: int, lut: dict = None) -> (
|
|
114
119
|
BlksprsTensor, Tensor):
|
|
115
120
|
"""Merges the specified partitions of a block-sparse tensor in compressed form along the last dimension.
|
|
116
121
|
|
|
@@ -120,7 +125,6 @@ def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
|
|
|
120
125
|
partitions (int): The number of partitions to be merged.
|
|
121
126
|
dim (int): The dimension along which to merge the tensor. Currently only supports dim=2.
|
|
122
127
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
123
|
-
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
124
128
|
lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
|
|
125
129
|
|
|
126
130
|
Returns:
|
|
@@ -135,79 +139,83 @@ def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
|
|
|
135
139
|
validate_device(x)
|
|
136
140
|
validate_sparsity(sparsity_block_size, (x, sparsity_layout))
|
|
137
141
|
validate_sparsity_block_size(sparsity_block_size, x)
|
|
138
|
-
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
139
142
|
|
|
140
143
|
adjusted_dim = dim % 3
|
|
141
144
|
if adjusted_dim != 2:
|
|
142
145
|
raise NotImplementedError("Currently only supports dim=2")
|
|
143
146
|
|
|
144
|
-
lut =
|
|
145
|
-
|
|
146
|
-
return BlksprsTensor(
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
147
|
+
lut = merge_build_lut(lut, sparsity_layout, partitions)
|
|
148
|
+
|
|
149
|
+
return BlksprsTensor(merge_forward(
|
|
150
|
+
x, lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
|
|
151
|
+
partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_output"]
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
@triton_op("blksprs::merge_forward", mutates_args={})
|
|
155
|
+
def merge_forward(x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
156
|
+
_: int, __: int, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
|
|
157
|
+
with torch.no_grad():
|
|
158
|
+
return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
|
|
159
|
+
n_sparse_blocks)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def merge_wrapper_backward(ctx, grad_output):
|
|
163
|
+
sparsity_layout = ctx.saved_tensors[0]
|
|
164
|
+
num_partitions = ctx.num_partitions
|
|
165
|
+
dim = ctx.dim
|
|
166
|
+
sparsity_block_size = ctx.sparsity_block_size
|
|
167
|
+
|
|
168
|
+
return split(grad_output, sparsity_layout, num_partitions, dim,
|
|
169
|
+
sparsity_block_size)[0], None, None, None, None, None, None, None
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def merge_build_lut(lut: dict, sparsity_layout: Tensor, partitions: int):
|
|
173
|
+
if lut is None:
|
|
174
|
+
lut = dict()
|
|
175
|
+
|
|
176
|
+
if "sparsity_layout_output" not in lut:
|
|
177
|
+
sparsity_layout_output = (sparsity_layout.reshape(sparsity_layout.size(0) // partitions, partitions,
|
|
178
|
+
sparsity_layout.size(1), sparsity_layout.size(2))
|
|
179
|
+
.permute(0, 2, 1, 3)
|
|
180
|
+
.reshape(sparsity_layout.size(0) // partitions,
|
|
181
|
+
sparsity_layout.size(1),
|
|
182
|
+
sparsity_layout.size(2) * partitions).contiguous())
|
|
183
|
+
lut["sparsity_layout_output"] = sparsity_layout_output
|
|
184
|
+
|
|
185
|
+
if "sparsity_lut" not in lut:
|
|
186
|
+
sparsity_lut = torch.nonzero(lut["sparsity_layout_output"]).contiguous()
|
|
187
|
+
lut["sparsity_lut"] = sparsity_lut
|
|
188
|
+
|
|
189
|
+
if "sparsity_reverse_lut" not in lut:
|
|
190
|
+
sparsity_layout_flat = sparsity_layout.reshape(-1)
|
|
191
|
+
sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
|
|
192
|
+
(sparsity_layout_flat == 1) -
|
|
193
|
+
(1 * (sparsity_layout_flat == 0)))
|
|
194
|
+
.reshape(sparsity_layout.size(0) // partitions, partitions,
|
|
195
|
+
sparsity_layout.size(1), sparsity_layout.size(2))
|
|
196
|
+
.permute(0, 2, 1, 3)
|
|
197
|
+
.reshape(sparsity_layout.size(0) // partitions,
|
|
198
|
+
sparsity_layout.size(1), sparsity_layout.size(2) * partitions)
|
|
199
|
+
.reshape(-1).contiguous())
|
|
200
|
+
lut["sparsity_reverse_lut"] = sparsity_reverse_lut
|
|
201
|
+
|
|
202
|
+
if "n_sparse_blocks" not in lut:
|
|
203
|
+
n_sparse_blocks = torch.sum(lut["sparsity_layout_output"].to(torch.int)).item()
|
|
204
|
+
lut["n_sparse_blocks"] = n_sparse_blocks
|
|
205
|
+
|
|
206
|
+
validate_contiguous(lut["sparsity_layout_output"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])
|
|
207
|
+
|
|
208
|
+
return lut
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
# noinspection PyUnusedLocal
|
|
212
|
+
def merge_setup_context(ctx, inputs, output):
|
|
213
|
+
(_, sparsity_layout_o, _, _, num_partitions, dim, sparsity_block_size, _) = inputs
|
|
214
|
+
|
|
215
|
+
ctx.save_for_backward(sparsity_layout_o)
|
|
216
|
+
ctx.num_partitions = num_partitions
|
|
217
|
+
ctx.dim = dim
|
|
218
|
+
ctx.sparsity_block_size = sparsity_block_size
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
merge_forward.register_autograd(merge_wrapper_backward, setup_context=merge_setup_context)
|
blksprs/ops/repeat.py
CHANGED
|
@@ -1,17 +1,16 @@
|
|
|
1
1
|
import torch
|
|
2
|
-
import triton
|
|
3
2
|
from torch import Tensor
|
|
3
|
+
from torch._library import triton_op
|
|
4
4
|
|
|
5
|
-
from blksprs.ops.flow import
|
|
5
|
+
from blksprs.ops.flow import flow_pull_forward, flow_push_forward
|
|
6
6
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
7
|
-
from blksprs.utils.tools import get_triton_block_size, stride
|
|
8
7
|
from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
|
|
9
|
-
validate_sparsity, validate_sparsity_block_size
|
|
8
|
+
validate_sparsity, validate_sparsity_block_size
|
|
10
9
|
|
|
11
10
|
|
|
11
|
+
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
|
|
12
12
|
def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
|
|
13
|
-
sparsity_block_size: int, sparsity_layout_output: Tensor = None,
|
|
14
|
-
lut: dict = None) -> (
|
|
13
|
+
sparsity_block_size: int, sparsity_layout_output: Tensor = None, lut: dict = None) -> (
|
|
15
14
|
BlksprsTensor, Tensor):
|
|
16
15
|
"""Repeats a block-spare tensor in compressed form according to the given repeats.
|
|
17
16
|
|
|
@@ -30,7 +29,6 @@ def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int,
|
|
|
30
29
|
third dimension respectively.
|
|
31
30
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
32
31
|
sparsity_layout_output (Tensor): The desired sparsity layout of the output tensor (default ``None``).
|
|
33
|
-
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
34
32
|
lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
|
|
35
33
|
|
|
36
34
|
Returns:
|
|
@@ -45,19 +43,17 @@ def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int,
|
|
|
45
43
|
validate_device(x)
|
|
46
44
|
validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
|
|
47
45
|
validate_sparsity_block_size(sparsity_block_size, x)
|
|
48
|
-
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
49
46
|
|
|
50
|
-
lut =
|
|
47
|
+
lut = repeat_build_lut(lut, sparsity_layout_x, repeats, sparsity_layout_output)
|
|
51
48
|
|
|
52
|
-
return BlksprsTensor(
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
triton_block_size)), lut["sparsity_layout_o"]
|
|
49
|
+
return BlksprsTensor(repeat_forward(
|
|
50
|
+
x, sparsity_layout_x, lut["sparsity_layout_o"], lut["sparsity_lut"],
|
|
51
|
+
lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_o"]
|
|
56
52
|
|
|
57
53
|
|
|
54
|
+
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
|
|
58
55
|
def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
|
|
59
|
-
sparsity_block_size: int, sparsity_layout_output: Tensor = None,
|
|
60
|
-
triton_block_size: int = None, lut: dict = None) -> (
|
|
56
|
+
sparsity_block_size: int, sparsity_layout_output: Tensor = None, lut: dict = None) -> (
|
|
61
57
|
BlksprsTensor, Tensor):
|
|
62
58
|
"""Repeats and interleaves the block-sparse tensor in compressed form.
|
|
63
59
|
|
|
@@ -74,7 +70,6 @@ def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
|
|
|
74
70
|
repeats (int): The number of times to repeat the matrices.
|
|
75
71
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
76
72
|
sparsity_layout_output (Tensor): The desired sparsity layout of the output tensor (default ``None``).
|
|
77
|
-
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
78
73
|
lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
|
|
79
74
|
|
|
80
75
|
Returns:
|
|
@@ -89,108 +84,111 @@ def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
|
|
|
89
84
|
validate_device(x)
|
|
90
85
|
validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
|
|
91
86
|
validate_sparsity_block_size(sparsity_block_size, x)
|
|
92
|
-
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
93
87
|
|
|
94
|
-
lut =
|
|
95
|
-
|
|
96
|
-
return BlksprsTensor(
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
88
|
+
lut = repeat_interleave_build_lut(lut, sparsity_layout_x, repeats, sparsity_layout_output)
|
|
89
|
+
|
|
90
|
+
return BlksprsTensor(repeat_forward(
|
|
91
|
+
x, sparsity_layout_x, lut["sparsity_layout_o"], lut["sparsity_lut"],
|
|
92
|
+
lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_o"]
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@triton_op("blksprs::repeat_forward", mutates_args={})
|
|
96
|
+
def repeat_forward(x: Tensor, _: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor,
|
|
97
|
+
sparsity_reverse_lut: Tensor,
|
|
98
|
+
sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
|
|
99
|
+
with torch.no_grad():
|
|
100
|
+
return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
|
|
101
|
+
n_sparse_blocks)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def repeat_wrapper_backward(ctx, grad_output):
|
|
105
|
+
sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut = ctx.saved_tensors
|
|
106
|
+
sparsity_block_size = ctx.sparsity_block_size
|
|
107
|
+
n_sparse_blocks = torch.sum(sparsity_layout_x.to(torch.int)).item()
|
|
108
|
+
|
|
109
|
+
return flow_push_forward(grad_output, sparsity_layout_o, sparsity_lut,
|
|
110
|
+
sparsity_reverse_lut, sparsity_block_size,
|
|
111
|
+
n_sparse_blocks), None, None, None, None, None, None
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def repeat_build_lut(lut: dict, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
|
|
115
|
+
sparsity_layout_output: Tensor):
|
|
116
|
+
if lut is None:
|
|
117
|
+
lut = dict()
|
|
118
|
+
|
|
119
|
+
if "sparsity_layout_o" not in lut:
|
|
120
|
+
sparsity_layout_o = sparsity_layout_x.repeat(repeats[0], repeats[1], repeats[2])
|
|
121
|
+
lut["sparsity_layout_o"] = sparsity_layout_o
|
|
122
|
+
|
|
123
|
+
if sparsity_layout_output is not None:
|
|
124
|
+
sparsity_layout_o = torch.logical_and(lut["sparsity_layout_o"], sparsity_layout_output)
|
|
125
|
+
lut["sparsity_layout_o"] = sparsity_layout_o
|
|
126
|
+
|
|
127
|
+
if "sparsity_lut" not in lut:
|
|
128
|
+
sparsity_lut = torch.nonzero(lut["sparsity_layout_o"]).contiguous()
|
|
129
|
+
lut["sparsity_lut"] = sparsity_lut
|
|
130
|
+
|
|
131
|
+
if "sparsity_reverse_lut" not in lut:
|
|
132
|
+
sparsity_layout_flat = sparsity_layout_x.reshape(-1)
|
|
133
|
+
sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
|
|
134
|
+
(sparsity_layout_flat == 1) -
|
|
135
|
+
(1 * (sparsity_layout_flat == 0)))
|
|
136
|
+
.reshape(sparsity_layout_x.size())
|
|
137
|
+
.repeat(repeats[0], repeats[1], repeats[2])
|
|
138
|
+
.reshape(-1).contiguous())
|
|
139
|
+
lut["sparsity_reverse_lut"] = sparsity_reverse_lut
|
|
140
|
+
|
|
141
|
+
if "n_sparse_blocks" not in lut:
|
|
142
|
+
n_sparse_blocks = torch.sum(lut["sparsity_layout_o"].to(torch.int)).item()
|
|
143
|
+
lut["n_sparse_blocks"] = n_sparse_blocks
|
|
144
|
+
|
|
145
|
+
validate_contiguous(sparsity_layout_o, lut["sparsity_lut"], lut["sparsity_reverse_lut"])
|
|
146
|
+
|
|
147
|
+
return lut
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def repeat_interleave_build_lut(lut: dict, sparsity_layout_x: Tensor, repeats: int,
|
|
151
|
+
sparsity_layout_output: Tensor):
|
|
152
|
+
if lut is None:
|
|
153
|
+
lut = dict()
|
|
154
|
+
|
|
155
|
+
if "sparsity_layout_o" not in lut:
|
|
156
|
+
sparsity_layout_o = torch.repeat_interleave(sparsity_layout_x, repeats, dim=0).contiguous()
|
|
157
|
+
lut["sparsity_layout_o"] = sparsity_layout_o
|
|
158
|
+
|
|
159
|
+
if sparsity_layout_output is not None:
|
|
160
|
+
sparsity_layout_o = torch.logical_and(lut["sparsity_layout_o"], sparsity_layout_output)
|
|
161
|
+
lut["sparsity_layout_o"] = sparsity_layout_o
|
|
162
|
+
|
|
163
|
+
if "sparsity_lut" not in lut:
|
|
164
|
+
sparsity_lut = torch.nonzero(lut["sparsity_layout_o"]).contiguous()
|
|
165
|
+
lut["sparsity_lut"] = sparsity_lut
|
|
166
|
+
|
|
167
|
+
if "sparsity_reverse_lut" not in lut:
|
|
168
|
+
sparsity_layout_flat = sparsity_layout_x.reshape(-1)
|
|
169
|
+
sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
|
|
170
|
+
(sparsity_layout_flat == 1) -
|
|
171
|
+
(1 * (sparsity_layout_flat == 0)))
|
|
172
|
+
.reshape(sparsity_layout_x.size())
|
|
173
|
+
.repeat_interleave(repeats, dim=0)
|
|
174
|
+
.reshape(-1).contiguous())
|
|
175
|
+
lut["sparsity_reverse_lut"] = sparsity_reverse_lut
|
|
176
|
+
|
|
177
|
+
if "n_sparse_blocks" not in lut:
|
|
178
|
+
n_sparse_blocks = torch.sum(lut["sparsity_layout_o"].to(torch.int)).item()
|
|
179
|
+
lut["n_sparse_blocks"] = n_sparse_blocks
|
|
180
|
+
|
|
181
|
+
validate_contiguous(sparsity_layout_o, lut["sparsity_lut"], lut["sparsity_reverse_lut"])
|
|
182
|
+
|
|
183
|
+
return lut
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
# noinspection PyUnusedLocal
|
|
187
|
+
def repeat_setup_context(ctx, inputs, output):
|
|
188
|
+
(_, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size, _) = inputs
|
|
189
|
+
|
|
190
|
+
ctx.save_for_backward(sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
|
|
191
|
+
ctx.sparsity_block_size = sparsity_block_size
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
repeat_forward.register_autograd(repeat_wrapper_backward, setup_context=repeat_setup_context)
|