blksprs 1.10.2__py3-none-any.whl → 2.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +2 -6
- blksprs/layouting/distribution_layout.py +32 -25
- blksprs/layouting/sparsity_layout.py +65 -52
- blksprs/ops/conversion.py +423 -374
- blksprs/ops/distribution.py +403 -335
- blksprs/ops/flow.py +135 -83
- blksprs/ops/matmul.py +221 -187
- blksprs/ops/misc/broadcast_ops.py +53 -35
- blksprs/ops/misc/row_wise.py +151 -91
- blksprs/ops/partitioning.py +136 -89
- blksprs/ops/repeat.py +115 -108
- blksprs/ops/softmax.py +244 -208
- blksprs/ops/transpose.py +69 -131
- blksprs/utils/benchmarking.py +3 -3
- blksprs/utils/tools.py +31 -4
- blksprs/utils/validation.py +0 -14
- {blksprs-1.10.2.dist-info → blksprs-2.0rc1.dist-info}/METADATA +42 -36
- blksprs-2.0rc1.dist-info/RECORD +22 -0
- {blksprs-1.10.2.dist-info → blksprs-2.0rc1.dist-info}/WHEEL +1 -1
- blksprs/ops/misc/exp.py +0 -104
- blksprs/utils/layout_utils.py +0 -17
- blksprs-1.10.2.dist-info/RECORD +0 -24
- {blksprs-1.10.2.dist-info → blksprs-2.0rc1.dist-info}/top_level.txt +0 -0
blksprs/ops/distribution.py
CHANGED
@@ -1,19 +1,20 @@
 import torch
 import triton
 from torch import Tensor
+from torch._library import triton_op
+from torch._library.triton import wrap_triton
 from triton import language as tl
 
-from blksprs.ops.conversion import to_dense
 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import
+from blksprs.utils.tools import stride, get_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-    validate_sparsity, validate_dtype_int, validate_sparsity_block_size
+    validate_sparsity, validate_dtype_int, validate_sparsity_block_size
 
 
 def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
            dim: int,
            idx: BlksprsTensor, sparsity_layout_idx: Tensor,
-           sparsity_block_size: int,
+           sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
     """Applies a gather operation on a block-sparse tensor in compressed form.
 
     Args:
@@ -23,7 +24,7 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
         idx (BlksprsTensor): The block-sparse indices tensor in compressed form specifying how to gather from the source tensor.
         sparsity_layout_idx (Tensor): The sparsity layout of the indices block-sparse tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
-
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
 
     Returns:
         BlksprsTensor: The result of the gather operation as a block-sparse tensor in compressed form.
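Editor's note (not part of the diff): the new ``lut`` argument turns the look-up tables that 1.10.2 rebuilt inside every call into a caller-owned cache. A minimal usage sketch, assuming the 2.0rc1 signature shown above; the tensors and layouts are placeholders for values produced by the usual blksprs layouting/conversion helpers, and the choice of ``dim=2`` is arbitrary:

from blksprs.ops.distribution import gather

# Shared cache; gather() inserts any missing entries ("sparsity_reverse_lut_x",
# "sparsity_lut_i") on first use and reuses them on later calls.
lut = {}

def gather_step(src, sparsity_layout_src, idx, sparsity_layout_idx, sparsity_block_size):
    # Reusing the dict is only valid while the sparsity layouts stay the same;
    # otherwise pass lut=None (the default) so the tables are rebuilt.
    return gather(src, sparsity_layout_src, 2, idx, sparsity_layout_idx,
                  sparsity_block_size, lut=lut)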
@@ -38,171 +39,203 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
     validate_device(src, idx)
     validate_sparsity(sparsity_block_size, (src, sparsity_layout_src), (idx, sparsity_layout_idx))
     validate_sparsity_block_size(sparsity_block_size, src, idx)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)
-
-    sparsity_layout_x_flat = sparsity_layout_src.reshape(-1)
-    sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
-                              (sparsity_layout_x_flat == 1) -
-                              (1 * (sparsity_layout_x_flat == 0)))
-
-    sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
-
-    validate_contiguous(sparsity_layout_src, sparsity_reverse_lut_x,
-                        sparsity_layout_idx, sparsity_lut_i)
 
     adjusted_dim = dim % 3
 
[... 144 removed lines (old 55-198) are not rendered in this diff view ...]
+    lut = gather_build_lut(lut, sparsity_layout_src, sparsity_layout_idx)
+
+    return BlksprsTensor(gather_forward(src, sparsity_layout_src, lut["sparsity_reverse_lut_x"],
+                                        adjusted_dim, idx, sparsity_layout_idx, lut["sparsity_lut_i"],
+                                        sparsity_block_size))
+
+
+@triton_op("blksprs::gather", mutates_args={})
+def gather_forward(x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
+                   dim: int, i: Tensor, _: Tensor, sparsity_lut_i: Tensor,
+                   sparsity_block_size: int) -> Tensor:
+    output = torch.empty_like(i, dtype=x.dtype)
+
+    x_b, x_r, x_c = x.size()
+    x_b_s, x_r_s, x_c_s = stride(x)
+    s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
+    s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
+    i_b, i_r, i_c = i.size()
+    i_b_s, i_r_s, i_c_s = stride(i)
+    s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
+    s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
+    o_b, o_r, o_c = output.size()
+    o_b_s, o_r_s, o_c_s = stride(output)
+
+    triton_grid = lambda meta: [o_b,
+                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+    (wrap_triton(gather_kernel)[triton_grid]
+     (x,
+      x_b, x_b_s, x_r_s, x_c_s,
+      s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+      sparsity_reverse_lut_x,
+      dim,
+      i,
+      i_b, i_b_s, i_r_s, i_c_s,
+      output,
+      o_b, o_b_s, o_r_s, o_c_s,
+      sparsity_lut_i, s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+      sparsity_block_size))
+
+    return output
+
+
+def gather_backward(ctx, grad_output):
+    sparsity_layout_x, i, sparsity_layout_i = ctx.saved_tensors
+    dim = ctx.dim
+    sparsity_block_size = ctx.sparsity_block_size
+
+    return scatter_reduce(grad_output, sparsity_layout_i,
+                          dim, i,
+                          sparsity_layout_x, sparsity_block_size,
+                          reduce_op="sum"), None, None, None, None, None, None, None
+
+
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=[],
+)
+@triton.jit
+def gather_kernel(x,
+                  x_b, x_b_s, x_r_s, x_c_s,
+                  s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+                  r_lut_x,
+                  dim,
+                  i,
+                  i_b, i_b_s, i_r_s, i_c_s,
+                  o,
+                  o_b, o_b_s, o_r_s, o_c_s,
+                  s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+                  sparsity_block_size,
+                  TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get valid triton block size
+    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
+
+    # Get position of current sparsity block consisting of its batch, row, and column index
+    spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
+    spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
+
+    spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
+    spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+
+    spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
+    spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+
+    # Load index values
+    blk_i_idx = ((pid_blk * i_b_s) +
+                 ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                 ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+    blk_i_msk = ((blk_i_idx >= 0 and
+                  blk_i_idx < i_b * i_b_s) and
+                 (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                  tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+    blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
+
+    # Get indices of sparsity blocks and positions within the blocks
+    pos_spa_blk_x = blk_i // sparsity_block_size
+    pos_spa_int_x = blk_i % sparsity_block_size
+
+    rev_dst_bat_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_o, dtype=tl.int32)
+    rev_dst_row_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_o, dtype=tl.int32)
+    rev_dst_col_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_o, dtype=tl.int32)
+    dst_row_x = (((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
+                 .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+    dst_col_x = (((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
+                 .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+    if dim == 0:
+        rev_dst_bat_x = blk_i
+    elif dim == 1:
+        rev_dst_row_x = pos_spa_blk_x
+        dst_row_x = pos_spa_int_x * x_r_s
+    elif dim == 2:
+        rev_dst_col_x = pos_spa_blk_x
+        dst_col_x = pos_spa_int_x * x_c_s
+
+    # Load reverse sparsity indices for x
+    rev_idx_spa_x_idx = ((rev_dst_bat_x * s_l_x_b_s) +
+                         (rev_dst_row_x * s_l_x_r_s) +
+                         (rev_dst_col_x * s_l_x_c_s))
+    rev_idx_spa_x_msk = ((rev_idx_spa_x_idx >= 0 and
+                          rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s) and
+                         (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                          tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+    rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
+
+    # Load x values
+    blk_x_idx = ((rev_idx_spa_x * x_b_s) +
+                 dst_row_x +
+                 dst_col_x)
+    blk_x_msk = (((blk_x_idx >= 0 and
+                   blk_x_idx < x_b * x_b_s) and
+                  rev_idx_spa_x_msk != -1) and
+                 (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                  tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+    # Store output
+    blk_o_idx = ((pid_blk * o_b_s) +
+                 ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                 ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+    blk_o_msk = (((blk_o_idx >= 0 and
+                   blk_o_idx < o_b * o_b_s) and
+                  rev_idx_spa_x_msk != -1) and
+                 (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                  tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+    tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+def gather_build_lut(lut: dict, sparsity_layout_src: Tensor, sparsity_layout_idx: Tensor):
+    if lut is None:
+        lut = dict()
+
+    if "sparsity_reverse_lut_x" not in lut:
+        sparsity_layout_x_flat = sparsity_layout_src.reshape(-1)
+        sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
+                                  (sparsity_layout_x_flat == 1) -
+                                  (1 * (sparsity_layout_x_flat == 0)))
+        lut["sparsity_reverse_lut_x"] = sparsity_reverse_lut_x
+
+    if "sparsity_lut_i" not in lut:
+        sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
+        lut["sparsity_lut_i"] = sparsity_lut_i
+
+    validate_contiguous(sparsity_layout_src, lut["sparsity_reverse_lut_x"],
+                        sparsity_layout_idx, lut["sparsity_lut_i"])
+
+    return lut
+
+
+# noinspection PyUnusedLocal
+def gather_setup_context(ctx, inputs, output):
+    (_, sparsity_layout_x, _, dim, i, sparsity_layout_i, _, sparsity_block_size) = inputs
+
+    ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_i)
+    ctx.dim = dim
+    ctx.sparsity_block_size = sparsity_block_size
+
+
+gather_forward.register_autograd(gather_backward, setup_context=gather_setup_context)
 
 
 def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
             dim: int,
             idx: BlksprsTensor,
             sparsity_layout_tgt: Tensor,
-            sparsity_block_size: int,
+            sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
     """Wrapper for ``scatter_reduce`` with ``reduce_op="none"``.
 
     """
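Editor's note (not part of the diff): ``gather_build_lut`` above keeps the reverse look-up construction that previously lived inline in ``gather``: a cumulative sum over the flattened sparsity layout numbers the present blocks in compressed order, and absent blocks map to ``-1``. A small self-contained check of that expression in plain PyTorch:

import torch

# One batch with a 2x2 grid of sparsity blocks, two of which are present.
sparsity_layout = torch.tensor([[[1, 0],
                                 [0, 1]]])
flat = sparsity_layout.reshape(-1)

# Same expression as in gather_build_lut / scatter_reduce_build_lut.
reverse_lut = ((torch.cumsum(flat, dim=-1) - 1) * (flat == 1) - (1 * (flat == 0)))

print(reverse_lut)  # tensor([ 0, -1, -1,  1]): index into the compressed block stack, -1 where absent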
@@ -211,7 +244,7 @@ def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
                           idx,
                           sparsity_layout_tgt,
                           sparsity_block_size,
-                          reduce_op="none",
+                          reduce_op="none", lut=lut)
 
 
 def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
@@ -219,7 +252,7 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
                    idx: BlksprsTensor,
                    sparsity_layout_tgt: Tensor,
                    sparsity_block_size: int,
-                   reduce_op: str = "sum",
+                   reduce_op: str = "sum", lut: dict = None) -> BlksprsTensor:
     """Applies a scatter operation on a block-sparse tensor in compressed form.
 
     Args:
@@ -231,7 +264,7 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
         sparsity_block_size (int): The size of the sparsity blocks.
         reduce_op (str, optional): The reduction operation to apply during the scatter operation (default ``"sum"``).
             Supported operations are ``"none"`` and ``"sum"``.
-
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
 
     Returns:
         BlksprsTensor: The result of the scatter operation as a block-sparse tensor in compressed form.
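Editor's note (not part of the diff): as the wrapper earlier in this file states, ``scatter`` is ``scatter_reduce`` with ``reduce_op="none"`` (plain stores, colliding writes are not accumulated, and the backward pass raises), while ``"sum"`` accumulates collisions via ``tl.atomic_add`` and is the differentiable variant, its gradient being a ``gather`` of the incoming gradient. A short sketch, assuming the 2.0rc1 signatures shown in this file; the tensor arguments are placeholders:

from blksprs.ops.distribution import scatter, scatter_reduce

def scatter_overwrite(src, sl_src, idx, sl_tgt, sparsity_block_size):
    # Equivalent to scatter_reduce(..., reduce_op="none"); no backward pass.
    return scatter(src, sl_src, 2, idx, sl_tgt, sparsity_block_size)

def scatter_sum(src, sl_src, idx, sl_tgt, sparsity_block_size, lut=None):
    # Accumulates values that land on the same position; supports autograd.
    return scatter_reduce(src, sl_src, 2, idx, sl_tgt, sparsity_block_size,
                          reduce_op="sum", lut=lut)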
@@ -246,183 +279,218 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
     validate_device(src, idx)
     validate_sparsity(sparsity_block_size, (src, sparsity_layout_src), (idx, sparsity_layout_src))
     validate_sparsity_block_size(sparsity_block_size, src, idx)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)
 
     if reduce_op not in ["none", "sum"]:
         raise ValueError(f"Reduction operation '{reduce_op}' is not supported")
 
-
+    adjusted_dim = dim % 3
 
-
-
-
-
+    lut = scatter_reduce_build_lut(lut, sparsity_layout_src, sparsity_layout_tgt)
+
+    return BlksprsTensor(scatter_reduce_forward(src, sparsity_layout_src, lut["sparsity_lut_x"],
+                                                adjusted_dim, idx,
+                                                sparsity_layout_tgt, lut["sparsity_reverse_lut_o"],
+                                                sparsity_block_size, lut["n_sparse_blocks"],
+                                                reduce_op))
+
+
+@triton_op("blksprs::scatter_reduce", mutates_args={})
+def scatter_reduce_forward(x: Tensor, _: Tensor, sparsity_lut_x: Tensor,
+                           dim: int, i: Tensor,
+                           sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
+                           sparsity_block_size: int, n_sparse_blocks: int,
+                           reduce_op: str) -> Tensor:
+    output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                         dtype=x.dtype, device=x.device)
+
+    x_b, x_r, x_c = x.size()
+    x_b_s, x_r_s, x_c_s = stride(x)
+    s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
+    s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut_x)
+    i_b, i_r, i_c = i.size()
+    i_b_s, i_r_s, i_c_s = stride(i)
+    o_b, o_r, o_c = output.size()
+    o_b_s, o_r_s, o_c_s = stride(output)
+    s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
+    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
+
+    triton_grid = lambda meta: [x_b,
+                                triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+    reduce_op_ind = 0
+    if reduce_op == "sum":
+        reduce_op_ind = 1
+
+    (wrap_triton(scatter_reduce_kernel)[triton_grid]
+     (x,
+      x_b, x_b_s, x_r_s, x_c_s,
+      sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+      dim,
+      i,
+      i_b, i_b_s, i_r_s, i_c_s,
+      output,
+      o_b, o_b_s,
+      s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+      sparsity_reverse_lut_o,
+      reduce_op_ind,
+      sparsity_block_size))
+
+    return output
+
+
+def scatter_reduce_backward(ctx, grad_output):
+    sparsity_layout_x, i, sparsity_layout_o = ctx.saved_tensors
+    dim = ctx.dim
+    sparsity_block_size = ctx.sparsity_block_size
+    reduce_op = ctx.reduce_op
+
+    if reduce_op == "sum":
+        return gather(grad_output, sparsity_layout_o, dim, i, sparsity_layout_x,
+                      sparsity_block_size), None, None, None, None, None, None, None, None, None
+    else:
+        raise ValueError(f"Reduction operation '{reduce_op}' does not support backward pass")
+
+
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=[],
+    reset_to_zero=["o"]
+)
+@triton.jit
+def scatter_reduce_kernel(x,
+                          x_b, x_b_s, x_r_s, x_c_s,
+                          s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+                          dim,
+                          i,
+                          i_b, i_b_s, i_r_s, i_c_s,
+                          o,
+                          o_b, o_b_s,
+                          s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+                          r_lut_o,
+                          reduce_op_ind,
+                          sparsity_block_size,
+                          TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get valid triton block size
+    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
+
+    # Get position of current sparsity block consisting of its batch, row, and column index
+    spa_bat_x_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
+    spa_bat_x_msk = (spa_bat_x_idx >= 0 and spa_bat_x_idx < s_lut_x_r * s_lut_x_r_s)
+    spa_bat_x = tl.load(s_lut_x + spa_bat_x_idx, mask=spa_bat_x_msk)
+
+    spa_row_x_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
+    spa_row_x_msk = (spa_row_x_idx >= 0 and spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
+    spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)
+
+    spa_col_x_idx = (pid_blk * s_lut_x_r_s + 2 * s_lut_x_c_s)
+    spa_col_x_msk = (spa_col_x_idx >= 0 and spa_col_x_idx < s_lut_x_r * s_lut_x_r_s)
+    spa_col_x = tl.load(s_lut_x + spa_col_x_idx, mask=spa_col_x_msk)
+
+    # Load x values
+    blk_x_idx = ((pid_blk * x_b_s) +
+                 ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                 ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_x_msk = ((blk_x_idx >= 0 and
+                  blk_x_idx < x_b * x_b_s) and
+                 (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                  tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+    # Load index values
+    blk_i_idx = ((pid_blk * i_b_s) +
+                 ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                 ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+    blk_i_msk = ((blk_i_idx >= 0 and
+                  blk_i_idx < i_b * i_b_s) and
+                 (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                  tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+    blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
+
+    # Get indices of sparsity blocks and positions within the blocks
+    pos_spa_blk_x = blk_i // sparsity_block_size
+    pos_spa_int_x = blk_i % sparsity_block_size
+
+    rev_dst_bat_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_x, dtype=tl.int32)
+    rev_dst_row_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_x, dtype=tl.int32)
+    rev_dst_col_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_x, dtype=tl.int32)
+    dst_row_o = (((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
+                 .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+    dst_col_o = (((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
+                 .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+    if dim == 0:
+        rev_dst_bat_o = blk_i
+    elif dim == 1:
+        rev_dst_row_o = pos_spa_blk_x
+        dst_row_o = pos_spa_int_x * x_r_s
+    elif dim == 2:
+        rev_dst_col_o = pos_spa_blk_x
+        dst_col_o = pos_spa_int_x * x_c_s
+
+    # Load reverse sparsity indices for o
+    rev_idx_spa_o_idx = ((rev_dst_bat_o * s_l_o_b_s) +
+                         (rev_dst_row_o * s_l_o_r_s) +
+                         (rev_dst_col_o * s_l_o_c_s))
+    rev_idx_spa_o_msk = ((rev_idx_spa_o_idx >= 0 and
+                          rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s) and
+                         (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                          tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+    rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
+
+    # Store output
+    blk_o_idx = ((rev_idx_spa_o * o_b_s) +
+                 dst_row_o +
+                 dst_col_o)
+    blk_o_msk = (((blk_o_idx >= 0 and
+                   blk_o_idx < o_b * o_b_s) and
+                  rev_idx_spa_o_msk != -1) and
+                 (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                  tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+
+    if reduce_op_ind == 0:
+        tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+    elif reduce_op_ind == 1:
+        tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
 
-    n_sparse_blocks = torch.sum(sparsity_layout_tgt.to(torch.int)).item()
 
-
-
+def scatter_reduce_build_lut(lut: dict, sparsity_layout_src: Tensor, sparsity_layout_tgt: Tensor):
+    if lut is None:
+        lut = dict()
+
+    if "sparsity_lut_x" not in lut:
+        sparsity_lut_x = torch.nonzero(sparsity_layout_src).contiguous()
+        lut["sparsity_lut_x"] = sparsity_lut_x
+
+    if "sparsity_reverse_lut_o" not in lut:
+        sparsity_layout_o_flat = sparsity_layout_tgt.reshape(-1)
+        sparsity_reverse_lut_o = ((torch.cumsum(sparsity_layout_o_flat, dim=-1) - 1) *
+                                  (sparsity_layout_o_flat == 1) -
+                                  (1 * (sparsity_layout_o_flat == 0)))
+        lut["sparsity_reverse_lut_o"] = sparsity_reverse_lut_o
+
+    if "n_sparse_blocks" not in lut:
+        n_sparse_blocks = torch.sum(sparsity_layout_tgt.to(torch.int)).item()
+        lut["n_sparse_blocks"] = n_sparse_blocks
+
+    validate_contiguous(sparsity_layout_src, lut["sparsity_lut_x"],
+                        sparsity_layout_tgt, lut["sparsity_reverse_lut_o"])
+
+    return lut
+
+
+# noinspection PyUnusedLocal
+def scatter_reduce_setup_context(ctx, inputs, output):
+    (_, sparsity_layout_x, _, dim, i, sparsity_layout_o, _, sparsity_block_size, _, reduce_op) = inputs
+
+    ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_o)
+    ctx.dim = dim
+    ctx.sparsity_block_size = sparsity_block_size
+    ctx.reduce_op = reduce_op
 
-    adjusted_dim = dim % 3
 
-
-                                                          adjusted_dim, idx,
-                                                          sparsity_layout_tgt, sparsity_reverse_lut_o,
-                                                          sparsity_block_size, n_sparse_blocks,
-                                                          reduce_op, triton_block_size))
-
-
-class _BlocksparseScatterReduce(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_lut_x: Tensor,
-                dim: int, i: Tensor,
-                sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
-                sparsity_block_size: int, n_sparse_blocks: int,
-                reduce_op: str, triton_block_size: int) -> Tensor:
-        output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
-                             dtype=x.dtype, device=x.device)
-
-        x_b, x_r, x_c = x.size()
-        x_b_s, x_r_s, x_c_s = stride(x)
-        s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
-        s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut_x)
-        i_b, i_r, i_c = i.size()
-        i_b_s, i_r_s, i_c_s = stride(i)
-        o_b, o_r, o_c = output.size()
-        o_b_s, o_r_s, o_c_s = stride(output)
-        s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
-        s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
-
-        if triton_block_size is None:
-            triton_block_size = get_triton_block_size(sparsity_block_size)
-
-        triton_grid = lambda meta: [x_b,
-                                    triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
-                                    triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
-
-        reduce_op_ind = 0
-        if reduce_op == "sum":
-            reduce_op_ind = 1
-
-        (_BlocksparseScatterReduce.kernel_blocksparse_scatter[triton_grid]
-         (x,
-          x_b, x_b_s, x_r_s, x_c_s,
-          sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
-          dim,
-          i,
-          i_b, i_b_s, i_r_s, i_c_s,
-          output,
-          o_b, o_b_s,
-          s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
-          sparsity_reverse_lut_o,
-          reduce_op_ind,
-          sparsity_block_size,
-          triton_block_size))
-
-        ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_o)
-        ctx.dim = dim
-        ctx.sparsity_block_size = sparsity_block_size
-        ctx.reduce_op = reduce_op
-        ctx.triton_block_size = triton_block_size
-
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        sparsity_layout_x, i, sparsity_layout_o = ctx.saved_tensors
-        dim = ctx.dim
-        sparsity_block_size = ctx.sparsity_block_size
-        reduce_op = ctx.reduce_op
-        triton_block_size = ctx.triton_block_size
-
-        if reduce_op == "sum":
-            return gather(grad_output, sparsity_layout_o, dim, i, sparsity_layout_x, sparsity_block_size,
-                          triton_block_size=triton_block_size), None, None, None, None, None, None, None, None, None, None
-        else:
-            raise ValueError(f"Reduction operation '{reduce_op}' does not support backward pass")
-
-    @staticmethod
-    @triton.jit
-    def kernel_blocksparse_scatter(x,
-                                   x_b, x_b_s, x_r_s, x_c_s,
-                                   s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
-                                   dim,
-                                   i,
-                                   i_b, i_b_s, i_r_s, i_c_s,
-                                   o,
-                                   o_b, o_b_s,
-                                   s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
-                                   r_lut_o,
-                                   reduce_op_ind,
-                                   sparsity_block_size,
-                                   TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-        # Get triton block indices
-        pid_blk = tl.program_id(axis=0)
-        pid_row = tl.program_id(axis=1)
-        pid_col = tl.program_id(axis=2)
-
-        # Get position of current sparsity block consisting of its batch, row, and column index
-        spa_bat_x_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
-        spa_bat_x_msk = (spa_bat_x_idx >= 0 and spa_bat_x_idx < s_lut_x_r * s_lut_x_r_s)
-        spa_bat_x = tl.load(s_lut_x + spa_bat_x_idx, mask=spa_bat_x_msk)
-
-        spa_row_x_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
-        spa_row_x_msk = (spa_row_x_idx >= 0 and spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
-        spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)
-
-        spa_col_x_idx = (pid_blk * s_lut_x_r_s + 2 * s_lut_x_c_s)
-        spa_col_x_msk = (spa_col_x_idx >= 0 and spa_col_x_idx < s_lut_x_r * s_lut_x_r_s)
-        spa_col_x = tl.load(s_lut_x + spa_col_x_idx, mask=spa_col_x_msk)
-
-        # Load x values
-        blk_x_idx = ((pid_blk * x_b_s) +
-                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
-        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
-        # Load index values
-        blk_i_idx = ((pid_blk * i_b_s) +
-                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
-                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
-        blk_i_msk = (blk_i_idx >= 0 and blk_i_idx < i_b * i_b_s)
-        blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
-
-        # Get indices of sparsity blocks and positions within the blocks
-        pos_spa_blk_x = blk_i // sparsity_block_size
-        pos_spa_int_x = blk_i % sparsity_block_size
-
-        rev_dst_bat_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_x, dtype=tl.int32)
-        rev_dst_row_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_x, dtype=tl.int32)
-        rev_dst_col_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_x, dtype=tl.int32)
-        dst_row_o = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
-                     .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
-        dst_col_o = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
-                     .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
-        if dim == 0:
-            rev_dst_bat_o = blk_i
-        elif dim == 1:
-            rev_dst_row_o = pos_spa_blk_x
-            dst_row_o = pos_spa_int_x * x_r_s
-        elif dim == 2:
-            rev_dst_col_o = pos_spa_blk_x
-            dst_col_o = pos_spa_int_x * x_c_s
-
-        # Load reverse sparsity indices for o
-        rev_idx_spa_o_idx = ((rev_dst_bat_o * s_l_o_b_s) +
-                             (rev_dst_row_o * s_l_o_r_s) +
-                             (rev_dst_col_o * s_l_o_c_s))
-        rev_idx_spa_o_msk = (rev_idx_spa_o_idx >= 0 and rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
-        rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
-
-        # Store output
-        blk_o_idx = ((rev_idx_spa_o * o_b_s) +
-                     dst_row_o +
-                     dst_col_o)
-        blk_o_msk = ((blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s) and rev_idx_spa_o_msk != -1)
-
-        if reduce_op_ind == 0:
-            tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
-        elif reduce_op_ind == 1:
-            tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
+scatter_reduce_forward.register_autograd(scatter_reduce_backward, setup_context=scatter_reduce_setup_context)
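Editor's note (not part of the diff): the recurring structural change in this file, and presumably in the other ``ops`` modules listed above, is the replacement of ``torch.autograd.Function`` subclasses by ``@triton_op`` definitions whose kernels are launched through ``wrap_triton`` and whose gradients are attached with ``register_autograd``, presumably so the operations remain differentiable while being registered as proper custom operators. A condensed, hypothetical sketch of that pattern; the op name ``blksprs_example::scale`` and the toy kernel are invented for illustration, while the imports mirror the ones added in this diff:

import torch
import triton
import triton.language as tl
from torch import Tensor
from torch._library import triton_op
from torch._library.triton import wrap_triton


@triton.jit
def scale_kernel(x_ptr, o_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Elementwise toy kernel: o = 2 * x over a flat view of the tensor.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(o_ptr + offsets, x * 2.0, mask=mask)


@triton_op("blksprs_example::scale", mutates_args={})
def scale_forward(x: Tensor) -> Tensor:
    output = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
    wrap_triton(scale_kernel)[grid](x, output, x.numel(), BLOCK_SIZE=1024)
    return output


def scale_backward(ctx, grad_output):
    # d(2 * x)/dx = 2, so the incoming gradient is doubled as well.
    return grad_output * 2.0


# noinspection PyUnusedLocal
def scale_setup_context(ctx, inputs, output):
    pass  # nothing needs to be saved for this toy op


scale_forward.register_autograd(scale_backward, setup_context=scale_setup_context)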