blksprs 2.1-py3-none-any.whl → 2.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/__init__.py CHANGED
@@ -1,6 +1,6 @@
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 
-__version__ = "2.1"
+__version__ = "2.1.2"
 
 
 class ops:
blksprs/ops/conversion.py CHANGED
@@ -13,7 +13,7 @@ from blksprs.utils.validation import validate_contiguous, validate_dimensions, v
 
 
 def to_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int) -> BlksprsTensor:
-    """Wrapper for ``to_sparse``.
+    """Wrapper for :func:`to_sparse`.
 
     """
     return to_sparse(x, sparsity_layout, sparsity_block_size)
@@ -167,7 +167,7 @@ to_sparse_forward.register_autograd(to_sparse_wrapper_backward, setup_context=to
 
 def from_blksprs(x: BlksprsTensor, sparsity_layout: Tensor,
                  sparsity_block_size: int, fill_value: float = 0, lut: dict = None) -> Tensor:
-    """Wrapper for ``to_dense``.
+    """Wrapper for :func:`to_dense`.
 
     """
     return to_dense(x, sparsity_layout, sparsity_block_size, fill_value=fill_value, lut=lut)
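Both wrappers simply forward to the underlying conversion ops, so a dense tensor can be round-tripped through compressed form. A minimal sketch, assuming a CUDA device, an all-dense integer layout, and that `to_blksprs`/`from_blksprs` are exposed under `bs.ops` like the other operations in the README example (shapes and block size below are illustrative, not taken from the package's tests):

```python
# Round-trip sketch: dense -> block-sparse compressed form -> dense.
# Shapes, block size, and the all-ones layout are illustrative assumptions.
import torch
import blksprs as bs

sparsity_block_size = 16
x = torch.randn(2, 64, 64, device="cuda")
# Every block marked as present (a dense layout), just to exercise the API.
sparsity_layout = torch.ones(2, 4, 4, dtype=torch.int32, device="cuda")

x_compressed = bs.ops.to_blksprs(x, sparsity_layout, sparsity_block_size)              # wrapper for to_sparse
x_roundtrip = bs.ops.from_blksprs(x_compressed, sparsity_layout, sparsity_block_size)  # wrapper for to_dense
assert torch.allclose(x, x_roundtrip)
```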
blksprs/ops/softmax.py CHANGED
@@ -9,15 +9,26 @@ from triton import language as tl
 
 from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.debugging import dbg_tensor_full
-from blksprs.utils.tools import stride
+from blksprs.utils.tools import stride, ceil_pow2
 from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_dtype_float_32
 
 
+def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, flag_fused: bool = True,
+            lut: dict = None) -> BlksprsTensor:
+    """Wrapper for :func:`softmax_regular` and :func:`softmax_fused` based on the ``flag_fused`` parameter.
+
+    """
+    if flag_fused:
+        return softmax_fused(x, sparsity_layout, sparsity_block_size, lut)
+    else:
+        return softmax_regular(x, sparsity_layout, sparsity_block_size, lut)
+
+
 @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
-def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
+def softmax_regular(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
+                    lut: dict = None) -> BlksprsTensor:
     """Computes the softmax of a block-sparse tensor in compressed form.
 
     Note:
@@ -338,13 +349,15 @@ def softmax_fused(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size
     lut = softmax_fused_build_lut(lut, sparsity_layout)
 
     return BlksprsTensor(softmax_fused_forward(x, sparsity_layout,
-                                               lut["sparsity_reverse_lut"],
+                                               lut["sparsity_reverse_lut_sorted"],
+                                               lut["max_blocks_line"],
                                                sparsity_block_size))
 
 
 @triton_op("blksprs::softmax_fused_forward", mutates_args={})
 def softmax_fused_forward(x: Tensor, sparsity_layout: Tensor,
-                          sparsity_reverse_lut: Tensor,
+                          sparsity_reverse_lut_sorted: Tensor,
+                          max_blocks_line: int,
                           sparsity_block_size: int) -> Tensor:
     output = torch.zeros_like(x)
 
@@ -361,23 +374,29 @@ def softmax_fused_forward(x: Tensor, sparsity_layout: Tensor,
                                             (x,
                                              x_b, x_b_s, x_r_s, x_c_s,
                                              output,
-                                             s_l_b, s_l_b_s, s_l_r_s, s_l_c, s_l_c_s,
-                                             sparsity_reverse_lut,
+                                             s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+                                             sparsity_reverse_lut_sorted,
+                                             max_blocks_line,
                                              sparsity_block_size))
 
     return output
 
 
 def softmax_fused_backward_wrapper(ctx, grad_output):
-    o, sparsity_layout, sparsity_reverse_lut = ctx.saved_tensors
+    o, sparsity_layout, sparsity_reverse_lut_sorted = ctx.saved_tensors
+    max_blocks_line = ctx.max_blocks_line
     sparsity_block_size = ctx.sparsity_block_size
 
-    return softmax_fused_backward(grad_output, o, sparsity_reverse_lut, sparsity_layout,
-                                  sparsity_block_size), None, None, None, None, None
+    return softmax_fused_backward(grad_output, o, sparsity_reverse_lut_sorted, sparsity_layout,
+                                  max_blocks_line, sparsity_block_size), None, None, None, None
 
 
 @triton_op("blksprs::softmax_fused_backward", mutates_args={})
-def softmax_fused_backward(grad_output: Tensor, o: Tensor, sparsity_reverse_lut: Tensor, sparsity_layout: Tensor,
+def softmax_fused_backward(grad_output: Tensor,
+                           o: Tensor,
+                           sparsity_reverse_lut_sorted: Tensor,
+                           sparsity_layout: Tensor,
+                           max_blocks_line: int,
                            sparsity_block_size: int) -> Tensor:
     with torch.no_grad():
         grad_x = torch.zeros_like(o)
@@ -398,9 +417,10 @@ def softmax_fused_backward(grad_output: Tensor, o: Tensor, sparsity_reverse_lut:
                                              g_b, g_b_s, g_r_s, g_c_s,
                                              o,
                                              o_b, o_b_s, o_r_s, o_c_s,
-                                             s_l_b, s_l_b_s, s_l_r_s, s_l_c, s_l_c_s,
-                                             sparsity_reverse_lut,
+                                             s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+                                             sparsity_reverse_lut_sorted,
                                              grad_x,
+                                             max_blocks_line,
                                              sparsity_block_size))
 
         return grad_x
@@ -417,8 +437,9 @@ def softmax_fused_backward(grad_output: Tensor, o: Tensor, sparsity_reverse_lut:
 def softmax_fused_kernel(x,
                          x_b, x_b_s, x_r_s, x_c_s,
                          o,
-                         s_l_b, s_l_b_s, s_l_r_s, s_l_c: tl.constexpr, s_l_c_s,
-                         r_lut,
+                         s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+                         r_lut_s,
+                         mbs: tl.constexpr,
                          sparsity_block_size: tl.constexpr,
                          TRITON_BLOCK_SIZE: tl.constexpr) -> None:
     # Get triton block indices
@@ -429,21 +450,21 @@ def softmax_fused_kernel(x,
     # Load reverse sparsity indices of row
     blk_rev_idx = (pid_bat * s_l_b_s +
                    pid_row * s_l_r_s +
-                   (tl.arange(0, s_l_c) * s_l_c_s))
+                   (tl.arange(0, mbs) * s_l_c_s))
     blk_rev_msk = (blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s)
-    blk_rev = tl.load(r_lut + blk_rev_idx, mask=blk_rev_msk).to(tl.int32)
+    blk_rev = tl.load(r_lut_s + blk_rev_idx, mask=blk_rev_msk).to(tl.int32)
 
     if (not (tl.min(blk_rev) == -1 and
              tl.max(blk_rev) == -1)):
         # Extend sparsity indices to cover sparsity blocks
         blk_rev_ext = tl.expand_dims(blk_rev, -1)
-        blk_rev_ext = tl.broadcast_to(blk_rev_ext, (s_l_c, sparsity_block_size))
-        blk_rev_ext = tl.reshape(blk_rev_ext, (s_l_c * sparsity_block_size))
+        blk_rev_ext = tl.broadcast_to(blk_rev_ext, (mbs, sparsity_block_size))
+        blk_rev_ext = tl.reshape(blk_rev_ext, (mbs * sparsity_block_size))
 
         # Load line of x
         blk_x_idx = (blk_rev_ext * x_b_s +
                      pid_lin * x_r_s +
-                     (tl.arange(0, s_l_c * sparsity_block_size) % sparsity_block_size) * x_c_s)
+                     (tl.arange(0, mbs * sparsity_block_size) % sparsity_block_size) * x_c_s)
         blk_x_mask = ((blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
                       and blk_rev_ext != -1)
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_mask, other=float("-inf"))
@@ -467,9 +488,10 @@ def softmax_fused_kernel_grad(g,
                               g_b, g_b_s, g_r_s, g_c_s,
                               x,
                               x_b, x_b_s, x_r_s, x_c_s,
-                              s_l_b, s_l_b_s, s_l_r_s, s_l_c: tl.constexpr, s_l_c_s,
-                              r_lut,
+                              s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+                              r_lut_s,
                               o,
+                              mbs: tl.constexpr,
                               sparsity_block_size: tl.constexpr,
                               TRITON_BLOCK_SIZE: tl.constexpr) -> None:
     # Get triton block indices
@@ -480,21 +502,21 @@ def softmax_fused_kernel_grad(g,
     # Load reverse sparsity indices of row
     blk_rev_idx = (pid_bat * s_l_b_s +
                    pid_row * s_l_r_s +
-                   (tl.arange(0, s_l_c) * s_l_c_s))
+                   (tl.arange(0, mbs) * s_l_c_s))
     blk_rev_msk = (blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s)
-    blk_rev = tl.load(r_lut + blk_rev_idx, mask=blk_rev_msk).to(tl.int32)
+    blk_rev = tl.load(r_lut_s + blk_rev_idx, mask=blk_rev_msk).to(tl.int32)
 
     if (not (tl.min(blk_rev) == -1 and
              tl.max(blk_rev) == -1)):
         # Extend sparsity indices to cover sparsity blocks
         blk_rev_ext = tl.expand_dims(blk_rev, -1)
-        blk_rev_ext = tl.broadcast_to(blk_rev_ext, (s_l_c, sparsity_block_size))
-        blk_rev_ext = tl.reshape(blk_rev_ext, (s_l_c * sparsity_block_size))
+        blk_rev_ext = tl.broadcast_to(blk_rev_ext, (mbs, sparsity_block_size))
+        blk_rev_ext = tl.reshape(blk_rev_ext, (mbs * sparsity_block_size))
 
         # Load line of g
         blk_g_idx = (blk_rev_ext * g_b_s +
                      pid_lin * g_r_s +
-                     (tl.arange(0, s_l_c * sparsity_block_size) % sparsity_block_size) * g_c_s)
+                     (tl.arange(0, mbs * sparsity_block_size) % sparsity_block_size) * g_c_s)
         blk_g_mask = ((blk_g_idx >= 0 and blk_g_idx < g_b * g_b_s)
                       and blk_rev_ext != -1)
         blk_g = tl.load(g + blk_g_idx, mask=blk_g_mask)
@@ -502,7 +524,7 @@ def softmax_fused_kernel_grad(g,
         # Load line of x
         blk_x_idx = (blk_rev_ext * x_b_s +
                      pid_lin * x_r_s +
-                     (tl.arange(0, s_l_c * sparsity_block_size) % sparsity_block_size) * x_c_s)
+                     (tl.arange(0, mbs * sparsity_block_size) % sparsity_block_size) * x_c_s)
         blk_x_mask = ((blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
                       and blk_rev_ext != -1)
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_mask)
@@ -510,6 +532,7 @@ def softmax_fused_kernel_grad(g,
         # Compute gradients
         blk_grad = blk_x * (blk_g - tl.sum(blk_x * blk_g))
 
+        # Store output
         tl.store(o + blk_x_idx, blk_grad, mask=blk_x_mask)
 
 
@@ -517,25 +540,36 @@ def softmax_fused_build_lut(lut: dict, sparsity_layout: Tensor):
     if lut is None:
         lut = dict()
 
-    if "sparsity_reverse_lut" not in lut:
+    if "sparsity_reverse_lut_sorted" not in lut:
         sparsity_layout_flat = sparsity_layout.reshape(-1)
-        sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                                 (sparsity_layout_flat == 1) -
-                                 (1 * (sparsity_layout_flat == 0)))
-                                .reshape(sparsity_layout.size())
-                                .reshape(-1).contiguous())
-        lut["sparsity_reverse_lut"] = sparsity_reverse_lut
-
-    validate_contiguous(sparsity_layout, lut["sparsity_reverse_lut"])
+        sparsity_reverse_lut_sorted = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                        (sparsity_layout_flat == 1) -
+                                        (1 * (sparsity_layout_flat == 0)))
+                                       .reshape(sparsity_layout.size())
+                                       .sort(descending=True, dim=-1)[0]
+                                       .reshape(-1).contiguous())
+        lut["sparsity_reverse_lut_sorted"] = sparsity_reverse_lut_sorted
+
+    if "max_blocks_line" not in lut:
+        sparsity_reverse_lut_sorted = lut["sparsity_reverse_lut_sorted"]
+        max_blocks_line = ((torch.reshape(sparsity_reverse_lut_sorted, (-1, sparsity_layout.size(-1)))
+                            != -1)
+                           .sum(dim=-1)
+                           .max()
+                           .item())
+        lut["max_blocks_line"] = min(ceil_pow2(max(max_blocks_line, 2)), sparsity_layout.size(-1))
+
+    validate_contiguous(sparsity_layout, lut["sparsity_reverse_lut_sorted"])
 
     return lut
 
 
 # noinspection PyUnusedLocal
 def softmax_fused_setup_context(ctx, inputs, output):
-    (_, sparsity_layout, sparsity_reverse_lut, sparsity_block_size) = inputs
+    (_, sparsity_layout, sparsity_reverse_lut_sorted, max_blocks_line, sparsity_block_size) = inputs
 
-    ctx.save_for_backward(output, sparsity_layout, sparsity_reverse_lut)
+    ctx.save_for_backward(output, sparsity_layout, sparsity_reverse_lut_sorted)
+    ctx.max_blocks_line = max_blocks_line
     ctx.sparsity_block_size = sparsity_block_size
 
 
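The reworked LUT sorts each row of the reverse lookup so that occupied blocks come first and records `max_blocks_line`, the densest row's block count rounded up to the next power of two and clamped to the layout width; the kernels then only iterate over `mbs` candidate blocks per line instead of the full layout width. A standalone sketch of that bookkeeping on a toy layout (the layout values below are illustrative, not taken from the package):

```python
# Sketch: how the sorted reverse LUT and max_blocks_line are derived (toy layout).
import torch
from blksprs.utils.tools import ceil_pow2  # helper added to tools.py in this release

# 1 batch, 4 block-rows, 8 block-columns; rows contain 1, 3, 5, and 0 occupied blocks.
sparsity_layout = torch.tensor([[[1, 0, 0, 0, 0, 0, 0, 0],
                                 [1, 1, 1, 0, 0, 0, 0, 0],
                                 [1, 1, 1, 1, 1, 0, 0, 0],
                                 [0, 0, 0, 0, 0, 0, 0, 0]]])

flat = sparsity_layout.reshape(-1)
# Block index where a block is present, -1 where it is absent, then sorted per row so
# that the present blocks come first (mirrors softmax_fused_build_lut above).
reverse_lut_sorted = (((torch.cumsum(flat, dim=-1) - 1) * (flat == 1) - 1 * (flat == 0))
                      .reshape(sparsity_layout.size())
                      .sort(descending=True, dim=-1)[0]
                      .reshape(-1))

# Densest row has 5 blocks -> padded to the next power of two (8), clamped to the width (8).
max_blocks_line = (reverse_lut_sorted.reshape(-1, sparsity_layout.size(-1)) != -1).sum(dim=-1).max().item()
mbs = min(ceil_pow2(max(max_blocks_line, 2)), sparsity_layout.size(-1))
print(mbs)  # 8 for this toy layout
```

The resulting `mbs` is passed to the kernels as a `tl.constexpr`, and the power-of-two rounding is presumably needed because it feeds `tl.arange`, which Triton only accepts for power-of-two extents.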
blksprs/utils/tools.py CHANGED
@@ -26,3 +26,8 @@ def stride(x: Tensor):
         return x.size(1) * x.size(2), x.size(2), 1
     else:
         raise NotImplementedError
+
+def ceil_pow2(x: int) -> int:
+    if x <= 0:
+        raise ValueError("Input must be a positive integer.")
+    return 1 << (x - 1).bit_length()
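`ceil_pow2` rounds a positive integer up to the nearest power of two via the bit length of `x - 1`. A few spot checks of the rounding (illustrative values only, not part of the package's test suite):

```python
# Spot checks of ceil_pow2 (illustrative values only).
from blksprs.utils.tools import ceil_pow2

assert ceil_pow2(1) == 1     # (1 - 1).bit_length() == 0, so 1 << 0
assert ceil_pow2(2) == 2
assert ceil_pow2(5) == 8
assert ceil_pow2(64) == 64
assert ceil_pow2(65) == 128
```

Inside `softmax_fused_build_lut` the result is additionally clamped to the layout width, so `max_blocks_line` never exceeds the number of block columns.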
blksprs-2.1.dist-info/METADATA → blksprs-2.1.2.dist-info/METADATA RENAMED
@@ -1,7 +1,7 @@
 Metadata-Version: 2.4
 Name: blksprs
-Version: 2.1
-Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
+Version: 2.1.2
+Summary: A lightweight library for operations on block-sparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
 Project-URL: Bugtracker, https://github.com/FelixSchoen/blksprs/issues
@@ -197,6 +197,7 @@ def test_readme():
     # Other available functions
     bs.ops.transpose(o_sparse, sparsity_layout_o, sparsity_block_size)
     bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size)
+    bs.ops.softmax_fused(o_sparse, sparsity_layout_o, sparsity_block_size)  # Significantly faster version that requires that rows of matrix fit into memory
     bs.ops.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size)
     bs.ops.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size)
 
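The README example now lists `softmax_fused` next to the regular softmax; combined with the new dispatch wrapper in `softmax.py`, the fused path can be selected either explicitly or through `flag_fused`. A hedged usage sketch (shapes, block size, and the all-dense layout are illustrative assumptions, and it is assumed that `bs.ops.softmax` forwards the `flag_fused` keyword to the wrapper shown above):

```python
# Sketch: selecting the fused softmax path explicitly or via the wrapper's flag_fused.
import torch
import blksprs as bs

sparsity_block_size = 16
sparsity_layout = torch.ones(2, 4, 4, dtype=torch.int32, device="cuda")
x_sparse = bs.ops.to_blksprs(torch.randn(2, 64, 64, device="cuda"),
                             sparsity_layout, sparsity_block_size)

# Explicit fused call, as added to the README example above.
out_fused = bs.ops.softmax_fused(x_sparse, sparsity_layout, sparsity_block_size)

# Dispatch through the new wrapper: flag_fused defaults to True, False selects softmax_regular.
out_regular = bs.ops.softmax(x_sparse, sparsity_layout, sparsity_block_size, flag_fused=False)
```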
blksprs-2.1.dist-info/RECORD → blksprs-2.1.2.dist-info/RECORD RENAMED
@@ -1,13 +1,13 @@
-blksprs/__init__.py,sha256=o_Rj7fz_70vbMGLePihczVIVcM8E28vY3ah-d1q4ZO0,1613
+blksprs/__init__.py,sha256=NRxydw4i9jg7WeDuojfEePdtdbughV9AZsEcT9yywK4,1615
 blksprs/layouting/distribution_layout.py,sha256=ur1ty_2U-Hfj78hMWsLZvu7ZuGhzW3qGLKMc72DfTZM,5861
 blksprs/layouting/sparsity_layout.py,sha256=eXHmu2h7K5Q-YUpfOxocJoeP_5ZoQFZf_eHLxRZQbYU,11207
-blksprs/ops/conversion.py,sha256=kf5HKofZ4nVeHCIqQoYKiIlgsAhq33Tnmnr1c17Fkqs,21906
+blksprs/ops/conversion.py,sha256=RgVSyiULLwv8KWQqSyXpKwTr4Qp-lpDK9i-zKlN841I,21914
 blksprs/ops/distribution.py,sha256=0tPldv0ARzmCV1CU2jvfqpHBgOuHPrDFiCtqsLs7CZc,20789
 blksprs/ops/flow.py,sha256=qdWBCLDSkKaa8CAfkO1NgH-J5N7yMsILyR7qEpyrIUU,8246
 blksprs/ops/matmul.py,sha256=5tVBKU_lglUjaLDi6J_dscdqlmzRz38OGxqAxZxZXDs,11879
 blksprs/ops/partitioning.py,sha256=cfQmY9BZqGTvvJorIhtb-EyuGRJGPraWR-wTKdb47aI,9954
 blksprs/ops/repeat.py,sha256=TLYNxwPuT9y5K9xyM41WK5gnggAJF3lI61Q2K7zWjns,9035
-blksprs/ops/softmax.py,sha256=H0OxST_XX1QLa7HDTDHznzibVHAxnp5sVbMU32HLxf0,21967
+blksprs/ops/softmax.py,sha256=1UIovPrdE_zgAIPqjmOTFn8CMbd_2Z8tPP-vMBxU07I,23526
 blksprs/ops/transpose.py,sha256=U-VAyLRT6_NDv9qYSFzBqfVlDeIpTqAMEXkqto0VF6w,4072
 blksprs/ops/misc/broadcast_ops.py,sha256=-PrHiSJikZh8nXUmXxSCtFEP27TTxFr4wcrNxBjnimk,5987
 blksprs/ops/misc/row_wise.py,sha256=n5FJjAuOd8BHBJQx4bsQwr-HmXkR9PYVAqfk77wjOFU,19653
@@ -15,9 +15,9 @@ blksprs/utils/autotuning.py,sha256=a-kmWRjJ3eED2XbjkQeOJSyW8bdIs27HgKMPvAKqWeU,2
 blksprs/utils/benchmarking.py,sha256=dLabDscTFn5NkmOI1g7DnKeTneUYW3RIVv9MDF-8BKc,1271
 blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
 blksprs/utils/processing.py,sha256=RNkEDc0g-sNHRuMPkRzNWU13d3_lIkXMJdoqES4yQTM,3738
-blksprs/utils/tools.py,sha256=CPf7viQ2OTcZFrB1aSL8_us4VE9M6YEfDz2dE30jr9I,715
+blksprs/utils/tools.py,sha256=TKygEKge4wJtJnXXDg8BTL8vzBpqIJsQ_A3_5FmLpcE,859
 blksprs/utils/validation.py,sha256=G8eQlvJVMKfEX3k2AwBD0A6Ck-gFoRLpLNY6HXsB3fA,4348
-blksprs-2.1.dist-info/METADATA,sha256=uPVm8Y7fX5iModz6j3hNAftdtauCsJ-iYrMa-Pv3xnU,9506
-blksprs-2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-blksprs-2.1.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
-blksprs-2.1.dist-info/RECORD,,
+blksprs-2.1.2.dist-info/METADATA,sha256=U20ZL7XLhrgiMd_0QGFik0Ci43SDoCT8q876-1yCeNo,9665
+blksprs-2.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+blksprs-2.1.2.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
+blksprs-2.1.2.dist-info/RECORD,,