blksprs 2.1.1-py3-none-any.whl → 2.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/__init__.py CHANGED
@@ -1,6 +1,6 @@
  from blksprs.utils.blksprs_tensor import BlksprsTensor

- __version__ = "2.1.1"
+ __version__ = "2.1.2"


  class ops:
blksprs/ops/softmax.py CHANGED
@@ -9,7 +9,7 @@ from triton import language as tl

  from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
  from blksprs.utils.blksprs_tensor import BlksprsTensor
- from blksprs.utils.tools import stride
+ from blksprs.utils.tools import stride, ceil_pow2
  from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
  validate_sparsity, validate_sparsity_block_size, validate_dtype_float_32
@@ -349,13 +349,15 @@ def softmax_fused(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size
  lut = softmax_fused_build_lut(lut, sparsity_layout)

  return BlksprsTensor(softmax_fused_forward(x, sparsity_layout,
- lut["sparsity_reverse_lut"],
+ lut["sparsity_reverse_lut_sorted"],
+ lut["max_blocks_line"],
  sparsity_block_size))


  @triton_op("blksprs::softmax_fused_forward", mutates_args={})
  def softmax_fused_forward(x: Tensor, sparsity_layout: Tensor,
- sparsity_reverse_lut: Tensor,
+ sparsity_reverse_lut_sorted: Tensor,
+ max_blocks_line: int,
  sparsity_block_size: int) -> Tensor:
  output = torch.zeros_like(x)

@@ -372,23 +374,29 @@ def softmax_fused_forward(x: Tensor, sparsity_layout: Tensor,
  (x,
  x_b, x_b_s, x_r_s, x_c_s,
  output,
- s_l_b, s_l_b_s, s_l_r_s, s_l_c, s_l_c_s,
- sparsity_reverse_lut,
+ s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+ sparsity_reverse_lut_sorted,
+ max_blocks_line,
  sparsity_block_size))

  return output


  def softmax_fused_backward_wrapper(ctx, grad_output):
- o, sparsity_layout, sparsity_reverse_lut = ctx.saved_tensors
+ o, sparsity_layout, sparsity_reverse_lut_sorted = ctx.saved_tensors
+ max_blocks_line = ctx.max_blocks_line
  sparsity_block_size = ctx.sparsity_block_size

- return softmax_fused_backward(grad_output, o, sparsity_reverse_lut, sparsity_layout,
- sparsity_block_size), None, None, None, None, None
+ return softmax_fused_backward(grad_output, o, sparsity_reverse_lut_sorted, sparsity_layout,
+ max_blocks_line, sparsity_block_size), None, None, None, None


  @triton_op("blksprs::softmax_fused_backward", mutates_args={})
- def softmax_fused_backward(grad_output: Tensor, o: Tensor, sparsity_reverse_lut: Tensor, sparsity_layout: Tensor,
+ def softmax_fused_backward(grad_output: Tensor,
+ o: Tensor,
+ sparsity_reverse_lut_sorted: Tensor,
+ sparsity_layout: Tensor,
+ max_blocks_line: int,
  sparsity_block_size: int) -> Tensor:
  with torch.no_grad():
  grad_x = torch.zeros_like(o)
@@ -409,9 +417,10 @@ def softmax_fused_backward(grad_output: Tensor, o: Tensor, sparsity_reverse_lut:
  g_b, g_b_s, g_r_s, g_c_s,
  o,
  o_b, o_b_s, o_r_s, o_c_s,
- s_l_b, s_l_b_s, s_l_r_s, s_l_c, s_l_c_s,
- sparsity_reverse_lut,
+ s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+ sparsity_reverse_lut_sorted,
  grad_x,
+ max_blocks_line,
  sparsity_block_size))

  return grad_x
@@ -428,8 +437,9 @@ def softmax_fused_backward(grad_output: Tensor, o: Tensor, sparsity_reverse_lut:
  def softmax_fused_kernel(x,
  x_b, x_b_s, x_r_s, x_c_s,
  o,
- s_l_b, s_l_b_s, s_l_r_s, s_l_c: tl.constexpr, s_l_c_s,
- r_lut,
+ s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+ r_lut_s,
+ mbs: tl.constexpr,
  sparsity_block_size: tl.constexpr,
  TRITON_BLOCK_SIZE: tl.constexpr) -> None:
  # Get triton block indices
@@ -440,21 +450,21 @@ def softmax_fused_kernel(x,
  # Load reverse sparsity indices of row
  blk_rev_idx = (pid_bat * s_l_b_s +
  pid_row * s_l_r_s +
- (tl.arange(0, s_l_c) * s_l_c_s))
+ (tl.arange(0, mbs) * s_l_c_s))
  blk_rev_msk = (blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s)
- blk_rev = tl.load(r_lut + blk_rev_idx, mask=blk_rev_msk).to(tl.int32)
+ blk_rev = tl.load(r_lut_s + blk_rev_idx, mask=blk_rev_msk).to(tl.int32)

  if (not (tl.min(blk_rev) == -1 and
  tl.max(blk_rev) == -1)):
  # Extend sparsity indices to cover sparsity blocks
  blk_rev_ext = tl.expand_dims(blk_rev, -1)
- blk_rev_ext = tl.broadcast_to(blk_rev_ext, (s_l_c, sparsity_block_size))
- blk_rev_ext = tl.reshape(blk_rev_ext, (s_l_c * sparsity_block_size))
+ blk_rev_ext = tl.broadcast_to(blk_rev_ext, (mbs, sparsity_block_size))
+ blk_rev_ext = tl.reshape(blk_rev_ext, (mbs * sparsity_block_size))

  # Load line of x
  blk_x_idx = (blk_rev_ext * x_b_s +
  pid_lin * x_r_s +
- (tl.arange(0, s_l_c * sparsity_block_size) % sparsity_block_size) * x_c_s)
+ (tl.arange(0, mbs * sparsity_block_size) % sparsity_block_size) * x_c_s)
  blk_x_mask = ((blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
  and blk_rev_ext != -1)
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_mask, other=float("-inf"))
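In effect, both fused kernels now generate their per-line index range from `mbs` (the power-of-two-padded maximum number of occupied blocks in any line, computed in the build-LUT hunk further down) instead of the total number of block columns `s_l_c`, and read the occupied block indices from the front of the sorted reverse LUT. A rough sizing illustration with invented numbers (not taken from the package):

    # Invented example values, for illustration only.
    sparsity_block_size = 64
    num_block_columns = 16   # s_l_c: sized the tl.arange in the old kernel
    max_blocks_line = 4      # mbs: padded maximum of occupied blocks in any line

    lanes_before = num_block_columns * sparsity_block_size  # 1024 indices generated per line
    lanes_after = max_blocks_line * sparsity_block_size      # 256 indices generated per line

For sparsely occupied lines this should shrink the per-line working set considerably; fully dense layouts are unchanged, since `max_blocks_line` is capped at the number of block columns.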
@@ -478,9 +488,10 @@ def softmax_fused_kernel_grad(g,
  g_b, g_b_s, g_r_s, g_c_s,
  x,
  x_b, x_b_s, x_r_s, x_c_s,
- s_l_b, s_l_b_s, s_l_r_s, s_l_c: tl.constexpr, s_l_c_s,
- r_lut,
+ s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+ r_lut_s,
  o,
+ mbs: tl.constexpr,
  sparsity_block_size: tl.constexpr,
  TRITON_BLOCK_SIZE: tl.constexpr) -> None:
  # Get triton block indices
@@ -491,21 +502,21 @@ def softmax_fused_kernel_grad(g,
  # Load reverse sparsity indices of row
  blk_rev_idx = (pid_bat * s_l_b_s +
  pid_row * s_l_r_s +
- (tl.arange(0, s_l_c) * s_l_c_s))
+ (tl.arange(0, mbs) * s_l_c_s))
  blk_rev_msk = (blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s)
- blk_rev = tl.load(r_lut + blk_rev_idx, mask=blk_rev_msk).to(tl.int32)
+ blk_rev = tl.load(r_lut_s + blk_rev_idx, mask=blk_rev_msk).to(tl.int32)

  if (not (tl.min(blk_rev) == -1 and
  tl.max(blk_rev) == -1)):
  # Extend sparsity indices to cover sparsity blocks
  blk_rev_ext = tl.expand_dims(blk_rev, -1)
- blk_rev_ext = tl.broadcast_to(blk_rev_ext, (s_l_c, sparsity_block_size))
- blk_rev_ext = tl.reshape(blk_rev_ext, (s_l_c * sparsity_block_size))
+ blk_rev_ext = tl.broadcast_to(blk_rev_ext, (mbs, sparsity_block_size))
+ blk_rev_ext = tl.reshape(blk_rev_ext, (mbs * sparsity_block_size))

  # Load line of g
  blk_g_idx = (blk_rev_ext * g_b_s +
  pid_lin * g_r_s +
- (tl.arange(0, s_l_c * sparsity_block_size) % sparsity_block_size) * g_c_s)
+ (tl.arange(0, mbs * sparsity_block_size) % sparsity_block_size) * g_c_s)
  blk_g_mask = ((blk_g_idx >= 0 and blk_g_idx < g_b * g_b_s)
  and blk_rev_ext != -1)
  blk_g = tl.load(g + blk_g_idx, mask=blk_g_mask)
@@ -513,7 +524,7 @@ def softmax_fused_kernel_grad(g,
  # Load line of x
  blk_x_idx = (blk_rev_ext * x_b_s +
  pid_lin * x_r_s +
- (tl.arange(0, s_l_c * sparsity_block_size) % sparsity_block_size) * x_c_s)
+ (tl.arange(0, mbs * sparsity_block_size) % sparsity_block_size) * x_c_s)
  blk_x_mask = ((blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
  and blk_rev_ext != -1)
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_mask)
@@ -521,6 +532,7 @@ def softmax_fused_kernel_grad(g,
  # Compute gradients
  blk_grad = blk_x * (blk_g - tl.sum(blk_x * blk_g))

+ # Store output
  tl.store(o + blk_x_idx, blk_grad, mask=blk_x_mask)


@@ -528,25 +540,36 @@ def softmax_fused_build_lut(lut: dict, sparsity_layout: Tensor):
  if lut is None:
  lut = dict()

- if "sparsity_reverse_lut" not in lut:
+ if "sparsity_reverse_lut_sorted" not in lut:
  sparsity_layout_flat = sparsity_layout.reshape(-1)
- sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
- (sparsity_layout_flat == 1) -
- (1 * (sparsity_layout_flat == 0)))
- .reshape(sparsity_layout.size())
- .reshape(-1).contiguous())
- lut["sparsity_reverse_lut"] = sparsity_reverse_lut
-
- validate_contiguous(sparsity_layout, lut["sparsity_reverse_lut"])
+ sparsity_reverse_lut_sorted = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+ (sparsity_layout_flat == 1) -
+ (1 * (sparsity_layout_flat == 0)))
+ .reshape(sparsity_layout.size())
+ .sort(descending=True, dim=-1)[0]
+ .reshape(-1).contiguous())
+ lut["sparsity_reverse_lut_sorted"] = sparsity_reverse_lut_sorted
+
+ if "max_blocks_line" not in lut:
+ sparsity_reverse_lut_sorted = lut["sparsity_reverse_lut_sorted"]
+ max_blocks_line = ((torch.reshape(sparsity_reverse_lut_sorted, (-1, sparsity_layout.size(-1)))
+ != -1)
+ .sum(dim=-1)
+ .max()
+ .item())
+ lut["max_blocks_line"] = min(ceil_pow2(max(max_blocks_line, 2)), sparsity_layout.size(-1))
+
+ validate_contiguous(sparsity_layout, lut["sparsity_reverse_lut_sorted"])

  return lut


  # noinspection PyUnusedLocal
  def softmax_fused_setup_context(ctx, inputs, output):
- (_, sparsity_layout, sparsity_reverse_lut, sparsity_block_size) = inputs
+ (_, sparsity_layout, sparsity_reverse_lut_sorted, max_blocks_line, sparsity_block_size) = inputs

- ctx.save_for_backward(output, sparsity_layout, sparsity_reverse_lut)
+ ctx.save_for_backward(output, sparsity_layout, sparsity_reverse_lut_sorted)
+ ctx.max_blocks_line = max_blocks_line
  ctx.sparsity_block_size = sparsity_block_size


blksprs/utils/tools.py CHANGED
@@ -26,3 +26,8 @@ def stride(x: Tensor):
  return x.size(1) * x.size(2), x.size(2), 1
  else:
  raise NotImplementedError
+
+ def ceil_pow2(x: int) -> int:
+ if x <= 0:
+ raise ValueError("Input must be a positive integer.")
+ return 1 << (x - 1).bit_length()
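As a standalone sketch (with an invented toy layout, not code from the package), this is roughly what the new lookup-table logic in softmax.py computes with the help of ceil_pow2, presumably so that the kernels' tl.arange extents stay powers of two:

    import torch

    def ceil_pow2(x: int) -> int:
        # Mirrors the helper added above: round x up to the next power of two.
        if x <= 0:
            raise ValueError("Input must be a positive integer.")
        return 1 << (x - 1).bit_length()

    # Toy layout: 1 batch, 2 block-rows, 4 block-columns, three occupied blocks.
    sparsity_layout = torch.tensor([[[1, 0, 1, 0],
                                     [0, 0, 1, 0]]])
    flat = sparsity_layout.reshape(-1)

    # Reverse LUT: position of each occupied block in the compacted tensor, -1 elsewhere.
    reverse_lut = (torch.cumsum(flat, dim=-1) - 1) * (flat == 1) - 1 * (flat == 0)
    # Sorting each line in descending order moves the valid indices to the front,
    # so the kernels only need to scan the first max_blocks_line entries per line.
    reverse_lut_sorted = (reverse_lut.reshape(sparsity_layout.size())
                          .sort(descending=True, dim=-1)[0])  # [[[1, 0, -1, -1], [2, -1, -1, -1]]]

    # Largest number of occupied blocks in any line, padded to a power of two (floor 2)
    # and capped at the number of block columns.
    max_blocks_line = int((reverse_lut_sorted != -1).sum(dim=-1).max())
    max_blocks_line = min(ceil_pow2(max(max_blocks_line, 2)), sparsity_layout.size(-1))
    assert max_blocks_line == 2

The real build function additionally flattens the sorted LUT back to one dimension, makes it contiguous, and runs validate_contiguous before handing it to the kernels.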
blksprs-2.1.1.dist-info/METADATA → blksprs-2.1.2.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: blksprs
- Version: 2.1.1
+ Version: 2.1.2
  Summary: A lightweight library for operations on block-sparse matrices in PyTorch.
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
blksprs-2.1.1.dist-info/RECORD → blksprs-2.1.2.dist-info/RECORD RENAMED
@@ -1,4 +1,4 @@
- blksprs/__init__.py,sha256=KrLh_rkijisv0BXHY6hwCiGLQMVfw--jnAE-91f0C_k,1615
+ blksprs/__init__.py,sha256=NRxydw4i9jg7WeDuojfEePdtdbughV9AZsEcT9yywK4,1615
  blksprs/layouting/distribution_layout.py,sha256=ur1ty_2U-Hfj78hMWsLZvu7ZuGhzW3qGLKMc72DfTZM,5861
  blksprs/layouting/sparsity_layout.py,sha256=eXHmu2h7K5Q-YUpfOxocJoeP_5ZoQFZf_eHLxRZQbYU,11207
  blksprs/ops/conversion.py,sha256=RgVSyiULLwv8KWQqSyXpKwTr4Qp-lpDK9i-zKlN841I,21914
@@ -7,7 +7,7 @@ blksprs/ops/flow.py,sha256=qdWBCLDSkKaa8CAfkO1NgH-J5N7yMsILyR7qEpyrIUU,8246
  blksprs/ops/matmul.py,sha256=5tVBKU_lglUjaLDi6J_dscdqlmzRz38OGxqAxZxZXDs,11879
  blksprs/ops/partitioning.py,sha256=cfQmY9BZqGTvvJorIhtb-EyuGRJGPraWR-wTKdb47aI,9954
  blksprs/ops/repeat.py,sha256=TLYNxwPuT9y5K9xyM41WK5gnggAJF3lI61Q2K7zWjns,9035
- blksprs/ops/softmax.py,sha256=ByiEoM4dEt1IlRMkSDTJZh8CTk0OkBcyGbA_j1prkOw,22397
+ blksprs/ops/softmax.py,sha256=1UIovPrdE_zgAIPqjmOTFn8CMbd_2Z8tPP-vMBxU07I,23526
  blksprs/ops/transpose.py,sha256=U-VAyLRT6_NDv9qYSFzBqfVlDeIpTqAMEXkqto0VF6w,4072
  blksprs/ops/misc/broadcast_ops.py,sha256=-PrHiSJikZh8nXUmXxSCtFEP27TTxFr4wcrNxBjnimk,5987
  blksprs/ops/misc/row_wise.py,sha256=n5FJjAuOd8BHBJQx4bsQwr-HmXkR9PYVAqfk77wjOFU,19653
@@ -15,9 +15,9 @@ blksprs/utils/autotuning.py,sha256=a-kmWRjJ3eED2XbjkQeOJSyW8bdIs27HgKMPvAKqWeU,2
  blksprs/utils/benchmarking.py,sha256=dLabDscTFn5NkmOI1g7DnKeTneUYW3RIVv9MDF-8BKc,1271
  blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
  blksprs/utils/processing.py,sha256=RNkEDc0g-sNHRuMPkRzNWU13d3_lIkXMJdoqES4yQTM,3738
- blksprs/utils/tools.py,sha256=CPf7viQ2OTcZFrB1aSL8_us4VE9M6YEfDz2dE30jr9I,715
+ blksprs/utils/tools.py,sha256=TKygEKge4wJtJnXXDg8BTL8vzBpqIJsQ_A3_5FmLpcE,859
  blksprs/utils/validation.py,sha256=G8eQlvJVMKfEX3k2AwBD0A6Ck-gFoRLpLNY6HXsB3fA,4348
- blksprs-2.1.1.dist-info/METADATA,sha256=dcEdCX15J2yUzUix6-dJyNQru35gxOY8t0GrY8pFT4w,9665
- blksprs-2.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- blksprs-2.1.1.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
- blksprs-2.1.1.dist-info/RECORD,,
+ blksprs-2.1.2.dist-info/METADATA,sha256=U20ZL7XLhrgiMd_0QGFik0Ci43SDoCT8q876-1yCeNo,9665
+ blksprs-2.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ blksprs-2.1.2.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
+ blksprs-2.1.2.dist-info/RECORD,,