blksprs 1.11__py3-none-any.whl → 2.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +2 -5
- blksprs/layouting/distribution_layout.py +32 -25
- blksprs/layouting/sparsity_layout.py +65 -52
- blksprs/ops/conversion.py +421 -399
- blksprs/ops/distribution.py +404 -366
- blksprs/ops/flow.py +125 -106
- blksprs/ops/matmul.py +220 -204
- blksprs/ops/misc/broadcast_ops.py +53 -35
- blksprs/ops/misc/row_wise.py +151 -91
- blksprs/ops/partitioning.py +136 -132
- blksprs/ops/repeat.py +115 -120
- blksprs/ops/softmax.py +274 -246
- blksprs/ops/transpose.py +52 -51
- blksprs/utils/benchmarking.py +3 -3
- blksprs/utils/tools.py +31 -4
- blksprs/utils/validation.py +0 -14
- {blksprs-1.11.dist-info → blksprs-2.0rc1.dist-info}/METADATA +42 -36
- blksprs-2.0rc1.dist-info/RECORD +22 -0
- {blksprs-1.11.dist-info → blksprs-2.0rc1.dist-info}/WHEEL +1 -1
- blksprs/utils/layout_utils.py +0 -17
- blksprs-1.11.dist-info/RECORD +0 -23
- {blksprs-1.11.dist-info → blksprs-2.0rc1.dist-info}/top_level.txt +0 -0
blksprs/ops/distribution.py
CHANGED
|
@@ -1,19 +1,20 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
import triton
|
|
3
3
|
from torch import Tensor
|
|
4
|
+
from torch._library import triton_op
|
|
5
|
+
from torch._library.triton import wrap_triton
|
|
4
6
|
from triton import language as tl
|
|
5
7
|
|
|
6
|
-
from blksprs.ops.conversion import to_dense
|
|
7
8
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
8
|
-
from blksprs.utils.tools import
|
|
9
|
+
from blksprs.utils.tools import stride, get_autotune_configs
|
|
9
10
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
10
|
-
validate_sparsity, validate_dtype_int, validate_sparsity_block_size
|
|
11
|
+
validate_sparsity, validate_dtype_int, validate_sparsity_block_size
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
|
|
14
15
|
dim: int,
|
|
15
16
|
idx: BlksprsTensor, sparsity_layout_idx: Tensor,
|
|
16
|
-
sparsity_block_size: int,
|
|
17
|
+
sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
|
|
17
18
|
"""Applies a gather operation on a block-sparse tensor in compressed form.
|
|
18
19
|
|
|
19
20
|
Args:
|
|
@@ -23,7 +24,6 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
|
|
|
23
24
|
idx (BlksprsTensor): The block-sparse indices tensor in compressed form specifying how to gather from the source tensor.
|
|
24
25
|
sparsity_layout_idx (Tensor): The sparsity layout of the indices block-sparse tensor.
|
|
25
26
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
26
|
-
triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
|
|
27
27
|
lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
|
|
28
28
|
|
|
29
29
|
Returns:
|
|
@@ -39,184 +39,203 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
|
|
|
39
39
|
validate_device(src, idx)
|
|
40
40
|
validate_sparsity(sparsity_block_size, (src, sparsity_layout_src), (idx, sparsity_layout_idx))
|
|
41
41
|
validate_sparsity_block_size(sparsity_block_size, src, idx)
|
|
42
|
-
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
43
42
|
|
|
44
43
|
adjusted_dim = dim % 3
|
|
45
44
|
|
|
46
|
-
lut =
|
|
47
|
-
|
|
48
|
-
return BlksprsTensor(
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
45
|
+
lut = gather_build_lut(lut, sparsity_layout_src, sparsity_layout_idx)
|
|
46
|
+
|
|
47
|
+
return BlksprsTensor(gather_forward(src, sparsity_layout_src, lut["sparsity_reverse_lut_x"],
|
|
48
|
+
adjusted_dim, idx, sparsity_layout_idx, lut["sparsity_lut_i"],
|
|
49
|
+
sparsity_block_size))
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@triton_op("blksprs::gather", mutates_args={})
def gather_forward(x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
                   dim: int, i: Tensor, _: Tensor, sparsity_lut_i: Tensor,
                   sparsity_block_size: int) -> Tensor:
    """Launches ``gather_kernel`` and returns the gathered values.

    The result is allocated with the shape of ``i`` and the dtype of ``x``;
    the kernel is the only writer of that buffer.

    Args:
        x (Tensor): The source tensor (unpacked as three dimensions).
        sparsity_layout_x (Tensor): The sparsity layout of the source tensor (unpacked as three dimensions).
        sparsity_reverse_lut_x (Tensor): The reverse look-up table for the source tensor.
        dim (int): The dimension along which to gather (the kernel branches on 0, 1 and 2).
        i (Tensor): The index tensor (unpacked as three dimensions).
        _ (Tensor): Unused here; kept so the registered operator signature stays stable.
        sparsity_lut_i (Tensor): The look-up table of the index tensor (unpacked as two dimensions).
        sparsity_block_size (int): The size of the sparsity blocks.

    Returns:
        Tensor: The output tensor filled by the kernel.

    """
    output = torch.empty_like(i, dtype=x.dtype)

    # Sizes and strides are unpacked in full so that a tensor of unexpected rank fails here
    x_b, x_r, x_c = x.size()
    x_b_s, x_r_s, x_c_s = stride(x)

    s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
    s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)

    i_b, i_r, i_c = i.size()
    i_b_s, i_r_s, i_c_s = stride(i)

    s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
    s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)

    o_b, o_r, o_c = output.size()
    o_b_s, o_r_s, o_c_s = stride(output)

    def triton_grid(meta):
        # One program per output block and per (row, column) tile of that block
        tile = meta["TRITON_BLOCK_SIZE"]
        return [o_b, triton.cdiv(o_r, tile), triton.cdiv(o_c, tile)]

    kernel = wrap_triton(gather_kernel)
    kernel[triton_grid](
        x,
        x_b, x_b_s, x_r_s, x_c_s,
        s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
        sparsity_reverse_lut_x,
        dim,
        i,
        i_b, i_b_s, i_r_s, i_c_s,
        output,
        o_b, o_b_s, o_r_s, o_c_s,
        sparsity_lut_i, s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
        sparsity_block_size)

    return output
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def gather_backward(ctx, grad_output):
    """Computes the gradient of the gather operation.

    The incoming gradient is scattered back with ``scatter_reduce`` using
    ``reduce_op="sum"``, reusing the index tensor and layouts stored on ``ctx``.

    Args:
        ctx: The autograd context populated by ``gather_setup_context``.
        grad_output: The gradient with respect to the output of the forward pass.

    Returns:
        tuple: The gradient for the first forward input followed by ``None`` for the remaining seven inputs.

    """
    sparsity_layout_x, i, sparsity_layout_i = ctx.saved_tensors
    dim = ctx.dim
    sparsity_block_size = ctx.sparsity_block_size

    grad_x = scatter_reduce(grad_output, sparsity_layout_i,
                            dim, i,
                            sparsity_layout_x, sparsity_block_size,
                            reduce_op="sum")

    # Only the source tensor receives a gradient
    return (grad_x,) + (None,) * 7
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
@triton.autotune(
    configs=get_autotune_configs(),
    key=[],
)
@triton.jit
def gather_kernel(x,
                  x_b, x_b_s, x_r_s, x_c_s,
                  s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
                  r_lut_x,
                  dim,
                  i,
                  i_b, i_b_s, i_r_s, i_c_s,
                  o,
                  o_b, o_b_s, o_r_s, o_c_s,
                  s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
                  sparsity_block_size,
                  TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    """Gathers values from ``x`` into ``o`` at the positions given by ``i``.

    Each program handles one tile of one block: program axis 0 selects the block,
    axes 1 and 2 select the row and column tile inside that block.
    """
    # Triton block indices
    pid_blk = tl.program_id(axis=0)
    pid_row = tl.program_id(axis=1)
    pid_col = tl.program_id(axis=2)

    # Valid triton block size (a tile never spans more than one sparsity block)
    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)

    # Tile-local range, the row/column positions it maps to and the validity mask of the tile
    tile_rng = tl.arange(0, TRITON_BLOCK_SIZE)
    tile_rows = pid_row * val_tbs + tile_rng
    tile_cols = pid_col * val_tbs + tile_rng
    tile_msk = (tile_rng[:, None] < val_tbs and
                tile_rng[None, :] < val_tbs)

    # Position of the current sparsity block: batch, row and column index (LUT columns 0, 1, 2)
    lut_row_off = pid_blk * s_lut_o_r_s
    lut_limit = s_lut_o_r * s_lut_o_r_s

    spa_bat_o_idx = (lut_row_off + 0 * s_lut_o_c_s)
    spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < lut_limit)
    spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)

    spa_row_o_idx = (lut_row_off + 1 * s_lut_o_c_s)
    spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < lut_limit)
    spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)

    spa_col_o_idx = (lut_row_off + 2 * s_lut_o_c_s)
    spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < lut_limit)
    spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)

    # Load index values
    blk_i_idx = ((pid_blk * i_b_s) +
                 (tile_rows * i_r_s)[:, None] +
                 (tile_cols * i_c_s)[None, :])
    blk_i_msk = ((blk_i_idx >= 0 and
                  blk_i_idx < i_b * i_b_s) and
                 tile_msk)
    blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)

    # Split every index into the sparsity block it addresses and the position inside that block
    pos_spa_blk_x = blk_i // sparsity_block_size
    pos_spa_int_x = blk_i % sparsity_block_size

    # Defaults: read from the same batch/row/column block and the same in-block position as the output
    rev_dst_bat_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_o, dtype=tl.int32)
    rev_dst_row_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_o, dtype=tl.int32)
    rev_dst_col_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_o, dtype=tl.int32)
    dst_row_x = ((tile_rows * x_r_s)[:, None]
                 .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
    dst_col_x = ((tile_cols * x_c_s)[None, :]
                 .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))

    # Along the gathered dimension the coordinates come from the index values instead
    if dim == 0:
        rev_dst_bat_x = blk_i
    elif dim == 1:
        rev_dst_row_x = pos_spa_blk_x
        dst_row_x = pos_spa_int_x * x_r_s
    elif dim == 2:
        rev_dst_col_x = pos_spa_blk_x
        dst_col_x = pos_spa_int_x * x_c_s

    # Load reverse sparsity indices for x
    rev_idx_spa_x_idx = ((rev_dst_bat_x * s_l_x_b_s) +
                         (rev_dst_row_x * s_l_x_r_s) +
                         (rev_dst_col_x * s_l_x_c_s))
    rev_idx_spa_x_msk = ((rev_idx_spa_x_idx >= 0 and
                          rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s) and
                         tile_msk)
    rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)

    # Load x values
    # NOTE(review): the `!= -1` clause below compares the boolean mask `rev_idx_spa_x_msk`,
    # not the loaded value `rev_idx_spa_x` - confirm whether that is intended
    blk_x_idx = ((rev_idx_spa_x * x_b_s) +
                 dst_row_x +
                 dst_col_x)
    blk_x_msk = (((blk_x_idx >= 0 and
                   blk_x_idx < x_b * x_b_s) and
                  rev_idx_spa_x_msk != -1) and
                 tile_msk)
    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

    # Store output
    blk_o_idx = ((pid_blk * o_b_s) +
                 (tile_rows * o_r_s)[:, None] +
                 (tile_cols * o_c_s)[None, :])
    blk_o_msk = (((blk_o_idx >= 0 and
                   blk_o_idx < o_b * o_b_s) and
                  rev_idx_spa_x_msk != -1) and
                 tile_msk)
    tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def gather_build_lut(lut: dict, sparsity_layout_src: Tensor, sparsity_layout_idx: Tensor):
    """Completes the look-up tables required by the gather operation.

    Entries already present in ``lut`` are reused; missing ones are computed and
    written into the same dictionary.

    Args:
        lut (dict): A dictionary with (some of) the look-up tables, or ``None`` to build all of them.
        sparsity_layout_src (Tensor): The sparsity layout of the source tensor.
        sparsity_layout_idx (Tensor): The sparsity layout of the index tensor.

    Returns:
        dict: The dictionary holding ``"sparsity_reverse_lut_x"`` and ``"sparsity_lut_i"``.

    """
    lut = dict() if lut is None else lut

    if "sparsity_reverse_lut_x" not in lut:
        layout_flat = sparsity_layout_src.reshape(-1)
        # Running count of set entries minus one, kept only where the layout is 1 ...
        running_idx = (torch.cumsum(layout_flat, dim=-1) - 1) * (layout_flat == 1)
        # ... and -1 wherever the layout is 0
        empty_marker = (1 * (layout_flat == 0))
        lut["sparsity_reverse_lut_x"] = (running_idx - empty_marker)

    if "sparsity_lut_i" not in lut:
        lut["sparsity_lut_i"] = torch.nonzero(sparsity_layout_idx).contiguous()

    validate_contiguous(sparsity_layout_src, lut["sparsity_reverse_lut_x"],
                        sparsity_layout_idx, lut["sparsity_lut_i"])

    return lut
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
# noinspection PyUnusedLocal
def gather_setup_context(ctx, inputs, output):
    """Stores on ``ctx`` what ``gather_backward`` reads.

    Args:
        ctx: The autograd context to populate.
        inputs: The eight positional inputs of ``gather_forward``, in call order.
        output: The forward output (unused).

    """
    (_src, layout_x, _rev_lut_x, gather_dim, indices, layout_i, _lut_i, block_size) = inputs

    ctx.dim = gather_dim
    ctx.sparsity_block_size = block_size
    # Saved in the order in which the backward pass unpacks them
    ctx.save_for_backward(layout_x, indices, layout_i)


gather_forward.register_autograd(gather_backward, setup_context=gather_setup_context)
|
|
213
232
|
|
|
214
233
|
|
|
215
234
|
def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
|
|
216
235
|
dim: int,
|
|
217
236
|
idx: BlksprsTensor,
|
|
218
237
|
sparsity_layout_tgt: Tensor,
|
|
219
|
-
sparsity_block_size: int,
|
|
238
|
+
sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
|
|
220
239
|
"""Wrapper for ``scatter_reduce`` with ``reduce_op="none"``.
|
|
221
240
|
|
|
222
241
|
"""
|
|
@@ -225,7 +244,7 @@ def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
|
|
|
225
244
|
idx,
|
|
226
245
|
sparsity_layout_tgt,
|
|
227
246
|
sparsity_block_size,
|
|
228
|
-
reduce_op="none",
|
|
247
|
+
reduce_op="none", lut=lut)
|
|
229
248
|
|
|
230
249
|
|
|
231
250
|
def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
|
|
@@ -233,7 +252,7 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
|
|
|
233
252
|
idx: BlksprsTensor,
|
|
234
253
|
sparsity_layout_tgt: Tensor,
|
|
235
254
|
sparsity_block_size: int,
|
|
236
|
-
reduce_op: str = "sum",
|
|
255
|
+
reduce_op: str = "sum", lut: dict = None) -> BlksprsTensor:
|
|
237
256
|
"""Applies a scatter operation on a block-sparse tensor in compressed form.
|
|
238
257
|
|
|
239
258
|
Args:
|
|
@@ -245,7 +264,6 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
|
|
|
245
264
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
246
265
|
reduce_op (str, optional): The reduction operation to apply during the scatter operation (default ``"sum"``).
|
|
247
266
|
Supported operations are ``"none"`` and ``"sum"``.
|
|
248
|
-
triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
|
|
249
267
|
lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
|
|
250
268
|
|
|
251
269
|
Returns:
|
|
@@ -261,198 +279,218 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
|
|
|
261
279
|
validate_device(src, idx)
|
|
262
280
|
validate_sparsity(sparsity_block_size, (src, sparsity_layout_src), (idx, sparsity_layout_src))
|
|
263
281
|
validate_sparsity_block_size(sparsity_block_size, src, idx)
|
|
264
|
-
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
265
282
|
|
|
266
283
|
if reduce_op not in ["none", "sum"]:
|
|
267
284
|
raise ValueError(f"Reduction operation '{reduce_op}' is not supported")
|
|
268
285
|
|
|
269
286
|
adjusted_dim = dim % 3
|
|
270
287
|
|
|
271
|
-
lut =
|
|
272
|
-
|
|
273
|
-
return BlksprsTensor(
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
288
|
+
lut = scatter_reduce_build_lut(lut, sparsity_layout_src, sparsity_layout_tgt)
|
|
289
|
+
|
|
290
|
+
return BlksprsTensor(scatter_reduce_forward(src, sparsity_layout_src, lut["sparsity_lut_x"],
|
|
291
|
+
adjusted_dim, idx,
|
|
292
|
+
sparsity_layout_tgt, lut["sparsity_reverse_lut_o"],
|
|
293
|
+
sparsity_block_size, lut["n_sparse_blocks"],
|
|
294
|
+
reduce_op))
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
@triton_op("blksprs::scatter_reduce", mutates_args={})
def scatter_reduce_forward(x: Tensor, _: Tensor, sparsity_lut_x: Tensor,
                           dim: int, i: Tensor,
                           sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
                           sparsity_block_size: int, n_sparse_blocks: int,
                           reduce_op: str) -> Tensor:
    """Launches ``scatter_reduce_kernel`` and returns the scattered result.

    The result is a zero-initialised tensor of shape
    ``(n_sparse_blocks, sparsity_block_size, sparsity_block_size)`` with the dtype and device of ``x``.

    Args:
        x (Tensor): The source tensor (unpacked as three dimensions).
        _ (Tensor): Unused here; kept so the registered operator signature stays stable.
        sparsity_lut_x (Tensor): The look-up table of the source tensor (unpacked as two dimensions).
        dim (int): The dimension along which to scatter (the kernel branches on 0, 1 and 2).
        i (Tensor): The index tensor (unpacked as three dimensions).
        sparsity_layout_o (Tensor): The sparsity layout of the output tensor (unpacked as three dimensions).
        sparsity_reverse_lut_o (Tensor): The reverse look-up table for the output tensor.
        sparsity_block_size (int): The size of the sparsity blocks.
        n_sparse_blocks (int): The number of blocks to allocate for the output.
        reduce_op (str): ``"sum"`` selects the accumulating path of the kernel, any other value the plain store.

    Returns:
        Tensor: The output tensor written by the kernel.

    """
    output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                         dtype=x.dtype, device=x.device)

    # Sizes and strides are unpacked in full so that a tensor of unexpected rank fails here
    x_b, x_r, x_c = x.size()
    x_b_s, x_r_s, x_c_s = stride(x)

    s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
    s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut_x)

    i_b, i_r, i_c = i.size()
    i_b_s, i_r_s, i_c_s = stride(i)

    o_b, o_r, o_c = output.size()
    o_b_s, o_r_s, o_c_s = stride(output)

    s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)

    def triton_grid(meta):
        # One program per source block and per (row, column) tile of that block
        tile = meta["TRITON_BLOCK_SIZE"]
        return [x_b, triton.cdiv(x_r, tile), triton.cdiv(x_c, tile)]

    # The kernel receives the reduction as an integer flag: 1 for "sum", 0 otherwise
    reduce_op_ind = 1 if reduce_op == "sum" else 0

    kernel = wrap_triton(scatter_reduce_kernel)
    kernel[triton_grid](
        x,
        x_b, x_b_s, x_r_s, x_c_s,
        sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
        dim,
        i,
        i_b, i_b_s, i_r_s, i_c_s,
        output,
        o_b, o_b_s,
        s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
        sparsity_reverse_lut_o,
        reduce_op_ind,
        sparsity_block_size)

    return output
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
def scatter_reduce_backward(ctx, grad_output):
    """Computes the gradient of the scatter-reduce operation.

    For ``reduce_op == "sum"`` the incoming gradient is gathered back with ``gather``
    using the index tensor and layouts stored on ``ctx``.

    Args:
        ctx: The autograd context populated by ``scatter_reduce_setup_context``.
        grad_output: The gradient with respect to the output of the forward pass.

    Returns:
        tuple: The gradient for the first forward input followed by ``None`` for the remaining nine inputs.

    Raises:
        ValueError: If the stored reduction operation is anything other than ``"sum"``.

    """
    sparsity_layout_x, i, sparsity_layout_o = ctx.saved_tensors
    dim = ctx.dim
    sparsity_block_size = ctx.sparsity_block_size
    reduce_op = ctx.reduce_op

    if reduce_op != "sum":
        raise ValueError(f"Reduction operation '{reduce_op}' does not support backward pass")

    grad_x = gather(grad_output, sparsity_layout_o, dim, i, sparsity_layout_x,
                    sparsity_block_size)

    # Only the source tensor receives a gradient
    return (grad_x,) + (None,) * 9
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
@triton.autotune(
    configs=get_autotune_configs(),
    key=[],
    reset_to_zero=["o"]
)
@triton.jit
def scatter_reduce_kernel(x,
                          x_b, x_b_s, x_r_s, x_c_s,
                          s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
                          dim,
                          i,
                          i_b, i_b_s, i_r_s, i_c_s,
                          o,
                          o_b, o_b_s,
                          s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
                          r_lut_o,
                          reduce_op_ind,
                          sparsity_block_size,
                          TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    """Scatters values from ``x`` into ``o`` at the positions given by ``i``.

    Each program handles one tile of one source block: program axis 0 selects the block,
    axes 1 and 2 select the row and column tile inside that block. ``reduce_op_ind == 0``
    stores the values, ``reduce_op_ind == 1`` adds them atomically.
    """
    # Triton block indices
    pid_blk = tl.program_id(axis=0)
    pid_row = tl.program_id(axis=1)
    pid_col = tl.program_id(axis=2)

    # Valid triton block size (a tile never spans more than one sparsity block)
    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)

    # Tile-local range, the row/column positions it maps to and the validity mask of the tile
    tile_rng = tl.arange(0, TRITON_BLOCK_SIZE)
    tile_rows = pid_row * val_tbs + tile_rng
    tile_cols = pid_col * val_tbs + tile_rng
    tile_msk = (tile_rng[:, None] < val_tbs and
                tile_rng[None, :] < val_tbs)

    # Position of the current sparsity block: batch, row and column index (LUT columns 0, 1, 2)
    lut_row_off = pid_blk * s_lut_x_r_s
    lut_limit = s_lut_x_r * s_lut_x_r_s

    spa_bat_x_idx = (lut_row_off + 0 * s_lut_x_c_s)
    spa_bat_x_msk = (spa_bat_x_idx >= 0 and spa_bat_x_idx < lut_limit)
    spa_bat_x = tl.load(s_lut_x + spa_bat_x_idx, mask=spa_bat_x_msk)

    spa_row_x_idx = (lut_row_off + 1 * s_lut_x_c_s)
    spa_row_x_msk = (spa_row_x_idx >= 0 and spa_row_x_idx < lut_limit)
    spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)

    spa_col_x_idx = (lut_row_off + 2 * s_lut_x_c_s)
    spa_col_x_msk = (spa_col_x_idx >= 0 and spa_col_x_idx < lut_limit)
    spa_col_x = tl.load(s_lut_x + spa_col_x_idx, mask=spa_col_x_msk)

    # Load x values
    blk_x_idx = ((pid_blk * x_b_s) +
                 (tile_rows * x_r_s)[:, None] +
                 (tile_cols * x_c_s)[None, :])
    blk_x_msk = ((blk_x_idx >= 0 and
                  blk_x_idx < x_b * x_b_s) and
                 tile_msk)
    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

    # Load index values
    blk_i_idx = ((pid_blk * i_b_s) +
                 (tile_rows * i_r_s)[:, None] +
                 (tile_cols * i_c_s)[None, :])
    blk_i_msk = ((blk_i_idx >= 0 and
                  blk_i_idx < i_b * i_b_s) and
                 tile_msk)
    blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)

    # Split every index into the sparsity block it addresses and the position inside that block
    pos_spa_blk_x = blk_i // sparsity_block_size
    pos_spa_int_x = blk_i % sparsity_block_size

    # Defaults: write to the same batch/row/column block and the same in-block position as the source
    # NOTE(review): the in-block output offsets are built from the strides of x (x_r_s, x_c_s);
    # presumably output blocks share that layout - confirm
    rev_dst_bat_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_x, dtype=tl.int32)
    rev_dst_row_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_x, dtype=tl.int32)
    rev_dst_col_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_x, dtype=tl.int32)
    dst_row_o = ((tile_rows * x_r_s)[:, None]
                 .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
    dst_col_o = ((tile_cols * x_c_s)[None, :]
                 .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))

    # Along the scattered dimension the coordinates come from the index values instead
    if dim == 0:
        rev_dst_bat_o = blk_i
    elif dim == 1:
        rev_dst_row_o = pos_spa_blk_x
        dst_row_o = pos_spa_int_x * x_r_s
    elif dim == 2:
        rev_dst_col_o = pos_spa_blk_x
        dst_col_o = pos_spa_int_x * x_c_s

    # Load reverse sparsity indices for o
    rev_idx_spa_o_idx = ((rev_dst_bat_o * s_l_o_b_s) +
                         (rev_dst_row_o * s_l_o_r_s) +
                         (rev_dst_col_o * s_l_o_c_s))
    rev_idx_spa_o_msk = ((rev_idx_spa_o_idx >= 0 and
                          rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s) and
                         tile_msk)
    rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)

    # Store output
    # NOTE(review): the `!= -1` clause below compares the boolean mask `rev_idx_spa_o_msk`,
    # not the loaded value `rev_idx_spa_o` - confirm whether that is intended
    blk_o_idx = ((rev_idx_spa_o * o_b_s) +
                 dst_row_o +
                 dst_col_o)
    blk_o_msk = (((blk_o_idx >= 0 and
                   blk_o_idx < o_b * o_b_s) and
                  rev_idx_spa_o_msk != -1) and
                 tile_msk)

    if reduce_op_ind == 0:
        tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
    elif reduce_op_ind == 1:
        tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
|
|
459
|
+
|
|
460
|
+
|
|
461
|
+
def scatter_reduce_build_lut(lut: dict, sparsity_layout_src: Tensor, sparsity_layout_tgt: Tensor):
    """Completes the look-up tables required by the scatter-reduce operation.

    Entries already present in ``lut`` are reused; missing ones are computed and
    written into the same dictionary.

    Args:
        lut (dict): A dictionary with (some of) the look-up tables, or ``None`` to build all of them.
        sparsity_layout_src (Tensor): The sparsity layout of the source tensor.
        sparsity_layout_tgt (Tensor): The sparsity layout of the target tensor.

    Returns:
        dict: The dictionary holding ``"sparsity_lut_x"``, ``"sparsity_reverse_lut_o"`` and ``"n_sparse_blocks"``.

    """
    lut = dict() if lut is None else lut

    if "sparsity_lut_x" not in lut:
        lut["sparsity_lut_x"] = torch.nonzero(sparsity_layout_src).contiguous()

    if "sparsity_reverse_lut_o" not in lut:
        layout_flat = sparsity_layout_tgt.reshape(-1)
        # Running count of set entries minus one, kept only where the layout is 1 ...
        running_idx = (torch.cumsum(layout_flat, dim=-1) - 1) * (layout_flat == 1)
        # ... and -1 wherever the layout is 0
        empty_marker = (1 * (layout_flat == 0))
        lut["sparsity_reverse_lut_o"] = (running_idx - empty_marker)

    if "n_sparse_blocks" not in lut:
        # Stored as a plain Python number via .item()
        lut["n_sparse_blocks"] = torch.sum(sparsity_layout_tgt.to(torch.int)).item()

    validate_contiguous(sparsity_layout_src, lut["sparsity_lut_x"],
                        sparsity_layout_tgt, lut["sparsity_reverse_lut_o"])

    return lut
|
|
484
|
+
|
|
485
|
+
|
|
486
|
+
# noinspection PyUnusedLocal
def scatter_reduce_setup_context(ctx, inputs, output):
    """Stores on ``ctx`` what ``scatter_reduce_backward`` reads.

    Args:
        ctx: The autograd context to populate.
        inputs: The ten positional inputs of ``scatter_reduce_forward``, in call order.
        output: The forward output (unused).

    """
    (_src, layout_x, _lut_x, scatter_dim, indices,
     layout_o, _rev_lut_o, block_size, _n_blocks, reduction) = inputs

    ctx.dim = scatter_dim
    ctx.sparsity_block_size = block_size
    ctx.reduce_op = reduction
    # Saved in the order in which the backward pass unpacks them
    ctx.save_for_backward(layout_x, indices, layout_o)


scatter_reduce_forward.register_autograd(scatter_reduce_backward, setup_context=scatter_reduce_setup_context)
|