blksprs 1.11__py3-none-any.whl → 2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +4 -5
- blksprs/layouting/distribution_layout.py +64 -48
- blksprs/layouting/sparsity_layout.py +96 -72
- blksprs/ops/conversion.py +349 -338
- blksprs/ops/distribution.py +318 -294
- blksprs/ops/flow.py +113 -100
- blksprs/ops/matmul.py +187 -172
- blksprs/ops/misc/broadcast_ops.py +68 -53
- blksprs/ops/misc/row_wise.py +223 -176
- blksprs/ops/partitioning.py +140 -132
- blksprs/ops/repeat.py +118 -120
- blksprs/ops/softmax.py +240 -214
- blksprs/ops/transpose.py +55 -52
- blksprs/utils/autotuning.py +78 -0
- blksprs/utils/benchmarking.py +3 -3
- blksprs/utils/processing.py +2 -1
- blksprs/utils/tools.py +5 -6
- blksprs/utils/validation.py +22 -16
- {blksprs-1.11.dist-info → blksprs-2.0.dist-info}/METADATA +55 -36
- blksprs-2.0.dist-info/RECORD +23 -0
- {blksprs-1.11.dist-info → blksprs-2.0.dist-info}/WHEEL +1 -1
- blksprs/utils/layout_utils.py +0 -17
- blksprs-1.11.dist-info/RECORD +0 -23
- {blksprs-1.11.dist-info → blksprs-2.0.dist-info}/top_level.txt +0 -0
blksprs/ops/misc/row_wise.py
CHANGED
|
@@ -1,16 +1,19 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
import triton
|
|
3
3
|
from torch import Tensor
|
|
4
|
+
from torch._library.triton import wrap_triton, triton_op
|
|
4
5
|
from triton import language as tl
|
|
5
6
|
|
|
7
|
+
from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
|
|
6
8
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
7
|
-
from blksprs.utils.tools import
|
|
9
|
+
from blksprs.utils.tools import stride
|
|
8
10
|
from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, validate_sparsity, \
|
|
9
|
-
validate_sparsity_block_size
|
|
11
|
+
validate_sparsity_block_size
|
|
10
12
|
|
|
11
13
|
|
|
14
|
+
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
|
|
12
15
|
def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
13
|
-
flag_slice_only: bool = False
|
|
16
|
+
flag_slice_only: bool = False) -> (BlksprsTensor, Tensor):
|
|
14
17
|
"""Computes the row-wise sum of a block-sparse tensor.
|
|
15
18
|
|
|
16
19
|
Returns a block-sparse tensor in compressed form with only one block per row, where the first entry contains the sum
|
|
@@ -25,7 +28,6 @@ def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
|
|
|
25
28
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
26
29
|
flag_slice_only (bool, optional): If set the output will be of shape ``[x.size(0), x.size(1), 1]``
|
|
27
30
|
(default ``False``).
|
|
28
|
-
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
29
31
|
|
|
30
32
|
Returns:
|
|
31
33
|
tuple[BlksprsTensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise sum
|
|
@@ -39,7 +41,6 @@ def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
|
|
|
39
41
|
validate_device(x)
|
|
40
42
|
validate_sparsity(sparsity_block_size, (x, sparsity_layout))
|
|
41
43
|
validate_sparsity_block_size(sparsity_block_size, x)
|
|
42
|
-
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
43
44
|
|
|
44
45
|
sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
|
|
45
46
|
|
|
@@ -54,50 +55,65 @@ def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
|
|
|
54
55
|
validate_contiguous(sparsity_layout, sparsity_lut,
|
|
55
56
|
sparsity_layout_output, sparsity_reverse_lut_output)
|
|
56
57
|
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
58
|
+
return BlksprsTensor(row_wise_sum_forward(
|
|
59
|
+
x, sparsity_lut, sparsity_layout_output, sparsity_reverse_lut_output,
|
|
60
|
+
sparsity_block_size, n_sparse_blocks_output, flag_slice_only)), sparsity_layout_output
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@triton_op("blksprs::row_wise_sum_forward", mutates_args={})
|
|
64
|
+
def row_wise_sum_forward(x: Tensor, sparsity_lut: Tensor,
|
|
65
|
+
sparsity_layout_output: Tensor, sparsity_reverse_lut_output: Tensor,
|
|
66
|
+
sparsity_block_size: int, n_sparse_blocks_output: int,
|
|
67
|
+
flag_slice_only: bool = False) -> Tensor:
|
|
68
|
+
with torch.no_grad():
|
|
69
|
+
output = torch.zeros(
|
|
70
|
+
size=(n_sparse_blocks_output, sparsity_block_size, 1 if flag_slice_only else sparsity_block_size),
|
|
71
|
+
dtype=x.dtype, device=x.device)
|
|
72
|
+
|
|
73
|
+
x_b, x_r, x_c = x.size()
|
|
74
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
75
|
+
s_lut_x_r, s_lut_x_c = sparsity_lut.size()
|
|
76
|
+
s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
|
|
77
|
+
o_b, o_r, o_c = output.size()
|
|
78
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
79
|
+
s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
|
|
80
|
+
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
|
|
81
|
+
|
|
82
|
+
triton_grid = lambda meta: [x_b,
|
|
83
|
+
triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
|
|
84
|
+
triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
|
|
85
|
+
|
|
86
|
+
(wrap_triton(row_wise_sum_kernel)[triton_grid]
|
|
87
|
+
(x,
|
|
88
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
89
|
+
sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
|
|
90
|
+
output,
|
|
91
|
+
o_b, o_b_s, o_r_s,
|
|
92
|
+
s_l_o_b, s_l_o_b_s, s_l_o_r_s,
|
|
93
|
+
sparsity_reverse_lut_output,
|
|
94
|
+
sparsity_block_size))
|
|
95
|
+
|
|
96
|
+
return output
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
# noinspection PyUnusedLocal
|
|
100
|
+
@triton.autotune(
|
|
101
|
+
configs=get_autotune_configs(),
|
|
102
|
+
key=["sparsity_block_size"],
|
|
103
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
104
|
+
reset_to_zero=["o"]
|
|
105
|
+
)
|
|
92
106
|
@triton.jit
|
|
93
|
-
def
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
107
|
+
def row_wise_sum_kernel(x,
|
|
108
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
109
|
+
s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
|
|
110
|
+
o,
|
|
111
|
+
o_b, o_b_s, o_r_s,
|
|
112
|
+
s_l_o_b, s_l_o_b_s, s_l_o_r_s,
|
|
113
|
+
r_lut_o,
|
|
114
|
+
sparsity_block_size,
|
|
115
|
+
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
116
|
+
# Get triton block indices
|
|
101
117
|
pid_blk = tl.program_id(axis=0)
|
|
102
118
|
pid_row = tl.program_id(axis=1)
|
|
103
119
|
pid_col = tl.program_id(axis=2)
|
|
@@ -117,27 +133,27 @@ def kernel_blocksparse_row_wise_sum(x,
|
|
|
117
133
|
rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
|
|
118
134
|
rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
|
|
119
135
|
|
|
120
|
-
if rev_idx_spa
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
blk_msk = (blk_idx >= 0 and blk_idx < x_b * x_b_s)
|
|
128
|
-
blk = tl.load(x + blk_idx, mask=blk_msk)
|
|
136
|
+
if rev_idx_spa >= 0:
|
|
137
|
+
blk_idx = ((pid_blk * x_b_s) +
|
|
138
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
139
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
140
|
+
blk_msk = (blk_idx >= 0 and
|
|
141
|
+
blk_idx < x_b * x_b_s)
|
|
142
|
+
blk = tl.load(x + blk_idx, mask=blk_msk)
|
|
129
143
|
|
|
130
|
-
|
|
144
|
+
buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
|
|
131
145
|
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
146
|
+
o_idx = (rev_idx_spa * o_b_s +
|
|
147
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
148
|
+
(tl.arange(0, 1))[None, :])
|
|
149
|
+
o_msk = (o_idx >= 0 and
|
|
150
|
+
o_idx < o_b * o_b_s)
|
|
151
|
+
tl.atomic_add(o + o_idx, buf, o_msk)
|
|
137
152
|
|
|
138
153
|
|
|
154
|
+
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
|
|
139
155
|
def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
140
|
-
flag_slice_only: bool = False
|
|
156
|
+
flag_slice_only: bool = False) -> (BlksprsTensor, Tensor):
|
|
141
157
|
"""Computes the row-wise max of a block-sparse tensor.
|
|
142
158
|
|
|
143
159
|
Returns a block-sparse tensor in compressed form with only one block per row, where the first entry contains the
|
|
@@ -152,13 +168,14 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
|
|
|
152
168
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
153
169
|
flag_slice_only (bool, optional): If set the output will be of shape ``[x.size(0), x.size(1), 1]``
|
|
154
170
|
(default ``False``).
|
|
155
|
-
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
156
171
|
|
|
157
172
|
Returns:
|
|
158
173
|
tuple[BlksprsTensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise max
|
|
159
174
|
of the input and the sparsity layout of the output tensor.
|
|
160
175
|
|
|
161
176
|
"""
|
|
177
|
+
# TODO Fix for triton bug, see https://github.com/triton-lang/triton/issues/6376, should be fixed with the upcoming 3.4.0 release
|
|
178
|
+
x = torch.where(x == -0.0, torch.tensor(0.0), x)
|
|
162
179
|
x = x.contiguous()
|
|
163
180
|
|
|
164
181
|
validate_dimensions(x)
|
|
@@ -166,7 +183,6 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
|
|
|
166
183
|
validate_device(x)
|
|
167
184
|
validate_sparsity(sparsity_block_size, (x, sparsity_layout))
|
|
168
185
|
validate_sparsity_block_size(sparsity_block_size, x)
|
|
169
|
-
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
170
186
|
|
|
171
187
|
sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
|
|
172
188
|
|
|
@@ -181,50 +197,67 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
|
|
|
181
197
|
validate_contiguous(sparsity_layout, sparsity_lut,
|
|
182
198
|
sparsity_layout_output, sparsity_reverse_lut_output)
|
|
183
199
|
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
200
|
+
return BlksprsTensor(
|
|
201
|
+
row_wise_max_forward(x, sparsity_lut, sparsity_layout_output, sparsity_reverse_lut_output, sparsity_block_size,
|
|
202
|
+
n_sparse_blocks_output, flag_slice_only)), sparsity_layout_output
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
@triton_op("blksprs::row_wise_max_forward", mutates_args={})
|
|
206
|
+
def row_wise_max_forward(x: Tensor, sparsity_lut: Tensor,
|
|
207
|
+
sparsity_layout_output: Tensor, sparsity_reverse_lut_output: Tensor,
|
|
208
|
+
sparsity_block_size: int, n_sparse_blocks_output: int,
|
|
209
|
+
flag_slice_only: bool = False) -> Tensor:
|
|
210
|
+
with torch.no_grad():
|
|
211
|
+
output = torch.full(size=(n_sparse_blocks_output,
|
|
212
|
+
sparsity_block_size,
|
|
213
|
+
1 if flag_slice_only else sparsity_block_size),
|
|
214
|
+
fill_value=torch.finfo(x.dtype).min,
|
|
215
|
+
device=x.device)
|
|
216
|
+
|
|
217
|
+
x_b, x_r, x_c = x.size()
|
|
218
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
219
|
+
s_lut_x_r, s_lut_x_c = sparsity_lut.size()
|
|
220
|
+
s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
|
|
221
|
+
o_b, o_r, o_c = output.size()
|
|
222
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
223
|
+
s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
|
|
224
|
+
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
|
|
225
|
+
|
|
226
|
+
triton_grid = lambda meta: [x_b,
|
|
227
|
+
triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
|
|
228
|
+
triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
|
|
229
|
+
|
|
230
|
+
(wrap_triton(row_wise_max_kernel)[triton_grid]
|
|
231
|
+
(x,
|
|
232
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
233
|
+
sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
|
|
234
|
+
output,
|
|
235
|
+
o_b, o_b_s, o_r_s,
|
|
236
|
+
s_l_o_b, s_l_o_b_s, s_l_o_r_s,
|
|
237
|
+
sparsity_reverse_lut_output,
|
|
238
|
+
sparsity_block_size))
|
|
239
|
+
|
|
240
|
+
return output
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
# noinspection PyUnusedLocal
|
|
244
|
+
@triton.autotune(
|
|
245
|
+
configs=get_autotune_configs(),
|
|
246
|
+
key=["sparsity_block_size"],
|
|
247
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
248
|
+
restore_value=["o"]
|
|
249
|
+
)
|
|
219
250
|
@triton.jit
|
|
220
|
-
def
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
251
|
+
def row_wise_max_kernel(x,
|
|
252
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
253
|
+
s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
|
|
254
|
+
o,
|
|
255
|
+
o_b, o_b_s, o_r_s,
|
|
256
|
+
s_l_o_b, s_l_o_b_s, s_l_o_r_s,
|
|
257
|
+
r_lut_o,
|
|
258
|
+
sparsity_block_size,
|
|
259
|
+
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
260
|
+
# Get triton block indices
|
|
228
261
|
pid_blk = tl.program_id(axis=0)
|
|
229
262
|
pid_row = tl.program_id(axis=1)
|
|
230
263
|
pid_col = tl.program_id(axis=2)
|
|
@@ -244,27 +277,27 @@ def kernel_blocksparse_row_wise_max(x,
|
|
|
244
277
|
rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
|
|
245
278
|
rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
|
|
246
279
|
|
|
247
|
-
if rev_idx_spa
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
blk_msk = (blk_idx >= 0 and blk_idx < x_b * x_b_s)
|
|
255
|
-
blk = tl.load(x + blk_idx, mask=blk_msk)
|
|
280
|
+
if rev_idx_spa >= 0:
|
|
281
|
+
blk_idx = ((pid_blk * x_b_s) +
|
|
282
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
283
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
284
|
+
blk_msk = (blk_idx >= 0 and
|
|
285
|
+
blk_idx < x_b * x_b_s)
|
|
286
|
+
blk = tl.load(x + blk_idx, mask=blk_msk)
|
|
256
287
|
|
|
257
|
-
|
|
288
|
+
buf = tl.reshape(tl.max(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
|
|
258
289
|
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
290
|
+
o_idx = (rev_idx_spa * o_b_s +
|
|
291
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
292
|
+
(tl.arange(0, 1))[None, :])
|
|
293
|
+
o_msk = (o_idx >= 0 and
|
|
294
|
+
o_idx < o_b * o_b_s)
|
|
295
|
+
tl.atomic_max(o + o_idx, buf, o_msk)
|
|
264
296
|
|
|
265
297
|
|
|
298
|
+
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
|
|
266
299
|
def row_wise_add(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
|
|
267
|
-
sparsity_block_size: int
|
|
300
|
+
sparsity_block_size: int) -> BlksprsTensor:
|
|
268
301
|
"""For each row in ``y`` adds the value to each value in the corresponding row of the block-sparse tensor ``x``.
|
|
269
302
|
|
|
270
303
|
Args:
|
|
@@ -272,7 +305,6 @@ def row_wise_add(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
|
|
|
272
305
|
sparsity_layout_x (Tensor): The sparsity layout of the block-sparse tensor.
|
|
273
306
|
y (BlksprsTensor): A block-sparse tensor in compressed form with only one value per row and a single column of sparse blocks.
|
|
274
307
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
275
|
-
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
276
308
|
|
|
277
309
|
Returns:
|
|
278
310
|
BlksprsTensor: The values of ``x`` with the first value of ``y`` in each row added to them as a block-sparse tensor in
|
|
@@ -284,9 +316,8 @@ def row_wise_add(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
|
|
|
284
316
|
validate_device(x)
|
|
285
317
|
validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
|
|
286
318
|
validate_sparsity_block_size(sparsity_block_size, x)
|
|
287
|
-
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
288
319
|
|
|
289
|
-
|
|
320
|
+
sparsity_lut_x = torch.nonzero(sparsity_layout_x).contiguous()
|
|
290
321
|
|
|
291
322
|
sparsity_layout_rwm, _ = torch.max(sparsity_layout_x, dim=-1, keepdim=True)
|
|
292
323
|
sparsity_layout_rwm_flat = sparsity_layout_rwm.reshape(-1)
|
|
@@ -294,60 +325,73 @@ def row_wise_add(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
|
|
|
294
325
|
(sparsity_layout_rwm_flat == 1) -
|
|
295
326
|
(1 * (sparsity_layout_rwm_flat == 0)))
|
|
296
327
|
|
|
297
|
-
validate_contiguous(sparsity_layout_x,
|
|
298
|
-
|
|
299
|
-
output = torch.empty_like(x)
|
|
300
|
-
|
|
301
|
-
x_b, x_r, x_c = x.size()
|
|
302
|
-
x_b_s, x_r_s, x_c_s = stride(x)
|
|
303
|
-
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
304
|
-
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
305
|
-
y_b, y_r, y_c = y.size()
|
|
306
|
-
y_b_s, y_r_s, y_c_s = stride(y)
|
|
307
|
-
s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_rwm.size()
|
|
308
|
-
s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_rwm)
|
|
309
|
-
o_b, o_r, o_c = output.size()
|
|
310
|
-
o_b_s, o_r_s, o_c_s = stride(output)
|
|
311
|
-
|
|
312
|
-
if triton_block_size is None:
|
|
313
|
-
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
314
|
-
|
|
315
|
-
triton_grid = lambda meta: [o_b,
|
|
316
|
-
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
317
|
-
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
318
|
-
|
|
319
|
-
(kernel_blocksparse_row_wise_add[triton_grid]
|
|
320
|
-
(x,
|
|
321
|
-
x_b, x_b_s, x_r_s, x_c_s,
|
|
322
|
-
sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
323
|
-
y, y_b, y_b_s, y_r_s, y_c_s,
|
|
324
|
-
s_l_y_b, s_l_y_b_s, s_l_y_r_s,
|
|
325
|
-
sparsity_reverse_lut_rwm,
|
|
326
|
-
output,
|
|
327
|
-
o_b, o_b_s, o_r_s, o_c_s,
|
|
328
|
-
triton_block_size
|
|
329
|
-
))
|
|
328
|
+
validate_contiguous(sparsity_layout_x, sparsity_lut_x, sparsity_reverse_lut_rwm)
|
|
330
329
|
|
|
331
|
-
return BlksprsTensor(
|
|
330
|
+
return BlksprsTensor(row_wise_add_forward(x, sparsity_lut_x, sparsity_layout_rwm,
|
|
331
|
+
sparsity_reverse_lut_rwm, y, sparsity_block_size))
|
|
332
332
|
|
|
333
333
|
|
|
334
334
|
def row_wise_sub(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
|
|
335
|
-
sparsity_block_size: int
|
|
335
|
+
sparsity_block_size: int) -> BlksprsTensor:
|
|
336
336
|
"""Wrapper for ``row_wise_add`` with negated y.
|
|
337
337
|
|
|
338
338
|
"""
|
|
339
|
-
return row_wise_add(x, sparsity_layout_x, torch.neg(y), sparsity_block_size
|
|
340
|
-
|
|
341
|
-
|
|
339
|
+
return row_wise_add(x, sparsity_layout_x, torch.neg(y), sparsity_block_size)
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
@triton_op("blksprs::row_wise_add_forward", mutates_args={})
|
|
343
|
+
def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
|
|
344
|
+
sparsity_layout_x_rwm: Tensor, sparsity_reverse_x_lut_rwm: Tensor,
|
|
345
|
+
y: Tensor, sparsity_block_size: int) -> Tensor:
|
|
346
|
+
with torch.no_grad():
|
|
347
|
+
output = torch.zeros_like(x)
|
|
348
|
+
|
|
349
|
+
x_b, x_r, x_c = x.size()
|
|
350
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
351
|
+
s_lut_r, s_lut_c = sparsity_lut_x.size()
|
|
352
|
+
s_lut_r_s, s_lut_c_s = stride(sparsity_lut_x)
|
|
353
|
+
y_b, y_r, y_c = y.size()
|
|
354
|
+
y_b_s, y_r_s, y_c_s = stride(y)
|
|
355
|
+
s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_x_rwm.size()
|
|
356
|
+
s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_x_rwm)
|
|
357
|
+
o_b, o_r, o_c = output.size()
|
|
358
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
359
|
+
|
|
360
|
+
triton_grid = lambda meta: [o_b,
|
|
361
|
+
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
362
|
+
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
363
|
+
|
|
364
|
+
(wrap_triton(kernel_blocksparse_row_wise_add)[triton_grid]
|
|
365
|
+
(x,
|
|
366
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
367
|
+
sparsity_lut_x, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
368
|
+
y, y_b, y_b_s, y_r_s, y_c_s,
|
|
369
|
+
s_l_y_b, s_l_y_b_s, s_l_y_r_s,
|
|
370
|
+
sparsity_reverse_x_lut_rwm,
|
|
371
|
+
output,
|
|
372
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
373
|
+
sparsity_block_size))
|
|
374
|
+
|
|
375
|
+
return output
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
# noinspection PyUnusedLocal
|
|
379
|
+
@triton.autotune(
|
|
380
|
+
configs=get_autotune_configs(),
|
|
381
|
+
key=["sparsity_block_size"],
|
|
382
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
383
|
+
reset_to_zero=["o"]
|
|
384
|
+
)
|
|
342
385
|
@triton.jit
|
|
343
386
|
def kernel_blocksparse_row_wise_add(x,
|
|
344
387
|
x_b, x_b_s, x_r_s, x_c_s,
|
|
345
|
-
|
|
388
|
+
s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
|
|
346
389
|
y, y_b, y_b_s, y_r_s, y_c_s,
|
|
347
390
|
s_l_y_b, s_l_y_b_s, s_l_y_r_s,
|
|
348
391
|
r_lut_y,
|
|
349
392
|
o,
|
|
350
393
|
o_b, o_b_s, o_r_s, o_c_s,
|
|
394
|
+
sparsity_block_size,
|
|
351
395
|
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
352
396
|
# Get triton block indices
|
|
353
397
|
pid_blk = tl.program_id(axis=0)
|
|
@@ -355,13 +399,13 @@ def kernel_blocksparse_row_wise_add(x,
|
|
|
355
399
|
pid_col = tl.program_id(axis=2)
|
|
356
400
|
|
|
357
401
|
# Get position of current sparsity block consisting of its batch and row index
|
|
358
|
-
spa_bat_idx = (pid_blk *
|
|
359
|
-
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx <
|
|
360
|
-
spa_bat = tl.load(
|
|
402
|
+
spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
|
|
403
|
+
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
|
|
404
|
+
spa_bat = tl.load(s_lut_x + spa_bat_idx, mask=spa_bat_msk)
|
|
361
405
|
|
|
362
|
-
spa_row_idx = (pid_blk *
|
|
363
|
-
spa_row_msk = (spa_row_idx >= 0 and spa_row_idx <
|
|
364
|
-
spa_row = tl.load(
|
|
406
|
+
spa_row_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
|
|
407
|
+
spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_x_r * s_lut_x_r_s)
|
|
408
|
+
spa_row = tl.load(s_lut_x + spa_row_idx, mask=spa_row_msk)
|
|
365
409
|
|
|
366
410
|
# Get reverse sparsity indices for s
|
|
367
411
|
rev_idx_spa_s_idx = (spa_bat * s_l_y_b_s +
|
|
@@ -377,14 +421,16 @@ def kernel_blocksparse_row_wise_add(x,
|
|
|
377
421
|
blk_x_idx = ((pid_blk * x_b_s) +
|
|
378
422
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
379
423
|
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
380
|
-
blk_x_msk = (blk_x_idx >= 0 and
|
|
424
|
+
blk_x_msk = (blk_x_idx >= 0 and
|
|
425
|
+
blk_x_idx < x_b * x_b_s)
|
|
381
426
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
382
427
|
|
|
383
428
|
# Load sum block
|
|
384
429
|
blk_s_idx = (rev_idx_spa_s * y_b_s +
|
|
385
430
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
|
|
386
431
|
(tl.arange(0, 1) * y_c_s)[None, :])
|
|
387
|
-
blk_s_msk = (blk_s_idx >= 0 and
|
|
432
|
+
blk_s_msk = (blk_s_idx >= 0 and
|
|
433
|
+
blk_s_idx < y_b * y_b_s)
|
|
388
434
|
blk_s = tl.load(y + blk_s_idx, mask=blk_s_msk)
|
|
389
435
|
|
|
390
436
|
# Compute exp
|
|
@@ -394,5 +440,6 @@ def kernel_blocksparse_row_wise_add(x,
|
|
|
394
440
|
blk_o_idx = ((pid_blk * o_b_s) +
|
|
395
441
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
396
442
|
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
397
|
-
blk_o_msk = (blk_o_idx >= 0 and
|
|
443
|
+
blk_o_msk = (blk_o_idx >= 0 and
|
|
444
|
+
blk_o_idx < o_b * o_b_s)
|
|
398
445
|
tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
|