blksprs 1.0__tar.gz → 1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: blksprs
- Version: 1.0
+ Version: 1.1
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -23,16 +23,20 @@ Requires-Dist: pdoc3; extra == "deploy"
 
  ## Overview
 
- A lightweight library for operations on blocksparse matrices in PyTorch.
+ A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
 
  Currently supported operations (includes gradient calculation):
 
- - Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support for `sparse = sparse @ sparse` matmul_)
+ - Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
+ for `sparse = sparse @ sparse` matmul_)
  - Softmax
  - Transposition
+ - Gather
+ - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
  - Conversion from and to sparse form
 
- As with this library sparse matrices are represented using a tuple of `(matrix, sparsity_layout, sparsity_block_size)`, any element-wise operations can be applied in regular torch-like fashion.
+ As with this library sparse matrices are represented using a tuple of `(matrix, sparsity_layout, sparsity_block_size)`,
+ any element-wise operations can be applied in regular torch-like fashion.
  These include, e.g.,
 
  - Element-wise addition and subtraction
@@ -40,8 +44,17 @@ These include, e.g.,
  - Element-wise exponentiation
  - ...
 
+ Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
+ match.
+
+ Furthermore, the library provides a set of utility functions for the creation of sparsity layouts based on existing
+ dense tensors.
+
  ## Installation
 
+ Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is only compatible with
+ the Linux platform.
+
  We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) using pip:
 
  ```pip install blksprs```
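As an aside (not part of the package diff), the element-wise rule described in the overview above can be illustrated with a minimal sketch using the 1.1 function names from the usage section further down. The shapes and values are illustrative assumptions, and a CUDA device is required since the kernels are written in Triton.

```python
import torch

from blksprs.layouting.sparsity_layout import build_sparsity_layout
from blksprs.ops.conversion import to_sparse, to_dense

sparsity_block_size = 16

# Two tensors in regular (dense) form; since no block is all-zero, both share the
# same fully populated sparsity layout, so element-wise operations between them are valid
x = torch.randn(size=(2, 64, 64), device="cuda")
y = torch.randn(size=(2, 64, 64), device="cuda")

sparsity_layout = build_sparsity_layout(x, sparsity_block_size)

x_sparse = to_sparse(x, sparsity_layout, sparsity_block_size)
y_sparse = to_sparse(y, sparsity_layout, sparsity_block_size)

# Element-wise operations are applied in regular torch-like fashion to the compressed tensors
z_sparse = torch.exp(x_sparse + y_sparse) * 0.5

z_dense = to_dense(z_sparse, sparsity_layout, sparsity_block_size)
assert torch.allclose(z_dense, torch.exp(x + y) * 0.5)
```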
@@ -52,12 +65,19 @@ See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.
 
  ## Usage
 
+ We provide an example below to demonstrate the usage of the library.
+ For more detailed examples, please refer to
+ the [test cases](https://github.com/FelixSchoen/blksprs/blob/main/test/cases/test_blocksparse.py) which cover all
+ implemented operations and functions.
+ The example below can also be found in
+ the [test cases](https://github.com/FelixSchoen/blksprs/blob/main/test/cases/test_readme.py).
+
  ```python
  import torch
 
- from blksprs.layouting.sparsity_layout import create_sparsity_layout
+ from blksprs.layouting.sparsity_layout import build_sparsity_layout
  from blksprs.ops.conversion import to_sparse, to_dense
- from blksprs.ops.matmul_sss import matmul_sss
+ from blksprs.ops.matmul import matmul
  from blksprs.ops.row_wise_sum import row_wise_sum
  from blksprs.ops.softmax import softmax
  from blksprs.ops.transpose import transpose
@@ -65,7 +85,7 @@ from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
 
 
  def test_readme():
- # Set up parameters
+ # Set up parameters (batch size, number of heads, dimensions for matrices (m, k) and (n, k))
  b, h, m, n, k = 2, 4, 64, 64, 16
 
  # Percentage of blocks that will be sparse in the output for demonstration purposes
@@ -78,7 +98,6 @@ def test_readme():
  # If it is set to ``none`` a value will be chosen automatically
  triton_block_size = None
 
-
  # Initialise random (dense) tensors
  x = torch.randn(size=(b, h, m, k), device="cuda")
  y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()
@@ -88,8 +107,8 @@ def test_readme():
  y_dense, y_shape_original = do_shape_blocksparse(y)
 
  # Create sparsity layouts from existing tensors
- sparsity_layout_x = create_sparsity_layout(x_dense, sparsity_block_size, triton_block_size=triton_block_size)
- sparsity_layout_y = create_sparsity_layout(y_dense, sparsity_block_size, triton_block_size=triton_block_size)
+ sparsity_layout_x = build_sparsity_layout(x_dense, sparsity_block_size, triton_block_size=triton_block_size)
+ sparsity_layout_y = build_sparsity_layout(y_dense, sparsity_block_size, triton_block_size=triton_block_size)
 
  # Create random sparsity layout for output tensor
  sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)
@@ -99,8 +118,8 @@ def test_readme():
  y_sparse = to_sparse(y_dense, sparsity_layout_y, sparsity_block_size, triton_block_size=triton_block_size)
 
  # Perform matrix multiplication
- o_sparse = matmul_sss(x_sparse, y_sparse, sparsity_layout_x, sparsity_layout_y, sparsity_layout_o,
- sparsity_block_size, triton_block_size=triton_block_size)
+ o_sparse = matmul(x_sparse, sparsity_layout_x, y_sparse, sparsity_layout_y, sparsity_layout_o, sparsity_block_size,
+ triton_block_size=triton_block_size)
  o_dense = to_dense(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
 
  # Sanity check
@@ -115,7 +134,7 @@ def test_readme():
  assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2)  # Note that small numerical differences are expected
 
  # Assert that the output has the correct sparsity layout
- actual_sparsity_layout_o = create_sparsity_layout(o_dense, sparsity_block_size, triton_block_size=triton_block_size)
+ actual_sparsity_layout_o = build_sparsity_layout(o_dense, sparsity_block_size, triton_block_size=triton_block_size)
  assert torch.allclose(actual_sparsity_layout_o, sparsity_layout_o)
 
  # Convert output tensor back to original shape
@@ -2,16 +2,20 @@
 
  ## Overview
 
- A lightweight library for operations on blocksparse matrices in PyTorch.
+ A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
 
  Currently supported operations (includes gradient calculation):
 
- - Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support for `sparse = sparse @ sparse` matmul_)
+ - Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
+ for `sparse = sparse @ sparse` matmul_)
  - Softmax
  - Transposition
+ - Gather
+ - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
  - Conversion from and to sparse form
 
- As with this library sparse matrices are represented using a tuple of `(matrix, sparsity_layout, sparsity_block_size)`, any element-wise operations can be applied in regular torch-like fashion.
+ As with this library sparse matrices are represented using a tuple of `(matrix, sparsity_layout, sparsity_block_size)`,
+ any element-wise operations can be applied in regular torch-like fashion.
  These include, e.g.,
 
  - Element-wise addition and subtraction
@@ -19,8 +23,17 @@ These include, e.g.,
  - Element-wise exponentiation
  - ...
 
+ Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
+ match.
+
+ Furthermore, the library provides a set of utility functions for the creation of sparsity layouts based on existing
+ dense tensors.
+
  ## Installation
 
+ Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is only compatible with
+ the Linux platform.
+
  We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) using pip:
 
  ```pip install blksprs```
@@ -31,12 +44,19 @@ See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.
 
  ## Usage
 
+ We provide an example below to demonstrate the usage of the library.
+ For more detailed examples, please refer to
+ the [test cases](https://github.com/FelixSchoen/blksprs/blob/main/test/cases/test_blocksparse.py) which cover all
+ implemented operations and functions.
+ The example below can also be found in
+ the [test cases](https://github.com/FelixSchoen/blksprs/blob/main/test/cases/test_readme.py).
+
  ```python
  import torch
 
- from blksprs.layouting.sparsity_layout import create_sparsity_layout
+ from blksprs.layouting.sparsity_layout import build_sparsity_layout
  from blksprs.ops.conversion import to_sparse, to_dense
- from blksprs.ops.matmul_sss import matmul_sss
+ from blksprs.ops.matmul import matmul
  from blksprs.ops.row_wise_sum import row_wise_sum
  from blksprs.ops.softmax import softmax
  from blksprs.ops.transpose import transpose
@@ -44,7 +64,7 @@ from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
 
 
  def test_readme():
- # Set up parameters
+ # Set up parameters (batch size, number of heads, dimensions for matrices (m, k) and (n, k))
  b, h, m, n, k = 2, 4, 64, 64, 16
 
  # Percentage of blocks that will be sparse in the output for demonstration purposes
@@ -57,7 +77,6 @@ def test_readme():
  # If it is set to ``none`` a value will be chosen automatically
  triton_block_size = None
 
-
  # Initialise random (dense) tensors
  x = torch.randn(size=(b, h, m, k), device="cuda")
  y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()
@@ -67,8 +86,8 @@ def test_readme():
  y_dense, y_shape_original = do_shape_blocksparse(y)
 
  # Create sparsity layouts from existing tensors
- sparsity_layout_x = create_sparsity_layout(x_dense, sparsity_block_size, triton_block_size=triton_block_size)
- sparsity_layout_y = create_sparsity_layout(y_dense, sparsity_block_size, triton_block_size=triton_block_size)
+ sparsity_layout_x = build_sparsity_layout(x_dense, sparsity_block_size, triton_block_size=triton_block_size)
+ sparsity_layout_y = build_sparsity_layout(y_dense, sparsity_block_size, triton_block_size=triton_block_size)
 
  # Create random sparsity layout for output tensor
  sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)
@@ -78,8 +97,8 @@ def test_readme():
  y_sparse = to_sparse(y_dense, sparsity_layout_y, sparsity_block_size, triton_block_size=triton_block_size)
 
  # Perform matrix multiplication
- o_sparse = matmul_sss(x_sparse, y_sparse, sparsity_layout_x, sparsity_layout_y, sparsity_layout_o,
- sparsity_block_size, triton_block_size=triton_block_size)
+ o_sparse = matmul(x_sparse, sparsity_layout_x, y_sparse, sparsity_layout_y, sparsity_layout_o, sparsity_block_size,
+ triton_block_size=triton_block_size)
  o_dense = to_dense(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
 
  # Sanity check
@@ -94,7 +113,7 @@ def test_readme():
  assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2)  # Note that small numerical differences are expected
 
  # Assert that the output has the correct sparsity layout
- actual_sparsity_layout_o = create_sparsity_layout(o_dense, sparsity_block_size, triton_block_size=triton_block_size)
+ actual_sparsity_layout_o = build_sparsity_layout(o_dense, sparsity_block_size, triton_block_size=triton_block_size)
  assert torch.allclose(actual_sparsity_layout_o, sparsity_layout_o)
 
  # Convert output tensor back to original shape
@@ -0,0 +1,114 @@
+ import torch
+ import triton
+ from torch import Tensor
+ from triton import language as tl
+
+ from blksprs.utils.tools import get_triton_block_size
+ from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
+ validate_contiguous
+
+
+ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
+ size_target: torch.Size,
+ sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+ """Builds the sparsity layout of either the source of a gather or the target of a scatter operation.
+
+ Args:
+ indices (Tensor): The block-sparse indices tensor in compressed form used for the gather or scatter operation.
+ sparsity_layout_indices (Tensor): The sparsity layout of the indices block-sparse tensor.
+ size_target (torch.Size): The size of the block-sparse target tensor in regular form.
+ sparsity_block_size (int): The size of the sparsity blocks.
+ triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+
+ Returns:
+ Tensor: The sparsity layout of the source or target tensor.
+
+ """
+ validate_dimensions(indices)
+ validate_contiguous(indices)
+ validate_device(indices)
+
+ sparsity_lut_i = torch.nonzero(sparsity_layout_indices).contiguous()
+
+ output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size, size_target[2] // sparsity_block_size,
+ device=indices.device, dtype=torch.int32)
+
+ i_b, i_r, i_c = indices.size()
+ i_b_s, i_r_s, i_c_s = indices.stride()
+ s_l_i_b, s_l_i_r, s_l_i_c = sparsity_layout_indices.size()
+ s_l_i_b_s, s_l_i_r_s, s_l_i_c_s = sparsity_layout_indices.stride()
+ s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
+ s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
+ o_b, o_r, o_c = output.size()
+ o_b_s, o_r_s, o_c_s = output.stride()
+
+ if triton_block_size is None:
+ triton_block_size = get_triton_block_size(sparsity_block_size)
+
+ validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+ triton_grid = lambda meta: [i_b,
+ triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
+ triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]
+
+ (kernel_distribution_layout[triton_grid]
+ (indices,
+ i_b, i_b_s, i_r_s, i_c_s,
+ sparsity_layout_indices,
+ s_l_i_b, s_l_i_b_s, s_l_i_r, s_l_i_r_s, s_l_i_c, s_l_i_c_s,
+ sparsity_lut_i,
+ s_lut_i_r, s_lut_i_r_s, s_lut_i_c, s_lut_i_c_s,
+ output,
+ o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
+ sparsity_block_size,
+ triton_block_size))
+
+ return output
+
+
+ @triton.jit
+ def kernel_distribution_layout(i,
+ i_b, i_b_s, i_r_s, i_c_s,
+ s_l_i,
+ s_l_i_b, s_l_i_b_s, s_l_i_r, s_l_i_r_s, s_l_i_c, s_l_i_c_s,
+ s_lut_i,
+ s_lut_i_r, s_lut_i_r_s, s_lut_i_c, s_lut_i_c_s,
+ o,
+ o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
+ sparsity_block_size,
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+ # Get triton block indices
+ pid_blk = tl.program_id(axis=0)
+ pid_row = tl.program_id(axis=1)
+ pid_col = tl.program_id(axis=2)
+
+ # Get position of current sparsity block consisting of its batch, row, and column index
+ spa_bat_i_idx = (pid_blk * s_lut_i_r_s + 0 * s_lut_i_c_s)
+ spa_bat_i_msk = (spa_bat_i_idx < s_lut_i_r * s_lut_i_r_s)
+ spa_bat_i = tl.load(s_lut_i + spa_bat_i_idx, mask=spa_bat_i_msk)
+
+ spa_row_i_idx = (pid_blk * s_lut_i_r_s + 1 * s_lut_i_c_s)
+ spa_row_i_msk = (spa_row_i_idx < s_lut_i_r * s_lut_i_r_s)
+ spa_row_i = tl.load(s_lut_i + spa_row_i_idx, mask=spa_row_i_msk)
+
+ blk_i_idx = (pid_blk * i_b_s +
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+ blk_i_msk = (blk_i_idx < i_b * i_b_s)
+ blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)
+
+ blk_i = blk_i // sparsity_block_size
+ blk_v = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), 1, dtype=tl.int32)
+
+ blk_o_idx = ((spa_bat_i * o_b_s) +
+ (spa_row_i * o_r_s) +
+ (blk_i * o_c_s))
+ blk_o_msk = (blk_o_idx < o_b * o_b_s)
+ tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
+
+ # if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
+ # blk_o_idx = (pid_bat * o_b_s +
+ # (((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +
+ # ((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))
+ # blk_o_msk = (blk_o_idx < o_b * o_b_s)
+ # tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
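For orientation (not part of the diff), below is a rough usage sketch of the new `build_distribution_layout` helper, based only on its docstring above. The import path and the exact contents of the gather/scatter indices tensor are assumptions.

```python
import torch

# Assumed import path; adjust to wherever the package exposes build_distribution_layout
from blksprs.layouting.distribution_layout import build_distribution_layout

sparsity_block_size = 16
b, m, n = 2, 64, 64

# Sparsity layout of the (compressed) indices tensor; here all blocks are present
sparsity_layout_i = torch.ones(b, m // sparsity_block_size, n // sparsity_block_size,
                               dtype=torch.int32, device="cuda")
n_blocks = int(sparsity_layout_i.sum())

# Compressed indices tensor: one (block x block) block per non-zero layout entry,
# holding column positions into the gather source / scatter target (assumption)
indices = torch.randint(0, n, size=(n_blocks, sparsity_block_size, sparsity_block_size),
                        device="cuda")

# Sparsity layout of the source (gather) or target (scatter) tensor whose regular-form
# size is (b, m, n): a block is marked wherever some index points into it
sparsity_layout_d = build_distribution_layout(indices, sparsity_layout_i,
                                              torch.Size((b, m, n)), sparsity_block_size)
```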
@@ -5,13 +5,23 @@ from triton import language as tl
 
  from blksprs.utils.tools import get_triton_block_size
  from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
- validate_dtype_float, validate_contiguous
+ validate_contiguous
 
 
- def create_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+ def build_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+ """Builds the sparsity layout of a dense tensor covering its sparse blocks.
+
+ Args:
+ x (Tensor): A block-sparse (or dense) tensor in regular form.
+ sparsity_block_size (int): The size of the sparsity blocks.
+ triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+
+ Returns:
+ Tensor: The sparsity layout of the input block-sparse (or dense) tensor.
+
+ """
  validate_dimensions(x)
  validate_contiguous(x)
- validate_dtype_float(x)
  validate_device(x)
 
  output = torch.zeros(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size,
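For reference (not part of the diff), a short sketch of what the renamed `build_sparsity_layout` returns, following its docstring above and the round-trip assertion in the README example: a layout of shape `(batch, rows // sparsity_block_size, cols // sparsity_block_size)` whose entries mark the non-zero blocks of the regular-form input. The shapes below are illustrative.

```python
import torch

from blksprs.layouting.sparsity_layout import build_sparsity_layout

sparsity_block_size = 16

x = torch.randn(size=(2, 64, 64), device="cuda")
x[:, :16, :16] = 0  # zero out the top-left 16x16 block of every batch entry

layout = build_sparsity_layout(x, sparsity_block_size)

# One layout entry per sparsity block of the regular-form tensor
assert layout.shape == (2, 4, 4)
# The all-zero block is not covered by the layout; the other blocks are
assert layout[0, 0, 0] == 0 and layout[0, 0, 1] == 1
```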
@@ -33,9 +43,9 @@ def create_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_siz
 
  (kernel_sparsity_layout[triton_grid]
  (x,
- x_b, x_b_s, x_r, x_r_s, x_c, x_c_s,
+ x_b, x_b_s, x_r_s, x_c_s,
  output,
- o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
+ o_b, o_b_s, o_r_s, o_c_s,
  sparsity_block_size,
  triton_block_size))
 
@@ -44,9 +54,9 @@ def create_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_siz
 
  @triton.jit
  def kernel_sparsity_layout(x,
- x_b, x_b_s, x_r, x_r_s, x_c, x_c_s,
+ x_b, x_b_s, x_r_s, x_c_s,
  o,
- o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
+ o_b, o_b_s, o_r_s, o_c_s,
  sparsity_block_size,
  TRITON_BLOCK_SIZE: tl.constexpr) -> None:
  # Get triton block indices
@@ -0,0 +1,132 @@
+ import torch
+ import triton
+ from torch import Tensor
+ from triton import language as tl
+
+ from blksprs.utils.tools import get_triton_block_size
+ from blksprs.utils.validation import validate_contiguous, validate_device, \
+ validate_sparsity_block_size, validate_triton_block_size
+
+
+ def broadcast_addition(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
+ sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+ """Performs a broadcast and subsequent addition of two dense tensors x and y. Returns a block-sparse tensor in
+ compressed form.
+
+ Args:
+ x (Tensor): A dense input tensor.
+ y (Tensor): A dense input tensor.
+ sparsity_layout_output (Tensor): The sparsity layout of the output tensor.
+ sparsity_block_size (int): The size of the sparsity blocks.
+ triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+
+ Returns:
+ Tensor: The result of the operation as a block-sparse tensor in compressed form. Each element o(i, j) of the
+ output tensor corresponds to x(i) + y(j).
+
+ """
+ validate_device(x, y)
+ validate_contiguous(x, y)
+ if x.size(-1) != y.size(-1):
+ raise ValueError("Dimensions of tensors must match")
+ validate_sparsity_block_size(sparsity_block_size)
+ validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+ sparsity_lut_o = torch.nonzero(sparsity_layout_output).contiguous()
+
+ n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
+
+ validate_contiguous(sparsity_layout_output, sparsity_lut_o)
+
+ output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, device=x.device)
+
+ x_b, x_c = x.size()
+ x_b_s, x_c_s = x.stride()
+ y_b, y_c = y.size()
+ y_b_s, y_c_s = y.stride()
+ o_b, o_r, o_c = output.size()
+ o_b_s, o_r_s, o_c_s = output.stride()
+ s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
+ s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_o.stride()
+
+ if triton_block_size is None:
+ triton_block_size = get_triton_block_size(sparsity_block_size)
+
+ triton_grid = lambda meta: [o_b,
+ triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+ triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+ (kernel_broadcast_addition[triton_grid]
+ (x,
+ x_b, x_b_s, x_c_s,
+ y,
+ y_b, y_b_s, y_c_s,
+ output,
+ o_b, o_b_s, o_r_s, o_c_s,
+ sparsity_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+ sparsity_block_size,
+ triton_block_size))
+
+ return output
+
+
+ def broadcast_subtraction(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
+ sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+ """Wrapper for ``broadcast_addition`` with negated y.
+
+ """
+ return broadcast_addition(x, torch.neg(y), sparsity_layout_output, sparsity_block_size, triton_block_size)
+
+
+ @triton.jit
+ def kernel_broadcast_addition(x,
+ x_b, x_b_s, x_c_s,
+ y,
+ y_b, y_b_s, y_c_s,
+ o,
+ o_b, o_b_s, o_r_s, o_c_s,
+ s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+ sparsity_block_size,
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+ # Get triton block indices
+ pid_blk = tl.program_id(axis=0)
+ pid_row = tl.program_id(axis=1)
+ pid_col = tl.program_id(axis=2)
+
+ # Get position of current sparsity block consisting of its batch, row, and column index
+ spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
+ spa_bat_o_msk = (spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+ spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
+
+ spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
+ spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+ spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+
+ spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
+ spa_col_o_msk = (spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+ spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+
+ # Load x block
+ blk_x_idx = (spa_bat_o * x_b_s +
+ ((spa_row_o * sparsity_block_size + pid_row * TRITON_BLOCK_SIZE +
+ tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+ blk_x_msk = (blk_x_idx < x_b * x_b_s)
+ blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+ # Load y block
+ blk_y_idx = (spa_bat_o * y_b_s +
+ ((spa_col_o * sparsity_block_size + pid_col * TRITON_BLOCK_SIZE +
+ tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
+ blk_y_msk = (blk_y_idx < y_b * y_b_s)
+ blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
+
+ # Compute sum
+ blk_x, blk_y = tl.broadcast(tl.trans(blk_x), blk_y)
+ buf = blk_x + blk_y
+
+ # Store result
+ blk_o_idx = ((pid_blk * o_b_s) +
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+ blk_o_msk = (blk_o_idx < o_b * o_b_s)
+ tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
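For orientation (not part of the diff), a rough usage sketch of the new broadcast helpers, based only on the docstrings above; the import path is an assumption. `x` and `y` are dense 2D tensors with matching last dimensions, and the result is a block-sparse tensor in compressed form with `o(i, j) = x(i) + y(j)` per batch entry.

```python
import torch

# Assumed import path; adjust to wherever the package exposes the broadcast helpers
from blksprs.ops.broadcast_ops import broadcast_addition, broadcast_subtraction

sparsity_block_size = 16
b, n = 2, 64

x = torch.randn(size=(b, n), device="cuda")
y = torch.randn(size=(b, n), device="cuda")

# Layout of the (b, n // block, n // block) output; here all blocks are kept
sparsity_layout_o = torch.ones(b, n // sparsity_block_size, n // sparsity_block_size,
                               dtype=torch.int32, device="cuda")

# Compressed output; conceptually o[batch, i, j] = x[batch, i] + y[batch, j]
o_sum = broadcast_addition(x, y, sparsity_layout_o, sparsity_block_size)

# Same broadcast with y negated, i.e. o[batch, i, j] = x[batch, i] - y[batch, j]
o_diff = broadcast_subtraction(x, y, sparsity_layout_o, sparsity_block_size)
```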
@@ -4,21 +4,33 @@ from torch import Tensor
  from triton import language as tl
 
  from blksprs.utils.tools import get_triton_block_size
- from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_dtype_float, validate_device
+ from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
+ validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
 
 
  def to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
  triton_block_size: int = None) -> Tensor:
- """Converts a blocksparse tensor to a dense tensor based on the given sparsity layout.
+ """Converts a block-sparse tensor in compressed form to a block-sparse tensor in regular form based on the given
+ sparsity layout.
 
- The ``fill_value`` is used to fill the resulting dense tensor with a specific value (default ``0``) where the
- blocksparse tensor is not present.
+ Args:
+ x (Tensor): A block-sparse tensor in compressed form.
+ sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+ sparsity_block_size (int): The size of the sparsity blocks.
+ fill_value (float): The value to fill the resulting dense tensor with where the block-sparse tensor is not
+ present (default ``0``).
+ triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+ Returns:
+ Tensor: The block-sparse tensor converted to regular form.
 
  """
  validate_dimensions(x)
  validate_contiguous(x, sparsity_layout)
- validate_dtype_float(x)
  validate_device(x)
+ validate_sparsity(sparsity_block_size, (x, sparsity_layout))
+ validate_sparsity_block_size(sparsity_block_size, x)
+ validate_triton_block_size(triton_block_size, sparsity_block_size)
 
  sparsity_layout_flat = sparsity_layout.reshape(-1)
  sparsity_reverse_lut = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
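A minimal round-trip sketch (not part of the diff) of the two conversions documented above, mirroring the assertion already used in the README example:

```python
import torch

from blksprs.layouting.sparsity_layout import build_sparsity_layout
from blksprs.ops.conversion import to_sparse, to_dense

sparsity_block_size = 16
x = torch.randn(size=(2, 64, 64), device="cuda")

layout = build_sparsity_layout(x, sparsity_block_size)

# Regular form -> compressed form: shape (n_sparse_blocks, block, block)
x_compressed = to_sparse(x, layout, sparsity_block_size)

# Compressed form -> regular form, filling absent blocks with fill_value (default 0)
x_regular = to_dense(x_compressed, layout, sparsity_block_size)

assert torch.allclose(x, x_regular)
```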
@@ -68,7 +80,7 @@ class _BlocksparseToDense(torch.autograd.Function):
  sparsity_block_size,
  triton_block_size))
 
- ctx.sparsity_layout = sparsity_layout
+ ctx.save_for_backward(sparsity_layout)
  ctx.sparsity_block_size = sparsity_block_size
  ctx.triton_block_size = triton_block_size
 
@@ -76,11 +88,12 @@ class _BlocksparseToDense(torch.autograd.Function):
 
  @staticmethod
  def backward(ctx, grad_output):
- sparsity_layout = ctx.sparsity_layout
+ sparsity_layout = ctx.saved_tensors[0]
  sparsity_block_size = ctx.sparsity_block_size
  triton_block_size = ctx.triton_block_size
 
- return to_sparse(grad_output, sparsity_layout, sparsity_block_size, triton_block_size), None, None, None, None, None
+ return to_sparse(grad_output, sparsity_layout, sparsity_block_size,
+ triton_block_size), None, None, None, None, None
 
  @staticmethod
  @triton.jit
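As context for the `ctx.save_for_backward(...)` / `ctx.saved_tensors` change above, a generic PyTorch sketch (not blksprs code): tensors needed in `backward` should be registered through `save_for_backward` rather than stashed as plain attributes, so autograd can track them and detect in-place modifications; non-tensor values can still live on `ctx`.

```python
import torch


class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor):
        ctx.save_for_backward(factor)  # tensors needed in backward go through save_for_backward
        ctx.note = "non-tensor state may stay on ctx"
        return x * factor

    @staticmethod
    def backward(ctx, grad_output):
        factor, = ctx.saved_tensors
        return grad_output * factor, None  # no gradient for the factor input here
```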
@@ -124,18 +137,29 @@ class _BlocksparseToDense(torch.autograd.Function):
 
 
  def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
- """Converts a dense tensor to a blocksparse tensor based on the given sparsity layout.
+ """Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
+ sparsity layout.
+
+ Args:
+ x (Tensor): A block-sparse tensor in regular form.
+ sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+ sparsity_block_size (int): The size of the sparsity blocks.
+ triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+ Returns:
+ Tensor: The block-sparse tensor converted to compressed form.
 
  """
  validate_dimensions(x)
- validate_contiguous(x, sparsity_layout)
- validate_dtype_float(x)
+ validate_contiguous(x)
  validate_device(x)
+ validate_sparsity_block_size(sparsity_block_size, x)
+ validate_triton_block_size(triton_block_size, sparsity_block_size)
 
  sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
  n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
 
- validate_contiguous(sparsity_lut)
+ validate_contiguous(sparsity_layout, sparsity_lut)
 
  return _BlocksparseToSparse.apply(x,
  sparsity_layout, sparsity_lut,
@@ -149,7 +173,8 @@ class _BlocksparseToSparse(torch.autograd.Function):
  def forward(ctx, x: Tensor,
  sparsity_layout: Tensor, sparsity_lut: Tensor,
  sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
- output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size), device=x.device)
+ output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size), dtype=x.dtype,
+ device=x.device)
 
  x_b, x_r, x_c = x.size()
  x_b_s, x_r_s, x_c_s = x.stride()
@@ -172,7 +197,7 @@ class _BlocksparseToSparse(torch.autograd.Function):
  sparsity_block_size,
  triton_block_size))
 
- ctx.sparsity_layout = sparsity_layout
+ ctx.save_for_backward(sparsity_layout)
  ctx.sparsity_block_size = sparsity_block_size
  ctx.triton_block_size = triton_block_size
 
@@ -180,7 +205,7 @@ class _BlocksparseToSparse(torch.autograd.Function):
  @staticmethod
  def backward(ctx, grad_output):
- sparsity_layout = ctx.sparsity_layout
+ sparsity_layout = ctx.saved_tensors[0]
  sparsity_block_size = ctx.sparsity_block_size
  triton_block_size = ctx.triton_block_size
 