blksprs 1.2__tar.gz → 1.3__tar.gz

This diff shows the changes between publicly released versions of the package, as they appear in the supported public registries. It is provided for informational purposes only.
@@ -1,11 +1,11 @@
  Metadata-Version: 2.1
  Name: blksprs
- Version: 1.2
+ Version: 1.3
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
  Project-URL: Bugtracker, https://github.com/FelixSchoen/blksprs/issues
- Requires-Python: >=3.12
+ Requires-Python: >=3.11
  Description-Content-Type: text/markdown
  Requires-Dist: torch
  Provides-Extra: test
@@ -21,6 +21,9 @@ Requires-Dist: pdoc3; extra == "deploy"
 
  # blksprs
 
+ [![GitHub Release](https://img.shields.io/github/v/release/FelixSchoen/blksprs?include_prereleases&label=Latest%20Release)](https://github.com/FelixSchoen/blksprs/releases)
+ [![Python Version](https://img.shields.io/badge/Python%20Version-3.11-blue)](https://www.python.org/downloads/release/python-3119/)
+
  ## Overview
 
  A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
@@ -1,5 +1,8 @@
  # blksprs
 
+ [![GitHub Release](https://img.shields.io/github/v/release/FelixSchoen/blksprs?include_prereleases&label=Latest%20Release)](https://github.com/FelixSchoen/blksprs/releases)
+ [![Python Version](https://img.shields.io/badge/Python%20Version-3.11-blue)](https://www.python.org/downloads/release/python-3119/)
+
  ## Overview
 
  A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
@@ -0,0 +1,130 @@
+ import torch
+ import triton
+ from torch import Tensor
+ from triton import language as tl
+
+ from blksprs.utils.tools import get_triton_block_size
+ from blksprs.utils.validation import validate_contiguous, validate_device, \
+     validate_sparsity_block_size, validate_triton_block_size, validate_dimensions
+
+
+ def repeat_interleave(x: Tensor, sparsity_layout: Tensor, repeats: int,
+                       sparsity_block_size: int, triton_block_size: int = None) -> tuple[Tensor, Tensor]:
+     """Repeats and interleaves the block-sparse tensor in compressed form.
+
+     Repeats each matrix contained in the tensor ``repeats`` times and places the copies consecutively in the
+     output tensor.
+
+     Args:
+         x (Tensor): A block-sparse tensor in compressed form.
+         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+         repeats (int): The number of times to repeat the matrices.
+         sparsity_block_size (int): The size of the sparsity blocks.
+         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+     Returns:
+         Tensor: A block-sparse tensor in compressed form containing the repeated and interleaved matrices.
+         Tensor: The sparsity layout of the resulting output tensor.
+
+     """
+     validate_dimensions(x)
+     validate_contiguous(x)
+     validate_device(x)
+     validate_sparsity_block_size(sparsity_block_size, x)
+     validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+     sparsity_layout_output = torch.repeat_interleave(sparsity_layout, repeats, dim=0).contiguous()
+
+     sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
+
+     sparsity_layout_output_flat = sparsity_layout_output.reshape(-1)
+     sparsity_output_reverse_lut = ((torch.cumsum(sparsity_layout_output_flat, dim=-1) - 1) *
+                                    (sparsity_layout_output_flat == 1) -
+                                    (1 * (sparsity_layout_output_flat == 0)))
+
+     n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
+
+     validate_contiguous(sparsity_layout, sparsity_lut, sparsity_layout_output, sparsity_output_reverse_lut)
+
+     output = torch.empty(n_sparse_blocks * repeats, sparsity_block_size, sparsity_block_size,
+                          dtype=x.dtype, device=x.device)
+
+     x_b, x_r, x_c = x.size()
+     x_b_s, x_r_s, x_c_s = x.stride()
+     s_lut_r, s_lut_c = sparsity_lut.size()
+     s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
+     o_b, o_r, o_c = output.size()
+     o_b_s, o_r_s, o_c_s = output.stride()
+     s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
+     s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_output.stride()
+
+     if triton_block_size is None:
+         triton_block_size = get_triton_block_size(sparsity_block_size)
+
+     triton_grid = lambda meta: [x_b,
+                                 triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                 triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+     (kernel_repeat_interleave[triton_grid]
+      (x,
+       x_b, x_b_s, x_r_s, x_c_s,
+       sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+       output,
+       o_b, o_b_s, o_r_s, o_c_s,
+       s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+       sparsity_output_reverse_lut,
+       repeats,
+       triton_block_size))
+
+     return output, sparsity_layout_output
+
+
+ @triton.jit
+ def kernel_repeat_interleave(x,
+                              x_b, x_b_s, x_r_s, x_c_s,
+                              s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                              o,
+                              o_b, o_b_s, o_r_s, o_c_s,
+                              s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+                              r_lut_o,
+                              repeats,
+                              TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     # Get triton block indices
+     pid_blk = tl.program_id(axis=0)
+     pid_row = tl.program_id(axis=1)
+     pid_col = tl.program_id(axis=2)
+
+     # Get sparsity index of the current input block, consisting of its batch, row, and column index
+     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+     spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+     spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+     spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+     spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+     spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+     spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
+     spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
+     spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+
+     # Load block
+     blk_x_idx = ((pid_blk * x_b_s) +
+                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+     blk_x_msk = (blk_x_idx < x_b * x_b_s)
+     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+     for repeat in range(repeats):
+         # Get reverse sparsity index
+         rev_idx_spa_idx = ((spa_bat * repeats + repeat) * s_l_o_b_s +
+                            spa_row * s_l_o_r_s +
+                            spa_col * s_l_o_c_s)
+         rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
+         rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
+
+         # Store block
+         blk_o_idx = ((rev_idx_spa * o_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+         blk_o_msk = (blk_o_idx < o_b * o_b_s)
+         tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
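
For orientation, here is a minimal usage sketch of the new `repeat_interleave` operation added in this release. Only the call signature comes from the file above; the shapes, the CUDA device, the hand-built layout, and the compressed-form construction are illustrative assumptions.

```python
import torch

from blksprs.misc.repeat_interleave import repeat_interleave

# Illustrative setup: two batch matrices split into 32x32 sparsity blocks,
# with a hand-picked layout marking which blocks are populated.
sparsity_block_size = 32
sparsity_layout = torch.tensor([[[1, 0],
                                 [0, 1]],
                                [[1, 1],
                                 [0, 0]]], device="cuda")

# The compressed form stacks only the populated blocks along dim 0
# (assumption inferred from the output allocation in the file above).
n_sparse_blocks = int(sparsity_layout.sum())
x = torch.randn(n_sparse_blocks, sparsity_block_size, sparsity_block_size,
                device="cuda")

# Repeat every matrix twice; the copies are placed consecutively along the
# batch dimension, so the output layout has 2x the batch entries.
y, sparsity_layout_y = repeat_interleave(x, sparsity_layout, 2,
                                         sparsity_block_size)
assert sparsity_layout_y.size(0) == 2 * sparsity_layout.size(0)
```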
@@ -129,7 +129,7 @@ class _BlocksparseTranspose(torch.autograd.Function):
      spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
      spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
 
-     # Get reverse sparsity indices
+     # Get reverse sparsity index
      rev_idx_spa_idx = (spa_bat * s_l_b_s +
                         spa_row * s_l_r_s +
                         spa_col * s_l_c_s)
@@ -10,7 +10,7 @@ def do_shape_blocksparse(x: Tensor):
 
 
  def undo_shape_blocksparse(x: Tensor, shape: Size):
-     if x.shape[-2:] == shape[-2:]:
+     if x.shape[:-2] == shape[:-2]:
          return x
 
      return x.reshape((*shape[:-2], *x.shape[-2:]))
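
To illustrate the fix: the old condition compared the last two dimensions, which the do/undo reshaping never changes, so the early return fired even when the flattened batch dimensions still had to be restored. The new condition compares the batch dimensions instead. A small sketch with assumed illustrative shapes:

```python
import torch

# do_shape_blocksparse flattens the leading batch dims: (2, 3, 32, 32) -> (6, 32, 32)
x = torch.empty(6, 32, 32)
shape = torch.Size((2, 3, 32, 32))  # the original shape to restore

# Old check: x.shape[-2:] == shape[-2:] compares (32, 32) == (32, 32) -> True,
# incorrectly skipping the reshape. The fixed check compares the batch dims:
if x.shape[:-2] == shape[:-2]:      # (6,) == (2, 3) -> False
    out = x
else:
    out = x.reshape((*shape[:-2], *x.shape[-2:]))

assert out.shape == shape
```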
@@ -1,11 +1,11 @@
  Metadata-Version: 2.1
  Name: blksprs
- Version: 1.2
+ Version: 1.3
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
  Project-URL: Bugtracker, https://github.com/FelixSchoen/blksprs/issues
- Requires-Python: >=3.12
+ Requires-Python: >=3.11
  Description-Content-Type: text/markdown
  Requires-Dist: torch
  Provides-Extra: test
@@ -21,6 +21,9 @@ Requires-Dist: pdoc3; extra == "deploy"
 
  # blksprs
 
+ [![GitHub Release](https://img.shields.io/github/v/release/FelixSchoen/blksprs?include_prereleases&label=Latest%20Release)](https://github.com/FelixSchoen/blksprs/releases)
+ [![Python Version](https://img.shields.io/badge/Python%20Version-3.11-blue)](https://www.python.org/downloads/release/python-3119/)
+
  ## Overview
 
  A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
@@ -8,6 +8,7 @@ blksprs.egg-info/top_level.txt
  blksprs/layouting/distribution_layout.py
  blksprs/layouting/sparsity_layout.py
  blksprs/misc/broadcast_addition.py
+ blksprs/misc/repeat_interleave.py
  blksprs/ops/conversion.py
  blksprs/ops/distribution.py
  blksprs/ops/exp.py
@@ -1,10 +1,10 @@
  [project]
  name = "blksprs"
- version = "1.2"
+ version = "1.3"
  authors = [{ name = "Felix Schön", email = "schoen@kr.tuwien.ac.at" }]
  description = "A lightweight library for operations on blocksparse matrices in PyTorch."
  readme = "README.md"
- requires-python = ">=3.12"
+ requires-python = ">=3.11"
  license = { file = "LICENSE.md" }
  dependencies = [
      "torch"