blksprs 2.0rc7__py3-none-any.whl → 2.0rc8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +1 -0
- blksprs/layouting/distribution_layout.py +39 -26
- blksprs/layouting/sparsity_layout.py +58 -45
- blksprs/ops/conversion.py +86 -84
- blksprs/ops/distribution.py +80 -78
- blksprs/ops/flow.py +64 -60
- blksprs/ops/matmul.py +50 -55
- blksprs/ops/misc/broadcast_ops.py +28 -27
- blksprs/ops/misc/row_wise.py +123 -125
- blksprs/ops/partitioning.py +12 -10
- blksprs/ops/repeat.py +6 -5
- blksprs/ops/softmax.py +55 -47
- blksprs/ops/transpose.py +8 -7
- blksprs/utils/autotuning.py +10 -10
- blksprs/utils/processing.py +0 -1
- blksprs/utils/tools.py +8 -0
- {blksprs-2.0rc7.dist-info → blksprs-2.0rc8.dist-info}/METADATA +1 -1
- blksprs-2.0rc8.dist-info/RECORD +23 -0
- {blksprs-2.0rc7.dist-info → blksprs-2.0rc8.dist-info}/WHEEL +1 -1
- blksprs-2.0rc7.dist-info/RECORD +0 -23
- {blksprs-2.0rc7.dist-info → blksprs-2.0rc8.dist-info}/top_level.txt +0 -0
blksprs/__init__.py
CHANGED
|
@@ -1,6 +1,10 @@
|
|
|
1
|
+
import typing
|
|
2
|
+
|
|
1
3
|
import torch
|
|
2
4
|
import triton
|
|
3
5
|
from torch import Tensor
|
|
6
|
+
from torch._library import triton_op
|
|
7
|
+
from torch._library.triton import wrap_triton
|
|
4
8
|
from triton import language as tl
|
|
5
9
|
|
|
6
10
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
@@ -10,6 +14,7 @@ from blksprs.utils.validation import validate_dimensions, validate_device, \
|
|
|
10
14
|
validate_contiguous
|
|
11
15
|
|
|
12
16
|
|
|
17
|
+
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
|
|
13
18
|
def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: Tensor,
|
|
14
19
|
dim: int, size_target: torch.Size,
|
|
15
20
|
sparsity_block_size: int) -> Tensor:
|
|
@@ -34,32 +39,40 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
|
|
|
34
39
|
|
|
35
40
|
adjusted_dim = dim % 3
|
|
36
41
|
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
42
|
+
return build_distribution_layout_operation(indices, sparsity_lut_i, adjusted_dim, size_target, sparsity_block_size)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@triton_op("blksprs::build_distribution_layout", mutates_args={})
|
|
46
|
+
def build_distribution_layout_operation(indices: Tensor, sparsity_lut_i: Tensor,
|
|
47
|
+
adjusted_dim: int, size_target: typing.List[int],
|
|
48
|
+
sparsity_block_size: int) -> Tensor:
|
|
49
|
+
with torch.no_grad():
|
|
50
|
+
output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size,
|
|
51
|
+
size_target[2] // sparsity_block_size,
|
|
52
|
+
dtype=torch.bool, device=indices.device)
|
|
53
|
+
|
|
54
|
+
i_b, i_r, i_c = indices.size()
|
|
55
|
+
i_b_s, i_r_s, i_c_s = stride(indices)
|
|
56
|
+
s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
|
|
57
|
+
s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
|
|
58
|
+
o_b, o_r, o_c = output.size()
|
|
59
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
60
|
+
|
|
61
|
+
triton_grid = lambda meta: [i_b,
|
|
62
|
+
triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
|
|
63
|
+
triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]
|
|
64
|
+
|
|
65
|
+
(wrap_triton(build_distribution_layout_kernel)[triton_grid]
|
|
66
|
+
(indices,
|
|
67
|
+
i_b, i_b_s, i_r_s, i_c_s,
|
|
68
|
+
sparsity_lut_i,
|
|
69
|
+
s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
|
|
70
|
+
adjusted_dim,
|
|
71
|
+
output,
|
|
72
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
73
|
+
sparsity_block_size))
|
|
74
|
+
|
|
75
|
+
return output
|
|
63
76
|
|
|
64
77
|
|
|
65
78
|
@triton.autotune(
|
|
@@ -3,7 +3,7 @@ import math
|
|
|
3
3
|
import torch
|
|
4
4
|
import triton
|
|
5
5
|
from torch import Tensor
|
|
6
|
-
from torch._library.triton import wrap_triton
|
|
6
|
+
from torch._library.triton import wrap_triton, triton_op
|
|
7
7
|
from triton import language as tl
|
|
8
8
|
|
|
9
9
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
@@ -29,27 +29,32 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
|
|
|
29
29
|
validate_contiguous(x)
|
|
30
30
|
validate_device(x)
|
|
31
31
|
|
|
32
|
-
|
|
33
|
-
dtype=torch.bool, device=x.device)
|
|
32
|
+
return build_sparsity_layout_operation(x, sparsity_block_size)
|
|
34
33
|
|
|
35
|
-
x_b, x_r, x_c = x.size()
|
|
36
|
-
x_b_s, x_r_s, x_c_s = stride(x)
|
|
37
|
-
o_b, o_r, o_c = output.size()
|
|
38
|
-
o_b_s, o_r_s, o_c_s = stride(output)
|
|
39
34
|
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
35
|
+
@triton_op("blksprs::build_sparsity_layout", mutates_args={})
|
|
36
|
+
def build_sparsity_layout_operation(x: Tensor, sparsity_block_size: int) -> Tensor:
|
|
37
|
+
with torch.no_grad():
|
|
38
|
+
output = torch.zeros(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size,
|
|
39
|
+
dtype=torch.bool, device=x.device)
|
|
43
40
|
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
output,
|
|
49
|
-
o_b, o_b_s, o_r_s, o_c_s,
|
|
50
|
-
sparsity_block_size))
|
|
41
|
+
x_b, x_r, x_c = x.size()
|
|
42
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
43
|
+
o_b, o_r, o_c = output.size()
|
|
44
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
51
45
|
|
|
52
|
-
|
|
46
|
+
triton_grid = lambda meta: [x_b,
|
|
47
|
+
triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
|
|
48
|
+
triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
|
|
49
|
+
|
|
50
|
+
(wrap_triton(build_sparsity_layout_kernel)[triton_grid]
|
|
51
|
+
(x,
|
|
52
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
53
|
+
output,
|
|
54
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
55
|
+
sparsity_block_size))
|
|
56
|
+
|
|
57
|
+
return output
|
|
53
58
|
|
|
54
59
|
|
|
55
60
|
@triton.autotune(
|
|
@@ -87,6 +92,7 @@ def build_sparsity_layout_kernel(x,
|
|
|
87
92
|
tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
|
|
88
93
|
|
|
89
94
|
|
|
95
|
+
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
|
|
90
96
|
def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tensor,
|
|
91
97
|
sparsity_block_size_from: int, sparsity_block_size_to: int) -> Tensor:
|
|
92
98
|
"""Builds the sparsity layout of a block-sparse tensor in compressed form if a different sparsity block size were
|
|
@@ -114,33 +120,40 @@ def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tenso
|
|
|
114
120
|
|
|
115
121
|
validate_contiguous(sparsity_layout_from, sparsity_lut)
|
|
116
122
|
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
123
|
+
return build_sparsity_layout_adaption_operation(x, sparsity_layout_from, sparsity_lut,
|
|
124
|
+
sparsity_block_size_from, sparsity_block_size_to)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
@triton_op("blksprs::build_sparsity_layout_adaption", mutates_args={})
|
|
128
|
+
def build_sparsity_layout_adaption_operation(x: Tensor, sparsity_layout_from: Tensor, sparsity_lut: Tensor,
|
|
129
|
+
sparsity_block_size_from: int, sparsity_block_size_to: int) -> Tensor:
|
|
130
|
+
with torch.no_grad():
|
|
131
|
+
o_b = sparsity_layout_from.size(0)
|
|
132
|
+
o_r = math.ceil(sparsity_layout_from.size(1) * sparsity_block_size_from // sparsity_block_size_to)
|
|
133
|
+
o_c = math.ceil(sparsity_layout_from.size(2) * sparsity_block_size_from // sparsity_block_size_to)
|
|
134
|
+
|
|
135
|
+
output = torch.zeros(o_b, o_r, o_c, dtype=torch.bool, device=x.device)
|
|
136
|
+
|
|
137
|
+
x_b, x_r, x_c = x.size()
|
|
138
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
139
|
+
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
140
|
+
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
141
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
142
|
+
|
|
143
|
+
triton_grid = lambda meta: [x_b,
|
|
144
|
+
triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
|
|
145
|
+
triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
|
|
146
|
+
|
|
147
|
+
(wrap_triton(build_sparsity_layout_adaption_kernel)[triton_grid]
|
|
148
|
+
(x,
|
|
149
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
150
|
+
sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
151
|
+
output,
|
|
152
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
153
|
+
sparsity_block_size_from,
|
|
154
|
+
sparsity_block_size_to))
|
|
155
|
+
|
|
156
|
+
return output
|
|
144
157
|
|
|
145
158
|
|
|
146
159
|
@triton.autotune(
|
blksprs/ops/conversion.py
CHANGED
|
@@ -52,33 +52,34 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor,
|
|
|
52
52
|
lut["sparsity_lut"], sparsity_block_size, lut["n_sparse_blocks"]))
|
|
53
53
|
|
|
54
54
|
|
|
55
|
-
@triton_op("blksprs::
|
|
55
|
+
@triton_op("blksprs::to_sparse_forward", mutates_args={})
|
|
56
56
|
def to_sparse_forward(x: Tensor, _: Tensor,
|
|
57
57
|
sparsity_lut: Tensor, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
|
|
58
|
-
|
|
59
|
-
|
|
58
|
+
with torch.no_grad():
|
|
59
|
+
output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
|
|
60
|
+
dtype=x.dtype, device=x.device)
|
|
60
61
|
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
62
|
+
x_b, x_r, x_c = x.size()
|
|
63
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
64
|
+
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
65
|
+
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
66
|
+
o_b, o_r, o_c = output.size()
|
|
67
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
67
68
|
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
69
|
+
triton_grid = lambda meta: [o_b,
|
|
70
|
+
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
71
|
+
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
71
72
|
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
73
|
+
(wrap_triton(to_sparse_kernel)[triton_grid]
|
|
74
|
+
(x, x_b, x_b_s, x_r_s, x_c_s,
|
|
75
|
+
sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
76
|
+
output, o_b_s, o_r_s, o_c_s,
|
|
77
|
+
sparsity_block_size))
|
|
77
78
|
|
|
78
|
-
|
|
79
|
+
return output
|
|
79
80
|
|
|
80
81
|
|
|
81
|
-
def
|
|
82
|
+
def to_sparse_wrapper_backward(ctx, grad_output):
|
|
82
83
|
sparsity_layout = ctx.saved_tensors[0]
|
|
83
84
|
sparsity_block_size = ctx.sparsity_block_size
|
|
84
85
|
|
|
@@ -161,7 +162,7 @@ def to_sparse_setup_context(ctx, inputs, output):
|
|
|
161
162
|
ctx.sparsity_block_size = sparsity_block_size
|
|
162
163
|
|
|
163
164
|
|
|
164
|
-
to_sparse_forward.register_autograd(
|
|
165
|
+
to_sparse_forward.register_autograd(to_sparse_wrapper_backward, setup_context=to_sparse_setup_context)
|
|
165
166
|
|
|
166
167
|
|
|
167
168
|
def from_blksprs(x: BlksprsTensor, sparsity_layout: Tensor,
|
|
@@ -207,38 +208,39 @@ def to_dense(x: BlksprsTensor, sparsity_layout: Tensor,
|
|
|
207
208
|
lut["sparsity_reverse_lut"], sparsity_block_size, fill_value)
|
|
208
209
|
|
|
209
210
|
|
|
210
|
-
@triton_op("blksprs::
|
|
211
|
+
@triton_op("blksprs::to_dense_forward", mutates_args={})
|
|
211
212
|
def to_dense_forward(x: Tensor, sparsity_layout: Tensor,
|
|
212
213
|
sparsity_reverse_lut: Tensor,
|
|
213
214
|
sparsity_block_size: int, fill_value: float) -> Tensor:
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
215
|
+
with torch.no_grad():
|
|
216
|
+
output = torch.full(size=(sparsity_layout.size(0), sparsity_layout.size(1) * sparsity_block_size,
|
|
217
|
+
sparsity_layout.size(2) * sparsity_block_size), fill_value=fill_value,
|
|
218
|
+
dtype=x.dtype, device=x.device)
|
|
219
|
+
|
|
220
|
+
x_b, x_r, x_c = x.shape
|
|
221
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
222
|
+
s_l_b, s_l_r, s_l_c = sparsity_layout.size()
|
|
223
|
+
s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout)
|
|
224
|
+
o_b, o_r, o_c = output.size()
|
|
225
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
226
|
+
|
|
227
|
+
triton_grid = lambda meta: [o_b,
|
|
228
|
+
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
229
|
+
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
230
|
+
|
|
231
|
+
(wrap_triton(to_dense_kernel)[triton_grid]
|
|
232
|
+
(x,
|
|
233
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
234
|
+
s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
|
|
235
|
+
sparsity_reverse_lut,
|
|
236
|
+
output,
|
|
237
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
238
|
+
sparsity_block_size))
|
|
239
|
+
|
|
240
|
+
return output
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def to_dense_wrapper_backward(ctx, grad_output):
|
|
242
244
|
sparsity_layout = ctx.saved_tensors[0]
|
|
243
245
|
sparsity_block_size = ctx.sparsity_block_size
|
|
244
246
|
|
|
@@ -316,7 +318,7 @@ def to_dense_setup_context(ctx, inputs, output):
|
|
|
316
318
|
ctx.sparsity_block_size = sparsity_block_size
|
|
317
319
|
|
|
318
320
|
|
|
319
|
-
to_dense_forward.register_autograd(
|
|
321
|
+
to_dense_forward.register_autograd(to_dense_wrapper_backward, setup_context=to_dense_setup_context)
|
|
320
322
|
|
|
321
323
|
|
|
322
324
|
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
|
|
@@ -372,45 +374,45 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
|
|
|
372
374
|
n_sparse_blocks_to)), sparsity_layout_to
|
|
373
375
|
|
|
374
376
|
|
|
375
|
-
@triton_op("blksprs::
|
|
377
|
+
@triton_op("blksprs::adapt_layout_forward", mutates_args={})
|
|
376
378
|
def adapt_layout_forward(x: Tensor,
|
|
377
379
|
sparsity_layout_from: Tensor, sparsity_reverse_lut_from: Tensor,
|
|
378
380
|
sparsity_block_size_from: int,
|
|
379
381
|
_: Tensor, sparsity_lut_to: Tensor,
|
|
380
382
|
sparsity_block_size_to: int,
|
|
381
383
|
n_sparse_blocks_to: int) -> Tensor:
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
def
|
|
384
|
+
with torch.no_grad():
|
|
385
|
+
output = torch.zeros(size=(n_sparse_blocks_to, sparsity_block_size_to, sparsity_block_size_to),
|
|
386
|
+
dtype=x.dtype, device=x.device)
|
|
387
|
+
|
|
388
|
+
x_b, x_r, x_c = x.size()
|
|
389
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
390
|
+
s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_from.size()
|
|
391
|
+
s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_from)
|
|
392
|
+
o_b, o_r, o_c = output.size()
|
|
393
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
394
|
+
s_lut_o_r, s_lut_o_c = sparsity_lut_to.size()
|
|
395
|
+
s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_to)
|
|
396
|
+
|
|
397
|
+
triton_grid = lambda meta: [o_b,
|
|
398
|
+
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
399
|
+
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
400
|
+
|
|
401
|
+
(wrap_triton(adapt_layout_kernel)[triton_grid]
|
|
402
|
+
(x,
|
|
403
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
404
|
+
s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
|
|
405
|
+
sparsity_reverse_lut_from,
|
|
406
|
+
output,
|
|
407
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
408
|
+
sparsity_lut_to, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
|
|
409
|
+
sparsity_block_size_from,
|
|
410
|
+
sparsity_block_size_to))
|
|
411
|
+
|
|
412
|
+
return output
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
def adapt_layout_wrapper_backward(ctx, grad_output):
|
|
414
416
|
x, sparsity_layout_from, sparsity_layout_to = ctx.saved_tensors
|
|
415
417
|
sparsity_block_size_from = ctx.sparsity_block_size_from
|
|
416
418
|
sparsity_block_size_to = ctx.sparsity_block_size_to
|
|
@@ -501,4 +503,4 @@ def adapt_layout_setup_context(ctx, inputs, output):
|
|
|
501
503
|
ctx.sparsity_block_size_to = sparsity_block_size_to
|
|
502
504
|
|
|
503
505
|
|
|
504
|
-
adapt_layout_forward.register_autograd(
|
|
506
|
+
adapt_layout_forward.register_autograd(adapt_layout_wrapper_backward, setup_context=adapt_layout_setup_context)
|
blksprs/ops/distribution.py
CHANGED
|
@@ -51,44 +51,45 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
|
|
|
51
51
|
sparsity_block_size))
|
|
52
52
|
|
|
53
53
|
|
|
54
|
-
@triton_op("blksprs::
|
|
54
|
+
@triton_op("blksprs::gather_forward", mutates_args={})
|
|
55
55
|
def gather_forward(x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
|
|
56
56
|
dim: int, i: Tensor, _: Tensor, sparsity_lut_i: Tensor,
|
|
57
57
|
sparsity_block_size: int) -> Tensor:
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
58
|
+
with torch.no_grad():
|
|
59
|
+
output = torch.zeros_like(i, dtype=x.dtype)
|
|
60
|
+
|
|
61
|
+
x_b, x_r, x_c = x.size()
|
|
62
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
63
|
+
s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
|
|
64
|
+
s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
|
|
65
|
+
i_b, i_r, i_c = i.size()
|
|
66
|
+
i_b_s, i_r_s, i_c_s = stride(i)
|
|
67
|
+
s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
|
|
68
|
+
s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
|
|
69
|
+
o_b, o_r, o_c = output.size()
|
|
70
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
71
|
+
|
|
72
|
+
triton_grid = lambda meta: [o_b,
|
|
73
|
+
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
74
|
+
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
75
|
+
|
|
76
|
+
(wrap_triton(gather_kernel)[triton_grid]
|
|
77
|
+
(x,
|
|
78
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
79
|
+
s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
|
|
80
|
+
sparsity_reverse_lut_x,
|
|
81
|
+
dim,
|
|
82
|
+
i,
|
|
83
|
+
i_b, i_b_s, i_r_s, i_c_s,
|
|
84
|
+
output,
|
|
85
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
86
|
+
sparsity_lut_i, s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
|
|
87
|
+
sparsity_block_size))
|
|
88
|
+
|
|
89
|
+
return output
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def gather_wrapper_backward(ctx, grad_output):
|
|
92
93
|
sparsity_layout_x, i, sparsity_layout_i = ctx.saved_tensors
|
|
93
94
|
dim = ctx.dim
|
|
94
95
|
sparsity_block_size = ctx.sparsity_block_size
|
|
@@ -221,7 +222,7 @@ def gather_setup_context(ctx, inputs, output):
|
|
|
221
222
|
ctx.sparsity_block_size = sparsity_block_size
|
|
222
223
|
|
|
223
224
|
|
|
224
|
-
gather_forward.register_autograd(
|
|
225
|
+
gather_forward.register_autograd(gather_wrapper_backward, setup_context=gather_setup_context)
|
|
225
226
|
|
|
226
227
|
|
|
227
228
|
def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
|
|
@@ -288,52 +289,53 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
|
|
|
288
289
|
reduce_op))
|
|
289
290
|
|
|
290
291
|
|
|
291
|
-
@triton_op("blksprs::
|
|
292
|
+
@triton_op("blksprs::scatter_reduce_forward", mutates_args={})
|
|
292
293
|
def scatter_reduce_forward(x: Tensor, _: Tensor, sparsity_lut_x: Tensor,
|
|
293
294
|
dim: int, i: Tensor,
|
|
294
295
|
sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
|
|
295
296
|
sparsity_block_size: int, n_sparse_blocks: int,
|
|
296
297
|
reduce_op: str) -> Tensor:
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
298
|
+
with torch.no_grad():
|
|
299
|
+
output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
|
|
300
|
+
dtype=x.dtype, device=x.device)
|
|
301
|
+
|
|
302
|
+
x_b, x_r, x_c = x.size()
|
|
303
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
304
|
+
s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
|
|
305
|
+
s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut_x)
|
|
306
|
+
i_b, i_r, i_c = i.size()
|
|
307
|
+
i_b_s, i_r_s, i_c_s = stride(i)
|
|
308
|
+
o_b, o_r, o_c = output.size()
|
|
309
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
310
|
+
s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
|
|
311
|
+
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
|
|
312
|
+
|
|
313
|
+
triton_grid = lambda meta: [x_b,
|
|
314
|
+
triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
|
|
315
|
+
triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
|
|
316
|
+
|
|
317
|
+
reduce_op_ind = 0
|
|
318
|
+
if reduce_op == "sum":
|
|
319
|
+
reduce_op_ind = 1
|
|
320
|
+
|
|
321
|
+
(wrap_triton(scatter_reduce_kernel)[triton_grid]
|
|
322
|
+
(x,
|
|
323
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
324
|
+
sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
|
|
325
|
+
dim,
|
|
326
|
+
i,
|
|
327
|
+
i_b, i_b_s, i_r_s, i_c_s,
|
|
328
|
+
output,
|
|
329
|
+
o_b, o_b_s,
|
|
330
|
+
s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
|
|
331
|
+
sparsity_reverse_lut_o,
|
|
332
|
+
reduce_op_ind,
|
|
333
|
+
sparsity_block_size))
|
|
334
|
+
|
|
335
|
+
return output
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
def scatter_reduce_wrapper_backward(ctx, grad_output):
|
|
337
339
|
sparsity_layout_x, i, sparsity_layout_o = ctx.saved_tensors
|
|
338
340
|
dim = ctx.dim
|
|
339
341
|
sparsity_block_size = ctx.sparsity_block_size
|
|
@@ -477,4 +479,4 @@ def scatter_reduce_setup_context(ctx, inputs, output):
|
|
|
477
479
|
ctx.reduce_op = reduce_op
|
|
478
480
|
|
|
479
481
|
|
|
480
|
-
scatter_reduce_forward.register_autograd(
|
|
482
|
+
scatter_reduce_forward.register_autograd(scatter_reduce_wrapper_backward, setup_context=scatter_reduce_setup_context)
|