blksprs-2.1.1-py3-none-any.whl → blksprs-2.1.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +1 -1
- blksprs/ops/softmax.py +60 -37
- blksprs/utils/tools.py +5 -0
- {blksprs-2.1.1.dist-info → blksprs-2.1.2.dist-info}/METADATA +1 -1
- {blksprs-2.1.1.dist-info → blksprs-2.1.2.dist-info}/RECORD +7 -7
- {blksprs-2.1.1.dist-info → blksprs-2.1.2.dist-info}/WHEEL +0 -0
- {blksprs-2.1.1.dist-info → blksprs-2.1.2.dist-info}/top_level.txt +0 -0
blksprs/__init__.py
CHANGED
blksprs/ops/softmax.py
CHANGED
@@ -9,7 +9,7 @@ from triton import language as tl
 
 from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import stride
+from blksprs.utils.tools import stride, ceil_pow2
 from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_dtype_float_32
@@ -349,13 +349,15 @@ def softmax_fused(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size
     lut = softmax_fused_build_lut(lut, sparsity_layout)
 
     return BlksprsTensor(softmax_fused_forward(x, sparsity_layout,
-                                               lut["…
+                                               lut["sparsity_reverse_lut_sorted"],
+                                               lut["max_blocks_line"],
                                                sparsity_block_size))
 
 
 @triton_op("blksprs::softmax_fused_forward", mutates_args={})
 def softmax_fused_forward(x: Tensor, sparsity_layout: Tensor,
-…
+                          sparsity_reverse_lut_sorted: Tensor,
+                          max_blocks_line: int,
                           sparsity_block_size: int) -> Tensor:
     output = torch.zeros_like(x)
 
@@ -372,23 +374,29 @@ def softmax_fused_forward(x: Tensor, sparsity_layout: Tensor,
         (x,
          x_b, x_b_s, x_r_s, x_c_s,
          output,
-         s_l_b, s_l_b_s, s_l_r_s,
-…
+         s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+         sparsity_reverse_lut_sorted,
+         max_blocks_line,
          sparsity_block_size))
 
     return output
 
 
 def softmax_fused_backward_wrapper(ctx, grad_output):
-    o, sparsity_layout,…
+    o, sparsity_layout, sparsity_reverse_lut_sorted = ctx.saved_tensors
+    max_blocks_line = ctx.max_blocks_line
     sparsity_block_size = ctx.sparsity_block_size
 
-    return softmax_fused_backward(grad_output, o,…
-                                  sparsity_block_size), None, None, None, None
+    return softmax_fused_backward(grad_output, o, sparsity_reverse_lut_sorted, sparsity_layout,
+                                  max_blocks_line, sparsity_block_size), None, None, None, None
 
 
 @triton_op("blksprs::softmax_fused_backward", mutates_args={})
-def softmax_fused_backward(grad_output: Tensor,…
+def softmax_fused_backward(grad_output: Tensor,
+                           o: Tensor,
+                           sparsity_reverse_lut_sorted: Tensor,
+                           sparsity_layout: Tensor,
+                           max_blocks_line: int,
                            sparsity_block_size: int) -> Tensor:
     with torch.no_grad():
         grad_x = torch.zeros_like(o)
@@ -409,9 +417,10 @@ def softmax_fused_backward(grad_output: Tensor, o: Tensor, sparsity_reverse_lut:
          g_b, g_b_s, g_r_s, g_c_s,
          o,
          o_b, o_b_s, o_r_s, o_c_s,
-         s_l_b, s_l_b_s, s_l_r_s,
-…
+         s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+         sparsity_reverse_lut_sorted,
          grad_x,
+         max_blocks_line,
          sparsity_block_size))
 
     return grad_x
@@ -428,8 +437,9 @@ def softmax_fused_backward(grad_output: Tensor, o: Tensor, sparsity_reverse_lut:
 def softmax_fused_kernel(x,
                          x_b, x_b_s, x_r_s, x_c_s,
                          o,
-                         s_l_b, s_l_b_s, s_l_r_s,
-…
+                         s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+                         r_lut_s,
+                         mbs: tl.constexpr,
                          sparsity_block_size: tl.constexpr,
                          TRITON_BLOCK_SIZE: tl.constexpr) -> None:
     # Get triton block indices
@@ -440,21 +450,21 @@ def softmax_fused_kernel(x,
     # Load reverse sparsity indices of row
     blk_rev_idx = (pid_bat * s_l_b_s +
                    pid_row * s_l_r_s +
-                   (tl.arange(0,…
+                   (tl.arange(0, mbs) * s_l_c_s))
     blk_rev_msk = (blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s)
-    blk_rev = tl.load(…
+    blk_rev = tl.load(r_lut_s + blk_rev_idx, mask=blk_rev_msk).to(tl.int32)
 
     if (not (tl.min(blk_rev) == -1 and
              tl.max(blk_rev) == -1)):
         # Extend sparsity indices to cover sparsity blocks
         blk_rev_ext = tl.expand_dims(blk_rev, -1)
-        blk_rev_ext = tl.broadcast_to(blk_rev_ext, (…
-        blk_rev_ext = tl.reshape(blk_rev_ext, (…
+        blk_rev_ext = tl.broadcast_to(blk_rev_ext, (mbs, sparsity_block_size))
+        blk_rev_ext = tl.reshape(blk_rev_ext, (mbs * sparsity_block_size))
 
         # Load line of x
         blk_x_idx = (blk_rev_ext * x_b_s +
                      pid_lin * x_r_s +
-                     (tl.arange(0,…
+                     (tl.arange(0, mbs * sparsity_block_size) % sparsity_block_size) * x_c_s)
         blk_x_mask = ((blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
                       and blk_rev_ext != -1)
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_mask, other=float("-inf"))
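In the updated kernels, r_lut_s receives the sorted reverse LUT and mbs (the max_blocks_line constant) bounds how many sparsity blocks any block-row contributes, so each program gathers at most mbs blocks per line. A rough PyTorch sketch of that per-line gather; the helper name, argument layout, and padding convention below are illustrative assumptions, not blksprs API:

```python
import torch

def gather_line(x, rev_lut_row, line_in_block, sparsity_block_size, max_blocks_line):
    # x            : compacted nonzero blocks, shape (num_blocks, sbs, sbs)
    # rev_lut_row  : one block-row of sparsity_reverse_lut_sorted
    #                (valid block indices first, -1 padding)
    idx = rev_lut_row[:max_blocks_line]            # at most mbs blocks per line
    valid = idx != -1
    line = torch.full((max_blocks_line, sparsity_block_size), float("-inf"),
                      dtype=x.dtype, device=x.device)
    line[valid] = x[idx[valid], line_in_block]     # requested row of each present block
    return line.reshape(-1)                        # length mbs * sparsity_block_size
```

The kernel then computes the softmax over this gathered line (empty slots stay at -inf and contribute nothing) and stores the results back to the same masked positions.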
@@ -478,9 +488,10 @@ def softmax_fused_kernel_grad(g,
                              g_b, g_b_s, g_r_s, g_c_s,
                              x,
                              x_b, x_b_s, x_r_s, x_c_s,
-                             s_l_b, s_l_b_s, s_l_r_s,
-…
+                             s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+                             r_lut_s,
                              o,
+                             mbs: tl.constexpr,
                              sparsity_block_size: tl.constexpr,
                              TRITON_BLOCK_SIZE: tl.constexpr) -> None:
     # Get triton block indices
@@ -491,21 +502,21 @@ def softmax_fused_kernel_grad(g,
     # Load reverse sparsity indices of row
     blk_rev_idx = (pid_bat * s_l_b_s +
                    pid_row * s_l_r_s +
-                   (tl.arange(0,…
+                   (tl.arange(0, mbs) * s_l_c_s))
     blk_rev_msk = (blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s)
-    blk_rev = tl.load(…
+    blk_rev = tl.load(r_lut_s + blk_rev_idx, mask=blk_rev_msk).to(tl.int32)
 
     if (not (tl.min(blk_rev) == -1 and
              tl.max(blk_rev) == -1)):
         # Extend sparsity indices to cover sparsity blocks
         blk_rev_ext = tl.expand_dims(blk_rev, -1)
-        blk_rev_ext = tl.broadcast_to(blk_rev_ext, (…
-        blk_rev_ext = tl.reshape(blk_rev_ext, (…
+        blk_rev_ext = tl.broadcast_to(blk_rev_ext, (mbs, sparsity_block_size))
+        blk_rev_ext = tl.reshape(blk_rev_ext, (mbs * sparsity_block_size))
 
         # Load line of g
         blk_g_idx = (blk_rev_ext * g_b_s +
                      pid_lin * g_r_s +
-                     (tl.arange(0,…
+                     (tl.arange(0, mbs * sparsity_block_size) % sparsity_block_size) * g_c_s)
         blk_g_mask = ((blk_g_idx >= 0 and blk_g_idx < g_b * g_b_s)
                       and blk_rev_ext != -1)
         blk_g = tl.load(g + blk_g_idx, mask=blk_g_mask)
@@ -513,7 +524,7 @@ def softmax_fused_kernel_grad(g,
         # Load line of x
         blk_x_idx = (blk_rev_ext * x_b_s +
                      pid_lin * x_r_s +
-                     (tl.arange(0,…
+                     (tl.arange(0, mbs * sparsity_block_size) % sparsity_block_size) * x_c_s)
         blk_x_mask = ((blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
                       and blk_rev_ext != -1)
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_mask)
@@ -521,6 +532,7 @@ def softmax_fused_kernel_grad(g,
         # Compute gradients
         blk_grad = blk_x * (blk_g - tl.sum(blk_x * blk_g))
 
+        # Store output
         tl.store(o + blk_x_idx, blk_grad, mask=blk_x_mask)
 
 
@@ -528,25 +540,36 @@ def softmax_fused_build_lut(lut: dict, sparsity_layout: Tensor):
     if lut is None:
         lut = dict()
 
-    if "…
+    if "sparsity_reverse_lut_sorted" not in lut:
         sparsity_layout_flat = sparsity_layout.reshape(-1)
-…
-…
-…
-…
-…
-…
-…
-…
+        sparsity_reverse_lut_sorted = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                        (sparsity_layout_flat == 1) -
+                                        (1 * (sparsity_layout_flat == 0)))
+                                       .reshape(sparsity_layout.size())
+                                       .sort(descending=True, dim=-1)[0]
+                                       .reshape(-1).contiguous())
+        lut["sparsity_reverse_lut_sorted"] = sparsity_reverse_lut_sorted
+
+    if "max_blocks_line" not in lut:
+        sparsity_reverse_lut_sorted = lut["sparsity_reverse_lut_sorted"]
+        max_blocks_line = ((torch.reshape(sparsity_reverse_lut_sorted, (-1, sparsity_layout.size(-1)))
+                            != -1)
+                           .sum(dim=-1)
+                           .max()
+                           .item())
+        lut["max_blocks_line"] = min(ceil_pow2(max(max_blocks_line, 2)), sparsity_layout.size(-1))
+
+    validate_contiguous(sparsity_layout, lut["sparsity_reverse_lut_sorted"])
 
     return lut
 
 
 # noinspection PyUnusedLocal
 def softmax_fused_setup_context(ctx, inputs, output):
-    (_, sparsity_layout,…
+    (_, sparsity_layout, sparsity_reverse_lut_sorted, max_blocks_line, sparsity_block_size) = inputs
 
-    ctx.save_for_backward(output, sparsity_layout,…
+    ctx.save_for_backward(output, sparsity_layout, sparsity_reverse_lut_sorted)
+    ctx.max_blocks_line = max_blocks_line
     ctx.sparsity_block_size = sparsity_block_size
 
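As a self-contained illustration of the LUT construction added above (the computation mirrors the diff; the tiny layout and variable names are only for demonstration): present blocks receive their running index into the compacted block tensor, absent blocks receive -1, and a descending per-row sort packs the valid indices to the front of each block-row.

```python
import torch

sparsity_layout = torch.tensor([[[1, 0, 1],
                                 [0, 0, 0]]])  # (batch, block-rows, block-cols)

flat = sparsity_layout.reshape(-1)
# Running index into the compacted block tensor for present blocks, -1 for absent blocks
rev_lut = (torch.cumsum(flat, dim=-1) - 1) * (flat == 1) - (1 * (flat == 0))
# Sort each block-row descending so valid block indices come first
rev_lut_sorted = (rev_lut.reshape(sparsity_layout.size())
                  .sort(descending=True, dim=-1)[0]
                  .reshape(-1).contiguous())
print(rev_lut_sorted)   # [1, 0, -1, -1, -1, -1]

# Largest number of valid blocks in any block-row
max_blocks_line = (rev_lut_sorted.reshape(-1, sparsity_layout.size(-1)) != -1).sum(dim=-1).max().item()
print(max_blocks_line)  # 2
```

In the diff, max_blocks_line is additionally clamped to at least 2, rounded up with ceil_pow2 (Triton's tl.arange requires power-of-two sizes), and capped at sparsity_layout.size(-1) before being stored in the LUT.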
blksprs/utils/tools.py
CHANGED
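The five lines added to this file are not captured in this view. Since softmax.py now imports ceil_pow2 from blksprs.utils.tools and uses it to size the per-line block count, the helper presumably rounds an integer up to the next power of two; a minimal sketch of that assumed behavior:

```python
def ceil_pow2(value: int) -> int:
    # Assumed behavior: smallest power of two >= value
    # (Triton's tl.arange requires power-of-two range sizes).
    return 1 if value <= 1 else 1 << (value - 1).bit_length()

# e.g. ceil_pow2(2) == 2, ceil_pow2(3) == 4
```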
{blksprs-2.1.1.dist-info → blksprs-2.1.2.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: blksprs
-Version: 2.1.1
+Version: 2.1.2
 Summary: A lightweight library for operations on block-sparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs

{blksprs-2.1.1.dist-info → blksprs-2.1.2.dist-info}/RECORD
CHANGED

@@ -1,4 +1,4 @@
-blksprs/__init__.py,sha256=…
+blksprs/__init__.py,sha256=NRxydw4i9jg7WeDuojfEePdtdbughV9AZsEcT9yywK4,1615
 blksprs/layouting/distribution_layout.py,sha256=ur1ty_2U-Hfj78hMWsLZvu7ZuGhzW3qGLKMc72DfTZM,5861
 blksprs/layouting/sparsity_layout.py,sha256=eXHmu2h7K5Q-YUpfOxocJoeP_5ZoQFZf_eHLxRZQbYU,11207
 blksprs/ops/conversion.py,sha256=RgVSyiULLwv8KWQqSyXpKwTr4Qp-lpDK9i-zKlN841I,21914
@@ -7,7 +7,7 @@ blksprs/ops/flow.py,sha256=qdWBCLDSkKaa8CAfkO1NgH-J5N7yMsILyR7qEpyrIUU,8246
 blksprs/ops/matmul.py,sha256=5tVBKU_lglUjaLDi6J_dscdqlmzRz38OGxqAxZxZXDs,11879
 blksprs/ops/partitioning.py,sha256=cfQmY9BZqGTvvJorIhtb-EyuGRJGPraWR-wTKdb47aI,9954
 blksprs/ops/repeat.py,sha256=TLYNxwPuT9y5K9xyM41WK5gnggAJF3lI61Q2K7zWjns,9035
-blksprs/ops/softmax.py,sha256=…
+blksprs/ops/softmax.py,sha256=1UIovPrdE_zgAIPqjmOTFn8CMbd_2Z8tPP-vMBxU07I,23526
 blksprs/ops/transpose.py,sha256=U-VAyLRT6_NDv9qYSFzBqfVlDeIpTqAMEXkqto0VF6w,4072
 blksprs/ops/misc/broadcast_ops.py,sha256=-PrHiSJikZh8nXUmXxSCtFEP27TTxFr4wcrNxBjnimk,5987
 blksprs/ops/misc/row_wise.py,sha256=n5FJjAuOd8BHBJQx4bsQwr-HmXkR9PYVAqfk77wjOFU,19653
@@ -15,9 +15,9 @@ blksprs/utils/autotuning.py,sha256=a-kmWRjJ3eED2XbjkQeOJSyW8bdIs27HgKMPvAKqWeU,2…
 blksprs/utils/benchmarking.py,sha256=dLabDscTFn5NkmOI1g7DnKeTneUYW3RIVv9MDF-8BKc,1271
 blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
 blksprs/utils/processing.py,sha256=RNkEDc0g-sNHRuMPkRzNWU13d3_lIkXMJdoqES4yQTM,3738
-blksprs/utils/tools.py,sha256=…
+blksprs/utils/tools.py,sha256=TKygEKge4wJtJnXXDg8BTL8vzBpqIJsQ_A3_5FmLpcE,859
 blksprs/utils/validation.py,sha256=G8eQlvJVMKfEX3k2AwBD0A6Ck-gFoRLpLNY6HXsB3fA,4348
-blksprs-2.1.…
-blksprs-2.1.…
-blksprs-2.1.…
-blksprs-2.1.…
+blksprs-2.1.2.dist-info/METADATA,sha256=U20ZL7XLhrgiMd_0QGFik0Ci43SDoCT8q876-1yCeNo,9665
+blksprs-2.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+blksprs-2.1.2.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
+blksprs-2.1.2.dist-info/RECORD,,

{blksprs-2.1.1.dist-info → blksprs-2.1.2.dist-info}/WHEEL
File without changes

{blksprs-2.1.1.dist-info → blksprs-2.1.2.dist-info}/top_level.txt
File without changes