blksprs 2.1__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
1
  from blksprs.utils.blksprs_tensor import BlksprsTensor
2
2
 
3
- __version__ = "2.1"
3
+ __version__ = "2.1.1"
4
4
 
5
5
 
6
6
  class ops:
blksprs/ops/conversion.py CHANGED
@@ -13,7 +13,7 @@ from blksprs.utils.validation import validate_contiguous, validate_dimensions, v
13
13
 
14
14
 
15
15
  def to_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int) -> BlksprsTensor:
16
- """Wrapper for ``to_sparse``.
16
+ """Wrapper for :func:`to_sparse`.
17
17
 
18
18
  """
19
19
  return to_sparse(x, sparsity_layout, sparsity_block_size)
@@ -167,7 +167,7 @@ to_sparse_forward.register_autograd(to_sparse_wrapper_backward, setup_context=to
167
167
 
168
168
  def from_blksprs(x: BlksprsTensor, sparsity_layout: Tensor,
169
169
  sparsity_block_size: int, fill_value: float = 0, lut: dict = None) -> Tensor:
170
- """Wrapper for ``to_dense``.
170
+ """Wrapper for :func:`to_dense`.
171
171
 
172
172
  """
173
173
  return to_dense(x, sparsity_layout, sparsity_block_size, fill_value=fill_value, lut=lut)
blksprs/ops/softmax.py CHANGED
@@ -9,15 +9,26 @@ from triton import language as tl
9
9
 
10
10
  from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
11
11
  from blksprs.utils.blksprs_tensor import BlksprsTensor
12
- from blksprs.utils.debugging import dbg_tensor_full
13
12
  from blksprs.utils.tools import stride
14
13
  from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
15
14
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
16
15
  validate_sparsity, validate_sparsity_block_size, validate_dtype_float_32
17
16
 
18
17
 
18
+ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, flag_fused: bool = True,
19
+ lut: dict = None) -> BlksprsTensor:
20
+ """Wrapper for :func:`softmax_regular` and :func:`softmax_fused` based on the ``flag_fused`` parameter.
21
+
22
+ """
23
+ if flag_fused:
24
+ return softmax_fused(x, sparsity_layout, sparsity_block_size, lut)
25
+ else:
26
+ return softmax_regular(x, sparsity_layout, sparsity_block_size, lut)
27
+
28
+
19
29
  @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
20
- def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
30
+ def softmax_regular(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
31
+ lut: dict = None) -> BlksprsTensor:
21
32
  """Computes the softmax of a block-sparse tensor in compressed form.
22
33
 
23
34
  Note:
@@ -1,7 +1,7 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: blksprs
3
- Version: 2.1
4
- Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
3
+ Version: 2.1.1
4
+ Summary: A lightweight library for operations on block-sparse matrices in PyTorch.
5
5
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
6
6
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
7
7
  Project-URL: Bugtracker, https://github.com/FelixSchoen/blksprs/issues
@@ -197,6 +197,7 @@ def test_readme():
197
197
  # Other available functions
198
198
  bs.ops.transpose(o_sparse, sparsity_layout_o, sparsity_block_size)
199
199
  bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size)
200
+ bs.ops.softmax_fused(o_sparse, sparsity_layout_o, sparsity_block_size) # Significantly faster version that requires that rows of matrix fit into memory
200
201
  bs.ops.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size)
201
202
  bs.ops.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size)
202
203
 
@@ -1,13 +1,13 @@
1
- blksprs/__init__.py,sha256=o_Rj7fz_70vbMGLePihczVIVcM8E28vY3ah-d1q4ZO0,1613
1
+ blksprs/__init__.py,sha256=KrLh_rkijisv0BXHY6hwCiGLQMVfw--jnAE-91f0C_k,1615
2
2
  blksprs/layouting/distribution_layout.py,sha256=ur1ty_2U-Hfj78hMWsLZvu7ZuGhzW3qGLKMc72DfTZM,5861
3
3
  blksprs/layouting/sparsity_layout.py,sha256=eXHmu2h7K5Q-YUpfOxocJoeP_5ZoQFZf_eHLxRZQbYU,11207
4
- blksprs/ops/conversion.py,sha256=kf5HKofZ4nVeHCIqQoYKiIlgsAhq33Tnmnr1c17Fkqs,21906
4
+ blksprs/ops/conversion.py,sha256=RgVSyiULLwv8KWQqSyXpKwTr4Qp-lpDK9i-zKlN841I,21914
5
5
  blksprs/ops/distribution.py,sha256=0tPldv0ARzmCV1CU2jvfqpHBgOuHPrDFiCtqsLs7CZc,20789
6
6
  blksprs/ops/flow.py,sha256=qdWBCLDSkKaa8CAfkO1NgH-J5N7yMsILyR7qEpyrIUU,8246
7
7
  blksprs/ops/matmul.py,sha256=5tVBKU_lglUjaLDi6J_dscdqlmzRz38OGxqAxZxZXDs,11879
8
8
  blksprs/ops/partitioning.py,sha256=cfQmY9BZqGTvvJorIhtb-EyuGRJGPraWR-wTKdb47aI,9954
9
9
  blksprs/ops/repeat.py,sha256=TLYNxwPuT9y5K9xyM41WK5gnggAJF3lI61Q2K7zWjns,9035
10
- blksprs/ops/softmax.py,sha256=H0OxST_XX1QLa7HDTDHznzibVHAxnp5sVbMU32HLxf0,21967
10
+ blksprs/ops/softmax.py,sha256=ByiEoM4dEt1IlRMkSDTJZh8CTk0OkBcyGbA_j1prkOw,22397
11
11
  blksprs/ops/transpose.py,sha256=U-VAyLRT6_NDv9qYSFzBqfVlDeIpTqAMEXkqto0VF6w,4072
12
12
  blksprs/ops/misc/broadcast_ops.py,sha256=-PrHiSJikZh8nXUmXxSCtFEP27TTxFr4wcrNxBjnimk,5987
13
13
  blksprs/ops/misc/row_wise.py,sha256=n5FJjAuOd8BHBJQx4bsQwr-HmXkR9PYVAqfk77wjOFU,19653
@@ -17,7 +17,7 @@ blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4
17
17
  blksprs/utils/processing.py,sha256=RNkEDc0g-sNHRuMPkRzNWU13d3_lIkXMJdoqES4yQTM,3738
18
18
  blksprs/utils/tools.py,sha256=CPf7viQ2OTcZFrB1aSL8_us4VE9M6YEfDz2dE30jr9I,715
19
19
  blksprs/utils/validation.py,sha256=G8eQlvJVMKfEX3k2AwBD0A6Ck-gFoRLpLNY6HXsB3fA,4348
20
- blksprs-2.1.dist-info/METADATA,sha256=uPVm8Y7fX5iModz6j3hNAftdtauCsJ-iYrMa-Pv3xnU,9506
21
- blksprs-2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
22
- blksprs-2.1.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
23
- blksprs-2.1.dist-info/RECORD,,
20
+ blksprs-2.1.1.dist-info/METADATA,sha256=dcEdCX15J2yUzUix6-dJyNQru35gxOY8t0GrY8pFT4w,9665
21
+ blksprs-2.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
22
+ blksprs-2.1.1.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
23
+ blksprs-2.1.1.dist-info/RECORD,,
File without changes