blksprs 2.0rc8__tar.gz → 2.1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {blksprs-2.0rc8 → blksprs-2.1.1}/PKG-INFO +3 -2
- {blksprs-2.0rc8 → blksprs-2.1.1}/README.md +1 -0
- {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/__init__.py +3 -2
- {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/ops/conversion.py +4 -4
- {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/ops/flow.py +1 -1
- {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/ops/softmax.py +252 -3
- {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/utils/tools.py +2 -10
- {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs.egg-info/PKG-INFO +3 -2
- {blksprs-2.0rc8 → blksprs-2.1.1}/pyproject.toml +2 -2
- {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/layouting/distribution_layout.py +0 -0
- {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/layouting/sparsity_layout.py +0 -0
- {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/ops/distribution.py +0 -0
- {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/ops/matmul.py +0 -0
- {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/ops/misc/broadcast_ops.py +0 -0
- {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/ops/misc/row_wise.py +0 -0
- {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/ops/partitioning.py +0 -0
- {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/ops/repeat.py +0 -0
- {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/ops/transpose.py +0 -0
- {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/utils/autotuning.py +0 -0
- {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/utils/benchmarking.py +0 -0
- {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/utils/blksprs_tensor.py +0 -0
- {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/utils/processing.py +0 -0
- {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs/utils/validation.py +0 -0
- {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs.egg-info/SOURCES.txt +0 -0
- {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs.egg-info/dependency_links.txt +0 -0
- {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs.egg-info/requires.txt +0 -0
- {blksprs-2.0rc8 → blksprs-2.1.1}/blksprs.egg-info/top_level.txt +0 -0
- {blksprs-2.0rc8 → blksprs-2.1.1}/setup.cfg +0 -0
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: blksprs
|
|
3
|
-
Version: 2.0rc8
|
|
4
|
-
Summary: A lightweight library for operations on
|
|
3
|
+
Version: 2.1.1
|
|
4
|
+
Summary: A lightweight library for operations on block-sparse matrices in PyTorch.
|
|
5
5
|
Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
|
|
6
6
|
Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
|
|
7
7
|
Project-URL: Bugtracker, https://github.com/FelixSchoen/blksprs/issues
|
|
@@ -197,6 +197,7 @@ def test_readme():
|
|
|
197
197
|
# Other available functions
|
|
198
198
|
bs.ops.transpose(o_sparse, sparsity_layout_o, sparsity_block_size)
|
|
199
199
|
bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size)
|
|
200
|
+
bs.ops.softmax_fused(o_sparse, sparsity_layout_o, sparsity_block_size) # Significantly faster version that requires that rows of matrix fit into memory
|
|
200
201
|
bs.ops.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size)
|
|
201
202
|
bs.ops.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size)
|
|
202
203
|
|
|
@@ -178,6 +178,7 @@ def test_readme():
|
|
|
178
178
|
# Other available functions
|
|
179
179
|
bs.ops.transpose(o_sparse, sparsity_layout_o, sparsity_block_size)
|
|
180
180
|
bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size)
|
|
181
|
+
bs.ops.softmax_fused(o_sparse, sparsity_layout_o, sparsity_block_size) # Significantly faster version that requires that rows of matrix fit into memory
|
|
181
182
|
bs.ops.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size)
|
|
182
183
|
bs.ops.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size)
|
|
183
184
|
|
|
@@ -1,12 +1,13 @@
|
|
|
1
|
-
from blksprs.utils.tools import version
|
|
2
1
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
3
2
|
|
|
3
|
+
__version__ = "2.1.1"
|
|
4
|
+
|
|
4
5
|
|
|
5
6
|
class ops:
|
|
6
7
|
from blksprs.ops.conversion import to_dense, to_sparse, from_blksprs, to_blksprs, adapt_layout
|
|
7
8
|
from blksprs.ops.distribution import gather, scatter, scatter_reduce
|
|
8
9
|
from blksprs.ops.matmul import matmul
|
|
9
|
-
from blksprs.ops.softmax import softmax
|
|
10
|
+
from blksprs.ops.softmax import softmax, softmax_fused
|
|
10
11
|
from blksprs.ops.transpose import transpose
|
|
11
12
|
from blksprs.ops.repeat import repeat, repeat_interleave
|
|
12
13
|
from blksprs.ops.partitioning import split, merge
|
|
@@ -13,7 +13,7 @@ from blksprs.utils.validation import validate_contiguous, validate_dimensions, v
|
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
def to_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int) -> BlksprsTensor:
|
|
16
|
-
"""Wrapper for
|
|
16
|
+
"""Wrapper for :func:`to_sparse`.
|
|
17
17
|
|
|
18
18
|
"""
|
|
19
19
|
return to_sparse(x, sparsity_layout, sparsity_block_size)
|
|
@@ -167,7 +167,7 @@ to_sparse_forward.register_autograd(to_sparse_wrapper_backward, setup_context=to
|
|
|
167
167
|
|
|
168
168
|
def from_blksprs(x: BlksprsTensor, sparsity_layout: Tensor,
|
|
169
169
|
sparsity_block_size: int, fill_value: float = 0, lut: dict = None) -> Tensor:
|
|
170
|
-
"""Wrapper for
|
|
170
|
+
"""Wrapper for :func:`to_dense`.
|
|
171
171
|
|
|
172
172
|
"""
|
|
173
173
|
return to_dense(x, sparsity_layout, sparsity_block_size, fill_value=fill_value, lut=lut)
|
|
@@ -204,8 +204,8 @@ def to_dense(x: BlksprsTensor, sparsity_layout: Tensor,
|
|
|
204
204
|
if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
|
|
205
205
|
return x
|
|
206
206
|
|
|
207
|
-
return to_dense_forward(x, sparsity_layout,
|
|
208
|
-
lut["sparsity_reverse_lut"], sparsity_block_size, fill_value)
|
|
207
|
+
return Tensor(to_dense_forward(x, sparsity_layout,
|
|
208
|
+
lut["sparsity_reverse_lut"], sparsity_block_size, fill_value))
|
|
209
209
|
|
|
210
210
|
|
|
211
211
|
@triton_op("blksprs::to_dense_forward", mutates_args={})
|
|
@@ -78,7 +78,7 @@ def flow_pull_kernel(x,
|
|
|
78
78
|
spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
|
|
79
79
|
spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
|
|
80
80
|
|
|
81
|
-
#
|
|
81
|
+
# Load reverse sparsity index
|
|
82
82
|
rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
|
|
83
83
|
spa_row * s_l_o_r_s +
|
|
84
84
|
spa_col * s_l_o_c_s)
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
import pdb
|
|
2
|
+
|
|
1
3
|
import torch
|
|
2
4
|
import triton
|
|
3
5
|
from torch import Tensor
|
|
@@ -13,8 +15,20 @@ from blksprs.utils.validation import validate_contiguous, validate_dimensions, v
|
|
|
13
15
|
validate_sparsity, validate_sparsity_block_size, validate_dtype_float_32
|
|
14
16
|
|
|
15
17
|
|
|
18
|
+
def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, flag_fused: bool = True,
|
|
19
|
+
lut: dict = None) -> BlksprsTensor:
|
|
20
|
+
"""Wrapper for :func:`softmax_regular` and :func:`softmax_fused` based on the ``flag_fused`` parameter.
|
|
21
|
+
|
|
22
|
+
"""
|
|
23
|
+
if flag_fused:
|
|
24
|
+
return softmax_fused(x, sparsity_layout, sparsity_block_size, lut)
|
|
25
|
+
else:
|
|
26
|
+
return softmax_regular(x, sparsity_layout, sparsity_block_size, lut)
|
|
27
|
+
|
|
28
|
+
|
|
16
29
|
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
|
|
17
|
-
def
|
|
30
|
+
def softmax_regular(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
31
|
+
lut: dict = None) -> BlksprsTensor:
|
|
18
32
|
"""Computes the softmax of a block-sparse tensor in compressed form.
|
|
19
33
|
|
|
20
34
|
Note:
|
|
@@ -100,6 +114,8 @@ def softmax_backward_wrapper(ctx, grad_output):
|
|
|
100
114
|
def softmax_backward(grad_output: Tensor, o: Tensor, sparsity_lut: Tensor, sparsity_layout: Tensor,
|
|
101
115
|
sparsity_block_size: int) -> Tensor:
|
|
102
116
|
with torch.no_grad():
|
|
117
|
+
grad_x = torch.zeros_like(o, dtype=torch.float)
|
|
118
|
+
|
|
103
119
|
s, sparsity_layout_s = row_wise_sum(grad_output * o, sparsity_layout, sparsity_block_size, flag_slice_only=True)
|
|
104
120
|
|
|
105
121
|
sparsity_layout_s_flat = sparsity_layout_s.reshape(-1)
|
|
@@ -116,8 +132,6 @@ def softmax_backward(grad_output: Tensor, o: Tensor, sparsity_lut: Tensor, spars
|
|
|
116
132
|
s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
|
|
117
133
|
s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)
|
|
118
134
|
|
|
119
|
-
grad_x = torch.zeros_like(o, dtype=torch.float)
|
|
120
|
-
|
|
121
135
|
triton_grid = lambda meta: [o_b,
|
|
122
136
|
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
123
137
|
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
@@ -302,3 +316,238 @@ def softmax_setup_context(ctx, inputs, output):
|
|
|
302
316
|
|
|
303
317
|
|
|
304
318
|
softmax_forward.register_autograd(softmax_backward_wrapper, setup_context=softmax_setup_context)
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
|
|
322
|
+
def softmax_fused(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
323
|
+
lut: dict = None) -> BlksprsTensor:
|
|
324
|
+
"""Computes the softmax fused for each row of a block-sparse tensor in compressed form.
|
|
325
|
+
|
|
326
|
+
Note:
|
|
327
|
+
This softmax implementation is a fused version that loads the entire row of a block-sparse tensor into memory.
|
|
328
|
+
See :func:`softmax` for a true block-wise softmax implementation.
|
|
329
|
+
|
|
330
|
+
Args:
|
|
331
|
+
x (BlksprsTensor): A block-sparse tensor in compressed form.
|
|
332
|
+
sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
|
|
333
|
+
sparsity_block_size (int): The size of the sparsity blocks.
|
|
334
|
+
lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
|
|
335
|
+
|
|
336
|
+
Returns:
|
|
337
|
+
BlksprsTensor: The result of the softmax operation as a block-sparse tensor in compressed form.
|
|
338
|
+
|
|
339
|
+
"""
|
|
340
|
+
x = x.contiguous()
|
|
341
|
+
|
|
342
|
+
validate_dimensions(x)
|
|
343
|
+
validate_contiguous(x)
|
|
344
|
+
validate_dtype_float_32(x)
|
|
345
|
+
validate_device(x)
|
|
346
|
+
validate_sparsity(sparsity_block_size, (x, sparsity_layout))
|
|
347
|
+
validate_sparsity_block_size(sparsity_block_size, x)
|
|
348
|
+
|
|
349
|
+
lut = softmax_fused_build_lut(lut, sparsity_layout)
|
|
350
|
+
|
|
351
|
+
return BlksprsTensor(softmax_fused_forward(x, sparsity_layout,
|
|
352
|
+
lut["sparsity_reverse_lut"],
|
|
353
|
+
sparsity_block_size))
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
@triton_op("blksprs::softmax_fused_forward", mutates_args={})
|
|
357
|
+
def softmax_fused_forward(x: Tensor, sparsity_layout: Tensor,
|
|
358
|
+
sparsity_reverse_lut: Tensor,
|
|
359
|
+
sparsity_block_size: int) -> Tensor:
|
|
360
|
+
output = torch.zeros_like(x)
|
|
361
|
+
|
|
362
|
+
x_b, x_r, x_c = x.size()
|
|
363
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
364
|
+
s_l_b, s_l_r, s_l_c = sparsity_layout.size()
|
|
365
|
+
s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout)
|
|
366
|
+
|
|
367
|
+
triton_grid = lambda meta: [s_l_b,
|
|
368
|
+
s_l_r,
|
|
369
|
+
sparsity_block_size]
|
|
370
|
+
|
|
371
|
+
(wrap_triton(softmax_fused_kernel)[triton_grid]
|
|
372
|
+
(x,
|
|
373
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
374
|
+
output,
|
|
375
|
+
s_l_b, s_l_b_s, s_l_r_s, s_l_c, s_l_c_s,
|
|
376
|
+
sparsity_reverse_lut,
|
|
377
|
+
sparsity_block_size))
|
|
378
|
+
|
|
379
|
+
return output
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
def softmax_fused_backward_wrapper(ctx, grad_output):
|
|
383
|
+
o, sparsity_layout, sparsity_reverse_lut = ctx.saved_tensors
|
|
384
|
+
sparsity_block_size = ctx.sparsity_block_size
|
|
385
|
+
|
|
386
|
+
return softmax_fused_backward(grad_output, o, sparsity_reverse_lut, sparsity_layout,
|
|
387
|
+
sparsity_block_size), None, None, None, None, None
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
@triton_op("blksprs::softmax_fused_backward", mutates_args={})
|
|
391
|
+
def softmax_fused_backward(grad_output: Tensor, o: Tensor, sparsity_reverse_lut: Tensor, sparsity_layout: Tensor,
|
|
392
|
+
sparsity_block_size: int) -> Tensor:
|
|
393
|
+
with torch.no_grad():
|
|
394
|
+
grad_x = torch.zeros_like(o)
|
|
395
|
+
|
|
396
|
+
g_b, g_r, g_c = grad_output.size()
|
|
397
|
+
g_b_s, g_r_s, g_c_s = stride(grad_output)
|
|
398
|
+
o_b, o_r, o_c = o.size()
|
|
399
|
+
o_b_s, o_r_s, o_c_s = stride(o)
|
|
400
|
+
s_l_b, s_l_r, s_l_c = sparsity_layout.size()
|
|
401
|
+
s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout)
|
|
402
|
+
|
|
403
|
+
triton_grid = lambda meta: [s_l_b,
|
|
404
|
+
s_l_r,
|
|
405
|
+
sparsity_block_size]
|
|
406
|
+
|
|
407
|
+
(wrap_triton(softmax_fused_kernel_grad)[triton_grid]
|
|
408
|
+
(grad_output,
|
|
409
|
+
g_b, g_b_s, g_r_s, g_c_s,
|
|
410
|
+
o,
|
|
411
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
412
|
+
s_l_b, s_l_b_s, s_l_r_s, s_l_c, s_l_c_s,
|
|
413
|
+
sparsity_reverse_lut,
|
|
414
|
+
grad_x,
|
|
415
|
+
sparsity_block_size))
|
|
416
|
+
|
|
417
|
+
return grad_x
|
|
418
|
+
|
|
419
|
+
|
|
420
|
+
# noinspection PyUnusedLocal
|
|
421
|
+
@triton.autotune(
|
|
422
|
+
configs=get_autotune_configs(),
|
|
423
|
+
key=["sparsity_block_size"],
|
|
424
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
425
|
+
reset_to_zero=["o"]
|
|
426
|
+
)
|
|
427
|
+
@triton.jit
|
|
428
|
+
def softmax_fused_kernel(x,
|
|
429
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
430
|
+
o,
|
|
431
|
+
s_l_b, s_l_b_s, s_l_r_s, s_l_c: tl.constexpr, s_l_c_s,
|
|
432
|
+
r_lut,
|
|
433
|
+
sparsity_block_size: tl.constexpr,
|
|
434
|
+
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
435
|
+
# Get triton block indices
|
|
436
|
+
pid_bat = tl.program_id(axis=0)
|
|
437
|
+
pid_row = tl.program_id(axis=1)
|
|
438
|
+
pid_lin = tl.program_id(axis=2)
|
|
439
|
+
|
|
440
|
+
# Load reverse sparsity indices of row
|
|
441
|
+
blk_rev_idx = (pid_bat * s_l_b_s +
|
|
442
|
+
pid_row * s_l_r_s +
|
|
443
|
+
(tl.arange(0, s_l_c) * s_l_c_s))
|
|
444
|
+
blk_rev_msk = (blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s)
|
|
445
|
+
blk_rev = tl.load(r_lut + blk_rev_idx, mask=blk_rev_msk).to(tl.int32)
|
|
446
|
+
|
|
447
|
+
if (not (tl.min(blk_rev) == -1 and
|
|
448
|
+
tl.max(blk_rev) == -1)):
|
|
449
|
+
# Extend sparsity indices to cover sparsity blocks
|
|
450
|
+
blk_rev_ext = tl.expand_dims(blk_rev, -1)
|
|
451
|
+
blk_rev_ext = tl.broadcast_to(blk_rev_ext, (s_l_c, sparsity_block_size))
|
|
452
|
+
blk_rev_ext = tl.reshape(blk_rev_ext, (s_l_c * sparsity_block_size))
|
|
453
|
+
|
|
454
|
+
# Load line of x
|
|
455
|
+
blk_x_idx = (blk_rev_ext * x_b_s +
|
|
456
|
+
pid_lin * x_r_s +
|
|
457
|
+
(tl.arange(0, s_l_c * sparsity_block_size) % sparsity_block_size) * x_c_s)
|
|
458
|
+
blk_x_mask = ((blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
|
|
459
|
+
and blk_rev_ext != -1)
|
|
460
|
+
blk_x = tl.load(x + blk_x_idx, mask=blk_x_mask, other=float("-inf"))
|
|
461
|
+
|
|
462
|
+
# Compute softmax
|
|
463
|
+
blk_x_softmax = tl.softmax(blk_x)
|
|
464
|
+
|
|
465
|
+
# Store output
|
|
466
|
+
tl.store(o + blk_x_idx, blk_x_softmax, mask=blk_x_mask)
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
# noinspection PyUnusedLocal
|
|
470
|
+
@triton.autotune(
|
|
471
|
+
configs=get_autotune_configs(),
|
|
472
|
+
key=["sparsity_block_size"],
|
|
473
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
474
|
+
reset_to_zero=["o"]
|
|
475
|
+
)
|
|
476
|
+
@triton.jit
|
|
477
|
+
def softmax_fused_kernel_grad(g,
|
|
478
|
+
g_b, g_b_s, g_r_s, g_c_s,
|
|
479
|
+
x,
|
|
480
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
481
|
+
s_l_b, s_l_b_s, s_l_r_s, s_l_c: tl.constexpr, s_l_c_s,
|
|
482
|
+
r_lut,
|
|
483
|
+
o,
|
|
484
|
+
sparsity_block_size: tl.constexpr,
|
|
485
|
+
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
486
|
+
# Get triton block indices
|
|
487
|
+
pid_bat = tl.program_id(axis=0)
|
|
488
|
+
pid_row = tl.program_id(axis=1)
|
|
489
|
+
pid_lin = tl.program_id(axis=2)
|
|
490
|
+
|
|
491
|
+
# Load reverse sparsity indices of row
|
|
492
|
+
blk_rev_idx = (pid_bat * s_l_b_s +
|
|
493
|
+
pid_row * s_l_r_s +
|
|
494
|
+
(tl.arange(0, s_l_c) * s_l_c_s))
|
|
495
|
+
blk_rev_msk = (blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s)
|
|
496
|
+
blk_rev = tl.load(r_lut + blk_rev_idx, mask=blk_rev_msk).to(tl.int32)
|
|
497
|
+
|
|
498
|
+
if (not (tl.min(blk_rev) == -1 and
|
|
499
|
+
tl.max(blk_rev) == -1)):
|
|
500
|
+
# Extend sparsity indices to cover sparsity blocks
|
|
501
|
+
blk_rev_ext = tl.expand_dims(blk_rev, -1)
|
|
502
|
+
blk_rev_ext = tl.broadcast_to(blk_rev_ext, (s_l_c, sparsity_block_size))
|
|
503
|
+
blk_rev_ext = tl.reshape(blk_rev_ext, (s_l_c * sparsity_block_size))
|
|
504
|
+
|
|
505
|
+
# Load line of g
|
|
506
|
+
blk_g_idx = (blk_rev_ext * g_b_s +
|
|
507
|
+
pid_lin * g_r_s +
|
|
508
|
+
(tl.arange(0, s_l_c * sparsity_block_size) % sparsity_block_size) * g_c_s)
|
|
509
|
+
blk_g_mask = ((blk_g_idx >= 0 and blk_g_idx < g_b * g_b_s)
|
|
510
|
+
and blk_rev_ext != -1)
|
|
511
|
+
blk_g = tl.load(g + blk_g_idx, mask=blk_g_mask)
|
|
512
|
+
|
|
513
|
+
# Load line of x
|
|
514
|
+
blk_x_idx = (blk_rev_ext * x_b_s +
|
|
515
|
+
pid_lin * x_r_s +
|
|
516
|
+
(tl.arange(0, s_l_c * sparsity_block_size) % sparsity_block_size) * x_c_s)
|
|
517
|
+
blk_x_mask = ((blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
|
|
518
|
+
and blk_rev_ext != -1)
|
|
519
|
+
blk_x = tl.load(x + blk_x_idx, mask=blk_x_mask)
|
|
520
|
+
|
|
521
|
+
# Compute gradients
|
|
522
|
+
blk_grad = blk_x * (blk_g - tl.sum(blk_x * blk_g))
|
|
523
|
+
|
|
524
|
+
tl.store(o + blk_x_idx, blk_grad, mask=blk_x_mask)
|
|
525
|
+
|
|
526
|
+
|
|
527
|
+
def softmax_fused_build_lut(lut: dict, sparsity_layout: Tensor):
|
|
528
|
+
if lut is None:
|
|
529
|
+
lut = dict()
|
|
530
|
+
|
|
531
|
+
if "sparsity_reverse_lut" not in lut:
|
|
532
|
+
sparsity_layout_flat = sparsity_layout.reshape(-1)
|
|
533
|
+
sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
|
|
534
|
+
(sparsity_layout_flat == 1) -
|
|
535
|
+
(1 * (sparsity_layout_flat == 0)))
|
|
536
|
+
.reshape(sparsity_layout.size())
|
|
537
|
+
.reshape(-1).contiguous())
|
|
538
|
+
lut["sparsity_reverse_lut"] = sparsity_reverse_lut
|
|
539
|
+
|
|
540
|
+
validate_contiguous(sparsity_layout, lut["sparsity_reverse_lut"])
|
|
541
|
+
|
|
542
|
+
return lut
|
|
543
|
+
|
|
544
|
+
|
|
545
|
+
# noinspection PyUnusedLocal
|
|
546
|
+
def softmax_fused_setup_context(ctx, inputs, output):
|
|
547
|
+
(_, sparsity_layout, sparsity_reverse_lut, sparsity_block_size) = inputs
|
|
548
|
+
|
|
549
|
+
ctx.save_for_backward(output, sparsity_layout, sparsity_reverse_lut)
|
|
550
|
+
ctx.sparsity_block_size = sparsity_block_size
|
|
551
|
+
|
|
552
|
+
|
|
553
|
+
softmax_fused_forward.register_autograd(softmax_fused_backward_wrapper, setup_context=softmax_fused_setup_context)
|
|
@@ -1,6 +1,3 @@
|
|
|
1
|
-
import tomllib
|
|
2
|
-
from pathlib import Path
|
|
3
|
-
|
|
4
1
|
import torch
|
|
5
2
|
from torch import Tensor, Size
|
|
6
3
|
|
|
@@ -8,19 +5,14 @@ from torch import Tensor, Size
|
|
|
8
5
|
torch._dynamo.config.capture_scalar_outputs = True
|
|
9
6
|
|
|
10
7
|
|
|
11
|
-
def
|
|
12
|
-
with open(Path(__file__).parent.parent.parent.joinpath("pyproject.toml"), "rb") as f:
|
|
13
|
-
return tomllib.load(f)["project"]["version"]
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
def do_shape_blocksparse(x: Tensor):
|
|
8
|
+
def do_shape_blocksparse(x: Tensor) -> tuple[Tensor, Size]:
|
|
17
9
|
if x.dim() == 3:
|
|
18
10
|
return x.contiguous(), x.size()
|
|
19
11
|
|
|
20
12
|
return x.reshape(-1, x.size(-2), x.size(-1)).contiguous(), x.size()
|
|
21
13
|
|
|
22
14
|
|
|
23
|
-
def undo_shape_blocksparse(x: Tensor, shape: Size):
|
|
15
|
+
def undo_shape_blocksparse(x: Tensor, shape: Size | tuple[int, ...]) -> Tensor:
|
|
24
16
|
if x.shape[:-2] == shape[:-2]:
|
|
25
17
|
return x
|
|
26
18
|
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: blksprs
|
|
3
|
-
Version: 2.0rc8
|
|
4
|
-
Summary: A lightweight library for operations on
|
|
3
|
+
Version: 2.1.1
|
|
4
|
+
Summary: A lightweight library for operations on block-sparse matrices in PyTorch.
|
|
5
5
|
Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
|
|
6
6
|
Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
|
|
7
7
|
Project-URL: Bugtracker, https://github.com/FelixSchoen/blksprs/issues
|
|
@@ -197,6 +197,7 @@ def test_readme():
|
|
|
197
197
|
# Other available functions
|
|
198
198
|
bs.ops.transpose(o_sparse, sparsity_layout_o, sparsity_block_size)
|
|
199
199
|
bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size)
|
|
200
|
+
bs.ops.softmax_fused(o_sparse, sparsity_layout_o, sparsity_block_size) # Significantly faster version that requires that rows of matrix fit into memory
|
|
200
201
|
bs.ops.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size)
|
|
201
202
|
bs.ops.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size)
|
|
202
203
|
|
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "blksprs"
|
|
3
|
-
version = "2.
|
|
3
|
+
version = "2.1.1"
|
|
4
4
|
authors = [{ name = "Felix Schön", email = "schoen@kr.tuwien.ac.at" }]
|
|
5
|
-
description = "A lightweight library for operations on
|
|
5
|
+
description = "A lightweight library for operations on block-sparse matrices in PyTorch."
|
|
6
6
|
readme = "README.md"
|
|
7
7
|
requires-python = ">=3.11"
|
|
8
8
|
license = { file = "LICENSE.md" }
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|