blksprs-1.11-py3-none-any.whl → blksprs-2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +4 -5
- blksprs/layouting/distribution_layout.py +64 -48
- blksprs/layouting/sparsity_layout.py +96 -72
- blksprs/ops/conversion.py +349 -338
- blksprs/ops/distribution.py +318 -294
- blksprs/ops/flow.py +113 -100
- blksprs/ops/matmul.py +187 -172
- blksprs/ops/misc/broadcast_ops.py +68 -53
- blksprs/ops/misc/row_wise.py +223 -176
- blksprs/ops/partitioning.py +140 -132
- blksprs/ops/repeat.py +118 -120
- blksprs/ops/softmax.py +240 -214
- blksprs/ops/transpose.py +55 -52
- blksprs/utils/autotuning.py +78 -0
- blksprs/utils/benchmarking.py +3 -3
- blksprs/utils/processing.py +2 -1
- blksprs/utils/tools.py +5 -6
- blksprs/utils/validation.py +22 -16
- {blksprs-1.11.dist-info → blksprs-2.0.dist-info}/METADATA +55 -36
- blksprs-2.0.dist-info/RECORD +23 -0
- {blksprs-1.11.dist-info → blksprs-2.0.dist-info}/WHEEL +1 -1
- blksprs/utils/layout_utils.py +0 -17
- blksprs-1.11.dist-info/RECORD +0 -23
- {blksprs-1.11.dist-info → blksprs-2.0.dist-info}/top_level.txt +0 -0
blksprs/ops/distribution.py
CHANGED
@@ -1,19 +1,22 @@
 import torch
 import triton
 from torch import Tensor
+from torch._library import triton_op
+from torch._library.triton import wrap_triton
 from triton import language as tl

-from blksprs.ops.conversion import to_dense
 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import
+from blksprs.utils.tools import stride
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-    validate_sparsity, validate_dtype_int, validate_sparsity_block_size
+    validate_sparsity, validate_dtype_int, validate_sparsity_block_size


+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
            dim: int,
            idx: BlksprsTensor, sparsity_layout_idx: Tensor,
-           sparsity_block_size: int,
+           sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
     """Applies a gather operation on a block-sparse tensor in compressed form.

     Args:
@@ -23,7 +26,6 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
         idx (BlksprsTensor): The block-sparse indices tensor in compressed form specifying how to gather from the source tensor.
         sparsity_layout_idx (Tensor): The sparsity layout of the indices block-sparse tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
-        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
         lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

     Returns:
@@ -39,45 +41,22 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
     validate_device(src, idx)
     validate_sparsity(sparsity_block_size, (src, sparsity_layout_src), (idx, sparsity_layout_idx))
     validate_sparsity_block_size(sparsity_block_size, src, idx)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)

     adjusted_dim = dim % 3

-    lut =
+    lut = gather_build_lut(lut, sparsity_layout_src, sparsity_layout_idx)

-    return BlksprsTensor(
+    return BlksprsTensor(gather_forward(src, sparsity_layout_src, lut["sparsity_reverse_lut_x"],
+                                        adjusted_dim, idx, sparsity_layout_idx, lut["sparsity_lut_i"],
+                                        sparsity_block_size))


-        if "sparsity_reverse_lut_x" not in lut:
-            sparsity_layout_x_flat = sparsity_layout_src.reshape(-1)
-            sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
-                                      (sparsity_layout_x_flat == 1) -
-                                      (1 * (sparsity_layout_x_flat == 0)))
-            lut["sparsity_reverse_lut_x"] = sparsity_reverse_lut_x
-
-        if "sparsity_lut_i" not in lut:
-            sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
-            lut["sparsity_lut_i"] = sparsity_lut_i
-
-        validate_contiguous(sparsity_layout_src, lut["sparsity_reverse_lut_x"],
-                            sparsity_layout_idx, lut["sparsity_lut_i"])
-
-        return lut
-
-    @staticmethod
-    def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
-                dim: int, i: Tensor, sparsity_layout_i: Tensor, sparsity_lut_i: Tensor,
-                sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
-        output = torch.empty_like(i, dtype=x.dtype)
+@triton_op("blksprs::gather_forward", mutates_args={})
+def gather_forward(x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
+                   dim: int, i: Tensor, _: Tensor, sparsity_lut_i: Tensor,
+                   sparsity_block_size: int) -> Tensor:
+    with torch.no_grad():
+        output = torch.zeros_like(i, dtype=x.dtype)

         x_b, x_r, x_c = x.size()
         x_b_s, x_r_s, x_c_s = stride(x)
@@ -90,14 +69,11 @@ class _BlocksparseGather(torch.autograd.Function):
         o_b, o_r, o_c = output.size()
         o_b_s, o_r_s, o_c_s = stride(output)

-        if triton_block_size is None:
-            triton_block_size = get_triton_block_size(sparsity_block_size)
-
         triton_grid = lambda meta: [o_b,
                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

-        (
+        (wrap_triton(gather_kernel)[triton_grid]
         (x,
          x_b, x_b_s, x_r_s, x_c_s,
          s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
@@ -108,115 +84,152 @@ class _BlocksparseGather(torch.autograd.Function):
          output,
          o_b, o_b_s, o_r_s, o_c_s,
          sparsity_lut_i, s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
-         sparsity_block_size
-         triton_block_size))
-
-        ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_i)
-        ctx.dim = dim
-        ctx.sparsity_block_size = sparsity_block_size
-        ctx.triton_block_size = triton_block_size
+         sparsity_block_size))

         return output

+
+def gather_wrapper_backward(ctx, grad_output):
+    sparsity_layout_x, i, sparsity_layout_i = ctx.saved_tensors
+    dim = ctx.dim
+    sparsity_block_size = ctx.sparsity_block_size
+
+    return scatter_reduce(grad_output, sparsity_layout_i,
+                          dim, i,
+                          sparsity_layout_x, sparsity_block_size,
+                          reduce_op="sum"), None, None, None, None, None, None, None
+
+
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
+    reset_to_zero=["o"]
+)
+@triton.jit
+def gather_kernel(x,
+                  x_b, x_b_s, x_r_s, x_c_s,
+                  s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+                  r_lut_x,
+                  dim,
+                  i,
+                  i_b, i_b_s, i_r_s, i_c_s,
+                  o,
+                  o_b, o_b_s, o_r_s, o_c_s,
+                  s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+                  sparsity_block_size,
+                  TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get position of current sparsity block consisting of its batch, row, and column index
+    spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
+    spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
+
+    spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
+    spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+
+    spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
+    spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+
+    # Load index values
+    blk_i_idx = ((pid_blk * i_b_s) +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+    blk_i_msk = (blk_i_idx >= 0 and
+                 blk_i_idx < i_b * i_b_s)
+    blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
+
+    # Get indices of sparsity blocks and positions within the blocks
+    pos_spa_blk_x = blk_i // sparsity_block_size
+    pos_spa_int_x = blk_i % sparsity_block_size
+
+    rev_dst_bat_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_o, dtype=tl.int32)
+    rev_dst_row_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_o, dtype=tl.int32)
+    rev_dst_col_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_o, dtype=tl.int32)
+    dst_row_x = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
+                 .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+    dst_col_x = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
+                 .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+    if dim == 0:
+        rev_dst_bat_x = blk_i
+    elif dim == 1:
+        rev_dst_row_x = pos_spa_blk_x
+        dst_row_x = pos_spa_int_x * x_r_s
+    elif dim == 2:
+        rev_dst_col_x = pos_spa_blk_x
+        dst_col_x = pos_spa_int_x * x_c_s
+
+    # Load reverse sparsity indices for x
+    rev_idx_spa_x_idx = ((rev_dst_bat_x * s_l_x_b_s) +
+                         (rev_dst_row_x * s_l_x_r_s) +
+                         (rev_dst_col_x * s_l_x_c_s))
+    rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and
+                         rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+    rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
+
+    # Load x values
+    blk_x_idx = ((rev_idx_spa_x * x_b_s) +
+                 dst_row_x +
+                 dst_col_x)
+    blk_x_msk = ((blk_x_idx >= 0 and
+                  blk_x_idx < x_b * x_b_s) and
+                 rev_idx_spa_x_msk != -1)
+    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+    # Store output
+    blk_o_idx = ((pid_blk * o_b_s) +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+    blk_o_msk = ((blk_o_idx >= 0 and
+                  blk_o_idx < o_b * o_b_s) and
+                 rev_idx_spa_x_msk != -1)
+    tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+def gather_build_lut(lut: dict, sparsity_layout_src: Tensor, sparsity_layout_idx: Tensor):
+    if lut is None:
+        lut = dict()
+
+    if "sparsity_reverse_lut_x" not in lut:
+        sparsity_layout_x_flat = sparsity_layout_src.reshape(-1)
+        sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
+                                  (sparsity_layout_x_flat == 1) -
+                                  (1 * (sparsity_layout_x_flat == 0)))
+        lut["sparsity_reverse_lut_x"] = sparsity_reverse_lut_x
+
+    if "sparsity_lut_i" not in lut:
+        sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
+        lut["sparsity_lut_i"] = sparsity_lut_i
+
+    validate_contiguous(sparsity_layout_src, lut["sparsity_reverse_lut_x"],
+                        sparsity_layout_idx, lut["sparsity_lut_i"])
+
+    return lut
+
+
+# noinspection PyUnusedLocal
+def gather_setup_context(ctx, inputs, output):
+    (_, sparsity_layout_x, _, dim, i, sparsity_layout_i, _, sparsity_block_size) = inputs
+
+    ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_i)
+    ctx.dim = dim
+    ctx.sparsity_block_size = sparsity_block_size
+
+
+gather_forward.register_autograd(gather_wrapper_backward, setup_context=gather_setup_context)


 def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
             dim: int,
             idx: BlksprsTensor,
             sparsity_layout_tgt: Tensor,
-            sparsity_block_size: int,
+            sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
     """Wrapper for ``scatter_reduce`` with ``reduce_op="none"``.

     """
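The `gather_build_lut` helper above (and `scatter_reduce_build_lut` further down) precomputes two look-up tables: `torch.nonzero` over a sparsity layout gives the (batch, row, column) position of every stored block, while a cumulative-sum expression gives the reverse mapping from a layout position to that block's index in the compressed tensor, with `-1` marking absent blocks. A minimal, self-contained sketch of the reverse table, using the same cumsum expression with a made-up layout:

```python
import torch

# Toy sparsity layout: 1 batch, a 2x2 grid of blocks, two of which are present.
sparsity_layout = torch.tensor([[[1, 0],
                                 [0, 1]]])
flat = sparsity_layout.reshape(-1)

# Same expression as in gather_build_lut / scatter_reduce_build_lut:
# present blocks get their index in the compressed tensor, absent blocks get -1.
reverse_lut = ((torch.cumsum(flat, dim=-1) - 1) * (flat == 1) - (1 * (flat == 0)))
print(reverse_lut)  # tensor([ 0, -1, -1,  1])
```

Because the optional ``lut`` dict is filled in place on first use, a caller can pass the same dict to repeated ``gather``/``scatter_reduce`` calls over an unchanged layout and skip rebuilding these tables.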
@@ -225,15 +238,16 @@ def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
                           idx,
                           sparsity_layout_tgt,
                           sparsity_block_size,
-                          reduce_op="none",
+                          reduce_op="none", lut=lut)


+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
 def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
                    dim: int,
                    idx: BlksprsTensor,
                    sparsity_layout_tgt: Tensor,
                    sparsity_block_size: int,
-                   reduce_op: str = "sum",
+                   reduce_op: str = "sum", lut: dict = None) -> BlksprsTensor:
     """Applies a scatter operation on a block-sparse tensor in compressed form.

     Args:
@@ -245,7 +259,6 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
         sparsity_block_size (int): The size of the sparsity blocks.
         reduce_op (str, optional): The reduction operation to apply during the scatter operation (default ``"sum"``).
             Supported operations are ``"none"`` and ``"sum"``.
-        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
         lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

     Returns:
@@ -261,55 +274,28 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
     validate_device(src, idx)
     validate_sparsity(sparsity_block_size, (src, sparsity_layout_src), (idx, sparsity_layout_src))
     validate_sparsity_block_size(sparsity_block_size, src, idx)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)

     if reduce_op not in ["none", "sum"]:
         raise ValueError(f"Reduction operation '{reduce_op}' is not supported")

     adjusted_dim = dim % 3

-    lut =
-
-    return BlksprsTensor(_BlocksparseScatterReduce.apply(src, sparsity_layout_src, lut["sparsity_lut_x"],
-                                                         adjusted_dim, idx,
-                                                         sparsity_layout_tgt, lut["sparsity_reverse_lut_o"],
-                                                         sparsity_block_size, lut["n_sparse_blocks"],
-                                                         reduce_op, triton_block_size))
-
-
-class _BlocksparseScatterReduce(torch.autograd.Function):
-
-    @staticmethod
-    def build_lut(lut: dict, sparsity_layout_src: Tensor, sparsity_layout_tgt: Tensor):
-        if lut is None:
-            lut = dict()
+    lut = scatter_reduce_build_lut(lut, sparsity_layout_src, sparsity_layout_tgt)

+    return BlksprsTensor(scatter_reduce_forward(src, sparsity_layout_src, lut["sparsity_lut_x"],
+                                                adjusted_dim, idx,
+                                                sparsity_layout_tgt, lut["sparsity_reverse_lut_o"],
+                                                sparsity_block_size, lut["n_sparse_blocks"],
+                                                reduce_op))

-        if "sparsity_reverse_lut_o" not in lut:
-            sparsity_layout_o_flat = sparsity_layout_tgt.reshape(-1)
-            sparsity_reverse_lut_o = ((torch.cumsum(sparsity_layout_o_flat, dim=-1) - 1) *
-                                      (sparsity_layout_o_flat == 1) -
-                                      (1 * (sparsity_layout_o_flat == 0)))
-            lut["sparsity_reverse_lut_o"] = sparsity_reverse_lut_o

-        return lut
-
-    @staticmethod
-    def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_lut_x: Tensor,
-                dim: int, i: Tensor,
-                sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
-                sparsity_block_size: int, n_sparse_blocks: int,
-                reduce_op: str, triton_block_size: int) -> Tensor:
+@triton_op("blksprs::scatter_reduce_forward", mutates_args={})
+def scatter_reduce_forward(x: Tensor, _: Tensor, sparsity_lut_x: Tensor,
+                           dim: int, i: Tensor,
+                           sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
+                           sparsity_block_size: int, n_sparse_blocks: int,
+                           reduce_op: str) -> Tensor:
+    with torch.no_grad():
         output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                              dtype=x.dtype, device=x.device)

@@ -324,9 +310,6 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
         s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
         s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)

-        if triton_block_size is None:
-            triton_block_size = get_triton_block_size(sparsity_block_size)
-
         triton_grid = lambda meta: [x_b,
                                     triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
                                     triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
@@ -335,7 +318,7 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
         if reduce_op == "sum":
             reduce_op_ind = 1

-        (
+        (wrap_triton(scatter_reduce_kernel)[triton_grid]
         (x,
          x_b, x_b_s, x_r_s, x_c_s,
          sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
@@ -347,112 +330,153 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
          s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
          sparsity_reverse_lut_o,
          reduce_op_ind,
-         sparsity_block_size
-         triton_block_size))
-
-        ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_o)
-        ctx.dim = dim
-        ctx.sparsity_block_size = sparsity_block_size
-        ctx.reduce_op = reduce_op
-        ctx.triton_block_size = triton_block_size
+         sparsity_block_size))

         return output

-    @staticmethod
-    def backward(ctx, grad_output):
-        sparsity_layout_x, i, sparsity_layout_o = ctx.saved_tensors
-        dim = ctx.dim
-        sparsity_block_size = ctx.sparsity_block_size
-        reduce_op = ctx.reduce_op
-        triton_block_size = ctx.triton_block_size

+def scatter_reduce_wrapper_backward(ctx, grad_output):
+    sparsity_layout_x, i, sparsity_layout_o = ctx.saved_tensors
+    dim = ctx.dim
+    sparsity_block_size = ctx.sparsity_block_size
+    reduce_op = ctx.reduce_op
+
+    if reduce_op == "sum":
+        return gather(grad_output, sparsity_layout_o, dim, i, sparsity_layout_x,
+                      sparsity_block_size), None, None, None, None, None, None, None, None, None
+    else:
+        raise ValueError(f"Reduction operation '{reduce_op}' does not support backward pass")
+
+
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
+    reset_to_zero=["o"]
+)
+@triton.jit
+def scatter_reduce_kernel(x,
+                          x_b, x_b_s, x_r_s, x_c_s,
+                          s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+                          dim,
+                          i,
+                          i_b, i_b_s, i_r_s, i_c_s,
+                          o,
+                          o_b, o_b_s,
+                          s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+                          r_lut_o,
+                          reduce_op_ind,
+                          sparsity_block_size,
+                          TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get position of current sparsity block consisting of its batch, row, and column index
+    spa_bat_x_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
+    spa_bat_x_msk = (spa_bat_x_idx >= 0 and spa_bat_x_idx < s_lut_x_r * s_lut_x_r_s)
+    spa_bat_x = tl.load(s_lut_x + spa_bat_x_idx, mask=spa_bat_x_msk)
+
+    spa_row_x_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
+    spa_row_x_msk = (spa_row_x_idx >= 0 and spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
+    spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)
+
+    spa_col_x_idx = (pid_blk * s_lut_x_r_s + 2 * s_lut_x_c_s)
+    spa_col_x_msk = (spa_col_x_idx >= 0 and spa_col_x_idx < s_lut_x_r * s_lut_x_r_s)
+    spa_col_x = tl.load(s_lut_x + spa_col_x_idx, mask=spa_col_x_msk)
+
+    # Load x values
+    blk_x_idx = ((pid_blk * x_b_s) +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_x_msk = (blk_x_idx >= 0 and
+                 blk_x_idx < x_b * x_b_s)
+    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+    # Load index values
+    blk_i_idx = ((pid_blk * i_b_s) +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+    blk_i_msk = (blk_i_idx >= 0 and
+                 blk_i_idx < i_b * i_b_s)
+    blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
+
+    # Get indices of sparsity blocks and positions within the blocks
+    pos_spa_blk_x = blk_i // sparsity_block_size
+    pos_spa_int_x = blk_i % sparsity_block_size
+
+    rev_dst_bat_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_x, dtype=tl.int32)
+    rev_dst_row_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_x, dtype=tl.int32)
+    rev_dst_col_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_x, dtype=tl.int32)
+    dst_row_o = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
+                 .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+    dst_col_o = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
+                 .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+    if dim == 0:
+        rev_dst_bat_o = blk_i
+    elif dim == 1:
+        rev_dst_row_o = pos_spa_blk_x
+        dst_row_o = pos_spa_int_x * x_r_s
+    elif dim == 2:
+        rev_dst_col_o = pos_spa_blk_x
+        dst_col_o = pos_spa_int_x * x_c_s
+
+    # Load reverse sparsity indices for o
+    rev_idx_spa_o_idx = ((rev_dst_bat_o * s_l_o_b_s) +
+                         (rev_dst_row_o * s_l_o_r_s) +
+                         (rev_dst_col_o * s_l_o_c_s))
+    rev_idx_spa_o_msk = (rev_idx_spa_o_idx >= 0 and
+                         rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
+    rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
+
+    # Store output
+    blk_o_idx = ((rev_idx_spa_o * o_b_s) +
+                 dst_row_o +
+                 dst_col_o)
+    blk_o_msk = ((blk_o_idx >= 0 and
+                  blk_o_idx < o_b * o_b_s) and
+                 rev_idx_spa_o_msk != -1)
+
+    if reduce_op_ind == 0:
+        tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+    elif reduce_op_ind == 1:
+        tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+def scatter_reduce_build_lut(lut: dict, sparsity_layout_src: Tensor, sparsity_layout_tgt: Tensor):
+    if lut is None:
+        lut = dict()
+
+    if "sparsity_lut_x" not in lut:
+        sparsity_lut_x = torch.nonzero(sparsity_layout_src).contiguous()
+        lut["sparsity_lut_x"] = sparsity_lut_x
+
+    if "sparsity_reverse_lut_o" not in lut:
+        sparsity_layout_o_flat = sparsity_layout_tgt.reshape(-1)
+        sparsity_reverse_lut_o = ((torch.cumsum(sparsity_layout_o_flat, dim=-1) - 1) *
+                                  (sparsity_layout_o_flat == 1) -
+                                  (1 * (sparsity_layout_o_flat == 0)))
+        lut["sparsity_reverse_lut_o"] = sparsity_reverse_lut_o
+
+    if "n_sparse_blocks" not in lut:
+        n_sparse_blocks = torch.sum(sparsity_layout_tgt.to(torch.int)).item()
+        lut["n_sparse_blocks"] = n_sparse_blocks
+
+    validate_contiguous(sparsity_layout_src, lut["sparsity_lut_x"],
+                        sparsity_layout_tgt, lut["sparsity_reverse_lut_o"])
+
+    return lut
+
+
+# noinspection PyUnusedLocal
+def scatter_reduce_setup_context(ctx, inputs, output):
+    (_, sparsity_layout_x, _, dim, i, sparsity_layout_o, _, sparsity_block_size, _, reduce_op) = inputs
+
+    ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_o)
+    ctx.dim = dim
+    ctx.sparsity_block_size = sparsity_block_size
+    ctx.reduce_op = reduce_op
+
+
+scatter_reduce_forward.register_autograd(scatter_reduce_wrapper_backward, setup_context=scatter_reduce_setup_context)
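Both operations now follow the same registration pattern: the forward is a plain module-level function decorated with `triton_op`, the Triton kernel is launched through `wrap_triton`, and the backward plus a `setup_context` hook are attached with `register_autograd`, replacing the former `torch.autograd.Function` subclasses. Below is a minimal standalone sketch of that pattern, independent of blksprs: the `example::scale` op, its kernel, and the block size are invented for illustration; only the import paths mirror the ones used in this file (recent PyTorch also exposes public counterparts under `torch.library`).

```python
import torch
import triton
from triton import language as tl
from torch._library import triton_op  # same import path as used in this diff
from torch._library.triton import wrap_triton


@triton.jit
def scale_kernel(x_ptr, o_ptr, n_elements, factor, BLOCK: tl.constexpr):
    # Each program scales one contiguous block of elements.
    offs = tl.program_id(axis=0) * BLOCK + tl.arange(0, BLOCK)
    msk = offs < n_elements
    x = tl.load(x_ptr + offs, mask=msk)
    tl.store(o_ptr + offs, x * factor, mask=msk)


@triton_op("example::scale", mutates_args={})
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Forward: allocate the output and launch the kernel through wrap_triton.
    out = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK"]),)
    wrap_triton(scale_kernel)[grid](x, out, x.numel(), factor, BLOCK=1024)
    return out


# noinspection PyUnusedLocal
def scale_setup_context(ctx, inputs, output):
    # Stash whatever the backward needs, mirroring gather_setup_context above.
    _, factor = inputs
    ctx.factor = factor


def scale_backward(ctx, grad_output):
    # One gradient slot per op input; the scalar factor gets None.
    return grad_output * ctx.factor, None


scale.register_autograd(scale_backward, setup_context=scale_setup_context)

# Hypothetical usage on a CUDA machine:
# x = torch.randn(4096, device="cuda", requires_grad=True)
# scale(x, 2.0).sum().backward()
```

Registering the forward as a `triton_op` (rather than an opaque autograd.Function) is the mechanism PyTorch provides for letting `torch.compile` trace into the wrapped Triton kernel, which appears to be the motivation behind this rewrite.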