blksprs 2.0rc7__py3-none-any.whl → 2.1__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
- blksprs/__init__.py +3 -1
- blksprs/layouting/distribution_layout.py +39 -26
- blksprs/layouting/sparsity_layout.py +58 -45
- blksprs/ops/conversion.py +88 -86
- blksprs/ops/distribution.py +80 -78
- blksprs/ops/flow.py +65 -61
- blksprs/ops/matmul.py +50 -55
- blksprs/ops/misc/broadcast_ops.py +28 -27
- blksprs/ops/misc/row_wise.py +123 -125
- blksprs/ops/partitioning.py +12 -10
- blksprs/ops/repeat.py +6 -5
- blksprs/ops/softmax.py +293 -47
- blksprs/ops/transpose.py +8 -7
- blksprs/utils/autotuning.py +10 -10
- blksprs/utils/processing.py +0 -1
- blksprs/utils/tools.py +2 -2
- {blksprs-2.0rc7.dist-info → blksprs-2.1.dist-info}/METADATA +1 -1
- blksprs-2.1.dist-info/RECORD +23 -0
- {blksprs-2.0rc7.dist-info → blksprs-2.1.dist-info}/WHEEL +1 -1
- blksprs-2.0rc7.dist-info/RECORD +0 -23
- {blksprs-2.0rc7.dist-info → blksprs-2.1.dist-info}/top_level.txt +0 -0
blksprs/__init__.py
CHANGED
```diff
@@ -1,11 +1,13 @@
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 
+__version__ = "2.1"
+
 
 class ops:
     from blksprs.ops.conversion import to_dense, to_sparse, from_blksprs, to_blksprs, adapt_layout
     from blksprs.ops.distribution import gather, scatter, scatter_reduce
     from blksprs.ops.matmul import matmul
-    from blksprs.ops.softmax import softmax
+    from blksprs.ops.softmax import softmax, softmax_fused
     from blksprs.ops.transpose import transpose
     from blksprs.ops.repeat import repeat, repeat_interleave
     from blksprs.ops.partitioning import split, merge
```
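The only API-surface change in this file is the new fused softmax export. A minimal sketch grounded in the import line above; the commented-out call is a hypothetical signature for illustration, not taken from this diff:

```python
# New in 2.1: softmax_fused is exported alongside softmax.
from blksprs.ops.softmax import softmax, softmax_fused

# Hypothetical usage for illustration only; the actual parameter list of
# softmax_fused lives in blksprs/ops/softmax.py, which this diff does not show.
# out = softmax_fused(x_sparse, sparsity_layout, sparsity_block_size)
```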
blksprs/layouting/distribution_layout.py
CHANGED
```diff
@@ -1,6 +1,10 @@
+import typing
+
 import torch
 import triton
 from torch import Tensor
+from torch._library import triton_op
+from torch._library.triton import wrap_triton
 from triton import language as tl
 
 from blksprs.utils.blksprs_tensor import BlksprsTensor
@@ -10,6 +14,7 @@ from blksprs.utils.validation import validate_dimensions, validate_device, \
     validate_contiguous
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: Tensor,
                               dim: int, size_target: torch.Size,
                               sparsity_block_size: int) -> Tensor:
@@ -34,32 +39,40 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
 
     adjusted_dim = dim % 3
 
-[old lines 37-62 not shown]
+    return build_distribution_layout_operation(indices, sparsity_lut_i, adjusted_dim, size_target, sparsity_block_size)
+
+
+@triton_op("blksprs::build_distribution_layout", mutates_args={})
+def build_distribution_layout_operation(indices: Tensor, sparsity_lut_i: Tensor,
+                                        adjusted_dim: int, size_target: typing.List[int],
+                                        sparsity_block_size: int) -> Tensor:
+    with torch.no_grad():
+        output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size,
+                             size_target[2] // sparsity_block_size,
+                             dtype=torch.bool, device=indices.device)
+
+        i_b, i_r, i_c = indices.size()
+        i_b_s, i_r_s, i_c_s = stride(indices)
+        s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
+        s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = stride(output)
+
+        triton_grid = lambda meta: [i_b,
+                                    triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]
+
+        (wrap_triton(build_distribution_layout_kernel)[triton_grid]
+         (indices,
+          i_b, i_b_s, i_r_s, i_c_s,
+          sparsity_lut_i,
+          s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+          adjusted_dim,
+          output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          sparsity_block_size))
+
+        return output
 
 
 @triton.autotune(
```
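The refactor visible in this file is the release's central change: each raw Triton kernel launch now lives inside a function registered with @triton_op and is launched through wrap_triton, which turns the launch into a torch.library custom op that torch.compile can trace, instead of an opaque Python call. Below is a minimal self-contained sketch of that same pattern on a toy elementwise kernel; the toy::add namespace and the kernel itself are illustrative inventions, not part of blksprs:

```python
import torch
import triton
import triton.language as tl
from torch._library import triton_op
from torch._library.triton import wrap_triton


@triton.jit
def add_kernel(x_ptr, y_ptr, o_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # One program instance handles one BLOCK_SIZE-wide slice of the flat tensors.
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(o_ptr + offsets, x + y, mask=mask)


@triton_op("toy::add", mutates_args={})
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    output = torch.empty_like(x)
    n_elements = output.numel()
    # The grid is a lambda over the launch metadata, mirroring the
    # triton_grid lambdas added throughout this diff.
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    wrap_triton(add_kernel)[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output
```

The blksprs functions above follow exactly this shape: allocate the output, compute sizes and strides, build the grid lambda, launch through wrap_triton, return the output.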
blksprs/layouting/sparsity_layout.py
CHANGED
```diff
@@ -3,7 +3,7 @@ import math
 import torch
 import triton
 from torch import Tensor
-from torch._library.triton import wrap_triton
+from torch._library.triton import wrap_triton, triton_op
 from triton import language as tl
 
 from blksprs.utils.blksprs_tensor import BlksprsTensor
@@ -29,27 +29,32 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
     validate_contiguous(x)
     validate_device(x)
 
-[old line 32 not shown]
-                         dtype=torch.bool, device=x.device)
+    return build_sparsity_layout_operation(x, sparsity_block_size)
 
-    x_b, x_r, x_c = x.size()
-    x_b_s, x_r_s, x_c_s = stride(x)
-    o_b, o_r, o_c = output.size()
-    o_b_s, o_r_s, o_c_s = stride(output)
 
-[old lines 40-42 not shown]
+@triton_op("blksprs::build_sparsity_layout", mutates_args={})
+def build_sparsity_layout_operation(x: Tensor, sparsity_block_size: int) -> Tensor:
+    with torch.no_grad():
+        output = torch.zeros(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size,
+                             dtype=torch.bool, device=x.device)
 
-[old lines 44-47 not shown]
-      output,
-      o_b, o_b_s, o_r_s, o_c_s,
-      sparsity_block_size))
+        x_b, x_r, x_c = x.size()
+        x_b_s, x_r_s, x_c_s = stride(x)
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = stride(output)
 
-[old line 52 not shown]
+        triton_grid = lambda meta: [x_b,
+                                    triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+        (wrap_triton(build_sparsity_layout_kernel)[triton_grid]
+         (x,
+          x_b, x_b_s, x_r_s, x_c_s,
+          output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          sparsity_block_size))
+
+        return output
 
 
 @triton.autotune(
@@ -87,6 +92,7 @@ def build_sparsity_layout_kernel(x,
     tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tensor,
                                    sparsity_block_size_from: int, sparsity_block_size_to: int) -> Tensor:
     """Builds the sparsity layout of a block-sparse tensor in compressed form if a different sparsity block size were
@@ -114,33 +120,40 @@ def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tenso
 
     validate_contiguous(sparsity_layout_from, sparsity_lut)
 
-[old lines 117-143 not shown]
+    return build_sparsity_layout_adaption_operation(x, sparsity_layout_from, sparsity_lut,
+                                                    sparsity_block_size_from, sparsity_block_size_to)
+
+
+@triton_op("blksprs::build_sparsity_layout_adaption", mutates_args={})
+def build_sparsity_layout_adaption_operation(x: Tensor, sparsity_layout_from: Tensor, sparsity_lut: Tensor,
+                                             sparsity_block_size_from: int, sparsity_block_size_to: int) -> Tensor:
+    with torch.no_grad():
+        o_b = sparsity_layout_from.size(0)
+        o_r = math.ceil(sparsity_layout_from.size(1) * sparsity_block_size_from // sparsity_block_size_to)
+        o_c = math.ceil(sparsity_layout_from.size(2) * sparsity_block_size_from // sparsity_block_size_to)
+
+        output = torch.zeros(o_b, o_r, o_c, dtype=torch.bool, device=x.device)
+
+        x_b, x_r, x_c = x.size()
+        x_b_s, x_r_s, x_c_s = stride(x)
+        s_lut_r, s_lut_c = sparsity_lut.size()
+        s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+        o_b_s, o_r_s, o_c_s = stride(output)
+
+        triton_grid = lambda meta: [x_b,
+                                    triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+        (wrap_triton(build_sparsity_layout_adaption_kernel)[triton_grid]
+         (x,
+          x_b, x_b_s, x_r_s, x_c_s,
+          sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+          output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          sparsity_block_size_from,
+          sparsity_block_size_to))
+
+        return output
 
 
 @triton.autotune(
```
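One detail worth noting in build_sparsity_layout_adaption_operation above is the block-grid arithmetic: the adapted layout's rows and columns scale by the ratio of the two block sizes. A worked sanity check (the concrete values are illustrative, not from the diff):

```python
import math

# A 4-block row of 32-wide blocks spans 128 elements; re-expressed with
# 64-wide blocks, the same span needs 2 blocks, matching the o_r / o_c
# computation above.
sparsity_block_size_from, sparsity_block_size_to = 32, 64
rows_from = 4  # sparsity_layout_from.size(1)

rows_to = math.ceil(rows_from * sparsity_block_size_from // sparsity_block_size_to)
assert rows_to == 2
```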
blksprs/ops/conversion.py
CHANGED
```diff
@@ -52,33 +52,34 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor,
         lut["sparsity_lut"], sparsity_block_size, lut["n_sparse_blocks"]))
 
 
-@triton_op("blksprs::
+@triton_op("blksprs::to_sparse_forward", mutates_args={})
 def to_sparse_forward(x: Tensor, _: Tensor,
                       sparsity_lut: Tensor, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
-[old lines 58-59 not shown]
+    with torch.no_grad():
+        output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                             dtype=x.dtype, device=x.device)
 
-[old lines 61-66 not shown]
+        x_b, x_r, x_c = x.size()
+        x_b_s, x_r_s, x_c_s = stride(x)
+        s_lut_r, s_lut_c = sparsity_lut.size()
+        s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = stride(output)
 
-[old lines 68-70 not shown]
+        triton_grid = lambda meta: [o_b,
+                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
 
-[old lines 72-76 not shown]
+        (wrap_triton(to_sparse_kernel)[triton_grid]
+         (x, x_b, x_b_s, x_r_s, x_c_s,
+          sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+          output, o_b_s, o_r_s, o_c_s,
+          sparsity_block_size))
 
-[old line 78 not shown]
+        return output
 
 
-def
+def to_sparse_wrapper_backward(ctx, grad_output):
     sparsity_layout = ctx.saved_tensors[0]
     sparsity_block_size = ctx.sparsity_block_size
 
@@ -161,7 +162,7 @@ def to_sparse_setup_context(ctx, inputs, output):
     ctx.sparsity_block_size = sparsity_block_size
 
 
-to_sparse_forward.register_autograd(
+to_sparse_forward.register_autograd(to_sparse_wrapper_backward, setup_context=to_sparse_setup_context)
 
 
 def from_blksprs(x: BlksprsTensor, sparsity_layout: Tensor,
```
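The completed register_autograd call above follows the standard torch.library custom-op pattern: a setup_context hook stashes whatever the backward pass needs on ctx, and a backward wrapper consumes it. A sketch of the same wiring on the hypothetical toy::add op from the earlier example (not blksprs code):

```python
def add_setup_context(ctx, inputs, output):
    # Nothing to save for a plain elementwise add; blksprs saves the
    # sparsity layout and block size here instead.
    pass


def add_backward(ctx, grad_output):
    # d(x + y)/dx = d(x + y)/dy = identity, so the incoming gradient
    # flows through unchanged to both inputs.
    return grad_output, grad_output


add.register_autograd(add_backward, setup_context=add_setup_context)
```

The remaining hunks below apply the identical treatment to to_dense_forward and adapt_layout_forward.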
```diff
@@ -203,42 +204,43 @@ def to_dense(x: BlksprsTensor, sparsity_layout: Tensor,
     if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
         return x
 
-    return to_dense_forward(x, sparsity_layout,
-                            lut["sparsity_reverse_lut"], sparsity_block_size, fill_value)
+    return Tensor(to_dense_forward(x, sparsity_layout,
+                                   lut["sparsity_reverse_lut"], sparsity_block_size, fill_value))
 
 
-@triton_op("blksprs::
+@triton_op("blksprs::to_dense_forward", mutates_args={})
 def to_dense_forward(x: Tensor, sparsity_layout: Tensor,
                      sparsity_reverse_lut: Tensor,
                      sparsity_block_size: int, fill_value: float) -> Tensor:
-[old lines 214-241 not shown]
+    with torch.no_grad():
+        output = torch.full(size=(sparsity_layout.size(0), sparsity_layout.size(1) * sparsity_block_size,
+                                  sparsity_layout.size(2) * sparsity_block_size), fill_value=fill_value,
+                            dtype=x.dtype, device=x.device)
+
+        x_b, x_r, x_c = x.shape
+        x_b_s, x_r_s, x_c_s = stride(x)
+        s_l_b, s_l_r, s_l_c = sparsity_layout.size()
+        s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout)
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = stride(output)
+
+        triton_grid = lambda meta: [o_b,
+                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+        (wrap_triton(to_dense_kernel)[triton_grid]
+         (x,
+          x_b, x_b_s, x_r_s, x_c_s,
+          s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+          sparsity_reverse_lut,
+          output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          sparsity_block_size))
+
+        return output
+
+
+def to_dense_wrapper_backward(ctx, grad_output):
     sparsity_layout = ctx.saved_tensors[0]
     sparsity_block_size = ctx.sparsity_block_size
 
@@ -316,7 +318,7 @@ def to_dense_setup_context(ctx, inputs, output):
     ctx.sparsity_block_size = sparsity_block_size
 
 
-to_dense_forward.register_autograd(
+to_dense_forward.register_autograd(to_dense_wrapper_backward, setup_context=to_dense_setup_context)
 
 
 @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
@@ -372,45 +374,45 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
         n_sparse_blocks_to)), sparsity_layout_to
 
 
-@triton_op("blksprs::
+@triton_op("blksprs::adapt_layout_forward", mutates_args={})
 def adapt_layout_forward(x: Tensor,
                          sparsity_layout_from: Tensor, sparsity_reverse_lut_from: Tensor,
                          sparsity_block_size_from: int,
                          _: Tensor, sparsity_lut_to: Tensor,
                          sparsity_block_size_to: int,
                          n_sparse_blocks_to: int) -> Tensor:
-[old lines 382-412 not shown]
-def
+    with torch.no_grad():
+        output = torch.zeros(size=(n_sparse_blocks_to, sparsity_block_size_to, sparsity_block_size_to),
+                             dtype=x.dtype, device=x.device)
+
+        x_b, x_r, x_c = x.size()
+        x_b_s, x_r_s, x_c_s = stride(x)
+        s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_from.size()
+        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_from)
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = stride(output)
+        s_lut_o_r, s_lut_o_c = sparsity_lut_to.size()
+        s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_to)
+
+        triton_grid = lambda meta: [o_b,
+                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+        (wrap_triton(adapt_layout_kernel)[triton_grid]
+         (x,
+          x_b, x_b_s, x_r_s, x_c_s,
+          s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+          sparsity_reverse_lut_from,
+          output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          sparsity_lut_to, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+          sparsity_block_size_from,
+          sparsity_block_size_to))
+
+        return output
+
+
+def adapt_layout_wrapper_backward(ctx, grad_output):
     x, sparsity_layout_from, sparsity_layout_to = ctx.saved_tensors
     sparsity_block_size_from = ctx.sparsity_block_size_from
     sparsity_block_size_to = ctx.sparsity_block_size_to
 
@@ -501,4 +503,4 @@ def adapt_layout_setup_context(ctx, inputs, output):
     ctx.sparsity_block_size_to = sparsity_block_size_to
 
 
-adapt_layout_forward.register_autograd(
+adapt_layout_forward.register_autograd(adapt_layout_wrapper_backward, setup_context=adapt_layout_setup_context)
```