blksprs 1.0.tar.gz → 1.2.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: blksprs
3
- Version: 1.0
3
+ Version: 1.2
4
4
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
5
5
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
6
6
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -23,16 +23,21 @@ Requires-Dist: pdoc3; extra == "deploy"
23
23
 
24
24
  ## Overview
25
25
 
26
- A lightweight library for operations on blocksparse matrices in PyTorch.
26
+ A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
27
27
 
28
28
  Currently supported operations (includes gradient calculation):
29
29
 
30
- - Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support for `sparse = sparse @ sparse` matmul_)
30
+ - Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
31
+ for `sparse = sparse @ sparse` matmul_)
31
32
  - Softmax
32
33
  - Transposition
33
- - Conversion from and to sparse form
34
+ - Gather
35
+ - Scatter (_supports either no reduction or summation; gradients are only available for summation_)
36
+ - Conversion to and from sparse form
37
+ - Conversion to different sparsity layouts and different sparsity block sizes
34
38
 
35
- As with this library sparse matrices are represented using a tuple of `(matrix, sparsity_layout, sparsity_block_size)`, any element-wise operations can be applied in regular torch-like fashion.
39
+ As sparse matrices are represented in this library using a tuple of `(matrix, sparsity_layout, sparsity_block_size)`,
40
+ any element-wise operations can be applied in regular torch-like fashion.
36
41
  These include, e.g.,
37
42
 
38
43
  - Element-wise addition and subtraction
@@ -40,24 +45,45 @@ These include, e.g.,
40
45
  - Element-wise exponentiation
41
46
  - ...
42
47
 
48
+ Note that in order to correctly apply element-wise operations between two sparse tensors, their sparsity layouts have to
49
+ match.
50
+
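A minimal sketch of this layout-matching requirement, assuming a CUDA device and the module paths used in the usage example further down (shapes and values are purely illustrative):

```python
import torch

from blksprs.layouting.sparsity_layout import build_sparsity_layout
from blksprs.ops.conversion import to_sparse, to_dense

# Illustrative shapes: batch of 2, 64x64 matrices, 16x16 sparsity blocks
sparsity_block_size = 16
x = torch.randn(size=(2, 64, 64), device="cuda")
y = torch.randn(size=(2, 64, 64), device="cuda")

# Build one layout and reuse it for both tensors so their sparsity layouts match
sparsity_layout = build_sparsity_layout(x, sparsity_block_size)
x_sparse = to_sparse(x, sparsity_layout, sparsity_block_size)
y_sparse = to_sparse(y, sparsity_layout, sparsity_block_size)

# Element-wise operations are applied in regular torch-like fashion on the compressed tensors
z_sparse = x_sparse + y_sparse
z_sparse = z_sparse * 0.5
z_sparse = torch.exp(z_sparse)

z_dense = to_dense(z_sparse, sparsity_layout, sparsity_block_size)
```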
51
+ Furthermore, the library provides a set of utility functions for the creation of sparsity layouts based on existing
52
+ dense tensors.
53
+
43
54
  ## Installation
44
55
 
56
+ Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is only compatible with
57
+ the Linux platform.
58
+
45
59
  We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) using pip:
46
60
 
47
61
  ```pip install blksprs```
48
62
 
63
+ ### Dependencies
64
+
65
+ - [PyTorch](https://pytorch.org/) (built with v2.4.0)
66
+ - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_
67
+
49
68
  ## Changelog
50
69
 
51
70
  See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.md) for a detailed changelog.
52
71
 
53
72
  ## Usage
54
73
 
74
+ We provide an example below to demonstrate the usage of the library.
75
+ For more detailed examples, please refer to
76
+ the [test cases](https://github.com/FelixSchoen/blksprs/blob/main/test/cases/test_blocksparse.py) which cover all
77
+ implemented operations and functions.
78
+ The example below can also be found in
79
+ the [test cases](https://github.com/FelixSchoen/blksprs/blob/main/test/cases/test_readme.py).
80
+
55
81
  ```python
56
82
  import torch
57
83
 
58
- from blksprs.layouting.sparsity_layout import create_sparsity_layout
84
+ from blksprs.layouting.sparsity_layout import build_sparsity_layout
59
85
  from blksprs.ops.conversion import to_sparse, to_dense
60
- from blksprs.ops.matmul_sss import matmul_sss
86
+ from blksprs.ops.matmul import matmul
61
87
  from blksprs.ops.row_wise_sum import row_wise_sum
62
88
  from blksprs.ops.softmax import softmax
63
89
  from blksprs.ops.transpose import transpose
@@ -65,7 +91,7 @@ from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
65
91
 
66
92
 
67
93
  def test_readme():
68
- # Set up parameters
94
+ # Set up parameters (batch size, number of heads, dimensions for matrices (m, k) and (n, k))
69
95
  b, h, m, n, k = 2, 4, 64, 64, 16
70
96
 
71
97
  # Percentage of blocks that will be sparse in the output for demonstration purposes
@@ -78,7 +104,6 @@ def test_readme():
78
104
  # If it is set to ``none`` a value will be chosen automatically
79
105
  triton_block_size = None
80
106
 
81
-
82
107
  # Initialise random (dense) tensors
83
108
  x = torch.randn(size=(b, h, m, k), device="cuda")
84
109
  y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()
@@ -88,8 +113,8 @@ def test_readme():
88
113
  y_dense, y_shape_original = do_shape_blocksparse(y)
89
114
 
90
115
  # Create sparsity layouts from existing tensors
91
- sparsity_layout_x = create_sparsity_layout(x_dense, sparsity_block_size, triton_block_size=triton_block_size)
92
- sparsity_layout_y = create_sparsity_layout(y_dense, sparsity_block_size, triton_block_size=triton_block_size)
116
+ sparsity_layout_x = build_sparsity_layout(x_dense, sparsity_block_size, triton_block_size=triton_block_size)
117
+ sparsity_layout_y = build_sparsity_layout(y_dense, sparsity_block_size, triton_block_size=triton_block_size)
93
118
 
94
119
  # Create random sparsity layout for output tensor
95
120
  sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)
@@ -99,8 +124,8 @@ def test_readme():
99
124
  y_sparse = to_sparse(y_dense, sparsity_layout_y, sparsity_block_size, triton_block_size=triton_block_size)
100
125
 
101
126
  # Perform matrix multiplication
102
- o_sparse = matmul_sss(x_sparse, y_sparse, sparsity_layout_x, sparsity_layout_y, sparsity_layout_o,
103
- sparsity_block_size, triton_block_size=triton_block_size)
127
+ o_sparse = matmul(x_sparse, sparsity_layout_x, y_sparse, sparsity_layout_y, sparsity_layout_o, sparsity_block_size,
128
+ triton_block_size=triton_block_size)
104
129
  o_dense = to_dense(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
105
130
 
106
131
  # Sanity check
@@ -115,7 +140,7 @@ def test_readme():
115
140
  assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2) # Note that small numerical differences are expected
116
141
 
117
142
  # Assert that the output has the correct sparsity layout
118
- actual_sparsity_layout_o = create_sparsity_layout(o_dense, sparsity_block_size, triton_block_size=triton_block_size)
143
+ actual_sparsity_layout_o = build_sparsity_layout(o_dense, sparsity_block_size, triton_block_size=triton_block_size)
119
144
  assert torch.allclose(actual_sparsity_layout_o, sparsity_layout_o)
120
145
 
121
146
  # Convert output tensor back to original shape
@@ -2,16 +2,21 @@
2
2
 
3
3
  ## Overview
4
4
 
5
- A lightweight library for operations on blocksparse matrices in PyTorch.
5
+ A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
6
6
 
7
7
  Currently supported operations (includes gradient calculation):
8
8
 
9
- - Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support for `sparse = sparse @ sparse` matmul_)
9
+ - Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
10
+ for `sparse = sparse @ sparse` matmul_)
10
11
  - Softmax
11
12
  - Transposition
12
- - Conversion from and to sparse form
13
+ - Gather
14
+ - Scatter (_supports either no reduction or summation; gradients are only available for summation_)
15
+ - Conversion to and from sparse form
16
+ - Conversion to different sparsity layouts and different sparsity block sizes
13
17
 
14
- As with this library sparse matrices are represented using a tuple of `(matrix, sparsity_layout, sparsity_block_size)`, any element-wise operations can be applied in regular torch-like fashion.
18
+ As sparse matrices are represented in this library using a tuple of `(matrix, sparsity_layout, sparsity_block_size)`,
19
+ any element-wise operations can be applied in regular torch-like fashion.
15
20
  These include, e.g.,
16
21
 
17
22
  - Element-wise addition and subtraction
@@ -19,24 +24,45 @@ These include, e.g.,
19
24
  - Element-wise exponentiation
20
25
  - ...
21
26
 
27
+ Note that in order to correctly apply element-wise operations between two sparse tensors, their sparsity layouts have to
28
+ match.
29
+
30
+ Furthermore, the library provides a set of utility functions for the creation of sparsity layouts based on existing
31
+ dense tensors.
32
+
22
33
  ## Installation
23
34
 
35
+ Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is only compatible with
36
+ the Linux platform.
37
+
24
38
  We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) using pip:
25
39
 
26
40
  ```pip install blksprs```
27
41
 
42
+ ### Dependencies
43
+
44
+ - [PyTorch](https://pytorch.org/) (built with v2.4.0)
45
+ - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_
46
+
28
47
  ## Changelog
29
48
 
30
49
  See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.md) for a detailed changelog.
31
50
 
32
51
  ## Usage
33
52
 
53
+ We provide an example below to demonstrate the usage of the library.
54
+ For more detailed examples, please refer to
55
+ the [test cases](https://github.com/FelixSchoen/blksprs/blob/main/test/cases/test_blocksparse.py) which cover all
56
+ implemented operations and functions.
57
+ The example below can also be found in
58
+ the [test cases](https://github.com/FelixSchoen/blksprs/blob/main/test/cases/test_readme.py).
59
+
34
60
  ```python
35
61
  import torch
36
62
 
37
- from blksprs.layouting.sparsity_layout import create_sparsity_layout
63
+ from blksprs.layouting.sparsity_layout import build_sparsity_layout
38
64
  from blksprs.ops.conversion import to_sparse, to_dense
39
- from blksprs.ops.matmul_sss import matmul_sss
65
+ from blksprs.ops.matmul import matmul
40
66
  from blksprs.ops.row_wise_sum import row_wise_sum
41
67
  from blksprs.ops.softmax import softmax
42
68
  from blksprs.ops.transpose import transpose
@@ -44,7 +70,7 @@ from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
44
70
 
45
71
 
46
72
  def test_readme():
47
- # Set up parameters
73
+ # Set up parameters (batch size, number of heads, dimensions for matrices (m, k) and (n, k))
48
74
  b, h, m, n, k = 2, 4, 64, 64, 16
49
75
 
50
76
  # Percentage of blocks that will be sparse in the output for demonstration purposes
@@ -57,7 +83,6 @@ def test_readme():
57
83
  # If it is set to ``None``, a value will be chosen automatically
58
84
  triton_block_size = None
59
85
 
60
-
61
86
  # Initialise random (dense) tensors
62
87
  x = torch.randn(size=(b, h, m, k), device="cuda")
63
88
  y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()
@@ -67,8 +92,8 @@ def test_readme():
67
92
  y_dense, y_shape_original = do_shape_blocksparse(y)
68
93
 
69
94
  # Create sparsity layouts from existing tensors
70
- sparsity_layout_x = create_sparsity_layout(x_dense, sparsity_block_size, triton_block_size=triton_block_size)
71
- sparsity_layout_y = create_sparsity_layout(y_dense, sparsity_block_size, triton_block_size=triton_block_size)
95
+ sparsity_layout_x = build_sparsity_layout(x_dense, sparsity_block_size, triton_block_size=triton_block_size)
96
+ sparsity_layout_y = build_sparsity_layout(y_dense, sparsity_block_size, triton_block_size=triton_block_size)
72
97
 
73
98
  # Create random sparsity layout for output tensor
74
99
  sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)
@@ -78,8 +103,8 @@ def test_readme():
78
103
  y_sparse = to_sparse(y_dense, sparsity_layout_y, sparsity_block_size, triton_block_size=triton_block_size)
79
104
 
80
105
  # Perform matrix multiplication
81
- o_sparse = matmul_sss(x_sparse, y_sparse, sparsity_layout_x, sparsity_layout_y, sparsity_layout_o,
82
- sparsity_block_size, triton_block_size=triton_block_size)
106
+ o_sparse = matmul(x_sparse, sparsity_layout_x, y_sparse, sparsity_layout_y, sparsity_layout_o, sparsity_block_size,
107
+ triton_block_size=triton_block_size)
83
108
  o_dense = to_dense(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
84
109
 
85
110
  # Sanity check
@@ -94,7 +119,7 @@ def test_readme():
94
119
  assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2) # Note that small numerical differences are expected
95
120
 
96
121
  # Assert that the output has the correct sparsity layout
97
- actual_sparsity_layout_o = create_sparsity_layout(o_dense, sparsity_block_size, triton_block_size=triton_block_size)
122
+ actual_sparsity_layout_o = build_sparsity_layout(o_dense, sparsity_block_size, triton_block_size=triton_block_size)
98
123
  assert torch.allclose(actual_sparsity_layout_o, sparsity_layout_o)
99
124
 
100
125
  # Convert output tensor back to original shape
@@ -0,0 +1,114 @@
1
+ import torch
2
+ import triton
3
+ from torch import Tensor
4
+ from triton import language as tl
5
+
6
+ from blksprs.utils.tools import get_triton_block_size
7
+ from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
8
+ validate_contiguous
9
+
10
+
11
+ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
12
+ size_target: torch.Size,
13
+ sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
14
+ """Builds the sparsity layout of either the source of a gather or the target of a scatter operation.
15
+
16
+ Args:
17
+ indices (Tensor): The block-sparse indices tensor in compressed form used for the gather or scatter operation.
18
+ sparsity_layout_indices (Tensor): The sparsity layout of the indices block-sparse tensor.
19
+ size_target (torch.Size): The size of the block-sparse target tensor in regular form.
20
+ sparsity_block_size (int): The size of the sparsity blocks.
21
+ triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
22
+
23
+ Returns:
24
+ Tensor: The sparsity layout of the source or target tensor.
25
+
26
+ """
27
+ validate_dimensions(indices)
28
+ validate_contiguous(indices)
29
+ validate_device(indices)
30
+
31
+ sparsity_lut_i = torch.nonzero(sparsity_layout_indices).contiguous()
32
+
33
+ output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size, size_target[2] // sparsity_block_size,
34
+ device=indices.device, dtype=torch.int32)
35
+
36
+ i_b, i_r, i_c = indices.size()
37
+ i_b_s, i_r_s, i_c_s = indices.stride()
38
+ s_l_i_b, s_l_i_r, s_l_i_c = sparsity_layout_indices.size()
39
+ s_l_i_b_s, s_l_i_r_s, s_l_i_c_s = sparsity_layout_indices.stride()
40
+ s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
41
+ s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
42
+ o_b, o_r, o_c = output.size()
43
+ o_b_s, o_r_s, o_c_s = output.stride()
44
+
45
+ if triton_block_size is None:
46
+ triton_block_size = get_triton_block_size(sparsity_block_size)
47
+
48
+ validate_triton_block_size(triton_block_size, sparsity_block_size)
49
+
50
+ triton_grid = lambda meta: [i_b,
51
+ triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
52
+ triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]
53
+
54
+ (kernel_distribution_layout[triton_grid]
55
+ (indices,
56
+ i_b, i_b_s, i_r_s, i_c_s,
57
+ sparsity_layout_indices,
58
+ s_l_i_b, s_l_i_b_s, s_l_i_r, s_l_i_r_s, s_l_i_c, s_l_i_c_s,
59
+ sparsity_lut_i,
60
+ s_lut_i_r, s_lut_i_r_s, s_lut_i_c, s_lut_i_c_s,
61
+ output,
62
+ o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
63
+ sparsity_block_size,
64
+ triton_block_size))
65
+
66
+ return output
67
+
68
+
69
+ @triton.jit
70
+ def kernel_distribution_layout(i,
71
+ i_b, i_b_s, i_r_s, i_c_s,
72
+ s_l_i,
73
+ s_l_i_b, s_l_i_b_s, s_l_i_r, s_l_i_r_s, s_l_i_c, s_l_i_c_s,
74
+ s_lut_i,
75
+ s_lut_i_r, s_lut_i_r_s, s_lut_i_c, s_lut_i_c_s,
76
+ o,
77
+ o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
78
+ sparsity_block_size,
79
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
80
+ # Get triton block indices
81
+ pid_blk = tl.program_id(axis=0)
82
+ pid_row = tl.program_id(axis=1)
83
+ pid_col = tl.program_id(axis=2)
84
+
85
+ # Get position of current sparsity block consisting of its batch, row, and column index
86
+ spa_bat_i_idx = (pid_blk * s_lut_i_r_s + 0 * s_lut_i_c_s)
87
+ spa_bat_i_msk = (spa_bat_i_idx < s_lut_i_r * s_lut_i_r_s)
88
+ spa_bat_i = tl.load(s_lut_i + spa_bat_i_idx, mask=spa_bat_i_msk)
89
+
90
+ spa_row_i_idx = (pid_blk * s_lut_i_r_s + 1 * s_lut_i_c_s)
91
+ spa_row_i_msk = (spa_row_i_idx < s_lut_i_r * s_lut_i_r_s)
92
+ spa_row_i = tl.load(s_lut_i + spa_row_i_idx, mask=spa_row_i_msk)
93
+
94
+ blk_i_idx = (pid_blk * i_b_s +
95
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
96
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
97
+ blk_i_msk = (blk_i_idx < i_b * i_b_s)
98
+ blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)
99
+
100
+ blk_i = blk_i // sparsity_block_size
101
+ blk_v = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), 1, dtype=tl.int32)
102
+
103
+ blk_o_idx = ((spa_bat_i * o_b_s) +
104
+ (spa_row_i * o_r_s) +
105
+ (blk_i * o_c_s))
106
+ blk_o_msk = (blk_o_idx < o_b * o_b_s)
107
+ tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
108
+
109
+ # if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
110
+ # blk_o_idx = (pid_bat * o_b_s +
111
+ # (((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +
112
+ # ((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))
113
+ # blk_o_msk = (blk_o_idx < o_b * o_b_s)
114
+ # tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
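A rough usage sketch for `build_distribution_layout` as added above; the import path is an assumption (file names are not shown in this diff), and shapes, dtypes, and index values are illustrative:

```python
import torch

# Assumed module path for the function added above
from blksprs.layouting.distribution_layout import build_distribution_layout

sparsity_block_size = 16

# Indices for a hypothetical gather/scatter: compressed form with a full (all-ones) sparsity layout
sparsity_layout_indices = torch.ones(size=(2, 4, 4), device="cuda", dtype=torch.int32)
indices = torch.randint(0, 64, size=(2 * 4 * 4, sparsity_block_size, sparsity_block_size),
                        device="cuda")

# Size of the scatter target (or gather source) in regular form
size_target = torch.Size((2, 64, 64))

# Layout marking the target blocks touched by the indices
sparsity_layout_target = build_distribution_layout(indices, sparsity_layout_indices,
                                                   size_target, sparsity_block_size)
```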
@@ -0,0 +1,190 @@
1
+ import math
2
+
3
+ import torch
4
+ import triton
5
+ from torch import Tensor
6
+ from triton import language as tl
7
+
8
+ from blksprs.utils.tools import get_triton_block_size
9
+ from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
10
+ validate_contiguous, validate_sparsity, validate_sparsity_block_size
11
+
12
+
13
+ def build_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
14
+ """Builds the sparsity layout of a dense tensor in regular form covering its sparse blocks.
15
+
16
+ Args:
17
+ x (Tensor): A block-sparse (or dense) tensor in regular form.
18
+ sparsity_block_size (int): The size of the sparsity blocks.
19
+ triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
20
+
21
+ Returns:
22
+ Tensor: The sparsity layout of the input block-sparse (or dense) tensor.
23
+
24
+ """
25
+ validate_dimensions(x)
26
+ validate_contiguous(x)
27
+ validate_device(x)
28
+
29
+ output = torch.zeros(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size,
30
+ device=x.device, dtype=torch.int32)
31
+
32
+ x_b, x_r, x_c = x.size()
33
+ x_b_s, x_r_s, x_c_s = x.stride()
34
+ o_b, o_r, o_c = output.size()
35
+ o_b_s, o_r_s, o_c_s = output.stride()
36
+
37
+ if triton_block_size is None:
38
+ triton_block_size = get_triton_block_size(sparsity_block_size)
39
+
40
+ validate_triton_block_size(triton_block_size, sparsity_block_size)
41
+
42
+ triton_grid = lambda meta: [x_b,
43
+ triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
44
+ triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
45
+
46
+ (kernel_sparsity_layout[triton_grid]
47
+ (x,
48
+ x_b, x_b_s, x_r_s, x_c_s,
49
+ output,
50
+ o_b, o_b_s, o_r_s, o_c_s,
51
+ sparsity_block_size,
52
+ triton_block_size))
53
+
54
+ return output
55
+
56
+
57
+ @triton.jit
58
+ def kernel_sparsity_layout(x,
59
+ x_b, x_b_s, x_r_s, x_c_s,
60
+ o,
61
+ o_b, o_b_s, o_r_s, o_c_s,
62
+ sparsity_block_size,
63
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
64
+ # Get triton block indices
65
+ pid_bat = tl.program_id(axis=0)
66
+ pid_row = tl.program_id(axis=1)
67
+ pid_col = tl.program_id(axis=2)
68
+
69
+ # Load x values
70
+ blk_x_idx = (pid_bat * x_b_s +
71
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
72
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
73
+ blk_x_msk = (blk_x_idx < x_b * x_b_s)
74
+ blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
75
+
76
+ # Store sparsity layout value
77
+ if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
78
+ blk_o_idx = (pid_bat * o_b_s +
79
+ (((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +
80
+ ((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))
81
+ blk_o_msk = (blk_o_idx < o_b * o_b_s)
82
+ tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
83
+
84
+
85
+ def build_sparsity_layout_adaption(x: Tensor, sparsity_layout_from: Tensor,
86
+ sparsity_block_size_from: int, sparsity_block_size_to: int,
87
+ triton_block_size: int = None) -> Tensor:
88
+ """Builds the sparsity layout of a block-sparse tensor in compressed form if a different sparsity block size were
89
+ used.
90
+
91
+ Args:
92
+ x (Tensor): A block-sparse tensor in compressed form.
93
+ sparsity_layout_from (Tensor): The sparsity layout of the input block-sparse tensor.
94
+ sparsity_block_size_from (int): The size of the sparsity blocks of the input tensor.
95
+ sparsity_block_size_to (int): The desired size of the sparsity blocks for the resulting layout.
96
+ triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
97
+
98
+ Returns:
99
+ Tensor: The sparsity layout in regular form using the new sparsity block size of the input block-sparse tensor
100
+ in compressed form.
101
+
102
+ """
103
+ validate_dimensions(x)
104
+ validate_contiguous(x, sparsity_layout_from)
105
+ validate_device(x)
106
+ validate_sparsity(sparsity_block_size_from, (x, sparsity_layout_from))
107
+ validate_sparsity_block_size(sparsity_block_size_from, x)
108
+ validate_sparsity_block_size(sparsity_block_size_to)
109
+ min_sparsity_block_size = min(sparsity_block_size_from, sparsity_block_size_to)
110
+ validate_triton_block_size(triton_block_size, min_sparsity_block_size)
111
+
112
+ sparsity_lut = torch.nonzero(sparsity_layout_from).contiguous()
113
+
114
+ validate_contiguous(sparsity_layout_from, sparsity_lut)
115
+
116
+ o_b = sparsity_layout_from.size(0)
117
+ o_r = math.ceil(sparsity_layout_from.size(1) * sparsity_block_size_from / sparsity_block_size_to)
118
+ o_c = math.ceil(sparsity_layout_from.size(2) * sparsity_block_size_from / sparsity_block_size_to)
119
+
120
+ output = torch.zeros(o_b, o_r, o_c, device=x.device, dtype=torch.int32)
121
+
122
+ x_b, x_r, x_c = x.size()
123
+ x_b_s, x_r_s, x_c_s = x.stride()
124
+ s_lut_r, s_lut_c = sparsity_lut.size()
125
+ s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
126
+ o_b_s, o_r_s, o_c_s = output.stride()
127
+
128
+ if triton_block_size is None:
129
+ triton_block_size = get_triton_block_size(sparsity_block_size_from)
130
+
131
+ triton_grid = lambda meta: [x_b,
132
+ triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
133
+ triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
134
+
135
+ (kernel_sparsity_layout_adaption[triton_grid]
136
+ (x,
137
+ x_b, x_b_s, x_r_s, x_c_s,
138
+ sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
139
+ output,
140
+ o_b, o_b_s, o_r_s, o_c_s,
141
+ sparsity_block_size_from,
142
+ sparsity_block_size_to,
143
+ triton_block_size))
144
+
145
+ return output
146
+
147
+
148
+ @triton.jit
149
+ def kernel_sparsity_layout_adaption(x,
150
+ x_b, x_b_s, x_r_s, x_c_s,
151
+ s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
152
+ o,
153
+ o_b, o_b_s, o_r_s, o_c_s,
154
+ sparsity_block_size_from,
155
+ sparsity_block_size_to,
156
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
157
+ # Get triton block indices
158
+ pid_blk = tl.program_id(axis=0)
159
+ pid_row = tl.program_id(axis=1)
160
+ pid_col = tl.program_id(axis=2)
161
+
162
+ # Get sparsity index of current output block consisting of its batch, row, and column index
163
+ spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
164
+ spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
165
+ spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
166
+
167
+ spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
168
+ spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
169
+ spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
170
+
171
+ spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
172
+ spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
173
+ spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
174
+
175
+ # Load x values
176
+ blk_x_idx = ((pid_blk * x_b_s) +
177
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
178
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
179
+ blk_x_msk = (blk_x_idx < x_b * x_b_s)
180
+ blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
181
+
182
+ # Store sparsity layout value
183
+ if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
184
+ blk_o_idx = ((spa_bat * o_b_s) +
185
+ (((spa_row * sparsity_block_size_from + pid_row * TRITON_BLOCK_SIZE)
186
+ // sparsity_block_size_to) * o_r_s) +
187
+ (((spa_col * sparsity_block_size_from + pid_col * TRITON_BLOCK_SIZE)
188
+ // sparsity_block_size_to) * o_c_s))
189
+ blk_o_msk = (blk_o_idx < o_b * o_b_s)
190
+ tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
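A short sketch of `build_sparsity_layout_adaption`, which corresponds to the README bullet on converting to different sparsity block sizes; shapes are illustrative, and `build_sparsity_layout`/`to_sparse` are used as in the README example:

```python
import torch

from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption
from blksprs.ops.conversion import to_sparse

# Illustrative tensor compressed with 32x32 sparsity blocks
x = torch.randn(size=(2, 64, 64), device="cuda")
sparsity_layout_32 = build_sparsity_layout(x, 32)
x_sparse_32 = to_sparse(x, sparsity_layout_32, 32)

# Layout the same data would have if 16x16 sparsity blocks were used instead
sparsity_layout_16 = build_sparsity_layout_adaption(x_sparse_32, sparsity_layout_32,
                                                    sparsity_block_size_from=32,
                                                    sparsity_block_size_to=16)
```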
@@ -0,0 +1,132 @@
1
+ import torch
2
+ import triton
3
+ from torch import Tensor
4
+ from triton import language as tl
5
+
6
+ from blksprs.utils.tools import get_triton_block_size
7
+ from blksprs.utils.validation import validate_contiguous, validate_device, \
8
+ validate_sparsity_block_size, validate_triton_block_size
9
+
10
+
11
+ def broadcast_addition(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
12
+ sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
13
+ """Performs a broadcast and subsequent addition of two dense tensors x and y. Returns a block-sparse tensor in
14
+ compressed form.
15
+
16
+ Args:
17
+ x (Tensor): A dense input tensor.
18
+ y (Tensor): A dense input tensor.
19
+ sparsity_layout_output (Tensor): The sparsity layout of the output tensor.
20
+ sparsity_block_size (int): The size of the sparsity blocks.
21
+ triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
22
+
23
+ Returns:
24
+ Tensor: The result of the operation as a block-sparse tensor in compressed form. Each element o(i, j) of the
25
+ output tensor corresponds to x(i) + y(j).
26
+
27
+ """
28
+ validate_device(x, y)
29
+ validate_contiguous(x, y)
30
+ if x.size(-1) != y.size(-1):
31
+ raise ValueError("Dimensions of tensors must match")
32
+ validate_sparsity_block_size(sparsity_block_size)
33
+ validate_triton_block_size(triton_block_size, sparsity_block_size)
34
+
35
+ sparsity_lut_o = torch.nonzero(sparsity_layout_output).contiguous()
36
+
37
+ n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
38
+
39
+ validate_contiguous(sparsity_layout_output, sparsity_lut_o)
40
+
41
+ output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, device=x.device)
42
+
43
+ x_b, x_c = x.size()
44
+ x_b_s, x_c_s = x.stride()
45
+ y_b, y_c = y.size()
46
+ y_b_s, y_c_s = y.stride()
47
+ o_b, o_r, o_c = output.size()
48
+ o_b_s, o_r_s, o_c_s = output.stride()
49
+ s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
50
+ s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_o.stride()
51
+
52
+ if triton_block_size is None:
53
+ triton_block_size = get_triton_block_size(sparsity_block_size)
54
+
55
+ triton_grid = lambda meta: [o_b,
56
+ triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
57
+ triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
58
+
59
+ (kernel_broadcast_addition[triton_grid]
60
+ (x,
61
+ x_b, x_b_s, x_c_s,
62
+ y,
63
+ y_b, y_b_s, y_c_s,
64
+ output,
65
+ o_b, o_b_s, o_r_s, o_c_s,
66
+ sparsity_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
67
+ sparsity_block_size,
68
+ triton_block_size))
69
+
70
+ return output
71
+
72
+
73
+ def broadcast_subtraction(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
74
+ sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
75
+ """Wrapper for ``broadcast_addition`` with negated y.
76
+
77
+ """
78
+ return broadcast_addition(x, torch.neg(y), sparsity_layout_output, sparsity_block_size, triton_block_size)
79
+
80
+
81
+ @triton.jit
82
+ def kernel_broadcast_addition(x,
83
+ x_b, x_b_s, x_c_s,
84
+ y,
85
+ y_b, y_b_s, y_c_s,
86
+ o,
87
+ o_b, o_b_s, o_r_s, o_c_s,
88
+ s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
89
+ sparsity_block_size,
90
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
91
+ # Get triton block indices
92
+ pid_blk = tl.program_id(axis=0)
93
+ pid_row = tl.program_id(axis=1)
94
+ pid_col = tl.program_id(axis=2)
95
+
96
+ # Get position of current sparsity block consisting of its batch, row, and column index
97
+ spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
98
+ spa_bat_o_msk = (spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
99
+ spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
100
+
101
+ spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
102
+ spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
103
+ spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
104
+
105
+ spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
106
+ spa_col_o_msk = (spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
107
+ spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
108
+
109
+ # Load x block
110
+ blk_x_idx = (spa_bat_o * x_b_s +
111
+ ((spa_row_o * sparsity_block_size + pid_row * TRITON_BLOCK_SIZE +
112
+ tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
113
+ blk_x_msk = (blk_x_idx < x_b * x_b_s)
114
+ blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
115
+
116
+ # Load y block
117
+ blk_y_idx = (spa_bat_o * y_b_s +
118
+ ((spa_col_o * sparsity_block_size + pid_col * TRITON_BLOCK_SIZE +
119
+ tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
120
+ blk_y_msk = (blk_y_idx < y_b * y_b_s)
121
+ blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
122
+
123
+ # Compute sum
124
+ blk_x, blk_y = tl.broadcast(tl.trans(blk_x), blk_y)
125
+ buf = blk_x + blk_y
126
+
127
+ # Store result
128
+ blk_o_idx = ((pid_blk * o_b_s) +
129
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
130
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
131
+ blk_o_msk = (blk_o_idx < o_b * o_b_s)
132
+ tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
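A brief usage sketch for `broadcast_addition` and `broadcast_subtraction` as defined above; the import path is an assumption (the file name is not shown in this diff), shapes are illustrative, and the output layout is simply all-ones:

```python
import torch

# Assumed module path for the functions added above
from blksprs.ops.broadcast_ops import broadcast_addition, broadcast_subtraction

sparsity_block_size = 16
b, n = 2, 64

# Two dense 2D inputs with matching last dimension
x = torch.randn(size=(b, n), device="cuda")
y = torch.randn(size=(b, n), device="cuda")

# Output layout: each o(i, j) = x(i) + y(j), so the layout spans (n // bs) x (n // bs) blocks per batch
sparsity_layout_o = torch.ones(size=(b, n // sparsity_block_size, n // sparsity_block_size),
                               device="cuda", dtype=torch.int32)

o_sum = broadcast_addition(x, y, sparsity_layout_o, sparsity_block_size)
o_diff = broadcast_subtraction(x, y, sparsity_layout_o, sparsity_block_size)
```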