blksprs-1.0-py3-none-any.whl → blksprs-1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/utils/tools.py CHANGED
@@ -15,33 +15,6 @@ def undo_shape_blocksparse(x: Tensor, shape: Size):
 
     return x.reshape((*shape[:-2], *x.shape[-2:]))
 
+
 def get_triton_block_size(sparsity_block_size: int, limit: int = 128):
     return min(sparsity_block_size, limit)
-
-#
-
-def slow_to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int):
-    output = torch.zeros(size=(sparsity_layout.size(0), sparsity_layout.size(1) * sparsity_block_size,
-                               sparsity_layout.size(2) * sparsity_block_size), device=x.device)
-    indices_sparse_blocks = sparsity_layout.nonzero(as_tuple=True)
-
-    for idx, (b, r, c) in enumerate(zip(*indices_sparse_blocks)):
-        t_r = r * sparsity_block_size
-        t_c = c * sparsity_block_size
-        to_insert = x[idx]
-        output[b, t_r:t_r + sparsity_block_size, t_c:t_c + sparsity_block_size] = to_insert
-
-    return output
-
-def slow_to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int):
-    num_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
-    output = torch.zeros(size=(num_sparse_blocks, sparsity_block_size, sparsity_block_size), device=x.device)
-    indices_sparse_blocks = sparsity_layout.nonzero(as_tuple=True)
-
-    for idx, (b, r, c) in enumerate(zip(*indices_sparse_blocks)):
-        t_r = r * sparsity_block_size
-        t_c = c * sparsity_block_size
-        to_insert = x[b, t_r:t_r + sparsity_block_size, t_c:t_c + sparsity_block_size]
-        output[idx] = to_insert
-
-    return output
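
The `slow_to_dense` / `slow_to_sparse` reference helpers removed above also document the block-sparse representation itself: a stack of blocks with shape `(num_sparse_blocks, sparsity_block_size, sparsity_block_size)` paired with a binary `sparsity_layout` of shape `(batch, rows / sparsity_block_size, cols / sparsity_block_size)`. The following CPU-only sketch is illustrative and not part of the 1.1 package; the helper body is copied from the 1.0 code above.

```python
import torch
from torch import Tensor


def slow_to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int) -> Tensor:
    # Copied from the 1.0 reference implementation above: gather every block flagged
    # in the layout into a stack of (num_sparse_blocks, block, block) tiles.
    num_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
    output = torch.zeros(size=(num_sparse_blocks, sparsity_block_size, sparsity_block_size), device=x.device)
    for idx, (b, r, c) in enumerate(zip(*sparsity_layout.nonzero(as_tuple=True))):
        t_r = r * sparsity_block_size
        t_c = c * sparsity_block_size
        output[idx] = x[b, t_r:t_r + sparsity_block_size, t_c:t_c + sparsity_block_size]
    return output


# One batch of a 64x64 matrix split into 4x4 blocks of size 16; only the diagonal blocks are kept.
layout = torch.eye(4, dtype=torch.int64).unsqueeze(0)  # shape (1, 4, 4)
dense = torch.randn(1, 64, 64)
blocks = slow_to_sparse(dense, layout, 16)
assert blocks.shape == (4, 16, 16)
assert torch.equal(blocks[1], dense[0, 16:32, 16:32])
```

The library's own conversions are the `to_sparse` / `to_dense` operations from `blksprs.ops.conversion`, shown in the usage example further below.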

blksprs/utils/validation.py CHANGED
@@ -3,24 +3,45 @@ from torch import Tensor
 
 
 def validate_dimensions(*tensors: Tensor) -> None:
+    if _skip_validation():
+        return
+
     for tensor in tensors:
         if tensor.dim() != 3:
             raise ValueError("Tensor must have 3 dimensions")
 
 
 def validate_contiguous(*tensors: Tensor) -> None:
+    if _skip_validation():
+        return
+
     for tensor in tensors:
         if not tensor.is_contiguous():
             raise ValueError("Tensor must be contiguous")
 
 
 def validate_dtype_float(*tensors: Tensor) -> None:
+    if _skip_validation():
+        return
+
     for tensor in tensors:
         if tensor.dtype != torch.float32:
             raise ValueError("Tensor must have float32 dtype")
 
 
+def validate_dtype_int(*tensors: Tensor) -> None:
+    if _skip_validation():
+        return
+
+    for tensor in tensors:
+        if tensor.dtype != torch.int32 and tensor.dtype != torch.int64:
+            raise ValueError("Tensor must have int32 or int64 dtype")
+
+
 def validate_device(*tensors: Tensor) -> None:
+    if _skip_validation():
+        return
+
     device = None
 
     for i, tensor in enumerate(tensors):
@@ -33,13 +54,44 @@ def validate_device(*tensors: Tensor) -> None:
         if tensor.device != device:
             raise ValueError("Tensors must be on same device")
 
+
 def validate_sparsity(sparsity_block_size: int, *tensor_sparsity_layout_tuples: tuple[Tensor, Tensor]) -> None:
+    if _skip_validation():
+        return
+
     for (tensor, sparsity_layout) in tensor_sparsity_layout_tuples:
+        _validate_sparsity_layout_values(sparsity_layout)
+
         if not (tensor.size(-1) == tensor.size(-2) == sparsity_block_size):
             raise ValueError("Blocks not conforming to sparsity block size")
         if not tensor.size(0) == torch.sum(sparsity_layout.reshape(-1)):
             raise ValueError("Mismatch between sparsity layout and blocks")
 
+
+def _validate_sparsity_layout_values(sparsity_layout: Tensor):
+    if not torch.all(torch.logical_or(sparsity_layout == 0, sparsity_layout == 1)):
+        raise ValueError("Sparsity layout values must be either 0 or 1")
+
+
+def validate_sparsity_block_size(sparsity_block_size: int, *tensors):
+    if _skip_validation():
+        return
+
+    if not (sparsity_block_size & (sparsity_block_size - 1)) == 0:
+        raise ValueError("Sparsity block size must be a power of 2")
+
+    for tensor in tensors:
+        if not (tensor.size(-1) % sparsity_block_size == 0 and tensor.size(-2) % sparsity_block_size == 0):
+            raise ValueError("Tensor sizes must be divisible by sparsity block size")
+
+
 def validate_triton_block_size(triton_block_size: int, sparsity_block_size: int):
+    if _skip_validation():
+        return
+
+    if triton_block_size is None:
+        return
+
     if triton_block_size > sparsity_block_size:
-        raise ValueError("Triton block size cannot be larger than sparsity block size")
+        raise ValueError("Triton block size cannot be larger than sparsity block size")
+
+
+def _skip_validation():
+    return False
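
Two details of the new validators are worth noting: `validate_sparsity_block_size` tests for a power of two with the bit trick `(n & (n - 1)) == 0`, and every public validator first consults `_skip_validation()`, which in this release always returns `False`, so validation always runs. A standalone illustration of the power-of-two check (not part of the package):

```python
def is_power_of_two(n: int) -> bool:
    # n & (n - 1) clears the lowest set bit, so the result is 0 exactly when n has a
    # single set bit. As in validate_sparsity_block_size above, n = 0 would also pass,
    # so a positive block size is assumed.
    return (n & (n - 1)) == 0


assert all(is_power_of_two(n) for n in (1, 2, 16, 64, 128))
assert not any(is_power_of_two(n) for n in (3, 12, 48, 100))
```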

blksprs-1.0.dist-info/METADATA → blksprs-1.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.0
+Version: 1.1
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -23,16 +23,20 @@ Requires-Dist: matplotlib; extra == "test"
 
 ## Overview
 
-A lightweight library for operations on blocksparse matrices in PyTorch.
+A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
 
 Currently supported operations (includes gradient calculation):
 
-- Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support for `sparse = sparse @ sparse` matmul_)
+- Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
+  for `sparse = sparse @ sparse` matmul_)
 - Softmax
 - Transposition
+- Gather
+- Scatter (_supports either no reduction or summation, gradients are only available for summation_)
 - Conversion from and to sparse form
 
-As with this library sparse matrices are represented using a tuple of `(matrix, sparsity_layout, sparsity_block_size)`, any element-wise operations can be applied in regular torch-like fashion.
+As with this library sparse matrices are represented using a tuple of `(matrix, sparsity_layout, sparsity_block_size)`,
+any element-wise operations can be applied in regular torch-like fashion.
 These include, e.g.,
 
 - Element-wise addition and subtraction
@@ -40,8 +44,17 @@ These include, e.g.,
 - Element-wise exponentiation
 - ...
 
+Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
+match.
+
+Furthermore, the library provides a set of utility functions for the creation of sparsity layouts based on existing
+dense tensors.
+
 ## Installation
 
+Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is only compatible with
+the Linux platform.
+
 We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) using pip:
 
 ```pip install blksprs```
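
As the overview above notes, element-wise operations between two sparse tensors require matching sparsity layouts: in sparse form the blocks are just a regular PyTorch tensor, and only the shared layout makes the blocks of both operands line up. A minimal sketch, assuming a CUDA device and using `build_sparsity_layout`, `to_sparse`, and `to_dense` from the usage example below:

```python
import torch

from blksprs.layouting.sparsity_layout import build_sparsity_layout
from blksprs.ops.conversion import to_sparse, to_dense

sparsity_block_size = 16

# Two dense tensors with the same block structure (here every block is non-zero),
# so both share one sparsity layout and element-wise operations match block by block.
x = torch.randn(size=(2, 64, 64), device="cuda")
y = torch.randn(size=(2, 64, 64), device="cuda")
sparsity_layout = build_sparsity_layout(x, sparsity_block_size)

x_sparse = to_sparse(x, sparsity_layout, sparsity_block_size)
y_sparse = to_sparse(y, sparsity_layout, sparsity_block_size)

# Element-wise addition is plain PyTorch on the block tensors.
z_sparse = x_sparse + y_sparse
z_dense = to_dense(z_sparse, sparsity_layout, sparsity_block_size)

assert torch.allclose(z_dense, x + y)
```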
@@ -52,12 +65,19 @@ See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.
 
 ## Usage
 
+We provide an example below to demonstrate the usage of the library.
+For more detailed examples, please refer to
+the [test cases](https://github.com/FelixSchoen/blksprs/blob/main/test/cases/test_blocksparse.py) which cover all
+implemented operations and functions.
+The example below can also be found in
+the [test cases](https://github.com/FelixSchoen/blksprs/blob/main/test/cases/test_readme.py).
+
 ```python
 import torch
 
-from blksprs.layouting.sparsity_layout import create_sparsity_layout
+from blksprs.layouting.sparsity_layout import build_sparsity_layout
 from blksprs.ops.conversion import to_sparse, to_dense
-from blksprs.ops.matmul_sss import matmul_sss
+from blksprs.ops.matmul import matmul
 from blksprs.ops.row_wise_sum import row_wise_sum
 from blksprs.ops.softmax import softmax
 from blksprs.ops.transpose import transpose
@@ -65,7 +85,7 @@ from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
 
 
 def test_readme():
-    # Set up parameters
+    # Set up parameters (batch size, number of heads, dimensions for matrices (m, k) and (n, k))
     b, h, m, n, k = 2, 4, 64, 64, 16
 
     # Percentage of blocks that will be sparse in the output for demonstration purposes
@@ -78,7 +98,6 @@ def test_readme():
     # If it is set to ``none`` a value will be chosen automatically
     triton_block_size = None
 
-
     # Initialise random (dense) tensors
     x = torch.randn(size=(b, h, m, k), device="cuda")
     y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()
@@ -88,8 +107,8 @@ def test_readme():
     y_dense, y_shape_original = do_shape_blocksparse(y)
 
     # Create sparsity layouts from existing tensors
-    sparsity_layout_x = create_sparsity_layout(x_dense, sparsity_block_size, triton_block_size=triton_block_size)
-    sparsity_layout_y = create_sparsity_layout(y_dense, sparsity_block_size, triton_block_size=triton_block_size)
+    sparsity_layout_x = build_sparsity_layout(x_dense, sparsity_block_size, triton_block_size=triton_block_size)
+    sparsity_layout_y = build_sparsity_layout(y_dense, sparsity_block_size, triton_block_size=triton_block_size)
 
     # Create random sparsity layout for output tensor
     sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)
@@ -99,8 +118,8 @@ def test_readme():
     y_sparse = to_sparse(y_dense, sparsity_layout_y, sparsity_block_size, triton_block_size=triton_block_size)
 
     # Perform matrix multiplication
-    o_sparse = matmul_sss(x_sparse, y_sparse, sparsity_layout_x, sparsity_layout_y, sparsity_layout_o,
-                          sparsity_block_size, triton_block_size=triton_block_size)
+    o_sparse = matmul(x_sparse, sparsity_layout_x, y_sparse, sparsity_layout_y, sparsity_layout_o, sparsity_block_size,
+                      triton_block_size=triton_block_size)
     o_dense = to_dense(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
 
     # Sanity check
@@ -115,7 +134,7 @@ def test_readme():
     assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2) # Note that small numerical differences are expected
 
     # Assert that the output has the correct sparsity layout
-    actual_sparsity_layout_o = create_sparsity_layout(o_dense, sparsity_block_size, triton_block_size=triton_block_size)
+    actual_sparsity_layout_o = build_sparsity_layout(o_dense, sparsity_block_size, triton_block_size=triton_block_size)
     assert torch.allclose(actual_sparsity_layout_o, sparsity_layout_o)
 
     # Convert output tensor back to original shape

blksprs-1.1.dist-info/RECORD ADDED
@@ -0,0 +1,17 @@
+blksprs/layouting/distribution_layout.py,sha256=GQ-ZRXbeImiLcbaqnL2FuUZ6DoFwmB0naT_YrOpD84Q,4940
+blksprs/layouting/sparsity_layout.py,sha256=Z9Ac88kQZVaPp27ymlLwyGN14ZfMyljXp6oM_gSsFMQ,2902
+blksprs/misc/broadcast_addition.py,sha256=vf1Hdqz9Uyqykto3DCjmdyepMzpMXL238SpANQqRAwI,5297
+blksprs/ops/conversion.py,sha256=COhHE5KvwhrtdUTLZX1wmxFe0kDNMY97iIhnkMmztBA,11362
+blksprs/ops/distribution.py,sha256=_fQb6fWpLxocAh86D74ATahChi0EK0eBb4eUOUEBVps,16769
+blksprs/ops/exp.py,sha256=qs8fVtCzxl4CKT4GepaqurjEL62jyi8VjMY12JFrFAU,3674
+blksprs/ops/matmul.py,sha256=x3lrYg4g8fIf5PeMtZY_SEpi11kP9RFcRoemCIxcSDE,11086
+blksprs/ops/row_wise_sum.py,sha256=ojuSejV37cLtRNS3lBfknA5KY3TEg8EHxOqVT6JZzoM,11387
+blksprs/ops/softmax.py,sha256=ZyeAVqmG_VzJ72FArGrpUSFfoSM4GPxyubrmNKERVIA,11654
+blksprs/ops/transpose.py,sha256=DVEXoxo2MoTNL3NZrjxsukMDrzk2vnEXL1uRnKFWkn0,6722
+blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
+blksprs/utils/tools.py,sha256=P2UALvccRjJJ7w05YGuaxB3qmNObgct4idfM0jlE2wg,465
+blksprs/utils/validation.py,sha256=gJYZO5C48YUrXV3Fy_Z_lCaOpiFj951FT-Od7sKfprg,3007
+blksprs-1.1.dist-info/METADATA,sha256=NIdEtqxj4SBUOP1eMlBz2RoOppwlQx9sJnRmDicWvp4,6982
+blksprs-1.1.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+blksprs-1.1.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
+blksprs-1.1.dist-info/RECORD,,

blksprs-1.0.dist-info/WHEEL → blksprs-1.1.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (74.1.2)
+Generator: setuptools (75.1.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 

blksprs-1.0.dist-info/RECORD REMOVED
@@ -1,14 +0,0 @@
-blksprs/layouting/sparsity_layout.py,sha256=fmzp0vasDDJuNwM-sPk3IBg1svCB2-ItELscppoSLPE,2553
-blksprs/ops/conversion.py,sha256=AYceIsv_g7xswoBqP3TGR2vjOjiLjTpwWG0vZ7XKCa8,10062
-blksprs/ops/exp.py,sha256=b0IuUVA_UoKNYDNT4Q3EFuXm7EEv_J2-DR7hfgCeT1Q,3222
-blksprs/ops/matmul_sss.py,sha256=34JSkO_9OOnQXB4KZHraElGDjbCx8p0dr9J5JebVdhY,10639
-blksprs/ops/row_wise_sum.py,sha256=ltoZpGVIApQBt_rbmknhsd-7MnibZRX6lkIzDnAC9k8,10462
-blksprs/ops/softmax.py,sha256=fJGKFshFMIVxaYJ_pgPD7EC0ooG-31XH9-qzRQ-xY5A,11018
-blksprs/ops/transpose.py,sha256=WEsXWRYDTWk2U36mt0aEQeReR3TG9TFrWAwVh3NNVYk,5985
-blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
-blksprs/utils/tools.py,sha256=bqJtXXUKmcUxQsQ4ZkrpDZ5P7gnCburSW5hVd9U2M3E,1708
-blksprs/utils/validation.py,sha256=d4BFxzX-zVa5mUv_t3IW_bZbbP3vzSWan_KC1lyw7bs,1639
-blksprs-1.0.dist-info/METADATA,sha256=hrQQ8iK3-F2b38ogT8gp0O-sxm7UFn3i6QNv6vsS2so,5991
-blksprs-1.0.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
-blksprs-1.0.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
-blksprs-1.0.dist-info/RECORD,,