blksprs 1.8.2__py3-none-any.whl → 1.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/__init__.py CHANGED
@@ -1,27 +1,40 @@
- from blksprs.ops.conversion import to_dense, to_sparse, from_blksprs, to_blksprs
- from blksprs.ops.distribution import gather, scatter, scatter_reduce
- from blksprs.ops.matmul import matmul
- from blksprs.ops.softmax import softmax
- from blksprs.ops.transpose import transpose
- from blksprs.ops.repeat import repeat, repeat_interleave
- from blksprs.misc.partitioning import split, merge
+ from blksprs.utils.blksprs_tensor import BlksprsTensor
 
+ class ops:
+     from blksprs.ops.conversion import to_dense, to_sparse, from_blksprs, to_blksprs, adapt_layout
+     from blksprs.ops.distribution import gather, scatter, scatter_reduce
+     from blksprs.ops.matmul import matmul
+     from blksprs.ops.softmax import softmax
+     from blksprs.ops.transpose import transpose
+     from blksprs.ops.repeat import repeat, repeat_interleave
+     from blksprs.ops.partitioning import split, merge
 
- class layout:
+     class misc:
+         from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_add, row_wise_sub
+         from blksprs.ops.misc.broadcast_ops import broadcast_add, broadcast_sub
+         from blksprs.ops.misc.exp import exp
+
+     class experimental:
+         from blksprs.ops.experimental.distribution_mdi import gather_mdi, scatter_reduce_mdi
+
+
+ class layouting:
      from blksprs.layouting.distribution_layout import build_distribution_layout
      from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption, \
          build_sparsity_layout_matmul, build_sparsity_layout_matmul_fast
 
+     class experimental:
+         from blksprs.ops.experimental.distribution_mdi import build_distribution_layout_mdi
 
- class misc:
-     from blksprs.misc.broadcast_ops import broadcast_add, broadcast_sub
-     from blksprs.misc.exp import exp
-     from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_add, row_wise_sub
-
-
- class util:
-     from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse, disable_validation
 
+ class utils:
+     from blksprs.utils.processing import apply_torch_linear
+     from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
+     from blksprs.utils.validation import disable_validation
 
- class experimental:
-     from blksprs.experimental.distribution_mdi import gather_mdi, scatter_reduce_mdi
+     class validation:
+         from blksprs.utils.validation import disable_validation
+         from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_dtype_float, \
+             validate_dtype_int, validate_device, validate_sparsity, validate_sparsity_dense, \
+             validate_sparsity_block_size, \
+             validate_triton_block_size
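
For orientation, 1.8.3 replaces the flat `bs.layout`, `bs.misc`, `bs.util`, and `bs.experimental` namespaces with class namespaces nested under `bs.ops`, `bs.layouting`, and `bs.utils`. A minimal migration sketch under that reading; the positional signatures follow the README example further down in this diff, and the `to_dense` call is assumed symmetric to `to_sparse`:

```python
import torch
import blksprs as bs

x = torch.randn(size=(2, 64, 64), device="cuda")
sparsity_block_size = 32

# 1.8.2:  bs.layout.build_sparsity_layout(...)  /  bs.to_sparse(...)
# 1.8.3:  layouting and conversion ops live under the new namespaces
sparsity_layout = bs.layouting.build_sparsity_layout(x, sparsity_block_size)
x_bs = bs.ops.to_sparse(x, sparsity_layout, sparsity_block_size)
x_dense = bs.ops.to_dense(x_bs, sparsity_layout, sparsity_block_size)  # assumed symmetric to to_sparse
```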
blksprs/ops/softmax.py CHANGED
@@ -3,8 +3,8 @@ import triton
  from torch import Tensor
  from triton import language as tl
 
- from blksprs.misc.exp import exp
- from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
+ from blksprs.ops.misc.exp import exp
+ from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
  from blksprs.utils.blksprs_tensor import BlksprsTensor
  from blksprs.utils.tools import get_triton_block_size, stride
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
blksprs/utils/processing.py ADDED
@@ -0,0 +1,41 @@
+ import torch
+ from torch import Tensor, nn
+ from triton.language import dtype
+
+ from blksprs.layouting.sparsity_layout import build_sparsity_layout_matmul_fast
+ from blksprs.ops.conversion import to_sparse
+ from blksprs.ops.matmul import matmul
+ from blksprs.ops.repeat import repeat
+ from blksprs.utils.blksprs_tensor import BlksprsTensor
+
+
+ def apply_torch_linear(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
+                        linear: nn.Linear) -> (BlksprsTensor, Tensor):
+     # Extract weight and bias
+     w = linear.weight
+     b = linear.bias
+
+     # Convert w to block-sparse representation
+     sparsity_layout_w_t = torch.ones(size=(sparsity_layout.size(0), w.size(1) // sparsity_block_size,
+                                            w.size(0) // sparsity_block_size), dtype=torch.bool, device=x.device)
+     w_t_bs = to_sparse(w.transpose(-1, -2).unsqueeze(0).repeat(sparsity_layout.size(0), 1, 1),
+                        sparsity_layout_w_t, sparsity_block_size)
+
+     # Apply weights
+     sparsity_layout_xw = build_sparsity_layout_matmul_fast(sparsity_layout, sparsity_layout_w_t)
+     xw = matmul(x, sparsity_layout, w_t_bs, sparsity_layout_w_t, sparsity_layout_xw, sparsity_block_size)
+     interim = xw
+
+     # Apply bias
+     if b is not None:
+         b_slice = b.unsqueeze(0).unsqueeze(0).repeat(1, sparsity_block_size, 1)
+         sparsity_layout_b_slice = torch.ones(size=(1, b_slice.size(1) // sparsity_block_size,
+                                                    b_slice.size(2) // sparsity_block_size), dtype=torch.bool,
+                                              device=x.device)
+         b_slice_bs = to_sparse(b_slice, sparsity_layout_b_slice, sparsity_block_size)
+         b_bs, sparsity_layout_b = repeat(b_slice_bs, sparsity_layout_b_slice,
+                                          (sparsity_layout.size(0), sparsity_layout_xw.size(1), 1), sparsity_block_size,
+                                          sparsity_layout_output=sparsity_layout_xw)
+         interim = interim + b_bs
+
+     return interim, sparsity_layout_xw
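
The new `blksprs/utils/processing.py` (exposed as `bs.utils.apply_torch_linear` in the reorganized `__init__.py`) wraps a dense `torch.nn.Linear` for block-sparse inputs: the transposed weight is converted under an all-ones layout, multiplied in with the block-sparse `matmul`, and the bias, if present, is tiled across the output via `repeat`. A hedged usage sketch; the shapes are illustrative, and the integer divisions above imply that both `in_features` and `out_features` must be multiples of `sparsity_block_size`:

```python
import torch
from torch import nn
import blksprs as bs

sparsity_block_size = 32
x = torch.randn(size=(2, 64, 64), device="cuda")
linear = nn.Linear(64, 128, device="cuda")  # both feature dims divisible by sparsity_block_size

sparsity_layout_x = bs.layouting.build_sparsity_layout(x, sparsity_block_size)
x_bs = bs.ops.to_sparse(x, sparsity_layout_x, sparsity_block_size)

# Computes x @ W^T (+ b) in block-sparse form and returns the result
# together with the sparsity layout of the product
xw_bs, sparsity_layout_xw = bs.utils.apply_torch_linear(x_bs, sparsity_layout_x,
                                                        sparsity_block_size, linear)
```

Since the weight is densified under an all-ones layout and repeated per batch entry, the savings come from the sparsity of `x` and of the product layout, not from the weight itself.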
blksprs/utils/tools.py CHANGED
@@ -1,7 +1,5 @@
  from torch import Tensor, Size
 
- from blksprs.utils.validation import _set_skip_validation
-
 
  def do_shape_blocksparse(x: Tensor):
      if x.dim() == 3:
@@ -21,8 +19,5 @@ def get_triton_block_size(sparsity_block_size: int, limit: int = 128):
      return min(sparsity_block_size, limit)
 
 
- def disable_validation():
-     _set_skip_validation(True)
-
  def stride(x: Tensor):
-     return x.view(x.shape).stride()
+     return x.view(x.shape).stride()
blksprs/utils/validation.py CHANGED
@@ -124,3 +124,7 @@ def _check_skip_validation():
  def _set_skip_validation(skip_validation: bool):
      global VALIDATION
      VALIDATION = not skip_validation
+
+
+ def disable_validation():
+     _set_skip_validation(True)
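
With this move, `disable_validation` sits next to the `VALIDATION` flag it toggles, removing the `tools` → `validation` import deleted above. The resulting call site, assuming the `bs.utils` binding from the new `__init__.py`:

```python
import blksprs as bs

# Flips the module-level VALIDATION flag in blksprs.utils.validation
# via _set_skip_validation(True); validation checks are skipped afterwards
bs.utils.disable_validation()
```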
{blksprs-1.8.2.dist-info → blksprs-1.8.3.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: blksprs
- Version: 1.8.2
+ Version: 1.8.3
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -22,6 +22,14 @@ Requires-Dist: matplotlib; extra == "test"
  [![GitHub Release](https://img.shields.io/github/v/release/FelixSchoen/blksprs?include_prereleases&label=Latest%20Release)](https://github.com/FelixSchoen/blksprs/releases)
  [![Python Version](https://img.shields.io/badge/Python%20Version-3.11-blue)](https://www.python.org/downloads/release/python-3119/)
 
+ ## Important Notice
+
+ 🚨 **Non-Final API** 🚨
+
+ Although it already supports a wide variety of functions, this library is still under active development and the API is
+ subject to change. For feature requests or bug reports, please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
+ We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
+
  ## Overview
 
  A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
@@ -51,14 +59,14 @@ These include, e.g.,
  Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
  match.
 
- Further helpful operations (included in the ``bs.misc`` module) that do **not** support gradient calculation include:
+ Further helpful operations (included in the ``bs.ops.misc`` module) that do **not** support gradient calculation include:
 
  - Row-wise sum, max, addition, and subtraction
  - Broadcast addition and subtraction between slices
 
  Furthermore, the library provides a set of utility functions for the creation of sparsity layouts based on existing
- dense tensors and for the scatter operation (module ``bs.layout``), as well as utility functions to ensure correct input
- dimensionality (module ``bs.util``).
+ dense tensors and for the scatter operation (module ``bs.layouting``), as well as utility functions to apply linear layers,
+ ensure correct input dimensionality, and validate input (module ``bs.utils``).
 
  ## Installation
 
@@ -111,14 +119,14 @@ def test_readme():
      y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()
 
      # Convert tensors to three-dimensional (dense) tensors since Triton can only handle tensors of exactly three dimensions
-     x_dense, x_shape_original = bs.util.do_shape_blocksparse(x)
-     y_dense, y_shape_original = bs.util.do_shape_blocksparse(y)
+     x_dense, x_shape_original = bs.utils.do_shape_blocksparse(x)
+     y_dense, y_shape_original = bs.utils.do_shape_blocksparse(y)
 
      # Create sparsity layouts from existing tensors
-     sparsity_layout_x = bs.layout.build_sparsity_layout(x_dense, sparsity_block_size,
-                                                         triton_block_size=triton_block_size)
-     sparsity_layout_y = bs.layout.build_sparsity_layout(y_dense, sparsity_block_size,
-                                                         triton_block_size=triton_block_size)
+     sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size,
+                                                            triton_block_size=triton_block_size)
+     sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size,
+                                                            triton_block_size=triton_block_size)
 
      # Create random sparsity layout for output tensor
      sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)
@@ -150,12 +158,12 @@ def test_readme():
      assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2)  # Note that small numerical differences are expected
 
      # Assert that the output has the correct sparsity layout
-     actual_sparsity_layout_o = bs.layout.build_sparsity_layout(o_dense, sparsity_block_size,
-                                                                triton_block_size=triton_block_size)
+     actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size,
+                                                                   triton_block_size=triton_block_size)
      assert torch.allclose(actual_sparsity_layout_o.to(torch.int), sparsity_layout_o)
 
      # Convert output tensor back to original shape
-     o = bs.util.undo_shape_blocksparse(o_dense, x_shape_original)
+     o = bs.utils.undo_shape_blocksparse(o_dense, x_shape_original)
 
      # Other available functions
      bs.transpose(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
blksprs-1.8.3.dist-info/RECORD ADDED
@@ -0,0 +1,23 @@
+ blksprs/__init__.py,sha256=YMrERuEf1hTv5vVdOvPEzh9rESn4uqOB7WHB12Qs5lU,1836
+ blksprs/layouting/distribution_layout.py,sha256=wmj1SwWyY_fhbvMmh6AXrR77LoSp6xLwUWCCyO9i5lk,4239
+ blksprs/layouting/sparsity_layout.py,sha256=-sScIn4hhG35j9BXytrojEzp8jnFkMargJjtivPV1fc,9755
+ blksprs/ops/conversion.py,sha256=ol-iV45wDzp9G1dJEkY53EdrvnmHzcl7QQmPJ-xqQTs,22410
+ blksprs/ops/distribution.py,sha256=fXZV6UegCVpIwzh-A825OSYClHWu5k0UMYdO2UGDUpM,17067
+ blksprs/ops/matmul.py,sha256=yh2ZnO0ZltT1AgadiFP0vX28YJ4n74xO-I_5vFUmOmA,11452
+ blksprs/ops/partitioning.py,sha256=K0ExR2a3W62d_9xxCJzsdJDLgtbxTI6P8loOOBdhPzE,7674
+ blksprs/ops/repeat.py,sha256=IvSIRbuyFn0b57LObymLgup0LqlWQ3ndIw-QuiYQcaU,14564
+ blksprs/ops/softmax.py,sha256=CDQT2KnwkJ4hGIgT0EUp6P92uiYpCdJQ9zxcdgSAAJA,12102
+ blksprs/ops/transpose.py,sha256=jxzFFffrj4S_9tiCrwwUMdz6EA98o1dziWXjlqb64a4,6859
+ blksprs/ops/experimental/distribution_mdi.py,sha256=HaRUu6LTWATzjuHWgddIUE-0fgY-O87STpJO4JY7k_8,20357
+ blksprs/ops/misc/broadcast_ops.py,sha256=cPtRJa3pkZfY1QG51CJ-zDn4SK-CRpX5LEXoKGGMvRU,5418
+ blksprs/ops/misc/exp.py,sha256=FnSFosBfJHuiEbD0MD-i4axLghRn4a0f8KvHXrKBB6M,3802
+ blksprs/ops/misc/row_wise.py,sha256=SvJuNww-_QoVKTyTjMvjmzHlBuUlTKamkuq_rKzwAqs,17081
+ blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
+ blksprs/utils/blksprs_tensor.py,sha256=VjplBgDhnf9sxf-1R5feA0xp5FDCDdaeZmCeoIRdCnc,151
+ blksprs/utils/processing.py,sha256=hYsFxEbQKcbqU4WtZWusPnWMHg8ZAZF1SKZJYjez9aU,2060
+ blksprs/utils/tools.py,sha256=r7Y4C37vfSWUyQTGwa8NyRqgovmsq9hMufkenqYHOxo,539
+ blksprs/utils/validation.py,sha256=IZxH2HZpePmv7lRqLsSwV_6FwsdnTXv9q4j98vCMSsQ,4195
+ blksprs-1.8.3.dist-info/METADATA,sha256=DZkJ_HeetF1V6-_F6GeG0uXT-QmttMFOq4ao8fiSMgQ,8458
+ blksprs-1.8.3.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+ blksprs-1.8.3.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
+ blksprs-1.8.3.dist-info/RECORD,,
blksprs-1.8.2.dist-info/RECORD DELETED
@@ -1,22 +0,0 @@
- blksprs/__init__.py,sha256=np0msosWMaZNVVfuFGt8rE6HZURyIald391dKAs1dSQ,1093
- blksprs/experimental/distribution_mdi.py,sha256=HaRUu6LTWATzjuHWgddIUE-0fgY-O87STpJO4JY7k_8,20357
- blksprs/layouting/distribution_layout.py,sha256=wmj1SwWyY_fhbvMmh6AXrR77LoSp6xLwUWCCyO9i5lk,4239
- blksprs/layouting/sparsity_layout.py,sha256=-sScIn4hhG35j9BXytrojEzp8jnFkMargJjtivPV1fc,9755
- blksprs/misc/broadcast_ops.py,sha256=cPtRJa3pkZfY1QG51CJ-zDn4SK-CRpX5LEXoKGGMvRU,5418
- blksprs/misc/exp.py,sha256=FnSFosBfJHuiEbD0MD-i4axLghRn4a0f8KvHXrKBB6M,3802
- blksprs/misc/partitioning.py,sha256=K0ExR2a3W62d_9xxCJzsdJDLgtbxTI6P8loOOBdhPzE,7674
- blksprs/misc/row_wise.py,sha256=SvJuNww-_QoVKTyTjMvjmzHlBuUlTKamkuq_rKzwAqs,17081
- blksprs/ops/conversion.py,sha256=ol-iV45wDzp9G1dJEkY53EdrvnmHzcl7QQmPJ-xqQTs,22410
- blksprs/ops/distribution.py,sha256=fXZV6UegCVpIwzh-A825OSYClHWu5k0UMYdO2UGDUpM,17067
- blksprs/ops/matmul.py,sha256=yh2ZnO0ZltT1AgadiFP0vX28YJ4n74xO-I_5vFUmOmA,11452
- blksprs/ops/repeat.py,sha256=IvSIRbuyFn0b57LObymLgup0LqlWQ3ndIw-QuiYQcaU,14564
- blksprs/ops/softmax.py,sha256=D9wITz3KB24QXGGjgn_RLQ0Iiq_SjX0bTbUyv9479uU,12094
- blksprs/ops/transpose.py,sha256=jxzFFffrj4S_9tiCrwwUMdz6EA98o1dziWXjlqb64a4,6859
- blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
- blksprs/utils/blksprs_tensor.py,sha256=VjplBgDhnf9sxf-1R5feA0xp5FDCDdaeZmCeoIRdCnc,151
- blksprs/utils/tools.py,sha256=S3836Zuc-BMigv-5mLTjRznCzuaF6oYW-Ir9zzUnr3o,655
- blksprs/utils/validation.py,sha256=WzihRPibXYzss3PMkhDt5_d3Q3NHA_d1TzTz3CoGPGg,4136
- blksprs-1.8.2.dist-info/METADATA,sha256=Zoc860mYmFss7v5ChNoi9407v1qDo_ecc6JUWCvaesg,8009
- blksprs-1.8.2.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
- blksprs-1.8.2.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
- blksprs-1.8.2.dist-info/RECORD,,