blksprs 1.10.2__py3-none-any.whl → 2.0__py3-none-any.whl
This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
- blksprs/__init__.py +4 -6
- blksprs/layouting/distribution_layout.py +64 -48
- blksprs/layouting/sparsity_layout.py +96 -72
- blksprs/ops/conversion.py +350 -312
- blksprs/ops/distribution.py +320 -266
- blksprs/ops/flow.py +135 -89
- blksprs/ops/matmul.py +184 -151
- blksprs/ops/misc/broadcast_ops.py +68 -53
- blksprs/ops/misc/row_wise.py +223 -176
- blksprs/ops/partitioning.py +140 -89
- blksprs/ops/repeat.py +118 -108
- blksprs/ops/softmax.py +201 -167
- blksprs/ops/transpose.py +71 -131
- blksprs/utils/autotuning.py +78 -0
- blksprs/utils/benchmarking.py +3 -3
- blksprs/utils/processing.py +2 -1
- blksprs/utils/tools.py +5 -6
- blksprs/utils/validation.py +22 -16
- {blksprs-1.10.2.dist-info → blksprs-2.0.dist-info}/METADATA +55 -36
- blksprs-2.0.dist-info/RECORD +23 -0
- {blksprs-1.10.2.dist-info → blksprs-2.0.dist-info}/WHEEL +1 -1
- blksprs/ops/misc/exp.py +0 -104
- blksprs/utils/layout_utils.py +0 -17
- blksprs-1.10.2.dist-info/RECORD +0 -24
- {blksprs-1.10.2.dist-info → blksprs-2.0.dist-info}/top_level.txt +0 -0
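One API change visible in the distribution.py diff below is the new optional lut argument on gather, scatter, and scatter_reduce: the look-up tables derived from the sparsity layouts are now built by dedicated *_build_lut helpers, and because those helpers add any missing entries to the dict that is passed in, a dict created for the first call can be reused for later calls with the same layouts, skipping the table construction. A minimal sketch of that reuse pattern, inferred from the diff below rather than taken from the blksprs documentation (gather_twice is an illustrative helper, dim=2 is an arbitrary choice, and the compressed tensors and layouts are assumed to come from blksprs' layouting and conversion helpers listed above):

    from blksprs.ops.distribution import gather

    def gather_twice(src, sparsity_layout_src, idx, sparsity_layout_idx, sparsity_block_size):
        # src and idx are block-sparse tensors in compressed form, with their sparsity layouts.
        lut = {}

        # First call: gather_build_lut populates the dict passed in (reverse LUT for the
        # source layout, nonzero LUT for the index layout).
        out_1 = gather(src, sparsity_layout_src, 2, idx, sparsity_layout_idx,
                       sparsity_block_size, lut=lut)

        # Second call with the same layouts reuses the cached tables instead of rebuilding them.
        out_2 = gather(src, sparsity_layout_src, 2, idx, sparsity_layout_idx,
                       sparsity_block_size, lut=lut)
        return out_1, out_2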
blksprs/ops/distribution.py
CHANGED
@@ -1,19 +1,22 @@
 import torch
 import triton
 from torch import Tensor
+from torch._library import triton_op
+from torch._library.triton import wrap_triton
 from triton import language as tl

-from blksprs.ops.conversion import to_dense
 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import
+from blksprs.utils.tools import stride
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-    validate_sparsity, validate_dtype_int, validate_sparsity_block_size
+    validate_sparsity, validate_dtype_int, validate_sparsity_block_size


+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
            dim: int,
            idx: BlksprsTensor, sparsity_layout_idx: Tensor,
-           sparsity_block_size: int,
+           sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
     """Applies a gather operation on a block-sparse tensor in compressed form.

     Args:
@@ -23,7 +26,7 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
         idx (BlksprsTensor): The block-sparse indices tensor in compressed form specifying how to gather from the source tensor.
         sparsity_layout_idx (Tensor): The sparsity layout of the indices block-sparse tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
-
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

     Returns:
         BlksprsTensor: The result of the gather operation as a block-sparse tensor in compressed form.
@@ -38,32 +41,22 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
     validate_device(src, idx)
     validate_sparsity(sparsity_block_size, (src, sparsity_layout_src), (idx, sparsity_layout_idx))
     validate_sparsity_block_size(sparsity_block_size, src, idx)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)
-
-    sparsity_layout_x_flat = sparsity_layout_src.reshape(-1)
-    sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
-                              (sparsity_layout_x_flat == 1) -
-                              (1 * (sparsity_layout_x_flat == 0)))
-
-    sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
-
-    validate_contiguous(sparsity_layout_src, sparsity_reverse_lut_x,
-                        sparsity_layout_idx, sparsity_lut_i)

     adjusted_dim = dim % 3

-
-                                                 adjusted_dim, idx, sparsity_layout_idx, sparsity_lut_i,
-                                                 sparsity_block_size, triton_block_size))
+    lut = gather_build_lut(lut, sparsity_layout_src, sparsity_layout_idx)

+    return BlksprsTensor(gather_forward(src, sparsity_layout_src, lut["sparsity_reverse_lut_x"],
+                                        adjusted_dim, idx, sparsity_layout_idx, lut["sparsity_lut_i"],
+                                        sparsity_block_size))

-class _BlocksparseGather(torch.autograd.Function):

- [old lines 62-66 removed; their content is not rendered in the source diff]
+@triton_op("blksprs::gather_forward", mutates_args={})
+def gather_forward(x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
+                   dim: int, i: Tensor, _: Tensor, sparsity_lut_i: Tensor,
+                   sparsity_block_size: int) -> Tensor:
+    with torch.no_grad():
+        output = torch.zeros_like(i, dtype=x.dtype)

         x_b, x_r, x_c = x.size()
         x_b_s, x_r_s, x_c_s = stride(x)
@@ -76,14 +69,11 @@ class _BlocksparseGather(torch.autograd.Function):
         o_b, o_r, o_c = output.size()
         o_b_s, o_r_s, o_c_s = stride(output)

-        if triton_block_size is None:
-            triton_block_size = get_triton_block_size(sparsity_block_size)
-
         triton_grid = lambda meta: [o_b,
                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

-        (
+        (wrap_triton(gather_kernel)[triton_grid]
          (x,
           x_b, x_b_s, x_r_s, x_c_s,
           s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
@@ -94,115 +84,152 @@ class _BlocksparseGather(torch.autograd.Function):
           output,
           o_b, o_b_s, o_r_s, o_c_s,
           sparsity_lut_i, s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
-          sparsity_block_size
-          triton_block_size))
-
-        ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_i)
-        ctx.dim = dim
-        ctx.sparsity_block_size = sparsity_block_size
-        ctx.triton_block_size = triton_block_size
+          sparsity_block_size))

         return output

- [old lines 107-198 removed; their content is not rendered in the source diff]
+
+def gather_wrapper_backward(ctx, grad_output):
+    sparsity_layout_x, i, sparsity_layout_i = ctx.saved_tensors
+    dim = ctx.dim
+    sparsity_block_size = ctx.sparsity_block_size
+
+    return scatter_reduce(grad_output, sparsity_layout_i,
+                          dim, i,
+                          sparsity_layout_x, sparsity_block_size,
+                          reduce_op="sum"), None, None, None, None, None, None, None
+
+
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
+    reset_to_zero=["o"]
+)
+@triton.jit
+def gather_kernel(x,
+                  x_b, x_b_s, x_r_s, x_c_s,
+                  s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+                  r_lut_x,
+                  dim,
+                  i,
+                  i_b, i_b_s, i_r_s, i_c_s,
+                  o,
+                  o_b, o_b_s, o_r_s, o_c_s,
+                  s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+                  sparsity_block_size,
+                  TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get position of current sparsity block consisting of its batch, row, and column index
+    spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
+    spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
+
+    spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
+    spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+
+    spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
+    spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+
+    # Load index values
+    blk_i_idx = ((pid_blk * i_b_s) +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+    blk_i_msk = (blk_i_idx >= 0 and
+                 blk_i_idx < i_b * i_b_s)
+    blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
+
+    # Get indices of sparsity blocks and positions within the blocks
+    pos_spa_blk_x = blk_i // sparsity_block_size
+    pos_spa_int_x = blk_i % sparsity_block_size
+
+    rev_dst_bat_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_o, dtype=tl.int32)
+    rev_dst_row_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_o, dtype=tl.int32)
+    rev_dst_col_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_o, dtype=tl.int32)
+    dst_row_x = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
+                 .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+    dst_col_x = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
+                 .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+    if dim == 0:
+        rev_dst_bat_x = blk_i
+    elif dim == 1:
+        rev_dst_row_x = pos_spa_blk_x
+        dst_row_x = pos_spa_int_x * x_r_s
+    elif dim == 2:
+        rev_dst_col_x = pos_spa_blk_x
+        dst_col_x = pos_spa_int_x * x_c_s
+
+    # Load reverse sparsity indices for x
+    rev_idx_spa_x_idx = ((rev_dst_bat_x * s_l_x_b_s) +
+                         (rev_dst_row_x * s_l_x_r_s) +
+                         (rev_dst_col_x * s_l_x_c_s))
+    rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and
+                         rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+    rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
+
+    # Load x values
+    blk_x_idx = ((rev_idx_spa_x * x_b_s) +
+                 dst_row_x +
+                 dst_col_x)
+    blk_x_msk = ((blk_x_idx >= 0 and
+                  blk_x_idx < x_b * x_b_s) and
+                 rev_idx_spa_x_msk != -1)
+    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+    # Store output
+    blk_o_idx = ((pid_blk * o_b_s) +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+    blk_o_msk = ((blk_o_idx >= 0 and
+                  blk_o_idx < o_b * o_b_s) and
+                 rev_idx_spa_x_msk != -1)
+    tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+def gather_build_lut(lut: dict, sparsity_layout_src: Tensor, sparsity_layout_idx: Tensor):
+    if lut is None:
+        lut = dict()
+
+    if "sparsity_reverse_lut_x" not in lut:
+        sparsity_layout_x_flat = sparsity_layout_src.reshape(-1)
+        sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
+                                  (sparsity_layout_x_flat == 1) -
+                                  (1 * (sparsity_layout_x_flat == 0)))
+        lut["sparsity_reverse_lut_x"] = sparsity_reverse_lut_x
+
+    if "sparsity_lut_i" not in lut:
+        sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
+        lut["sparsity_lut_i"] = sparsity_lut_i
+
+    validate_contiguous(sparsity_layout_src, lut["sparsity_reverse_lut_x"],
+                        sparsity_layout_idx, lut["sparsity_lut_i"])
+
+    return lut
+
+
+# noinspection PyUnusedLocal
+def gather_setup_context(ctx, inputs, output):
+    (_, sparsity_layout_x, _, dim, i, sparsity_layout_i, _, sparsity_block_size) = inputs
+
+    ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_i)
+    ctx.dim = dim
+    ctx.sparsity_block_size = sparsity_block_size
+
+
+gather_forward.register_autograd(gather_wrapper_backward, setup_context=gather_setup_context)


 def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
             dim: int,
             idx: BlksprsTensor,
             sparsity_layout_tgt: Tensor,
-            sparsity_block_size: int,
+            sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
     """Wrapper for ``scatter_reduce`` with ``reduce_op="none"``.

     """
@@ -211,15 +238,16 @@ def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
                           idx,
                           sparsity_layout_tgt,
                           sparsity_block_size,
-                          reduce_op="none",
+                          reduce_op="none", lut=lut)


+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
 def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
                    dim: int,
                    idx: BlksprsTensor,
                    sparsity_layout_tgt: Tensor,
                    sparsity_block_size: int,
-                   reduce_op: str = "sum",
+                   reduce_op: str = "sum", lut: dict = None) -> BlksprsTensor:
     """Applies a scatter operation on a block-sparse tensor in compressed form.

     Args:
@@ -231,7 +259,7 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
         sparsity_block_size (int): The size of the sparsity blocks.
         reduce_op (str, optional): The reduction operation to apply during the scatter operation (default ``"sum"``).
             Supported operations are ``"none"`` and ``"sum"``.
-
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

     Returns:
         BlksprsTensor: The result of the scatter operation as a block-sparse tensor in compressed form.
@@ -246,40 +274,28 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
     validate_device(src, idx)
     validate_sparsity(sparsity_block_size, (src, sparsity_layout_src), (idx, sparsity_layout_src))
     validate_sparsity_block_size(sparsity_block_size, src, idx)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)

     if reduce_op not in ["none", "sum"]:
         raise ValueError(f"Reduction operation '{reduce_op}' is not supported")

-    sparsity_lut_x = torch.nonzero(sparsity_layout_src).contiguous()
-
-    sparsity_layout_o_flat = sparsity_layout_tgt.reshape(-1)
-    sparsity_reverse_lut_o = ((torch.cumsum(sparsity_layout_o_flat, dim=-1) - 1) *
-                              (sparsity_layout_o_flat == 1) -
-                              (1 * (sparsity_layout_o_flat == 0)))
-
-    n_sparse_blocks = torch.sum(sparsity_layout_tgt.to(torch.int)).item()
-
-    validate_contiguous(sparsity_layout_src, sparsity_lut_x,
-                        sparsity_layout_tgt, sparsity_reverse_lut_o)
-
     adjusted_dim = dim % 3

-
-                                                        adjusted_dim, idx,
-                                                        sparsity_layout_tgt, sparsity_reverse_lut_o,
-                                                        sparsity_block_size, n_sparse_blocks,
-                                                        reduce_op, triton_block_size))
+    lut = scatter_reduce_build_lut(lut, sparsity_layout_src, sparsity_layout_tgt)

+    return BlksprsTensor(scatter_reduce_forward(src, sparsity_layout_src, lut["sparsity_lut_x"],
+                                                adjusted_dim, idx,
+                                                sparsity_layout_tgt, lut["sparsity_reverse_lut_o"],
+                                                sparsity_block_size, lut["n_sparse_blocks"],
+                                                reduce_op))

-class _BlocksparseScatterReduce(torch.autograd.Function):

- [old lines 277-282 removed; their content is not rendered in the source diff]
+@triton_op("blksprs::scatter_reduce_forward", mutates_args={})
+def scatter_reduce_forward(x: Tensor, _: Tensor, sparsity_lut_x: Tensor,
+                           dim: int, i: Tensor,
+                           sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
+                           sparsity_block_size: int, n_sparse_blocks: int,
+                           reduce_op: str) -> Tensor:
+    with torch.no_grad():
         output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                              dtype=x.dtype, device=x.device)

@@ -294,9 +310,6 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
         s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
         s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)

-        if triton_block_size is None:
-            triton_block_size = get_triton_block_size(sparsity_block_size)
-
         triton_grid = lambda meta: [x_b,
                                     triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
                                     triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
@@ -305,7 +318,7 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
         if reduce_op == "sum":
             reduce_op_ind = 1

-        (
+        (wrap_triton(scatter_reduce_kernel)[triton_grid]
          (x,
           x_b, x_b_s, x_r_s, x_c_s,
           sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
@@ -317,112 +330,153 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
           s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
           sparsity_reverse_lut_o,
           reduce_op_ind,
-          sparsity_block_size
-          triton_block_size))
-
-        ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_o)
-        ctx.dim = dim
-        ctx.sparsity_block_size = sparsity_block_size
-        ctx.reduce_op = reduce_op
-        ctx.triton_block_size = triton_block_size
+          sparsity_block_size))

         return output

-    @staticmethod
-    def backward(ctx, grad_output):
-        sparsity_layout_x, i, sparsity_layout_o = ctx.saved_tensors
-        dim = ctx.dim
-        sparsity_block_size = ctx.sparsity_block_size
-        reduce_op = ctx.reduce_op
-        triton_block_size = ctx.triton_block_size

- [old lines 339-428 removed; their content is not rendered in the source diff]
+def scatter_reduce_wrapper_backward(ctx, grad_output):
+    sparsity_layout_x, i, sparsity_layout_o = ctx.saved_tensors
+    dim = ctx.dim
+    sparsity_block_size = ctx.sparsity_block_size
+    reduce_op = ctx.reduce_op
+
+    if reduce_op == "sum":
+        return gather(grad_output, sparsity_layout_o, dim, i, sparsity_layout_x,
+                      sparsity_block_size), None, None, None, None, None, None, None, None, None
+    else:
+        raise ValueError(f"Reduction operation '{reduce_op}' does not support backward pass")
+
+
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
+    reset_to_zero=["o"]
+)
+@triton.jit
+def scatter_reduce_kernel(x,
+                          x_b, x_b_s, x_r_s, x_c_s,
+                          s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+                          dim,
+                          i,
+                          i_b, i_b_s, i_r_s, i_c_s,
+                          o,
+                          o_b, o_b_s,
+                          s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+                          r_lut_o,
+                          reduce_op_ind,
+                          sparsity_block_size,
+                          TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get position of current sparsity block consisting of its batch, row, and column index
+    spa_bat_x_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
+    spa_bat_x_msk = (spa_bat_x_idx >= 0 and spa_bat_x_idx < s_lut_x_r * s_lut_x_r_s)
+    spa_bat_x = tl.load(s_lut_x + spa_bat_x_idx, mask=spa_bat_x_msk)
+
+    spa_row_x_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
+    spa_row_x_msk = (spa_row_x_idx >= 0 and spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
+    spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)
+
+    spa_col_x_idx = (pid_blk * s_lut_x_r_s + 2 * s_lut_x_c_s)
+    spa_col_x_msk = (spa_col_x_idx >= 0 and spa_col_x_idx < s_lut_x_r * s_lut_x_r_s)
+    spa_col_x = tl.load(s_lut_x + spa_col_x_idx, mask=spa_col_x_msk)
+
+    # Load x values
+    blk_x_idx = ((pid_blk * x_b_s) +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_x_msk = (blk_x_idx >= 0 and
+                 blk_x_idx < x_b * x_b_s)
+    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+    # Load index values
+    blk_i_idx = ((pid_blk * i_b_s) +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+    blk_i_msk = (blk_i_idx >= 0 and
+                 blk_i_idx < i_b * i_b_s)
+    blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
+
+    # Get indices of sparsity blocks and positions within the blocks
+    pos_spa_blk_x = blk_i // sparsity_block_size
+    pos_spa_int_x = blk_i % sparsity_block_size
+
+    rev_dst_bat_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_x, dtype=tl.int32)
+    rev_dst_row_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_x, dtype=tl.int32)
+    rev_dst_col_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_x, dtype=tl.int32)
+    dst_row_o = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
+                 .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+    dst_col_o = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
+                 .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+    if dim == 0:
+        rev_dst_bat_o = blk_i
+    elif dim == 1:
+        rev_dst_row_o = pos_spa_blk_x
+        dst_row_o = pos_spa_int_x * x_r_s
+    elif dim == 2:
+        rev_dst_col_o = pos_spa_blk_x
+        dst_col_o = pos_spa_int_x * x_c_s
+
+    # Load reverse sparsity indices for o
+    rev_idx_spa_o_idx = ((rev_dst_bat_o * s_l_o_b_s) +
+                         (rev_dst_row_o * s_l_o_r_s) +
+                         (rev_dst_col_o * s_l_o_c_s))
+    rev_idx_spa_o_msk = (rev_idx_spa_o_idx >= 0 and
+                         rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
+    rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
+
+    # Store output
+    blk_o_idx = ((rev_idx_spa_o * o_b_s) +
+                 dst_row_o +
+                 dst_col_o)
+    blk_o_msk = ((blk_o_idx >= 0 and
+                  blk_o_idx < o_b * o_b_s) and
+                 rev_idx_spa_o_msk != -1)
+
+    if reduce_op_ind == 0:
+        tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+    elif reduce_op_ind == 1:
+        tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+def scatter_reduce_build_lut(lut: dict, sparsity_layout_src: Tensor, sparsity_layout_tgt: Tensor):
+    if lut is None:
+        lut = dict()
+
+    if "sparsity_lut_x" not in lut:
+        sparsity_lut_x = torch.nonzero(sparsity_layout_src).contiguous()
+        lut["sparsity_lut_x"] = sparsity_lut_x
+
+    if "sparsity_reverse_lut_o" not in lut:
+        sparsity_layout_o_flat = sparsity_layout_tgt.reshape(-1)
+        sparsity_reverse_lut_o = ((torch.cumsum(sparsity_layout_o_flat, dim=-1) - 1) *
+                                  (sparsity_layout_o_flat == 1) -
+                                  (1 * (sparsity_layout_o_flat == 0)))
+        lut["sparsity_reverse_lut_o"] = sparsity_reverse_lut_o
+
+    if "n_sparse_blocks" not in lut:
+        n_sparse_blocks = torch.sum(sparsity_layout_tgt.to(torch.int)).item()
+        lut["n_sparse_blocks"] = n_sparse_blocks
+
+    validate_contiguous(sparsity_layout_src, lut["sparsity_lut_x"],
+                        sparsity_layout_tgt, lut["sparsity_reverse_lut_o"])
+
+    return lut
+
+
+# noinspection PyUnusedLocal
+def scatter_reduce_setup_context(ctx, inputs, output):
+    (_, sparsity_layout_x, _, dim, i, sparsity_layout_o, _, sparsity_block_size, _, reduce_op) = inputs
+
+    ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_o)
+    ctx.dim = dim
+    ctx.sparsity_block_size = sparsity_block_size
+    ctx.reduce_op = reduce_op
+
+
+scatter_reduce_forward.register_autograd(scatter_reduce_wrapper_backward, setup_context=scatter_reduce_setup_context)
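Structurally, 2.0 replaces the torch.autograd.Function subclasses (_BlocksparseGather, _BlocksparseScatterReduce) with module-level forward functions registered through triton_op and wrap_triton, with the backward pass and context setup attached afterwards via register_autograd. The snippet below is a minimal, self-contained sketch of that registration pattern using the public torch.library.custom_op API on a trivial single-input op; it is an analogy for illustration, not blksprs code ("mylib::square" is made up, and blksprs 2.0 itself uses the torch._library triton_op entry points shown in the diff above, whose role is the same for Triton-backed kernels):

    import torch
    from torch.library import custom_op

    # Define the forward pass as a registered custom operator instead of an
    # autograd.Function subclass. The op returns a fresh tensor.
    @custom_op("mylib::square", mutates_args=())
    def square(x: torch.Tensor) -> torch.Tensor:
        return x * x

    def square_setup_context(ctx, inputs, output):
        # Runs after the forward pass; stashes whatever the backward pass will need.
        x, = inputs
        ctx.save_for_backward(x)

    def square_backward(ctx, grad_output):
        # d/dx x^2 = 2x
        x, = ctx.saved_tensors
        return grad_output * 2 * x

    # Autograd support is attached to the already-defined op, mirroring the
    # gather_forward.register_autograd(...) calls in the diff above.
    square.register_autograd(square_backward, setup_context=square_setup_context)

    x = torch.randn(4, requires_grad=True)
    square(x).sum().backward()  # x.grad == 2 * x

In the diff, gather_forward and scatter_reduce_forward follow this shape: the forward is a plain function wrapped as a custom op, the old ctx.save_for_backward logic moves into the *_setup_context helpers, and the old backward methods become the *_wrapper_backward functions.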