blksprs: blksprs-2.0rc6-py3-none-any.whl → blksprs-2.0rc8-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +1 -0
- blksprs/layouting/distribution_layout.py +39 -26
- blksprs/layouting/sparsity_layout.py +58 -45
- blksprs/ops/conversion.py +86 -84
- blksprs/ops/distribution.py +81 -79
- blksprs/ops/flow.py +64 -60
- blksprs/ops/matmul.py +50 -55
- blksprs/ops/misc/broadcast_ops.py +29 -27
- blksprs/ops/misc/row_wise.py +134 -132
- blksprs/ops/partitioning.py +12 -10
- blksprs/ops/repeat.py +6 -5
- blksprs/ops/softmax.py +55 -47
- blksprs/ops/transpose.py +8 -7
- blksprs/utils/autotuning.py +10 -10
- blksprs/utils/processing.py +0 -1
- blksprs/utils/tools.py +8 -9
- {blksprs-2.0rc6.dist-info → blksprs-2.0rc8.dist-info}/METADATA +7 -3
- blksprs-2.0rc8.dist-info/RECORD +23 -0
- {blksprs-2.0rc6.dist-info → blksprs-2.0rc8.dist-info}/WHEEL +1 -1
- blksprs-2.0rc6.dist-info/RECORD +0 -23
- {blksprs-2.0rc6.dist-info → blksprs-2.0rc8.dist-info}/top_level.txt +0 -0
blksprs/ops/misc/row_wise.py
CHANGED
@@ -4,9 +4,9 @@ from torch import Tensor
 from torch._library.triton import wrap_triton, triton_op
 from triton import language as tl

-from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import stride, get_autocast_min_val
 from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
+from blksprs.utils.blksprs_tensor import BlksprsTensor
+from blksprs.utils.tools import stride
 from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, validate_sparsity, \
     validate_sparsity_block_size

@@ -60,41 +60,43 @@ def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
         sparsity_block_size, n_sparse_blocks_output, flag_slice_only)), sparsity_layout_output


-@triton_op("blksprs::…
+@triton_op("blksprs::row_wise_sum_forward", mutates_args={})
 def row_wise_sum_forward(x: Tensor, sparsity_lut: Tensor,
                          sparsity_layout_output: Tensor, sparsity_reverse_lut_output: Tensor,
                          sparsity_block_size: int, n_sparse_blocks_output: int,
                          flag_slice_only: bool = False) -> Tensor:
-    [previous function body (old lines 68-97) not rendered in the diff source]
+    with torch.no_grad():
+        output = torch.zeros(
+            size=(n_sparse_blocks_output, sparsity_block_size, 1 if flag_slice_only else sparsity_block_size),
+            dtype=x.dtype, device=x.device)
+
+        x_b, x_r, x_c = x.size()
+        x_b_s, x_r_s, x_c_s = stride(x)
+        s_lut_x_r, s_lut_x_c = sparsity_lut.size()
+        s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = stride(output)
+        s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
+        s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
+
+        triton_grid = lambda meta: [x_b,
+                                    triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+        (wrap_triton(row_wise_sum_kernel)[triton_grid]
+         (x,
+          x_b, x_b_s, x_r_s, x_c_s,
+          sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+          output,
+          o_b, o_b_s, o_r_s,
+          s_l_o_b, s_l_o_b_s, s_l_o_r_s,
+          sparsity_reverse_lut_output,
+          sparsity_block_size))
+
+        return output
+
+
+# noinspection PyUnusedLocal
 @triton.autotune(
     configs=get_autotune_configs(),
     key=["sparsity_block_size"],
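Across this release, every forward entry point gains an explicit operator name in its `@triton_op("blksprs::…", mutates_args={})` decorator and wraps its body in `torch.no_grad()`, launching the kernel through `wrap_triton`. A minimal sketch of that registration pattern, using a hypothetical `mylib::copy_forward` op and `copy_kernel` (not the library's actual code):

```python
import torch
import triton
import triton.language as tl
from torch import Tensor
from torch._library.triton import triton_op, wrap_triton


@triton.jit
def copy_kernel(x_ptr, o_ptr, n, BLOCK: tl.constexpr):
    # Trivial elementwise copy, stand-in for a real blksprs kernel.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    msk = offs < n
    tl.store(o_ptr + offs, tl.load(x_ptr + offs, mask=msk), mask=msk)


@triton_op("mylib::copy_forward", mutates_args={})  # explicit op name, no mutated args
def copy_forward(x: Tensor) -> Tensor:
    with torch.no_grad():  # the op registers its own autograd; no tracing inside
        output = torch.empty_like(x)
        grid = lambda meta: [triton.cdiv(x.numel(), meta["BLOCK"])]
        wrap_triton(copy_kernel)[grid](x, output, x.numel(), BLOCK=1024)
        return output
```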
@@ -131,25 +133,22 @@ def row_wise_sum_kernel(x,
     rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
     rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

-    if rev_idx_spa…
-    [old line 135 not rendered in the diff source]
+    if rev_idx_spa >= 0:
+        blk_idx = ((pid_blk * x_b_s) +
+                   ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                   ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+        blk_msk = (blk_idx >= 0 and
+                   blk_idx < x_b * x_b_s)
+        blk = tl.load(x + blk_idx, mask=blk_msk)

-    [old line 138 not rendered in the diff source]
-               ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-               ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_msk = (blk_idx >= 0 and
-               blk_idx < x_b * x_b_s)
-    blk = tl.load(x + blk_idx, mask=blk_msk)
+        buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))

-    [old lines 145-150 not rendered in the diff source]
-             o_idx < o_b * o_b_s)
-    tl.atomic_add(o + o_idx, buf, o_msk)
+        o_idx = (rev_idx_spa * o_b_s +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                 (tl.arange(0, 1))[None, :])
+        o_msk = (o_idx >= 0 and
+                 o_idx < o_b * o_b_s)
+        tl.atomic_add(o + o_idx, buf, o_msk)


 @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
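The restructured kernel nests its whole body under `if rev_idx_spa >= 0`, so tiles whose target block is absent from the output layout (reverse-LUT entry -1) do nothing instead of issuing masked loads and atomics. Each column tile contributes a per-row partial sum via `tl.atomic_add`; the fixed point of those accumulations is an ordinary row-wise reduction. A rough dense-equivalent sketch (my reading of the semantics, not library code):

```python
import torch


def row_wise_sum_dense_reference(x_dense: torch.Tensor) -> torch.Tensor:
    # x_dense: (B, R, C). Each block-sparse column tile atomically adds its
    # per-row partial sum; accumulated over all tiles this is simply the full
    # row-wise reduction, kept as a column vector (cf. flag_slice_only).
    return x_dense.sum(dim=-1, keepdim=True)
```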
@@ -175,6 +174,8 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
     of the input and the sparsity layout of the output tensor.

     """
+    # TODO Fix for triton bug, see https://github.com/triton-lang/triton/issues/6376, should be fixed with the upcoming 3.4.0 release
+    x = torch.where(x == -0.0, torch.tensor(0.0), x)
     x = x.contiguous()

     validate_dimensions(x)
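The inserted workaround leans on IEEE-754 semantics: `-0.0 == 0.0` evaluates to True, so the `torch.where` rewrites both signed zeros to `+0.0` before the max kernel runs, sidestepping the negative-zero mishandling tracked in triton-lang/triton#6376. A quick illustration:

```python
import torch

x = torch.tensor([-0.0, 0.0, -1.0, 2.0])
# -0.0 == 0.0 is True under IEEE-754, so both signed zeros match the condition
# and are rewritten to +0.0; all other values pass through unchanged.
x = torch.where(x == -0.0, torch.tensor(0.0), x)
print(x)  # tensor([0., 0., -1., 2.])
```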
@@ -201,43 +202,45 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
         n_sparse_blocks_output, flag_slice_only)), sparsity_layout_output


-@triton_op("blksprs::…
+@triton_op("blksprs::row_wise_max_forward", mutates_args={})
 def row_wise_max_forward(x: Tensor, sparsity_lut: Tensor,
                          sparsity_layout_output: Tensor, sparsity_reverse_lut_output: Tensor,
                          sparsity_block_size: int, n_sparse_blocks_output: int,
                          flag_slice_only: bool = False) -> Tensor:
-    [previous function body (old lines 209-240) not rendered in the diff source]
+    with torch.no_grad():
+        output = torch.full(size=(n_sparse_blocks_output,
+                                  sparsity_block_size,
+                                  1 if flag_slice_only else sparsity_block_size),
+                            fill_value=torch.finfo(x.dtype).min,
+                            device=x.device)
+
+        x_b, x_r, x_c = x.size()
+        x_b_s, x_r_s, x_c_s = stride(x)
+        s_lut_x_r, s_lut_x_c = sparsity_lut.size()
+        s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = stride(output)
+        s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
+        s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
+
+        triton_grid = lambda meta: [x_b,
+                                    triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+        (wrap_triton(row_wise_max_kernel)[triton_grid]
+         (x,
+          x_b, x_b_s, x_r_s, x_c_s,
+          sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+          output,
+          o_b, o_b_s, o_r_s,
+          s_l_o_b, s_l_o_b_s, s_l_o_r_s,
+          sparsity_reverse_lut_output,
+          sparsity_block_size))
+
+        return output
+
+
+# noinspection PyUnusedLocal
 @triton.autotune(
     configs=get_autotune_configs(),
     key=["sparsity_block_size"],
@@ -274,25 +277,22 @@ def row_wise_max_kernel(x,
     rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
     rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

-    if rev_idx_spa…
-    [old lines 278-283 not rendered in the diff source]
-    blk_msk = (blk_idx >= 0 and
-               blk_idx < x_b * x_b_s)
-    blk = tl.load(x + blk_idx, mask=blk_msk)
+    if rev_idx_spa >= 0:
+        blk_idx = ((pid_blk * x_b_s) +
+                   ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                   ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+        blk_msk = (blk_idx >= 0 and
+                   blk_idx < x_b * x_b_s)
+        blk = tl.load(x + blk_idx, mask=blk_msk)

-    [old line 288 not rendered in the diff source]
+        buf = tl.reshape(tl.max(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))

-    [old lines 290-295 not rendered in the diff source]
+        o_idx = (rev_idx_spa * o_b_s +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                 (tl.arange(0, 1))[None, :])
+        o_msk = (o_idx >= 0 and
+                 o_idx < o_b * o_b_s)
+        tl.atomic_max(o + o_idx, buf, o_msk)


 @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
@@ -339,41 +339,43 @@ def row_wise_sub(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
     return row_wise_add(x, sparsity_layout_x, torch.neg(y), sparsity_block_size)


-@triton_op("blksprs::…
+@triton_op("blksprs::row_wise_add_forward", mutates_args={})
 def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
                          sparsity_layout_x_rwm: Tensor, sparsity_reverse_x_lut_rwm: Tensor,
                          y: Tensor, sparsity_block_size: int) -> Tensor:
-    [previous function body (old lines 346-376) not rendered in the diff source]
+    with torch.no_grad():
+        output = torch.zeros_like(x)
+
+        x_b, x_r, x_c = x.size()
+        x_b_s, x_r_s, x_c_s = stride(x)
+        s_lut_r, s_lut_c = sparsity_lut_x.size()
+        s_lut_r_s, s_lut_c_s = stride(sparsity_lut_x)
+        y_b, y_r, y_c = y.size()
+        y_b_s, y_r_s, y_c_s = stride(y)
+        s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_x_rwm.size()
+        s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_x_rwm)
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = stride(output)
+
+        triton_grid = lambda meta: [o_b,
+                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+        (wrap_triton(kernel_blocksparse_row_wise_add)[triton_grid]
+         (x,
+          x_b, x_b_s, x_r_s, x_c_s,
+          sparsity_lut_x, s_lut_r, s_lut_r_s, s_lut_c_s,
+          y, y_b, y_b_s, y_r_s, y_c_s,
+          s_l_y_b, s_l_y_b_s, s_l_y_r_s,
+          sparsity_reverse_x_lut_rwm,
+          output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          sparsity_block_size))
+
+        return output
+
+
+# noinspection PyUnusedLocal
 @triton.autotune(
     configs=get_autotune_configs(),
     key=["sparsity_block_size"],
blksprs/ops/partitioning.py
CHANGED
@@ -46,14 +46,15 @@ def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
         partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_output"]


-@triton_op("blksprs::…
+@triton_op("blksprs::split_forward", mutates_args={})
 def split_forward(x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
                   _: int, __: int, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
-    [previous function body (old lines 52-53) not rendered in the diff source]
+    with torch.no_grad():
+        return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                                 n_sparse_blocks)


-def …
+def split_wrapper_backward(ctx, grad_output):
     sparsity_layout = ctx.saved_tensors[0]
     num_partitions = ctx.num_partitions
     dim = ctx.dim

@@ -109,7 +110,7 @@ def split_setup_context(ctx, inputs, output):
     ctx.sparsity_block_size = sparsity_block_size


-split_forward.register_autograd(…
+split_forward.register_autograd(split_wrapper_backward, setup_context=split_setup_context)


 @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)

@@ -150,14 +151,15 @@ def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
         partitions, adjusted_dim, sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_output"]


-@triton_op("blksprs::…
+@triton_op("blksprs::merge_forward", mutates_args={})
 def merge_forward(x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
                   _: int, __: int, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
-    [previous function body (old lines 156-157) not rendered in the diff source]
+    with torch.no_grad():
+        return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                                 n_sparse_blocks)


-def …
+def merge_wrapper_backward(ctx, grad_output):
     sparsity_layout = ctx.saved_tensors[0]
     num_partitions = ctx.num_partitions
     dim = ctx.dim

@@ -216,4 +218,4 @@ def merge_setup_context(ctx, inputs, output):
     ctx.sparsity_block_size = sparsity_block_size


-merge_forward.register_autograd(…
+merge_forward.register_autograd(merge_wrapper_backward, setup_context=merge_setup_context)
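Both `split_forward` and `merge_forward` now complete the torch.library autograd pattern used throughout this release: a `setup_context` hook stashes tensors and attributes on `ctx` after the forward, and a named backward wrapper consumes them. A self-contained sketch of that registration flow, with a hypothetical `mylib::scale` op standing in for the blksprs ops:

```python
import torch
from torch import Tensor
from torch.library import custom_op


@custom_op("mylib::scale", mutates_args=())
def scale(x: Tensor, factor: float) -> Tensor:
    return x * factor


def scale_setup_context(ctx, inputs, output):
    # Runs after the forward; stash whatever the backward will need.
    _, factor = inputs
    ctx.factor = factor


def scale_backward(ctx, grad_output):
    # One gradient per forward input; None for non-differentiable args.
    return grad_output * ctx.factor, None


scale.register_autograd(scale_backward, setup_context=scale_setup_context)

x = torch.randn(3, requires_grad=True)
scale(x, 2.0).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])
```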
blksprs/ops/repeat.py
CHANGED
@@ -92,15 +92,16 @@ def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
         lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_o"]


-@triton_op("blksprs::…
+@triton_op("blksprs::repeat_forward", mutates_args={})
 def repeat_forward(x: Tensor, _: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor,
                    sparsity_reverse_lut: Tensor,
                    sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
-    [previous function body (old lines 99-100) not rendered in the diff source]
+    with torch.no_grad():
+        return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                                 n_sparse_blocks)


-def …
+def repeat_wrapper_backward(ctx, grad_output):
     sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut = ctx.saved_tensors
     sparsity_block_size = ctx.sparsity_block_size
     n_sparse_blocks = torch.sum(sparsity_layout_x.to(torch.int)).item()

@@ -190,4 +191,4 @@ def repeat_setup_context(ctx, inputs, output):
     ctx.sparsity_block_size = sparsity_block_size


-repeat_forward.register_autograd(…
+repeat_forward.register_autograd(repeat_wrapper_backward, setup_context=repeat_setup_context)
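With this change, `split_forward`, `merge_forward`, `repeat_forward` (and, below, `transpose_forward`) all reduce to a `flow_pull_forward` call under `torch.no_grad()`: the ops differ only in the LUTs they pass, so the actual data movement is centralized in one gather primitive. A rough model of that shared gather, under my assumed semantics of the reverse LUT (not the library's code):

```python
import torch


def flow_pull_reference(blocks: torch.Tensor, sparsity_reverse_lut: torch.Tensor) -> torch.Tensor:
    # Output block i pulls input block sparsity_reverse_lut[i]; entries of -1
    # mark blocks absent from the source layout, which stay zero-filled.
    out = torch.zeros((sparsity_reverse_lut.numel(), *blocks.shape[1:]),
                      dtype=blocks.dtype, device=blocks.device)
    present = sparsity_reverse_lut >= 0
    out[present] = blocks[sparsity_reverse_lut[present]]
    return out
```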
blksprs/ops/softmax.py
CHANGED
@@ -47,19 +47,13 @@ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
         sparsity_block_size))


-@triton_op("blksprs::…
+@triton_op("blksprs::softmax_forward", mutates_args={})
 def softmax_forward(x: Tensor, sparsity_layout: Tensor,
                     sparsity_lut: Tensor,
                     sparsity_reverse_lut_rws: Tensor,
                     sparsity_block_size: int) -> Tensor:
     output = torch.zeros_like(x)

-    x_b, x_r, x_c = x.size()
-    x_b_s, x_r_s, x_c_s = stride(x)
-    s_lut_r, s_lut_c = sparsity_lut.size()
-    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-    o_b, o_r, o_c = output.size()
-
     x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,
                                                        flag_slice_only=True)
     x_scaled = row_wise_sub(x, sparsity_layout, x_row_wise_max, sparsity_block_size)

@@ -67,6 +61,11 @@ def softmax_forward(x: Tensor, sparsity_layout: Tensor,
     x_exp_row_wise_sum, sparsity_layout_rws = row_wise_sum(x_exp, sparsity_layout, sparsity_block_size,
                                                            flag_slice_only=True)

+    x_b, x_r, x_c = x.size()
+    x_b_s, x_r_s, x_c_s = stride(x)
+    s_lut_r, s_lut_c = sparsity_lut.size()
+    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+    o_b, o_r, o_c = output.size()
     s_b, s_r, s_c = x_exp_row_wise_sum.shape
     s_b_s, s_r_s, s_c_s = stride(x_exp_row_wise_sum)
     s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_rws.shape

@@ -89,50 +88,58 @@ def softmax_forward(x: Tensor, sparsity_layout: Tensor,
     return output


-def …
+def softmax_backward_wrapper(ctx, grad_output):
     o, sparsity_layout, sparsity_lut = ctx.saved_tensors
     sparsity_block_size = ctx.sparsity_block_size

-    [previous backward implementation (old lines 96-133) not rendered in the diff source]
+    return softmax_backward(grad_output, o, sparsity_lut, sparsity_layout,
+                            sparsity_block_size), None, None, None, None, None
+
+
+@triton_op("blksprs::softmax_backward", mutates_args={})
+def softmax_backward(grad_output: Tensor, o: Tensor, sparsity_lut: Tensor, sparsity_layout: Tensor,
+                     sparsity_block_size: int) -> Tensor:
+    with torch.no_grad():
+        s, sparsity_layout_s = row_wise_sum(grad_output * o, sparsity_layout, sparsity_block_size, flag_slice_only=True)
+
+        sparsity_layout_s_flat = sparsity_layout_s.reshape(-1)
+        sparsity_reverse_lut_s = ((torch.cumsum(sparsity_layout_s_flat, dim=-1) - 1) *
+                                  (sparsity_layout_s_flat == 1) -
+                                  (1 * (sparsity_layout_s_flat == 0)))
+
+        o_b, o_r, o_c = o.size()
+        o_b_s, o_r_s, o_c_s = stride(o)
+        s_lut_r, s_lut_c = sparsity_lut.size()
+        s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+        s_b, s_r, s_c = s.size()
+        s_b_s, s_r_s, s_c_s = stride(s)
+        s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
+        s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)
+
+        grad_x = torch.zeros_like(o, dtype=torch.float)
+
+        triton_grid = lambda meta: [o_b,
+                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+        (wrap_triton(softmax_kernel_grad)[triton_grid]
+         (grad_output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          o,
+          o_b, o_b_s, o_r_s, o_c_s,
+          sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+          s,
+          s_b, s_b_s, s_r_s, s_c_s,
+          s_l_s_b, s_l_s_b_s, s_l_s_r_s,
+          sparsity_reverse_lut_s,
+          grad_x,
+          o_b, o_b_s, o_r_s, o_c_s,
+          sparsity_block_size))
+
+        return grad_x


+# noinspection PyUnusedLocal
 @triton.autotune(
     configs=get_autotune_configs(),
     key=["sparsity_block_size"],
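The rewritten backward is the standard softmax vector-Jacobian product: with y = softmax(x), the gradient is grad_x = y * (grad_y - Σ_c grad_y · y), and the `row_wise_sum(grad_output * o, …)` call computes exactly that per-row inner product before the kernel applies the elementwise part. A dense reference of the math (assuming `o` holds the forward output y, as the saved tensors suggest):

```python
import torch


def softmax_backward_dense_reference(grad_output: torch.Tensor, o: torch.Tensor) -> torch.Tensor:
    # o = softmax(x); the per-row Jacobian-vector product is
    # grad_x = o * (grad_output - <grad_output, o>).
    s = (grad_output * o).sum(dim=-1, keepdim=True)
    return o * (grad_output - s)
```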
@@ -193,6 +200,7 @@ def softmax_kernel(x,
     tl.store(o + blk_x_idx, buf, mask=blk_x_msk)


+# noinspection PyUnusedLocal
 @triton.autotune(
     configs=get_autotune_configs(),
     key=["sparsity_block_size"],

@@ -293,4 +301,4 @@ def softmax_setup_context(ctx, inputs, output):
     ctx.sparsity_block_size = sparsity_block_size


-softmax_forward.register_autograd(…
+softmax_forward.register_autograd(softmax_backward_wrapper, setup_context=softmax_setup_context)
blksprs/ops/transpose.py
CHANGED
@@ -28,7 +28,6 @@ def transpose(x: BlksprsTensor, sparsity_layout: Tensor,

     """
     x = x.contiguous()
-    x_t = x.transpose(-1, -2).contiguous()

     validate_dimensions(x)
     validate_contiguous(x)

@@ -38,20 +37,22 @@ def transpose(x: BlksprsTensor, sparsity_layout: Tensor,

     lut = transpose_build_lut(lut, sparsity_layout)

-    return BlksprsTensor(transpose_forward(…
+    return BlksprsTensor(transpose_forward(x, lut["sparsity_layout_t"],
                                            lut["sparsity_lut"], lut["sparsity_reverse_lut"],
                                            sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_t"]


-@triton_op("blksprs::…
+@triton_op("blksprs::transpose_forward", mutates_args={})
 def transpose_forward(x: Tensor, sparsity_layout_o: Tensor,
                       sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
                       sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
-    [previous function body (old lines 50-51) not rendered in the diff source]
+    with torch.no_grad():
+        x_t = x.transpose(-1, -2).contiguous()
+        return flow_pull_forward(x_t, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
+                                 sparsity_block_size, n_sparse_blocks)


-def …
+def transpose_wrapper_backward(ctx, grad_output):
     sparsity_layout = ctx.saved_tensors[0]
     sparsity_block_size = ctx.sparsity_block_size

@@ -96,4 +97,4 @@ def transpose_setup_context(ctx, inputs, output):
     ctx.sparsity_block_size = sparsity_block_size


-transpose_forward.register_autograd(…
+transpose_forward.register_autograd(transpose_wrapper_backward, setup_context=transpose_setup_context)
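Moving `x_t = x.transpose(-1, -2).contiguous()` out of the public wrapper and into `transpose_forward` keeps the whole computation inside the registered op, so the dense transpose is no longer performed outside the custom-op boundary; the wrapper now passes the untransposed `x`. Semantically the op is unchanged. A dense-equivalent reading (my assumption, matching the `sparsity_layout_t` naming):

```python
import torch


def blocksparse_transpose_dense_reference(x_dense: torch.Tensor, layout: torch.Tensor):
    # Transposing a block-sparse matrix transposes both the data and the
    # layout: block (b, i, j) of the input becomes block (b, j, i) of the output.
    return x_dense.transpose(-1, -2).contiguous(), layout.transpose(-1, -2).contiguous()
```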
blksprs/utils/autotuning.py
CHANGED
@@ -2,15 +2,7 @@ import os

 blksprs_autotune_mode = os.getenv("BLKSPRS_AUTOTUNE", "DEFAULT")

-if blksprs_autotune_mode == "…
-    autotune_parameters = [
-        (16, 3, 8),
-        […]
-        (32, 3, 8),
-        […]
-        (64, 3, 8),
-    ]
-elif blksprs_autotune_mode == "DEFAULT":
+if blksprs_autotune_mode == "DEFAULT":
     autotune_parameters = [
         (16, 3, 8),
         (16, 4, 4),

@@ -28,6 +20,14 @@ elif blksprs_autotune_mode == "DEFAULT":
         (128, 4, 4),
         (128, 5, 2),
     ]
+elif blksprs_autotune_mode == "TEST":
+    autotune_parameters = [
+        (16, 3, 8),
+        […]
+        (32, 3, 8),
+        […]
+        (64, 3, 8),
+    ]
 else:
     raise NotImplementedError(f"Unknown autotune mode: {blksprs_autotune_mode}")

@@ -75,4 +75,4 @@ def get_autotune_configs():
     autotune_configs.append(
         triton.Config({"TRITON_BLOCK_SIZE": block_size}, num_stages=num_stages, num_warps=num_warps))

-    return autotune_configs
\ No newline at end of file
+    return autotune_configs
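Each `(TRITON_BLOCK_SIZE, num_stages, num_warps)` tuple in `autotune_parameters` is expanded by `get_autotune_configs` into a `triton.Config`, as the final hunk shows; reordering the branches puts the common `DEFAULT` path first, while the `TEST` set is a trimmed grid for faster runs. A sketch of the expansion loop, consistent with the `append` shown in the hunk (the surrounding filtering is assumed):

```python
import triton

autotune_parameters = [(16, 3, 8), (32, 3, 8), (64, 3, 8)]  # TEST-mode grid


def get_autotune_configs():
    autotune_configs = []
    for block_size, num_stages, num_warps in autotune_parameters:
        autotune_configs.append(
            triton.Config({"TRITON_BLOCK_SIZE": block_size}, num_stages=num_stages, num_warps=num_warps))
    return autotune_configs
```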
blksprs/utils/processing.py
CHANGED
@@ -26,7 +26,6 @@ def apply_torch_linear(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block

     # Apply weights
     sparsity_layout_xw = build_sparsity_layout_matmul_fast(sparsity_layout, sparsity_layout_w_t)
-    # TODO At the moment, manual cast is needed. Bug with custom_fwd?
     xw = matmul(x, sparsity_layout, BlksprsTensor(w_t_bs.to(x.dtype)), sparsity_layout_w_t, sparsity_layout_xw, sparsity_block_size)
     interim = xw
