blksprs 1.1__tar.gz → 1.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,11 +1,11 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: blksprs
3
- Version: 1.1
3
+ Version: 1.2.1
4
4
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
5
5
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
6
6
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
7
7
  Project-URL: Bugtracker, https://github.com/FelixSchoen/blksprs/issues
8
- Requires-Python: >=3.12
8
+ Requires-Python: >=3.11
9
9
  Description-Content-Type: text/markdown
10
10
  Requires-Dist: torch
11
11
  Provides-Extra: test
@@ -21,6 +21,9 @@ Requires-Dist: pdoc3; extra == "deploy"
21
21
 
22
22
  # blksprs
23
23
 
24
+ [![GitHub Release](https://img.shields.io/github/v/release/FelixSchoen/blksprs?include_prereleases&label=Latest%20Release)](https://github.com/FelixSchoen/blksprs/releases)
25
+ [![Python Version](https://img.shields.io/badge/Python%20Version-3.11-blue)](https://www.python.org/downloads/release/python-3119/)
26
+
24
27
  ## Overview
25
28
 
26
29
  A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
@@ -33,7 +36,8 @@ Currently supported operations (includes gradient calculation):
33
36
  - Transposition
34
37
  - Gather
35
38
  - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
36
- - Conversion from and to sparse form
39
+ - Conversion to and from sparse form
40
+ - Conversion to different sparsity layouts and different sparsity block sizes
37
41
 
38
42
  Because this library represents sparse matrices as a tuple of `(matrix, sparsity_layout, sparsity_block_size)`,
39
43
  any element-wise operations can be applied in regular torch-like fashion.
@@ -59,6 +63,11 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u
59
63
 
60
64
  ```pip install blksprs```
61
65
 
66
+ ### Dependencies
67
+
68
+ - [PyTorch](https://pytorch.org/) (built with v2.4.0)
69
+ - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_
70
+
62
71
  ## Changelog
63
72
 
64
73
  See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.md) for a detailed changelog.
@@ -1,5 +1,8 @@
1
1
  # blksprs
2
2
 
3
+ [![GitHub Release](https://img.shields.io/github/v/release/FelixSchoen/blksprs?include_prereleases&label=Latest%20Release)](https://github.com/FelixSchoen/blksprs/releases)
4
+ [![Python Version](https://img.shields.io/badge/Python%20Version-3.11-blue)](https://www.python.org/downloads/release/python-3119/)
5
+
3
6
  ## Overview
4
7
 
5
8
  A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
@@ -12,7 +15,8 @@ Currently supported operations (includes gradient calculation):
12
15
  - Transposition
13
16
  - Gather
14
17
  - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
15
- - Conversion from and to sparse form
18
+ - Conversion to and from sparse form
19
+ - Conversion to different sparsity layouts and different sparsity block sizes
16
20
 
17
21
  Because this library represents sparse matrices as a tuple of `(matrix, sparsity_layout, sparsity_block_size)`,
18
22
  any element-wise operations can be applied in regular torch-like fashion.
@@ -38,6 +42,11 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u
38
42
 
39
43
  ```pip install blksprs```
40
44
 
45
+ ### Dependencies
46
+
47
+ - [PyTorch](https://pytorch.org/) (built with v2.4.0)
48
+ - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_
49
+
41
50
  ## Changelog
42
51
 
43
52
  See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.md) for a detailed changelog.
@@ -0,0 +1,190 @@
1
+ import math
2
+
3
+ import torch
4
+ import triton
5
+ from torch import Tensor
6
+ from triton import language as tl
7
+
8
+ from blksprs.utils.tools import get_triton_block_size
9
+ from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
10
+ validate_contiguous, validate_sparsity, validate_sparsity_block_size
11
+
12
+
13
def build_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
    """Derives the block-level sparsity layout of a tensor given in regular form.

    The launched kernel sets a layout entry to ``1`` whenever it finds a non-zero value in a tile belonging to the
    corresponding sparsity block; all other entries keep their initial value of ``0``.

    Args:
        x (Tensor): A block-sparse (or dense) tensor in regular form.
        sparsity_block_size (int): The size of the sparsity blocks.
        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
            If ``None``, the value is obtained from ``get_triton_block_size(sparsity_block_size)``.

    Returns:
        Tensor: An ``int32`` tensor of shape
            ``(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size)`` holding the
            sparsity layout of the input tensor.

    """
    validate_dimensions(x)
    validate_contiguous(x)
    validate_device(x)

    # Number of sparsity blocks per batch entry, rounded down
    n_batches = x.size(0)
    n_block_rows = x.size(1) // sparsity_block_size
    n_block_cols = x.size(2) // sparsity_block_size
    layout = torch.zeros(n_batches, n_block_rows, n_block_cols, device=x.device, dtype=torch.int32)

    x_b, x_r, x_c = x.size()
    x_b_s, x_r_s, x_c_s = x.stride()
    o_b = layout.size(0)
    o_b_s, o_r_s, o_c_s = layout.stride()

    if triton_block_size is None:
        triton_block_size = get_triton_block_size(sparsity_block_size)

    validate_triton_block_size(triton_block_size, sparsity_block_size)

    def launch_grid(meta):
        # One program per batch entry and per tile of the input along rows and columns
        tile_size = meta["TRITON_BLOCK_SIZE"]
        return [x_b, triton.cdiv(x_r, tile_size), triton.cdiv(x_c, tile_size)]

    kernel_sparsity_layout[launch_grid](
        x,
        x_b, x_b_s, x_r_s, x_c_s,
        layout,
        o_b, o_b_s, o_r_s, o_c_s,
        sparsity_block_size,
        triton_block_size)

    return layout
55
+
56
+
57
@triton.jit
def kernel_sparsity_layout(x,
                           x_b, x_b_s, x_r_s, x_c_s,
                           o,
                           o_b, o_b_s, o_r_s, o_c_s,
                           sparsity_block_size,
                           TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    # Triton kernel launched by ``build_sparsity_layout``.
    #
    # Each program loads one TRITON_BLOCK_SIZE x TRITON_BLOCK_SIZE tile of ``x`` and, if the tile contains a value
    # different from zero, writes ``1`` to the entry of ``o`` belonging to the sparsity block in which the tile starts.
    #
    # Arguments:
    #   x:                    Input tensor; ``x_b`` is its size along the first dimension, ``x_b_s``, ``x_r_s`` and
    #                         ``x_c_s`` are its strides.
    #   o:                    Output layout tensor; ``o_b`` is its size along the first dimension, ``o_b_s``,
    #                         ``o_r_s`` and ``o_c_s`` are its strides.
    #   sparsity_block_size:  Size of the sparsity blocks, used to map tile positions to layout entries.
    #   TRITON_BLOCK_SIZE:    Compile-time edge length of the tile handled by one program.

    # Get triton block indices
    pid_bat = tl.program_id(axis=0)
    pid_row = tl.program_id(axis=1)
    pid_col = tl.program_id(axis=2)

    # Load x values
    # Offsets of the tile's elements: batch offset plus a column vector of row offsets and a row vector of column
    # offsets, which broadcast to a 2D grid of offsets
    blk_x_idx = (pid_bat * x_b_s +
                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
    # Only offsets below ``x_b * x_b_s`` are read
    # NOTE(review): no ``other`` value is passed to ``tl.load``; confirm what masked-out lanes contribute to the
    # min/max reduction below
    blk_x_msk = (blk_x_idx < x_b * x_b_s)
    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

    # Store sparsity layout value
    # The tile holds a non-zero value exactly if its minimum or its maximum differs from zero
    if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
        # Layout entry of the sparsity block in which the tile's first row and column lie
        blk_o_idx = (pid_bat * o_b_s +
                     (((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +
                      ((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))
        blk_o_msk = (blk_o_idx < o_b * o_b_s)
        tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
83
+
84
+
85
def build_sparsity_layout_adaption(x: Tensor, sparsity_layout_from: Tensor,
                                   sparsity_block_size_from: int, sparsity_block_size_to: int,
                                   triton_block_size: int = None) -> Tensor:
    """Builds the sparsity layout of a block-sparse tensor in compressed form if a different sparsity block size were
        used.

    Args:
        x (Tensor): A block-sparse tensor in compressed form.
        sparsity_layout_from (Tensor): The sparsity layout of the input block-sparse tensor.
        sparsity_block_size_from (int): The size of the sparsity blocks of the input tensor.
        sparsity_block_size_to (int): The desired size of the sparsity blocks for the resulting layout.
        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
            If ``None``, it is derived from the smaller of the two sparsity block sizes.

    Returns:
        Tensor: The sparsity layout in regular form using the new sparsity block size of the input block-sparse tensor
            in compressed form. It is an ``int32`` tensor whose row and column counts are the covered number of
            elements divided by ``sparsity_block_size_to``, rounded up.

    """
    validate_dimensions(x)
    validate_contiguous(x, sparsity_layout_from)
    validate_device(x)
    validate_sparsity(sparsity_block_size_from, (x, sparsity_layout_from))
    validate_sparsity_block_size(sparsity_block_size_from, x)
    validate_sparsity_block_size(sparsity_block_size_to)
    min_sparsity_block_size = min(sparsity_block_size_from, sparsity_block_size_to)
    validate_triton_block_size(triton_block_size, min_sparsity_block_size)

    # One row per non-zero layout entry, holding its (batch, row, column) position
    sparsity_lut = torch.nonzero(sparsity_layout_from).contiguous()

    validate_contiguous(sparsity_layout_from, sparsity_lut)

    # Integer ceiling division: the previous ``math.ceil(a * b // c)`` floor-divided first, which made the ceil a
    # no-op and dropped a trailing, partially covered block of the new layout
    o_b = sparsity_layout_from.size(0)
    o_r = triton.cdiv(sparsity_layout_from.size(1) * sparsity_block_size_from, sparsity_block_size_to)
    o_c = triton.cdiv(sparsity_layout_from.size(2) * sparsity_block_size_from, sparsity_block_size_to)

    output = torch.zeros(o_b, o_r, o_c, device=x.device, dtype=torch.int32)

    x_b, x_r, x_c = x.size()
    x_b_s, x_r_s, x_c_s = x.stride()
    s_lut_r, s_lut_c = sparsity_lut.size()
    s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
    o_b_s, o_r_s, o_c_s = output.stride()

    if triton_block_size is None:
        # Derive the default from the smaller block size, i.e. the same bound explicit values are validated against;
        # the kernel marks only the target block in which a tile starts, so a tile must not span several target blocks
        triton_block_size = get_triton_block_size(min_sparsity_block_size)

    # One program per compressed block of ``x`` and per tile within that block
    triton_grid = lambda meta: [x_b,
                                triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
                                triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]

    (kernel_sparsity_layout_adaption[triton_grid]
     (x,
      x_b, x_b_s, x_r_s, x_c_s,
      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
      output,
      o_b, o_b_s, o_r_s, o_c_s,
      sparsity_block_size_from,
      sparsity_block_size_to,
      triton_block_size))

    return output
146
+
147
+
148
@triton.jit
def kernel_sparsity_layout_adaption(x,
                                    x_b, x_b_s, x_r_s, x_c_s,
                                    s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
                                    o,
                                    o_b, o_b_s, o_r_s, o_c_s,
                                    sparsity_block_size_from,
                                    sparsity_block_size_to,
                                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    # Triton kernel launched by ``build_sparsity_layout_adaption``.
    #
    # Each program handles one TRITON_BLOCK_SIZE x TRITON_BLOCK_SIZE tile of one block of ``x`` (selected by the first
    # grid axis), looks up where that block lies via ``s_lut`` and, if the tile contains a value different from zero,
    # writes ``1`` to the entry of ``o`` that covers the tile's starting position under ``sparsity_block_size_to``.
    #
    # Arguments:
    #   x:      Input tensor; ``x_b`` is its size along the first dimension, ``x_b_s``, ``x_r_s`` and ``x_c_s`` are
    #           its strides.
    #   s_lut:  Look-up table with one row per block; ``s_lut_r`` is its number of rows, ``s_lut_r_s`` and
    #           ``s_lut_c_s`` are its strides. Columns 0, 1 and 2 are read as batch, row and column index.
    #   o:      Output layout tensor; ``o_b`` is its size along the first dimension, ``o_b_s``, ``o_r_s`` and
    #           ``o_c_s`` are its strides.

    # Get triton block indices
    pid_blk = tl.program_id(axis=0)
    pid_row = tl.program_id(axis=1)
    pid_col = tl.program_id(axis=2)

    # Get sparsity index of current output block consisting of its batch, row, and column index
    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
    spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)

    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
    spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)

    spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
    spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
    spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)

    # Load x values
    # Offsets of the tile's elements within block ``pid_blk``; the row and column offset vectors broadcast to a 2D
    # grid of offsets
    blk_x_idx = ((pid_blk * x_b_s) +
                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
    # Only offsets below ``x_b * x_b_s`` are read
    blk_x_msk = (blk_x_idx < x_b * x_b_s)
    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

    # Store sparsity layout value
    # The tile holds a non-zero value exactly if its minimum or its maximum differs from zero
    if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
        # Position of the tile's first element in regular form (block position times old block size plus the tile's
        # shift inside the block), converted to a block position under the new block size
        blk_o_idx = ((spa_bat * o_b_s) +
                     (((spa_row * sparsity_block_size_from + pid_row * TRITON_BLOCK_SIZE)
                       // sparsity_block_size_to) * o_r_s) +
                     (((spa_col * sparsity_block_size_from + pid_col * TRITON_BLOCK_SIZE)
                       // sparsity_block_size_to) * o_c_s))
        blk_o_msk = (blk_o_idx < o_b * o_b_s)
        tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
@@ -1,8 +1,11 @@
1
+ from typing import Any
2
+
1
3
  import torch
2
4
  import triton
3
5
  from torch import Tensor
4
6
  from triton import language as tl
5
7
 
8
+ from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
6
9
  from blksprs.utils.tools import get_triton_block_size
7
10
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
8
11
  validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
@@ -39,6 +42,9 @@ def to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_
39
42
 
40
43
  validate_contiguous(sparsity_reverse_lut)
41
44
 
45
+ if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
46
+ return x
47
+
42
48
  return _BlocksparseToDense.apply(x,
43
49
  sparsity_layout, sparsity_reverse_lut,
44
50
  sparsity_block_size, fill_value,
@@ -161,6 +167,9 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
161
167
 
162
168
  validate_contiguous(sparsity_layout, sparsity_lut)
163
169
 
170
+ if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
171
+ return x
172
+
164
173
  return _BlocksparseToSparse.apply(x,
165
174
  sparsity_layout, sparsity_lut,
166
175
  sparsity_block_size, n_sparse_blocks,
@@ -178,10 +187,10 @@ class _BlocksparseToSparse(torch.autograd.Function):
178
187
 
179
188
  x_b, x_r, x_c = x.size()
180
189
  x_b_s, x_r_s, x_c_s = x.stride()
181
- o_b, o_r, o_c = output.size()
182
- o_b_s, o_r_s, o_c_s = output.stride()
183
190
  s_lut_r, s_lut_c = sparsity_lut.size()
184
191
  s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
192
+ o_b, o_r, o_c = output.size()
193
+ o_b_s, o_r_s, o_c_s = output.stride()
185
194
 
186
195
  if triton_block_size is None:
187
196
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -254,3 +263,189 @@ class _BlocksparseToSparse(torch.autograd.Function):
254
263
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
255
264
  blk_o_msk = (blk_o_idx < (pid_blk + 1) * o_b_s)
256
265
  tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
266
+
267
+
268
def adapt_layout(x: Tensor, sparsity_layout_from: Tensor, sparsity_block_size_from: int, sparsity_block_size_to: int,
                 preprocess_data: dict = None, triton_block_size: int = None) -> Tensor:
    """Converts a block-sparse tensor in compressed form to a different sparsity block size (and thereby a different
        sparsity layout), yielding a new block-sparse tensor in compressed form.

    Intermediate results can be supplied through ``preprocess_data`` under the keys ``"sparsity_reverse_lut_from"``,
    ``"sparsity_layout_to"``, ``"sparsity_lut_to"`` and ``"n_sparse_blocks_to"``; every key that is missing is
    computed by this function.

    Args:
        x (Tensor): A block-sparse tensor in compressed form.
        sparsity_layout_from (Tensor): The sparsity layout of the input block-sparse tensor.
        sparsity_block_size_from (int): The size of the sparsity blocks of the input sparsity layout.
        sparsity_block_size_to (int): The size of the sparsity blocks of the output sparsity layout.
        preprocess_data (dict): A dictionary containing data otherwise computed by the function (default ``None``).
        triton_block_size (int): The block size to use for the triton kernel (default ``None``).

    Returns:
        Tensor: The block-sparse tensor in compressed form with the adapted sparsity layout and sparsity block size.
            If both block sizes are equal and both layouts are equal, the input ``x`` itself is returned.

    """
    validate_dimensions(x)
    validate_contiguous(x, sparsity_layout_from)
    validate_device(x)
    validate_sparsity(sparsity_block_size_from, (x, sparsity_layout_from))
    validate_sparsity_block_size(sparsity_block_size_from, x)
    validate_sparsity_block_size(sparsity_block_size_to)
    min_sparsity_block_size = min(sparsity_block_size_from, sparsity_block_size_to)
    validate_triton_block_size(triton_block_size, min_sparsity_block_size)

    provided = {} if preprocess_data is None else preprocess_data

    # Reverse look-up table: index of each present block within the compressed tensor, -1 for absent blocks
    if "sparsity_reverse_lut_from" in provided:
        sparsity_reverse_lut_from = provided["sparsity_reverse_lut_from"]
    else:
        layout_flat = sparsity_layout_from.reshape(-1)
        running_index = torch.cumsum(layout_flat, dim=-1) - 1
        sparsity_reverse_lut_from = (running_index * (layout_flat == 1) -
                                     (1 * (layout_flat == 0)))

    # Sparsity layout under the new sparsity block size
    if "sparsity_layout_to" in provided:
        sparsity_layout_to = provided["sparsity_layout_to"]
    else:
        sparsity_layout_to = build_sparsity_layout_adaption(x, sparsity_layout_from,
                                                            sparsity_block_size_from, sparsity_block_size_to,
                                                            triton_block_size)

    # Positions of the blocks present in the new layout
    if "sparsity_lut_to" in provided:
        sparsity_lut_to = provided["sparsity_lut_to"]
    else:
        sparsity_lut_to = torch.nonzero(sparsity_layout_to).contiguous()

    # Number of blocks of the resulting compressed tensor
    if "n_sparse_blocks_to" in provided:
        n_sparse_blocks_to = provided["n_sparse_blocks_to"]
    else:
        n_sparse_blocks_to = torch.sum(sparsity_layout_to.to(torch.int)).item()

    validate_contiguous(sparsity_layout_to, sparsity_reverse_lut_from, sparsity_lut_to)

    # Nothing to adapt if neither the block size nor the layout changes
    if sparsity_block_size_from == sparsity_block_size_to:
        if torch.equal(sparsity_layout_from, sparsity_layout_to):
            return x

    return _BlocksparseAdaptLayout.apply(x,
                                         sparsity_layout_from, sparsity_reverse_lut_from, sparsity_block_size_from,
                                         sparsity_layout_to, sparsity_lut_to, sparsity_block_size_to,
                                         n_sparse_blocks_to, min_sparsity_block_size, triton_block_size)
331
+
332
+
333
class _BlocksparseAdaptLayout(torch.autograd.Function):
    """Autograd function copying the blocks of a compressed block-sparse tensor into a new compressed tensor that
    follows a different sparsity layout and sparsity block size.

    The backward pass applies ``adapt_layout`` in the opposite direction to the incoming gradient.
    """

    @staticmethod
    def forward(ctx, x: Tensor,
                sparsity_layout_from: Tensor, sparsity_reverse_lut_from: Tensor, sparsity_block_size_from: int,
                sparsity_layout_to: Tensor, sparsity_lut_to: Tensor, sparsity_block_size_to: int,
                n_sparse_blocks_to: int, min_sparsity_block_size: int, triton_block_size: int) -> Tensor:
        # Blocks (or parts of blocks) without a counterpart in the input stay zero
        output = torch.zeros(size=(n_sparse_blocks_to, sparsity_block_size_to, sparsity_block_size_to),
                             dtype=x.dtype, device=x.device)

        x_b, x_r, x_c = x.size()
        x_b_s, x_r_s, x_c_s = x.stride()
        s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_from.size()
        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_from.stride()
        o_b, o_r, o_c = output.size()
        o_b_s, o_r_s, o_c_s = output.stride()
        s_lut_o_r, s_lut_o_c = sparsity_lut_to.size()
        s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_to.stride()

        if triton_block_size is None:
            triton_block_size = get_triton_block_size(min_sparsity_block_size)

        # One program per output block and per tile within that block
        triton_grid = lambda meta: [o_b,
                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

        (_BlocksparseAdaptLayout.kernel_adapt_layout[triton_grid]
         (x,
          x_b, x_b_s, x_r_s, x_c_s,
          s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
          sparsity_reverse_lut_from,
          output,
          o_b, o_b_s, o_r_s, o_c_s,
          sparsity_lut_to, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
          sparsity_block_size_from,
          sparsity_block_size_to,
          triton_block_size))

        # Only the layouts are needed by backward; ``x`` is deliberately not saved so that the input is not kept
        # alive for the lifetime of the autograd graph
        ctx.save_for_backward(sparsity_layout_from, sparsity_layout_to)
        ctx.sparsity_block_size_from = sparsity_block_size_from
        ctx.sparsity_block_size_to = sparsity_block_size_to
        ctx.triton_block_size = triton_block_size

        return output

    @staticmethod
    def backward(ctx, grad_output):
        sparsity_layout_from, sparsity_layout_to = ctx.saved_tensors
        sparsity_block_size_from = ctx.sparsity_block_size_from
        sparsity_block_size_to = ctx.sparsity_block_size_to
        triton_block_size = ctx.triton_block_size

        # Map the gradient back from the "to" layout to the "from" layout; the target layout is already known, so it
        # is passed as preprocessed data. One return value per argument of forward (excluding ctx).
        return adapt_layout(grad_output, sparsity_layout_to, sparsity_block_size_to, sparsity_block_size_from,
                            preprocess_data={"sparsity_layout_to": sparsity_layout_from},
                            triton_block_size=triton_block_size), None, None, None, None, None, None, None, None, None

    @staticmethod
    @triton.jit
    def kernel_adapt_layout(x,
                            x_b, x_b_s, x_r_s, x_c_s,
                            s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
                            r_lut_x,
                            o,
                            o_b, o_b_s, o_r_s, o_c_s,
                            s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
                            sparsity_block_size_from,
                            sparsity_block_size_to,
                            TRITON_BLOCK_SIZE: tl.constexpr) -> None:
        # Each program fills one TRITON_BLOCK_SIZE x TRITON_BLOCK_SIZE tile of one output block with the matching
        # tile of the input, provided the input block covering that position is present.

        # Get triton block indices
        pid_blk = tl.program_id(axis=0)
        pid_row = tl.program_id(axis=1)
        pid_col = tl.program_id(axis=2)

        # Get position of current sparsity block consisting of its batch, row, and column index
        spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
        spa_bat_o_msk = (spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
        spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)

        spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
        spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
        spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)

        spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
        spa_col_o_msk = (spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
        spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)

        # Get equivalent sparsity block in from layout
        # (position of the tile's first element in regular form, divided by the block size of the from layout)
        spa_bat_x = spa_bat_o
        spa_row_x = (spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size_from
        spa_col_x = (spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size_from

        # Get reverse sparsity indices for x
        rev_idx_spa_x_idx = (spa_bat_x * s_l_x_b_s +
                             spa_row_x * s_l_x_r_s +
                             spa_col_x * s_l_x_c_s)
        rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
        rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)

        # If block is present commence operations
        if rev_idx_spa_x >= 0:
            # Calculate triton block size shifts
            # (index of the tile inside the input block, counted in units of TRITON_BLOCK_SIZE)
            shift_row_x = ((spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE)
                           % sparsity_block_size_from) // TRITON_BLOCK_SIZE
            shift_col_x = ((spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE)
                           % sparsity_block_size_from) // TRITON_BLOCK_SIZE

            # Load x values
            blk_x_idx = ((rev_idx_spa_x * x_b_s) +
                         ((shift_row_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                         ((shift_col_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
            blk_x_msk = (blk_x_idx < x_b * x_b_s)
            blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

            # Store output
            blk_o_idx = ((pid_blk * o_b_s) +
                         ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                         ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
            blk_o_msk = (blk_o_idx < o_b * o_b_s)
            tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
@@ -1,11 +1,11 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: blksprs
3
- Version: 1.1
3
+ Version: 1.2.1
4
4
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
5
5
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
6
6
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
7
7
  Project-URL: Bugtracker, https://github.com/FelixSchoen/blksprs/issues
8
- Requires-Python: >=3.12
8
+ Requires-Python: >=3.11
9
9
  Description-Content-Type: text/markdown
10
10
  Requires-Dist: torch
11
11
  Provides-Extra: test
@@ -21,6 +21,9 @@ Requires-Dist: pdoc3; extra == "deploy"
21
21
 
22
22
  # blksprs
23
23
 
24
+ [![GitHub Release](https://img.shields.io/github/v/release/FelixSchoen/blksprs?include_prereleases&label=Latest%20Release)](https://github.com/FelixSchoen/blksprs/releases)
25
+ [![Python Version](https://img.shields.io/badge/Python%20Version-3.11-blue)](https://www.python.org/downloads/release/python-3119/)
26
+
24
27
  ## Overview
25
28
 
26
29
  A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
@@ -33,7 +36,8 @@ Currently supported operations (includes gradient calculation):
33
36
  - Transposition
34
37
  - Gather
35
38
  - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
36
- - Conversion from and to sparse form
39
+ - Conversion to and from sparse form
40
+ - Conversion to different sparsity layouts and different sparsity block sizes
37
41
 
38
42
  Because this library represents sparse matrices as a tuple of `(matrix, sparsity_layout, sparsity_block_size)`,
39
43
  any element-wise operations can be applied in regular torch-like fashion.
@@ -59,6 +63,11 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u
59
63
 
60
64
  ```pip install blksprs```
61
65
 
66
+ ### Dependencies
67
+
68
+ - [PyTorch](https://pytorch.org/) (built with v2.4.0)
69
+ - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_
70
+
62
71
  ## Changelog
63
72
 
64
73
  See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.md) for a detailed changelog.
@@ -1,10 +1,10 @@
1
1
  [project]
2
2
  name = "blksprs"
3
- version = "1.1"
3
+ version = "1.2.1"
4
4
  authors = [{ name = "Felix Schön", email = "schoen@kr.tuwien.ac.at" }]
5
5
  description = "A lightweight library for operations on blocksparse matrices in PyTorch."
6
6
  readme = "README.md"
7
- requires-python = ">=3.12"
7
+ requires-python = ">=3.11"
8
8
  license = { file = "LICENSE.md" }
9
9
  dependencies = [
10
10
  "torch"
@@ -1,78 +0,0 @@
1
- import torch
2
- import triton
3
- from torch import Tensor
4
- from triton import language as tl
5
-
6
- from blksprs.utils.tools import get_triton_block_size
7
- from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
8
- validate_contiguous
9
-
10
-
11
- def build_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
12
- """Builds the sparsity layout of a dense tensor covering its sparse blocks.
13
-
14
- Args:
15
- x (Tensor): A block-sparse (or dense) tensor in regular form.
16
- sparsity_block_size (int): The size of the sparsity blocks.
17
- triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
18
-
19
- Returns:
20
- Tensor: The sparsity layout of the input block-sparse (or dense) tensor.
21
-
22
- """
23
- validate_dimensions(x)
24
- validate_contiguous(x)
25
- validate_device(x)
26
-
27
- output = torch.zeros(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size,
28
- device=x.device, dtype=torch.int32)
29
-
30
- x_b, x_r, x_c = x.size()
31
- x_b_s, x_r_s, x_c_s = x.stride()
32
- o_b, o_r, o_c = output.size()
33
- o_b_s, o_r_s, o_c_s = output.stride()
34
-
35
- if triton_block_size is None:
36
- triton_block_size = get_triton_block_size(sparsity_block_size)
37
-
38
- validate_triton_block_size(triton_block_size, sparsity_block_size)
39
-
40
- triton_grid = lambda meta: [x_b,
41
- triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
42
- triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
43
-
44
- (kernel_sparsity_layout[triton_grid]
45
- (x,
46
- x_b, x_b_s, x_r_s, x_c_s,
47
- output,
48
- o_b, o_b_s, o_r_s, o_c_s,
49
- sparsity_block_size,
50
- triton_block_size))
51
-
52
- return output
53
-
54
-
55
- @triton.jit
56
- def kernel_sparsity_layout(x,
57
- x_b, x_b_s, x_r_s, x_c_s,
58
- o,
59
- o_b, o_b_s, o_r_s, o_c_s,
60
- sparsity_block_size,
61
- TRITON_BLOCK_SIZE: tl.constexpr) -> None:
62
- # Get triton block indices
63
- pid_bat = tl.program_id(axis=0)
64
- pid_row = tl.program_id(axis=1)
65
- pid_col = tl.program_id(axis=2)
66
-
67
- blk_x_idx = (pid_bat * x_b_s +
68
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
69
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
70
- blk_x_msk = (blk_x_idx < x_b * x_b_s)
71
- blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
72
-
73
- if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
74
- blk_o_idx = (pid_bat * o_b_s +
75
- (((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +
76
- ((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))
77
- blk_o_msk = (blk_o_idx < o_b * o_b_s)
78
- tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes