blksprs 0.2b4__tar.gz → 1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs-1.1/PKG-INFO ADDED
@@ -0,0 +1,164 @@
1
+ Metadata-Version: 2.1
2
+ Name: blksprs
3
+ Version: 1.1
4
+ Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
5
+ Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
6
+ Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
7
+ Project-URL: Bugtracker, https://github.com/FelixSchoen/blksprs/issues
8
+ Requires-Python: >=3.12
9
+ Description-Content-Type: text/markdown
10
+ Requires-Dist: torch
11
+ Provides-Extra: test
12
+ Requires-Dist: pytest; extra == "test"
13
+ Requires-Dist: pytest-xdist; extra == "test"
14
+ Requires-Dist: pytest-cov; extra == "test"
15
+ Requires-Dist: coverage; extra == "test"
16
+ Requires-Dist: matplotlib; extra == "test"
17
+ Provides-Extra: deploy
18
+ Requires-Dist: build; extra == "deploy"
19
+ Requires-Dist: twine; extra == "deploy"
20
+ Requires-Dist: pdoc3; extra == "deploy"
21
+
22
+ # blksprs
23
+
24
+ ## Overview
25
+
26
+ A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
27
+
28
+ Currently supported operations (includes gradient calculation):
29
+
30
+ - Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
31
+ for `sparse = sparse @ sparse` matmul_)
32
+ - Softmax
33
+ - Transposition
34
+ - Gather
35
+ - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
36
+ - Conversion from and to sparse form
37
+
38
+ Since in this library sparse matrices are represented using a tuple of `(matrix, sparsity_layout, sparsity_block_size)`,
39
+ any element-wise operations can be applied in regular torch-like fashion.
40
+ These include, e.g.,
41
+
42
+ - Element-wise addition and subtraction
43
+ - Element-wise multiplication and division
44
+ - Element-wise exponentiation
45
+ - ...
46
+
47
+ Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
48
+ match.
49
+
50
+ Furthermore, the library provides a set of utility functions for the creation of sparsity layouts based on existing
51
+ dense tensors.
52
+
53
+ ## Installation
54
+
55
+ Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is only compatible with
56
+ the Linux platform.
57
+
58
+ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) using pip:
59
+
60
+ ```pip install blksprs```
61
+
62
+ ## Changelog
63
+
64
+ See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.md) for a detailed changelog.
65
+
66
+ ## Usage
67
+
68
+ We provide an example below to demonstrate the usage of the library.
69
+ For more detailed examples, please refer to
70
+ the [test cases](https://github.com/FelixSchoen/blksprs/blob/main/test/cases/test_blocksparse.py) which cover all
71
+ implemented operations and functions.
72
+ The example below can also be found in
73
+ the [test cases](https://github.com/FelixSchoen/blksprs/blob/main/test/cases/test_readme.py).
74
+
75
+ ```python
76
+ import torch
77
+
78
+ from blksprs.layouting.sparsity_layout import build_sparsity_layout
79
+ from blksprs.ops.conversion import to_sparse, to_dense
80
+ from blksprs.ops.matmul import matmul
81
+ from blksprs.ops.row_wise_sum import row_wise_sum
82
+ from blksprs.ops.softmax import softmax
83
+ from blksprs.ops.transpose import transpose
84
+ from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
85
+
86
+
87
+ def test_readme():
88
+ # Set up parameters (batch size, number of heads, dimensions for matrices (m, k) and (n, k))
89
+ b, h, m, n, k = 2, 4, 64, 64, 16
90
+
91
+ # Percentage of blocks that will be sparse in the output for demonstration purposes
92
+ sparsity_percentage = 25
93
+
94
+ # Must be a power of two, greater than or equal to 16 for matmul, and divide m, n, and k
95
+ sparsity_block_size = 16
96
+
97
+ # Must be a power of two and smaller than or equal to sparsity_block_size
98
+ # If it is set to ``None`` a value will be chosen automatically
99
+ triton_block_size = None
100
+
101
+ # Initialise random (dense) tensors
102
+ x = torch.randn(size=(b, h, m, k), device="cuda")
103
+ y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()
104
+
105
+ # Convert tensors to three-dimensional (dense) tensors since Triton can only handle tensors of exactly three dimensions
106
+ x_dense, x_shape_original = do_shape_blocksparse(x)
107
+ y_dense, y_shape_original = do_shape_blocksparse(y)
108
+
109
+ # Create sparsity layouts from existing tensors
110
+ sparsity_layout_x = build_sparsity_layout(x_dense, sparsity_block_size, triton_block_size=triton_block_size)
111
+ sparsity_layout_y = build_sparsity_layout(y_dense, sparsity_block_size, triton_block_size=triton_block_size)
112
+
113
+ # Create random sparsity layout for output tensor
114
+ sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)
115
+
116
+ # Convert tensors to sparse tensors for matrix multiplication
117
+ x_sparse = to_sparse(x_dense, sparsity_layout_x, sparsity_block_size, triton_block_size=triton_block_size)
118
+ y_sparse = to_sparse(y_dense, sparsity_layout_y, sparsity_block_size, triton_block_size=triton_block_size)
119
+
120
+ # Perform matrix multiplication
121
+ o_sparse = matmul(x_sparse, sparsity_layout_x, y_sparse, sparsity_layout_y, sparsity_layout_o, sparsity_block_size,
122
+ triton_block_size=triton_block_size)
123
+ o_dense = to_dense(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
124
+
125
+ # Sanity check
126
+ o_torch = torch.matmul(x_dense, y_dense)
127
+
128
+ # Perform round trip to set sparse blocks to 0
129
+ o_torch_round_trip = to_dense(
130
+ to_sparse(o_torch, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size),
131
+ sparsity_layout_o, sparsity_block_size, fill_value=0, triton_block_size=triton_block_size)
132
+
133
+ # Assert that the output is correct
134
+ assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2) # Note that small numerical differences are expected
135
+
136
+ # Assert that the output has the correct sparsity layout
137
+ actual_sparsity_layout_o = build_sparsity_layout(o_dense, sparsity_block_size, triton_block_size=triton_block_size)
138
+ assert torch.allclose(actual_sparsity_layout_o, sparsity_layout_o)
139
+
140
+ # Convert output tensor back to original shape
141
+ o = undo_shape_blocksparse(o_dense, x_shape_original)
142
+
143
+ # Other available functions
144
+ transpose(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
145
+ softmax(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
146
+ row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
147
+
148
+
149
+ def _get_random_sparsity_layout(b, m, n, sparsity_block_size, sparsity_percentage):
150
+ """Helper function, creates a random sparsity layout for a given shape with a given percentage of blocks marked as sparse.
151
+
152
+ """
153
+ m_s = m // sparsity_block_size
154
+ n_s = n // sparsity_block_size
155
+
156
+ sparsity_layout = torch.ones(size=(b, m_s, n_s), device="cuda", dtype=torch.int)
157
+
158
+ num_zero_elements = int(m_s * n_s * (sparsity_percentage / 100))
159
+ for b_i in range(b):
160
+ indices = torch.randperm(m_s * n_s)[:num_zero_elements]
161
+ sparsity_layout[b_i, indices // n_s, indices % n_s] = 0
162
+
163
+ return sparsity_layout
164
+ ```
blksprs-1.1/README.md ADDED
@@ -0,0 +1,143 @@
1
+ # blksprs
2
+
3
+ ## Overview
4
+
5
+ A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
6
+
7
+ Currently supported operations (includes gradient calculation):
8
+
9
+ - Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
10
+ for `sparse = sparse @ sparse` matmul_)
11
+ - Softmax
12
+ - Transposition
13
+ - Gather
14
+ - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
15
+ - Conversion from and to sparse form
16
+
17
+ Since in this library sparse matrices are represented using a tuple of `(matrix, sparsity_layout, sparsity_block_size)`,
18
+ any element-wise operations can be applied in regular torch-like fashion.
19
+ These include, e.g.,
20
+
21
+ - Element-wise addition and subtraction
22
+ - Element-wise multiplication and division
23
+ - Element-wise exponentiation
24
+ - ...
25
+
26
+ Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
27
+ match.
28
+
29
+ Furthermore, the library provides a set of utility functions for the creation of sparsity layouts based on existing
30
+ dense tensors.
31
+
32
+ ## Installation
33
+
34
+ Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is only compatible with
35
+ the Linux platform.
36
+
37
+ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) using pip:
38
+
39
+ ```pip install blksprs```
40
+
41
+ ## Changelog
42
+
43
+ See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.md) for a detailed changelog.
44
+
45
+ ## Usage
46
+
47
+ We provide an example below to demonstrate the usage of the library.
48
+ For more detailed examples, please refer to
49
+ the [test cases](https://github.com/FelixSchoen/blksprs/blob/main/test/cases/test_blocksparse.py) which cover all
50
+ implemented operations and functions.
51
+ The example below can also be found in
52
+ the [test cases](https://github.com/FelixSchoen/blksprs/blob/main/test/cases/test_readme.py).
53
+
54
+ ```python
55
+ import torch
56
+
57
+ from blksprs.layouting.sparsity_layout import build_sparsity_layout
58
+ from blksprs.ops.conversion import to_sparse, to_dense
59
+ from blksprs.ops.matmul import matmul
60
+ from blksprs.ops.row_wise_sum import row_wise_sum
61
+ from blksprs.ops.softmax import softmax
62
+ from blksprs.ops.transpose import transpose
63
+ from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
64
+
65
+
66
+ def test_readme():
67
+ # Set up parameters (batch size, number of heads, dimensions for matrices (m, k) and (n, k))
68
+ b, h, m, n, k = 2, 4, 64, 64, 16
69
+
70
+ # Percentage of blocks that will be sparse in the output for demonstration purposes
71
+ sparsity_percentage = 25
72
+
73
+ # Must be a power of two, greater than or equal to 16 for matmul, and divide m, n, and k
74
+ sparsity_block_size = 16
75
+
76
+ # Must be a power of two and smaller than or equal to sparsity_block_size
77
+ # If it is set to ``None`` a value will be chosen automatically
78
+ triton_block_size = None
79
+
80
+ # Initialise random (dense) tensors
81
+ x = torch.randn(size=(b, h, m, k), device="cuda")
82
+ y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()
83
+
84
+ # Convert tensors to three-dimensional (dense) tensors since Triton can only handle tensors of exactly three dimensions
85
+ x_dense, x_shape_original = do_shape_blocksparse(x)
86
+ y_dense, y_shape_original = do_shape_blocksparse(y)
87
+
88
+ # Create sparsity layouts from existing tensors
89
+ sparsity_layout_x = build_sparsity_layout(x_dense, sparsity_block_size, triton_block_size=triton_block_size)
90
+ sparsity_layout_y = build_sparsity_layout(y_dense, sparsity_block_size, triton_block_size=triton_block_size)
91
+
92
+ # Create random sparsity layout for output tensor
93
+ sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)
94
+
95
+ # Convert tensors to sparse tensors for matrix multiplication
96
+ x_sparse = to_sparse(x_dense, sparsity_layout_x, sparsity_block_size, triton_block_size=triton_block_size)
97
+ y_sparse = to_sparse(y_dense, sparsity_layout_y, sparsity_block_size, triton_block_size=triton_block_size)
98
+
99
+ # Perform matrix multiplication
100
+ o_sparse = matmul(x_sparse, sparsity_layout_x, y_sparse, sparsity_layout_y, sparsity_layout_o, sparsity_block_size,
101
+ triton_block_size=triton_block_size)
102
+ o_dense = to_dense(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
103
+
104
+ # Sanity check
105
+ o_torch = torch.matmul(x_dense, y_dense)
106
+
107
+ # Perform round trip to set sparse blocks to 0
108
+ o_torch_round_trip = to_dense(
109
+ to_sparse(o_torch, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size),
110
+ sparsity_layout_o, sparsity_block_size, fill_value=0, triton_block_size=triton_block_size)
111
+
112
+ # Assert that the output is correct
113
+ assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2) # Note that small numerical differences are expected
114
+
115
+ # Assert that the output has the correct sparsity layout
116
+ actual_sparsity_layout_o = build_sparsity_layout(o_dense, sparsity_block_size, triton_block_size=triton_block_size)
117
+ assert torch.allclose(actual_sparsity_layout_o, sparsity_layout_o)
118
+
119
+ # Convert output tensor back to original shape
120
+ o = undo_shape_blocksparse(o_dense, x_shape_original)
121
+
122
+ # Other available functions
123
+ transpose(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
124
+ softmax(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
125
+ row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
126
+
127
+
128
+ def _get_random_sparsity_layout(b, m, n, sparsity_block_size, sparsity_percentage):
129
+ """Helper function, creates a random sparsity layout for a given shape with a given percentage of blocks marked as sparse.
130
+
131
+ """
132
+ m_s = m // sparsity_block_size
133
+ n_s = n // sparsity_block_size
134
+
135
+ sparsity_layout = torch.ones(size=(b, m_s, n_s), device="cuda", dtype=torch.int)
136
+
137
+ num_zero_elements = int(m_s * n_s * (sparsity_percentage / 100))
138
+ for b_i in range(b):
139
+ indices = torch.randperm(m_s * n_s)[:num_zero_elements]
140
+ sparsity_layout[b_i, indices // n_s, indices % n_s] = 0
141
+
142
+ return sparsity_layout
143
+ ```
@@ -0,0 +1,114 @@
1
+ import torch
2
+ import triton
3
+ from torch import Tensor
4
+ from triton import language as tl
5
+
6
+ from blksprs.utils.tools import get_triton_block_size
7
+ from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
8
+ validate_contiguous
9
+
10
+
11
def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
                              size_target: torch.Size,
                              sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
    """Builds the sparsity layout of either the source of a gather or the target of a scatter operation.

    Args:
        indices (Tensor): The block-sparse indices tensor in compressed form used for the gather or scatter operation.
        sparsity_layout_indices (Tensor): The sparsity layout of the indices block-sparse tensor.
        size_target (torch.Size): The size of the block-sparse target tensor in regular form.
        sparsity_block_size (int): The size of the sparsity blocks.
        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).

    Returns:
        Tensor: The sparsity layout of the source or target tensor.

    """
    # Input sanity checks
    validate_dimensions(indices)
    validate_contiguous(indices)
    validate_device(indices)

    # Look-up table: row k holds the (batch, row, col) position of compressed block k
    lut_indices = torch.nonzero(sparsity_layout_indices).contiguous()

    # One int32 flag per sparsity block of the target, initialised to 0 (sparse)
    output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size, size_target[2] // sparsity_block_size,
                         device=indices.device, dtype=torch.int32)

    # Sizes and strides forwarded to the kernel for manual pointer arithmetic
    idx_bat, idx_row, idx_col = indices.size()
    idx_bat_s, idx_row_s, idx_col_s = indices.stride()
    lay_bat, lay_row, lay_col = sparsity_layout_indices.size()
    lay_bat_s, lay_row_s, lay_col_s = sparsity_layout_indices.stride()
    lut_row, lut_col = lut_indices.size()
    lut_row_s, lut_col_s = lut_indices.stride()
    out_bat, out_row, out_col = output.size()
    out_bat_s, out_row_s, out_col_s = output.stride()

    if triton_block_size is None:
        triton_block_size = get_triton_block_size(sparsity_block_size)
    validate_triton_block_size(triton_block_size, sparsity_block_size)

    # One program per triton tile of the compressed indices tensor
    grid = lambda meta: [idx_bat,
                         triton.cdiv(idx_row, meta["TRITON_BLOCK_SIZE"]),
                         triton.cdiv(idx_col, meta["TRITON_BLOCK_SIZE"])]

    kernel_distribution_layout[grid](indices,
                                     idx_bat, idx_bat_s, idx_row_s, idx_col_s,
                                     sparsity_layout_indices,
                                     lay_bat, lay_bat_s, lay_row, lay_row_s, lay_col, lay_col_s,
                                     lut_indices,
                                     lut_row, lut_row_s, lut_col, lut_col_s,
                                     output,
                                     out_bat, out_bat_s, out_row, out_row_s, out_col, out_col_s,
                                     sparsity_block_size,
                                     triton_block_size)

    return output
67
+
68
+
69
# Triton kernel backing ``build_distribution_layout``.
#
# One program instance handles one TRITON_BLOCK_SIZE x TRITON_BLOCK_SIZE tile of the
# compressed indices tensor ``i`` (grid axes: compressed block, tile row, tile col).
# For every index value in the tile it marks the sparsity block that value falls into
# as non-sparse (1) in the output layout ``o``.
#
# Naming scheme (mirrors the wrapper): ``i`` = indices, ``s_l_i`` = sparsity layout of
# the indices, ``s_lut_i`` = look-up table (nonzero positions) of that layout,
# ``o`` = output layout; ``*_b/_r/_c`` = batch/row/column sizes, ``*_s`` = strides.
# NOTE(review): the ``s_l_i*`` parameters are accepted but unused in this body;
# presumably kept for signature symmetry with sibling kernels — confirm.
@triton.jit
def kernel_distribution_layout(i,
                               i_b, i_b_s, i_r_s, i_c_s,
                               s_l_i,
                               s_l_i_b, s_l_i_b_s, s_l_i_r, s_l_i_r_s, s_l_i_c, s_l_i_c_s,
                               s_lut_i,
                               s_lut_i_r, s_lut_i_r_s, s_lut_i_c, s_lut_i_c_s,
                               o,
                               o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
                               sparsity_block_size,
                               TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    # Get triton block indices
    pid_blk = tl.program_id(axis=0)
    pid_row = tl.program_id(axis=1)
    pid_col = tl.program_id(axis=2)

    # Get position of current sparsity block consisting of its batch, row, and column index
    # (columns 0 and 1 of LUT row ``pid_blk``; the column index is not needed here)
    spa_bat_i_idx = (pid_blk * s_lut_i_r_s + 0 * s_lut_i_c_s)
    spa_bat_i_msk = (spa_bat_i_idx < s_lut_i_r * s_lut_i_r_s)
    spa_bat_i = tl.load(s_lut_i + spa_bat_i_idx, mask=spa_bat_i_msk)

    spa_row_i_idx = (pid_blk * s_lut_i_r_s + 1 * s_lut_i_c_s)
    spa_row_i_msk = (spa_row_i_idx < s_lut_i_r * s_lut_i_r_s)
    spa_row_i = tl.load(s_lut_i + spa_row_i_idx, mask=spa_row_i_msk)

    # Load the tile of raw (element-level) indices for this program
    blk_i_idx = (pid_blk * i_b_s +
                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
    blk_i_msk = (blk_i_idx < i_b * i_b_s)
    # NOTE(review): masked-off lanes have no ``other=`` default and so hold
    # unspecified values; they are assumed never to survive into the store below — confirm
    blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)

    # Convert element indices to sparsity-block column indices
    blk_i = blk_i // sparsity_block_size
    # A tile of ones to scatter into the output layout
    blk_v = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), 1, dtype=tl.int32)

    # Mark every referenced (batch, row, block-column) position as non-sparse
    blk_o_idx = ((spa_bat_i * o_b_s) +
                 (spa_row_i * o_r_s) +
                 (blk_i * o_c_s))
    blk_o_msk = (blk_o_idx < o_b * o_b_s)
    tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
@@ -0,0 +1,78 @@
1
+ import torch
2
+ import triton
3
+ from torch import Tensor
4
+ from triton import language as tl
5
+
6
+ from blksprs.utils.tools import get_triton_block_size
7
+ from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
8
+ validate_contiguous
9
+
10
+
11
def build_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
    """Builds the sparsity layout of a dense tensor covering its sparse blocks.

    Args:
        x (Tensor): A block-sparse (or dense) tensor in regular form.
        sparsity_block_size (int): The size of the sparsity blocks.
        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).

    Returns:
        Tensor: The sparsity layout of the input block-sparse (or dense) tensor.

    """
    # Input sanity checks
    validate_dimensions(x)
    validate_contiguous(x)
    validate_device(x)

    # One int32 flag per (batch, block-row, block-column), initialised to 0 (sparse)
    output = torch.zeros(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size,
                         device=x.device, dtype=torch.int32)

    # Sizes and strides forwarded to the kernel for manual pointer arithmetic
    num_bat, num_row, num_col = x.size()
    bat_s, row_s, col_s = x.stride()
    out_bat, out_row, out_col = output.size()
    out_bat_s, out_row_s, out_col_s = output.stride()

    if triton_block_size is None:
        triton_block_size = get_triton_block_size(sparsity_block_size)
    validate_triton_block_size(triton_block_size, sparsity_block_size)

    # One program per triton tile of the input tensor
    grid = lambda meta: [num_bat,
                         triton.cdiv(num_row, meta["TRITON_BLOCK_SIZE"]),
                         triton.cdiv(num_col, meta["TRITON_BLOCK_SIZE"])]

    kernel_sparsity_layout[grid](x,
                                 num_bat, bat_s, row_s, col_s,
                                 output,
                                 out_bat, out_bat_s, out_row_s, out_col_s,
                                 sparsity_block_size,
                                 triton_block_size)

    return output
53
+
54
+
55
# Triton kernel backing ``build_sparsity_layout``.
#
# One program instance inspects one TRITON_BLOCK_SIZE x TRITON_BLOCK_SIZE tile of the
# dense input ``x`` and, if the tile contains any nonzero value, sets the flag of the
# enclosing sparsity block in the output layout ``o`` to 1. Since the triton block size
# is validated to divide the sparsity block size, each tile lies entirely inside one
# sparsity block; several programs may store 1 to the same flag, which is idempotent.
@triton.jit
def kernel_sparsity_layout(x,
                           x_b, x_b_s, x_r_s, x_c_s,
                           o,
                           o_b, o_b_s, o_r_s, o_c_s,
                           sparsity_block_size,
                           TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    # Get triton block indices
    pid_bat = tl.program_id(axis=0)
    pid_row = tl.program_id(axis=1)
    pid_col = tl.program_id(axis=2)

    # Load this program's tile of the input tensor
    blk_x_idx = (pid_bat * x_b_s +
                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
    blk_x_msk = (blk_x_idx < x_b * x_b_s)
    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

    # Tile is non-sparse iff any element differs from 0 (min/max together cover
    # both negative and positive values)
    if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
        # Map the tile position to its enclosing sparsity-block position and flag it
        blk_o_idx = (pid_bat * o_b_s +
                     (((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +
                      ((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))
        blk_o_msk = (blk_o_idx < o_b * o_b_s)
        tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
@@ -0,0 +1,132 @@
1
+ import torch
2
+ import triton
3
+ from torch import Tensor
4
+ from triton import language as tl
5
+
6
+ from blksprs.utils.tools import get_triton_block_size
7
+ from blksprs.utils.validation import validate_contiguous, validate_device, \
8
+ validate_sparsity_block_size, validate_triton_block_size
9
+
10
+
11
def broadcast_addition(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
                       sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
    """Broadcasts the dense tensors x and y against each other and adds them.

    The result is returned as a block-sparse tensor in compressed form in which each
    element o(i, j) corresponds to x(i) + y(j), restricted to the blocks marked in
    ``sparsity_layout_output``.

    Args:
        x (Tensor): A dense input tensor.
        y (Tensor): A dense input tensor.
        sparsity_layout_output (Tensor): The sparsity layout of the output tensor.
        sparsity_block_size (int): The size of the sparsity blocks.
        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).

    Returns:
        Tensor: The result of the operation as a block-sparse tensor in compressed form.

    """
    # Input sanity checks
    validate_device(x, y)
    validate_contiguous(x, y)
    if x.size(-1) != y.size(-1):
        raise ValueError("Dimensions of tensors must match")
    validate_sparsity_block_size(sparsity_block_size)
    validate_triton_block_size(triton_block_size, sparsity_block_size)

    # Look-up table: row k holds the (batch, row, col) position of compressed block k
    lut_output = torch.nonzero(sparsity_layout_output).contiguous()

    num_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()

    validate_contiguous(sparsity_layout_output, lut_output)

    # Compressed output buffer, one dense tile per non-sparse block
    # NOTE(review): dtype defaults to float32 regardless of x.dtype — confirm intended
    output = torch.zeros(num_blocks, sparsity_block_size, sparsity_block_size, device=x.device)

    # Sizes and strides forwarded to the kernel for manual pointer arithmetic
    x_bat, x_feat = x.size()
    x_bat_s, x_feat_s = x.stride()
    y_bat, y_feat = y.size()
    y_bat_s, y_feat_s = y.stride()
    out_blk, out_row, out_col = output.size()
    out_blk_s, out_row_s, out_col_s = output.stride()
    lut_row, lut_col = lut_output.size()
    lut_row_s, lut_col_s = lut_output.stride()

    if triton_block_size is None:
        triton_block_size = get_triton_block_size(sparsity_block_size)

    # One program per triton tile of each compressed output block
    grid = lambda meta: [out_blk,
                         triton.cdiv(out_row, meta["TRITON_BLOCK_SIZE"]),
                         triton.cdiv(out_col, meta["TRITON_BLOCK_SIZE"])]

    kernel_broadcast_addition[grid](x,
                                    x_bat, x_bat_s, x_feat_s,
                                    y,
                                    y_bat, y_bat_s, y_feat_s,
                                    output,
                                    out_blk, out_blk_s, out_row_s, out_col_s,
                                    lut_output, lut_row, lut_row_s, lut_col_s,
                                    sparsity_block_size,
                                    triton_block_size)

    return output
71
+
72
+
73
def broadcast_subtraction(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
                          sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
    """Thin wrapper around ``broadcast_addition`` that negates y, yielding x(i) - y(j).

    """
    # Negating y reduces subtraction to the already-implemented broadcast addition
    return broadcast_addition(x, -y, sparsity_layout_output, sparsity_block_size, triton_block_size)
79
+
80
+
81
# Triton kernel backing ``broadcast_addition``.
#
# One program instance computes one TRITON_BLOCK_SIZE x TRITON_BLOCK_SIZE tile of one
# compressed output block: it loads the matching slices of the 2-D inputs ``x`` and
# ``y`` as row vectors, transposes the x slice into a column vector, broadcasts the
# pair to a full tile, and stores the outer sum x(i) + y(j).
#
# Naming scheme (mirrors the wrapper): ``s_lut_o`` = look-up table (nonzero positions)
# of the output sparsity layout; ``*_b/_r/_c`` = batch/row/column sizes, ``*_s`` = strides.
@triton.jit
def kernel_broadcast_addition(x,
                              x_b, x_b_s, x_c_s,
                              y,
                              y_b, y_b_s, y_c_s,
                              o,
                              o_b, o_b_s, o_r_s, o_c_s,
                              s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
                              sparsity_block_size,
                              TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    # Get triton block indices
    pid_blk = tl.program_id(axis=0)
    pid_row = tl.program_id(axis=1)
    pid_col = tl.program_id(axis=2)

    # Get position of current sparsity block consisting of its batch, row, and column index
    # (columns 0, 1, and 2 of LUT row ``pid_blk``)
    spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
    spa_bat_o_msk = (spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
    spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)

    spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
    spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
    spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)

    spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
    spa_col_o_msk = (spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
    spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)

    # Load x block: the slice of x covering this output tile's rows, as a row vector
    blk_x_idx = (spa_bat_o * x_b_s +
                 ((spa_row_o * sparsity_block_size + pid_row * TRITON_BLOCK_SIZE +
                   tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
    blk_x_msk = (blk_x_idx < x_b * x_b_s)
    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

    # Load y block: the slice of y covering this output tile's columns, as a row vector
    blk_y_idx = (spa_bat_o * y_b_s +
                 ((spa_col_o * sparsity_block_size + pid_col * TRITON_BLOCK_SIZE +
                   tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
    blk_y_msk = (blk_y_idx < y_b * y_b_s)
    blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)

    # Compute sum: transpose x to a column vector, broadcast both to a full tile,
    # so buf[i, j] = x[i] + y[j]
    blk_x, blk_y = tl.broadcast(tl.trans(blk_x), blk_y)
    buf = blk_x + blk_y

    # Store result into the compressed output block
    blk_o_idx = ((pid_blk * o_b_s) +
                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
    blk_o_msk = (blk_o_idx < o_b * o_b_s)
    tl.store(o + blk_o_idx, buf, mask=blk_o_msk)