blksprs 1.10.2__py3-none-any.whl → 2.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +2 -6
- blksprs/layouting/distribution_layout.py +32 -25
- blksprs/layouting/sparsity_layout.py +65 -52
- blksprs/ops/conversion.py +423 -374
- blksprs/ops/distribution.py +403 -335
- blksprs/ops/flow.py +135 -83
- blksprs/ops/matmul.py +221 -187
- blksprs/ops/misc/broadcast_ops.py +53 -35
- blksprs/ops/misc/row_wise.py +151 -91
- blksprs/ops/partitioning.py +136 -89
- blksprs/ops/repeat.py +115 -108
- blksprs/ops/softmax.py +244 -208
- blksprs/ops/transpose.py +69 -131
- blksprs/utils/benchmarking.py +3 -3
- blksprs/utils/tools.py +31 -4
- blksprs/utils/validation.py +0 -14
- {blksprs-1.10.2.dist-info → blksprs-2.0rc1.dist-info}/METADATA +42 -36
- blksprs-2.0rc1.dist-info/RECORD +22 -0
- {blksprs-1.10.2.dist-info → blksprs-2.0rc1.dist-info}/WHEEL +1 -1
- blksprs/ops/misc/exp.py +0 -104
- blksprs/utils/layout_utils.py +0 -17
- blksprs-1.10.2.dist-info/RECORD +0 -24
- {blksprs-1.10.2.dist-info → blksprs-2.0rc1.dist-info}/top_level.txt +0 -0
blksprs/ops/misc/row_wise.py
CHANGED
|
@@ -1,16 +1,17 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
import triton
|
|
3
3
|
from torch import Tensor
|
|
4
|
+
from torch._library.triton import wrap_triton, triton_op
|
|
4
5
|
from triton import language as tl
|
|
5
6
|
|
|
6
7
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
7
|
-
from blksprs.utils.tools import
|
|
8
|
+
from blksprs.utils.tools import stride, get_autotune_configs
|
|
8
9
|
from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, validate_sparsity, \
|
|
9
|
-
validate_sparsity_block_size
|
|
10
|
+
validate_sparsity_block_size
|
|
10
11
|
|
|
11
12
|
|
|
12
13
|
def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
13
|
-
flag_slice_only: bool = False
|
|
14
|
+
flag_slice_only: bool = False) -> (BlksprsTensor, Tensor):
|
|
14
15
|
"""Computes the row-wise sum of a block-sparse tensor.
|
|
15
16
|
|
|
16
17
|
Returns a block-sparse tensor in compressed form with only one block per row, where the first entry contains the sum
|
|
@@ -25,7 +26,6 @@ def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
|
|
|
25
26
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
26
27
|
flag_slice_only (bool, optional): If set the output will be of shape ``[x.size(0), x.size(1), 1]``
|
|
27
28
|
(default ``False``).
|
|
28
|
-
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
29
29
|
|
|
30
30
|
Returns:
|
|
31
31
|
tuple[BlksprsTensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise sum
|
|
@@ -39,7 +39,6 @@ def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
|
|
|
39
39
|
validate_device(x)
|
|
40
40
|
validate_sparsity(sparsity_block_size, (x, sparsity_layout))
|
|
41
41
|
validate_sparsity_block_size(sparsity_block_size, x)
|
|
42
|
-
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
43
42
|
|
|
44
43
|
sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
|
|
45
44
|
|
|
@@ -54,11 +53,19 @@ def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
|
|
|
54
53
|
validate_contiguous(sparsity_layout, sparsity_lut,
|
|
55
54
|
sparsity_layout_output, sparsity_reverse_lut_output)
|
|
56
55
|
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
56
|
+
return BlksprsTensor(row_wise_sum_forward(
|
|
57
|
+
x, sparsity_lut, sparsity_layout_output, sparsity_reverse_lut_output,
|
|
58
|
+
sparsity_block_size, n_sparse_blocks_output, flag_slice_only)), sparsity_layout_output
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@triton_op("blksprs::row_wise_sum", mutates_args={})
|
|
62
|
+
def row_wise_sum_forward(x: Tensor, sparsity_lut: Tensor,
|
|
63
|
+
sparsity_layout_output: Tensor, sparsity_reverse_lut_output: Tensor,
|
|
64
|
+
sparsity_block_size: int, n_sparse_blocks_output: int,
|
|
65
|
+
flag_slice_only: bool = False) -> Tensor:
|
|
66
|
+
output = torch.zeros(
|
|
67
|
+
size=(n_sparse_blocks_output, sparsity_block_size, 1 if flag_slice_only else sparsity_block_size),
|
|
68
|
+
dtype=x.dtype, device=x.device)
|
|
62
69
|
|
|
63
70
|
x_b, x_r, x_c = x.size()
|
|
64
71
|
x_b_s, x_r_s, x_c_s = stride(x)
|
|
@@ -69,14 +76,11 @@ def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
|
|
|
69
76
|
s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
|
|
70
77
|
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
|
|
71
78
|
|
|
72
|
-
if triton_block_size is None:
|
|
73
|
-
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
74
|
-
|
|
75
79
|
triton_grid = lambda meta: [x_b,
|
|
76
80
|
triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
|
|
77
81
|
triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
|
|
78
82
|
|
|
79
|
-
(
|
|
83
|
+
(wrap_triton(row_wise_sum_kernel)[triton_grid]
|
|
80
84
|
(x,
|
|
81
85
|
x_b, x_b_s, x_r_s, x_c_s,
|
|
82
86
|
sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
|
|
@@ -84,24 +88,34 @@ def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
|
|
|
84
88
|
o_b, o_b_s, o_r_s,
|
|
85
89
|
s_l_o_b, s_l_o_b_s, s_l_o_r_s,
|
|
86
90
|
sparsity_reverse_lut_output,
|
|
87
|
-
|
|
91
|
+
sparsity_block_size))
|
|
88
92
|
|
|
89
|
-
return
|
|
93
|
+
return output
|
|
90
94
|
|
|
91
95
|
|
|
96
|
+
@triton.autotune(
|
|
97
|
+
configs=get_autotune_configs(),
|
|
98
|
+
key=[],
|
|
99
|
+
reset_to_zero=["o"]
|
|
100
|
+
)
|
|
92
101
|
@triton.jit
|
|
93
|
-
def
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
102
|
+
def row_wise_sum_kernel(x,
|
|
103
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
104
|
+
s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
|
|
105
|
+
o,
|
|
106
|
+
o_b, o_b_s, o_r_s,
|
|
107
|
+
s_l_o_b, s_l_o_b_s, s_l_o_r_s,
|
|
108
|
+
r_lut_o,
|
|
109
|
+
sparsity_block_size,
|
|
110
|
+
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
111
|
+
# Get triton block indices
|
|
101
112
|
pid_blk = tl.program_id(axis=0)
|
|
102
113
|
pid_row = tl.program_id(axis=1)
|
|
103
114
|
pid_col = tl.program_id(axis=2)
|
|
104
115
|
|
|
116
|
+
# Get valid triton block size
|
|
117
|
+
val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
|
|
118
|
+
|
|
105
119
|
# Get position of current sparsity block consisting of its batch and row index
|
|
106
120
|
spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
|
|
107
121
|
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
|
|
@@ -122,22 +136,28 @@ def kernel_blocksparse_row_wise_sum(x,
|
|
|
122
136
|
return
|
|
123
137
|
|
|
124
138
|
blk_idx = ((pid_blk * x_b_s) +
|
|
125
|
-
((pid_row *
|
|
126
|
-
((pid_col *
|
|
127
|
-
blk_msk = (blk_idx >= 0 and
|
|
139
|
+
((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
140
|
+
((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
141
|
+
blk_msk = ((blk_idx >= 0 and
|
|
142
|
+
blk_idx < x_b * x_b_s) and
|
|
143
|
+
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
144
|
+
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
128
145
|
blk = tl.load(x + blk_idx, mask=blk_msk)
|
|
129
146
|
|
|
130
147
|
buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
|
|
131
148
|
|
|
132
149
|
o_idx = (rev_idx_spa * o_b_s +
|
|
133
|
-
((pid_row *
|
|
150
|
+
((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
134
151
|
(tl.arange(0, 1))[None, :])
|
|
135
|
-
o_msk = (o_idx >= 0 and
|
|
152
|
+
o_msk = ((o_idx >= 0 and
|
|
153
|
+
o_idx < o_b * o_b_s) and
|
|
154
|
+
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
155
|
+
tl.arange(0, 1)[None, :] < val_tbs))
|
|
136
156
|
tl.atomic_add(o + o_idx, buf, o_msk)
|
|
137
157
|
|
|
138
158
|
|
|
139
159
|
def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
140
|
-
flag_slice_only: bool = False
|
|
160
|
+
flag_slice_only: bool = False) -> (BlksprsTensor, Tensor):
|
|
141
161
|
"""Computes the row-wise max of a block-sparse tensor.
|
|
142
162
|
|
|
143
163
|
Returns a block-sparse tensor in compressed form with only one block per row, where the first entry contains the
|
|
@@ -152,7 +172,6 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
|
|
|
152
172
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
153
173
|
flag_slice_only (bool, optional): If set the output will be of shape ``[x.size(0), x.size(1), 1]``
|
|
154
174
|
(default ``False``).
|
|
155
|
-
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
156
175
|
|
|
157
176
|
Returns:
|
|
158
177
|
tuple[BlksprsTensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise max
|
|
@@ -166,7 +185,6 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
|
|
|
166
185
|
validate_device(x)
|
|
167
186
|
validate_sparsity(sparsity_block_size, (x, sparsity_layout))
|
|
168
187
|
validate_sparsity_block_size(sparsity_block_size, x)
|
|
169
|
-
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
170
188
|
|
|
171
189
|
sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
|
|
172
190
|
|
|
@@ -181,6 +199,16 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
|
|
|
181
199
|
validate_contiguous(sparsity_layout, sparsity_lut,
|
|
182
200
|
sparsity_layout_output, sparsity_reverse_lut_output)
|
|
183
201
|
|
|
202
|
+
return BlksprsTensor(
|
|
203
|
+
row_wise_max_forward(x, sparsity_lut, sparsity_layout_output, sparsity_reverse_lut_output, sparsity_block_size,
|
|
204
|
+
n_sparse_blocks_output, flag_slice_only)), sparsity_layout_output
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
@triton_op("blksprs::row_wise_max", mutates_args={})
|
|
208
|
+
def row_wise_max_forward(x: Tensor, sparsity_lut: Tensor,
|
|
209
|
+
sparsity_layout_output: Tensor, sparsity_reverse_lut_output: Tensor,
|
|
210
|
+
sparsity_block_size: int, n_sparse_blocks_output: int,
|
|
211
|
+
flag_slice_only: bool = False) -> Tensor:
|
|
184
212
|
output = torch.full(size=(n_sparse_blocks_output,
|
|
185
213
|
sparsity_block_size,
|
|
186
214
|
1 if flag_slice_only else sparsity_block_size),
|
|
@@ -196,14 +224,11 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
|
|
|
196
224
|
s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
|
|
197
225
|
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
|
|
198
226
|
|
|
199
|
-
if triton_block_size is None:
|
|
200
|
-
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
201
|
-
|
|
202
227
|
triton_grid = lambda meta: [x_b,
|
|
203
228
|
triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
|
|
204
229
|
triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
|
|
205
230
|
|
|
206
|
-
(
|
|
231
|
+
(wrap_triton(row_wise_max_kernel)[triton_grid]
|
|
207
232
|
(x,
|
|
208
233
|
x_b, x_b_s, x_r_s, x_c_s,
|
|
209
234
|
sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
|
|
@@ -211,24 +236,34 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
|
|
|
211
236
|
o_b, o_b_s, o_r_s,
|
|
212
237
|
s_l_o_b, s_l_o_b_s, s_l_o_r_s,
|
|
213
238
|
sparsity_reverse_lut_output,
|
|
214
|
-
|
|
239
|
+
sparsity_block_size))
|
|
215
240
|
|
|
216
|
-
return
|
|
241
|
+
return output
|
|
217
242
|
|
|
218
243
|
|
|
244
|
+
@triton.autotune(
|
|
245
|
+
configs=get_autotune_configs(),
|
|
246
|
+
key=[],
|
|
247
|
+
restore_value=["o"]
|
|
248
|
+
)
|
|
219
249
|
@triton.jit
|
|
220
|
-
def
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
250
|
+
def row_wise_max_kernel(x,
|
|
251
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
252
|
+
s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
|
|
253
|
+
o,
|
|
254
|
+
o_b, o_b_s, o_r_s,
|
|
255
|
+
s_l_o_b, s_l_o_b_s, s_l_o_r_s,
|
|
256
|
+
r_lut_o,
|
|
257
|
+
sparsity_block_size,
|
|
258
|
+
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
259
|
+
# Get triton block indices
|
|
228
260
|
pid_blk = tl.program_id(axis=0)
|
|
229
261
|
pid_row = tl.program_id(axis=1)
|
|
230
262
|
pid_col = tl.program_id(axis=2)
|
|
231
263
|
|
|
264
|
+
# Get valid triton block size
|
|
265
|
+
val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
|
|
266
|
+
|
|
232
267
|
# Get position of current sparsity block consisting of its batch and row index
|
|
233
268
|
spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
|
|
234
269
|
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
|
|
@@ -249,22 +284,28 @@ def kernel_blocksparse_row_wise_max(x,
|
|
|
249
284
|
return
|
|
250
285
|
|
|
251
286
|
blk_idx = ((pid_blk * x_b_s) +
|
|
252
|
-
((pid_row *
|
|
253
|
-
((pid_col *
|
|
254
|
-
blk_msk = (blk_idx >= 0 and
|
|
287
|
+
((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
288
|
+
((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
289
|
+
blk_msk = ((blk_idx >= 0 and
|
|
290
|
+
blk_idx < x_b * x_b_s) and
|
|
291
|
+
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
292
|
+
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
255
293
|
blk = tl.load(x + blk_idx, mask=blk_msk)
|
|
256
294
|
|
|
257
295
|
buf = tl.reshape(tl.max(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
|
|
258
296
|
|
|
259
297
|
o_idx = (rev_idx_spa * o_b_s +
|
|
260
|
-
((pid_row *
|
|
298
|
+
((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
261
299
|
(tl.arange(0, 1))[None, :])
|
|
262
|
-
o_msk = (o_idx >= 0 and
|
|
300
|
+
o_msk = ((o_idx >= 0 and
|
|
301
|
+
o_idx < o_b * o_b_s) and
|
|
302
|
+
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
303
|
+
tl.arange(0, 1)[None, :] < val_tbs))
|
|
263
304
|
tl.atomic_max(o + o_idx, buf, o_msk)
|
|
264
305
|
|
|
265
306
|
|
|
266
307
|
def row_wise_add(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
|
|
267
|
-
sparsity_block_size: int
|
|
308
|
+
sparsity_block_size: int) -> BlksprsTensor:
|
|
268
309
|
"""For each row in ``y`` adds the value to each value in the corresponding row of the block-sparse tensor ``x``.
|
|
269
310
|
|
|
270
311
|
Args:
|
|
@@ -272,7 +313,6 @@ def row_wise_add(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
|
|
|
272
313
|
sparsity_layout_x (Tensor): The sparsity layout of the block-sparse tensor.
|
|
273
314
|
y (BlksprsTensor): A block-sparse tensor in compressed form with only one value per row and a single column of sparse blocks.
|
|
274
315
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
275
|
-
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
276
316
|
|
|
277
317
|
Returns:
|
|
278
318
|
BlksprsTensor: The values of ``x`` with the first value of ``y`` in each row added to them as a block-sparse tensor in
|
|
@@ -284,9 +324,8 @@ def row_wise_add(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
|
|
|
284
324
|
validate_device(x)
|
|
285
325
|
validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
|
|
286
326
|
validate_sparsity_block_size(sparsity_block_size, x)
|
|
287
|
-
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
288
327
|
|
|
289
|
-
|
|
328
|
+
sparsity_lut_x = torch.nonzero(sparsity_layout_x).contiguous()
|
|
290
329
|
|
|
291
330
|
sparsity_layout_rwm, _ = torch.max(sparsity_layout_x, dim=-1, keepdim=True)
|
|
292
331
|
sparsity_layout_rwm_flat = sparsity_layout_rwm.reshape(-1)
|
|
@@ -294,24 +333,37 @@ def row_wise_add(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
|
|
|
294
333
|
(sparsity_layout_rwm_flat == 1) -
|
|
295
334
|
(1 * (sparsity_layout_rwm_flat == 0)))
|
|
296
335
|
|
|
297
|
-
validate_contiguous(sparsity_layout_x,
|
|
336
|
+
validate_contiguous(sparsity_layout_x, sparsity_lut_x, sparsity_reverse_lut_rwm)
|
|
337
|
+
|
|
338
|
+
return BlksprsTensor(row_wise_add_forward(x, sparsity_lut_x, sparsity_layout_rwm,
|
|
339
|
+
sparsity_reverse_lut_rwm, y, sparsity_block_size))
|
|
298
340
|
|
|
341
|
+
|
|
342
|
+
def row_wise_sub(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
|
|
343
|
+
sparsity_block_size: int) -> BlksprsTensor:
|
|
344
|
+
"""Wrapper for ``row_wise_add`` with negated y.
|
|
345
|
+
|
|
346
|
+
"""
|
|
347
|
+
return row_wise_add(x, sparsity_layout_x, torch.neg(y), sparsity_block_size)
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
@triton_op("blksprs::row_wise_add", mutates_args={})
|
|
351
|
+
def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
|
|
352
|
+
sparsity_layout_x_rwm: Tensor, sparsity_reverse_x_lut_rwm: Tensor,
|
|
353
|
+
y: Tensor, sparsity_block_size: int) -> Tensor:
|
|
299
354
|
output = torch.empty_like(x)
|
|
300
355
|
|
|
301
356
|
x_b, x_r, x_c = x.size()
|
|
302
357
|
x_b_s, x_r_s, x_c_s = stride(x)
|
|
303
|
-
s_lut_r, s_lut_c =
|
|
304
|
-
s_lut_r_s, s_lut_c_s = stride(
|
|
358
|
+
s_lut_r, s_lut_c = sparsity_lut_x.size()
|
|
359
|
+
s_lut_r_s, s_lut_c_s = stride(sparsity_lut_x)
|
|
305
360
|
y_b, y_r, y_c = y.size()
|
|
306
361
|
y_b_s, y_r_s, y_c_s = stride(y)
|
|
307
|
-
s_l_y_b, s_l_y_r, s_l_y_c =
|
|
308
|
-
s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(
|
|
362
|
+
s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_x_rwm.size()
|
|
363
|
+
s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_x_rwm)
|
|
309
364
|
o_b, o_r, o_c = output.size()
|
|
310
365
|
o_b_s, o_r_s, o_c_s = stride(output)
|
|
311
366
|
|
|
312
|
-
if triton_block_size is None:
|
|
313
|
-
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
314
|
-
|
|
315
367
|
triton_grid = lambda meta: [o_b,
|
|
316
368
|
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
317
369
|
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
@@ -319,49 +371,48 @@ def row_wise_add(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
|
|
|
319
371
|
(kernel_blocksparse_row_wise_add[triton_grid]
|
|
320
372
|
(x,
|
|
321
373
|
x_b, x_b_s, x_r_s, x_c_s,
|
|
322
|
-
|
|
374
|
+
sparsity_lut_x, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
323
375
|
y, y_b, y_b_s, y_r_s, y_c_s,
|
|
324
376
|
s_l_y_b, s_l_y_b_s, s_l_y_r_s,
|
|
325
|
-
|
|
377
|
+
sparsity_reverse_x_lut_rwm,
|
|
326
378
|
output,
|
|
327
379
|
o_b, o_b_s, o_r_s, o_c_s,
|
|
328
|
-
|
|
329
|
-
))
|
|
330
|
-
|
|
331
|
-
return BlksprsTensor(output)
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
def row_wise_sub(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
|
|
335
|
-
sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
|
|
336
|
-
"""Wrapper for ``row_wise_add`` with negated y.
|
|
380
|
+
sparsity_block_size))
|
|
337
381
|
|
|
338
|
-
|
|
339
|
-
return row_wise_add(x, sparsity_layout_x, torch.neg(y), sparsity_block_size, triton_block_size)
|
|
382
|
+
return output
|
|
340
383
|
|
|
341
384
|
|
|
385
|
+
@triton.autotune(
|
|
386
|
+
configs=get_autotune_configs(),
|
|
387
|
+
key=[]
|
|
388
|
+
)
|
|
342
389
|
@triton.jit
|
|
343
390
|
def kernel_blocksparse_row_wise_add(x,
|
|
344
391
|
x_b, x_b_s, x_r_s, x_c_s,
|
|
345
|
-
|
|
392
|
+
s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
|
|
346
393
|
y, y_b, y_b_s, y_r_s, y_c_s,
|
|
347
394
|
s_l_y_b, s_l_y_b_s, s_l_y_r_s,
|
|
348
395
|
r_lut_y,
|
|
349
396
|
o,
|
|
350
397
|
o_b, o_b_s, o_r_s, o_c_s,
|
|
398
|
+
sparsity_block_size,
|
|
351
399
|
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
352
400
|
# Get triton block indices
|
|
353
401
|
pid_blk = tl.program_id(axis=0)
|
|
354
402
|
pid_row = tl.program_id(axis=1)
|
|
355
403
|
pid_col = tl.program_id(axis=2)
|
|
356
404
|
|
|
405
|
+
# Get valid triton block size
|
|
406
|
+
val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
|
|
407
|
+
|
|
357
408
|
# Get position of current sparsity block consisting of its batch and row index
|
|
358
|
-
spa_bat_idx = (pid_blk *
|
|
359
|
-
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx <
|
|
360
|
-
spa_bat = tl.load(
|
|
409
|
+
spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
|
|
410
|
+
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
|
|
411
|
+
spa_bat = tl.load(s_lut_x + spa_bat_idx, mask=spa_bat_msk)
|
|
361
412
|
|
|
362
|
-
spa_row_idx = (pid_blk *
|
|
363
|
-
spa_row_msk = (spa_row_idx >= 0 and spa_row_idx <
|
|
364
|
-
spa_row = tl.load(
|
|
413
|
+
spa_row_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
|
|
414
|
+
spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_x_r * s_lut_x_r_s)
|
|
415
|
+
spa_row = tl.load(s_lut_x + spa_row_idx, mask=spa_row_msk)
|
|
365
416
|
|
|
366
417
|
# Get reverse sparsity indices for s
|
|
367
418
|
rev_idx_spa_s_idx = (spa_bat * s_l_y_b_s +
|
|
@@ -375,16 +426,22 @@ def kernel_blocksparse_row_wise_add(x,
|
|
|
375
426
|
|
|
376
427
|
# Load x block
|
|
377
428
|
blk_x_idx = ((pid_blk * x_b_s) +
|
|
378
|
-
((pid_row *
|
|
379
|
-
((pid_col *
|
|
380
|
-
blk_x_msk = (blk_x_idx >= 0 and
|
|
429
|
+
((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
430
|
+
((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
431
|
+
blk_x_msk = ((blk_x_idx >= 0 and
|
|
432
|
+
blk_x_idx < x_b * x_b_s) and
|
|
433
|
+
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
434
|
+
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
381
435
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
382
436
|
|
|
383
437
|
# Load sum block
|
|
384
438
|
blk_s_idx = (rev_idx_spa_s * y_b_s +
|
|
385
|
-
((pid_row *
|
|
439
|
+
((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
|
|
386
440
|
(tl.arange(0, 1) * y_c_s)[None, :])
|
|
387
|
-
blk_s_msk = (blk_s_idx >= 0 and
|
|
441
|
+
blk_s_msk = ((blk_s_idx >= 0 and
|
|
442
|
+
blk_s_idx < y_b * y_b_s) and
|
|
443
|
+
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
444
|
+
tl.arange(0, 1)[None, :] < val_tbs))
|
|
388
445
|
blk_s = tl.load(y + blk_s_idx, mask=blk_s_msk)
|
|
389
446
|
|
|
390
447
|
# Compute exp
|
|
@@ -392,7 +449,10 @@ def kernel_blocksparse_row_wise_add(x,
|
|
|
392
449
|
|
|
393
450
|
# Store block
|
|
394
451
|
blk_o_idx = ((pid_blk * o_b_s) +
|
|
395
|
-
((pid_row *
|
|
396
|
-
((pid_col *
|
|
397
|
-
blk_o_msk = (blk_o_idx >= 0 and
|
|
452
|
+
((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
453
|
+
((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
454
|
+
blk_o_msk = ((blk_o_idx >= 0 and
|
|
455
|
+
blk_o_idx < o_b * o_b_s) and
|
|
456
|
+
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
457
|
+
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
398
458
|
tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
|