blksprs 1.10.2__py3-none-any.whl → 2.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/ops/transpose.py CHANGED
@@ -1,16 +1,15 @@
1
1
  import torch
2
- import triton
3
2
  from torch import Tensor
4
- from triton import language as tl
3
+ from torch._library import triton_op
5
4
 
5
+ from blksprs.ops.flow import flow_pull_forward
6
6
  from blksprs.utils.blksprs_tensor import BlksprsTensor
7
- from blksprs.utils.tools import get_triton_block_size, stride
8
7
  from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
9
- validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
8
+ validate_sparsity, validate_sparsity_block_size
10
9
 
11
10
 
12
- def transpose(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> (
13
- BlksprsTensor, Tensor):
11
+ def transpose(x: BlksprsTensor, sparsity_layout: Tensor,
12
+ sparsity_block_size: int, lut: dict = None) -> (BlksprsTensor, Tensor):
14
13
  """Transposes a block-sparse tensor in compressed form.
15
14
 
16
15
  Note:
@@ -20,7 +19,7 @@ def transpose(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: in
20
19
  x (BlksprsTensor): A block-sparse tensor in compressed form.
21
20
  sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
22
21
  sparsity_block_size (int): The size of the sparsity blocks.
23
- triton_block_size (int): The block size to use for the triton kernel (default ``None``).
22
+ lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
24
23
 
25
24
  Returns:
26
25
  BlksprsTensor: The transposed block-sparse tensor in compressed form.
@@ -28,133 +27,72 @@ def transpose(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: in
28
27
 
29
28
  """
30
29
  x = x.contiguous()
30
+ x_t = x.transpose(-1, -2).contiguous()
31
31
 
32
32
  validate_dimensions(x)
33
33
  validate_contiguous(x)
34
34
  validate_device(x)
35
35
  validate_sparsity(sparsity_block_size, (x, sparsity_layout))
36
36
  validate_sparsity_block_size(sparsity_block_size, x)
37
- validate_triton_block_size(triton_block_size, sparsity_block_size)
38
-
39
- sparsity_layout_t = sparsity_layout.transpose(-1, -2).contiguous()
40
-
41
- sparsity_lut = torch.nonzero(sparsity_layout_t).contiguous()
42
-
43
- sparsity_layout_flat = sparsity_layout.reshape(-1)
44
- sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
45
- (sparsity_layout_flat == 1) -
46
- (1 * (sparsity_layout_flat == 0)))
47
- .reshape(sparsity_layout.size()).transpose(-1, -2).contiguous().reshape(-1))
48
-
49
- n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
50
-
51
- validate_contiguous(sparsity_layout_t, sparsity_lut, sparsity_reverse_lut)
52
-
53
- return BlksprsTensor(
54
- _BlocksparseTranspose.apply(x, sparsity_layout_t, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
55
- n_sparse_blocks, triton_block_size)), sparsity_layout_t
56
-
57
-
58
- class _BlocksparseTranspose(torch.autograd.Function):
59
-
60
- @staticmethod
61
- def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
62
- sparsity_block_size: int,
63
- n_sparse_blocks: int, triton_block_size: int) -> Tensor:
64
- output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
65
- dtype=x.dtype, device=x.device)
66
-
67
- x_b, x_r, x_c = x.size()
68
- x_b_s, x_r_s, x_c_s = stride(x)
69
- s_l_b, s_l_r, s_l_c = sparsity_layout_o.size()
70
- s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout_o)
71
- s_lut_r, s_lut_c = sparsity_lut.shape
72
- s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
73
- o_b, o_r, o_c = output.size()
74
- o_b_s, o_r_s, o_c_s = stride(output)
75
-
76
- if triton_block_size is None:
77
- triton_block_size = get_triton_block_size(sparsity_block_size)
78
-
79
- triton_grid = lambda meta: [o_b,
80
- triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
81
- triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
82
-
83
- (_BlocksparseTranspose.kernel_blocksparse_transpose[triton_grid]
84
- (x,
85
- x_b, x_b_s, x_r_s, x_c_s,
86
- s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
87
- sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
88
- sparsity_reverse_lut,
89
- output,
90
- o_b, o_b_s,
91
- triton_block_size))
92
-
93
- # Save for backward pass
94
- ctx.save_for_backward(sparsity_layout_o)
95
- ctx.sparsity_block_size = sparsity_block_size
96
- ctx.triton_block_size = triton_block_size
97
-
98
- return output
99
-
100
- @staticmethod
101
- def backward(ctx, grad_output):
102
- sparsity_layout = ctx.saved_tensors[0]
103
- sparsity_block_size = ctx.sparsity_block_size
104
- triton_block_size = ctx.triton_block_size
105
-
106
- return transpose(grad_output, sparsity_layout, sparsity_block_size, triton_block_size)[
107
- 0], None, None, None, None, None, None
108
-
109
- @staticmethod
110
- @triton.jit
111
- def kernel_blocksparse_transpose(x,
112
- x_b, x_b_s, x_r_s, x_c_s,
113
- s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
114
- s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
115
- r_lut,
116
- o,
117
- o_b, o_b_s,
118
- TRITON_BLOCK_SIZE: tl.constexpr) -> None:
119
- # Get triton block indices
120
- pid_blk = tl.program_id(axis=0)
121
- pid_row = tl.program_id(axis=1)
122
- pid_col = tl.program_id(axis=2)
123
-
124
- # Get sparsity index of current output block consisting of its batch, row, and column index
125
- spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
126
- spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
127
- spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
128
-
129
- spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
130
- spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
131
- spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
132
-
133
- spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
134
- spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
135
- spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
136
-
137
- # Get reverse sparsity index
138
- rev_idx_spa_idx = (spa_bat * s_l_b_s +
139
- spa_row * s_l_r_s +
140
- spa_col * s_l_c_s)
141
- rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_b * s_l_b_s)
142
- rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
143
-
144
- if rev_idx_spa == -1:
145
- tl.device_assert(False)
146
- return
147
-
148
- blk_x_idx = (rev_idx_spa * x_b_s +
149
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
150
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
151
- blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
152
- blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
153
-
154
- blk_x_t = tl.trans(blk_x)
155
-
156
- blk_o_idx = (pid_blk * o_b_s +
157
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
158
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
159
- blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
160
- tl.store(o + blk_o_idx, blk_x_t, mask=blk_o_msk)
37
+
38
+ lut = transpose_build_lut(lut, sparsity_layout)
39
+
40
+ return BlksprsTensor(transpose_forward(x_t, lut["sparsity_layout_t"],
41
+ lut["sparsity_lut"], lut["sparsity_reverse_lut"],
42
+ sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_t"]
43
+
44
+
45
+ @triton_op("blksprs::transpose", mutates_args={})
46
+ def transpose_forward(x: Tensor, sparsity_layout_o: Tensor,
47
+ sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
48
+ sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
49
+ return flow_pull_forward(x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
50
+ sparsity_block_size, n_sparse_blocks)
51
+
52
+
53
+ def transpose_backward(ctx, grad_output):
54
+ sparsity_layout = ctx.saved_tensors[0]
55
+ sparsity_block_size = ctx.sparsity_block_size
56
+
57
+ return transpose(grad_output, sparsity_layout, sparsity_block_size)[
58
+ 0], None, None, None, None, None
59
+
60
+
61
+ def transpose_build_lut(lut: dict, sparsity_layout: Tensor):
62
+ if lut is None:
63
+ lut = dict()
64
+
65
+ if "sparsity_layout_t" not in lut:
66
+ sparsity_layout_t = sparsity_layout.transpose(-1, -2).contiguous()
67
+ lut["sparsity_layout_t"] = sparsity_layout_t
68
+
69
+ if "sparsity_lut" not in lut:
70
+ sparsity_lut = torch.nonzero(lut["sparsity_layout_t"]).contiguous()
71
+ lut["sparsity_lut"] = sparsity_lut
72
+
73
+ if "sparsity_reverse_lut" not in lut:
74
+ sparsity_layout_flat = sparsity_layout.reshape(-1)
75
+ sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
76
+ (sparsity_layout_flat == 1) -
77
+ (1 * (sparsity_layout_flat == 0)))
78
+ .reshape(sparsity_layout.size()).transpose(-1, -2).contiguous().reshape(-1))
79
+ lut["sparsity_reverse_lut"] = sparsity_reverse_lut
80
+
81
+ if "n_sparse_blocks" not in lut:
82
+ n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
83
+ lut["n_sparse_blocks"] = n_sparse_blocks
84
+
85
+ validate_contiguous(lut["sparsity_layout_t"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])
86
+
87
+ return lut
88
+
89
+
90
+ # noinspection PyUnusedLocal
91
+ def transpose_setup_context(ctx, inputs, output):
92
+ (_, sparsity_layout_o, _, _, sparsity_block_size, _) = inputs
93
+
94
+ ctx.save_for_backward(sparsity_layout_o)
95
+ ctx.sparsity_block_size = sparsity_block_size
96
+
97
+
98
+ transpose_forward.register_autograd(transpose_backward, setup_context=transpose_setup_context)
@@ -5,13 +5,13 @@ from matplotlib import pyplot as plt
5
5
 
6
6
 
7
7
  def benchmark(method_labels: list[str], func_input_generator: Callable,
8
- matrix_sizes: list[int], sparsity_block_sizes: list[int], triton_block_sizes: list[int],
8
+ matrix_sizes: list[int], sparsity_block_sizes: list[int],
9
9
  *funcs_test_subject: Callable, y_lim_top: int = None):
10
10
  quantiles = [0.5, 0.2, 0.8]
11
11
  results = {}
12
12
 
13
- for matrix_size, sparsity_block_size, triton_block_size in zip(matrix_sizes, sparsity_block_sizes, triton_block_sizes):
14
- arguments = func_input_generator(matrix_size, sparsity_block_size, triton_block_size)
13
+ for matrix_size, sparsity_block_size in zip(matrix_sizes, sparsity_block_sizes):
14
+ arguments = func_input_generator(matrix_size, sparsity_block_size)
15
15
 
16
16
  for i, func_test_subject in enumerate(funcs_test_subject):
17
17
  func_ms_avg, func_ms_min, func_ms_max = triton.testing.do_bench(
blksprs/utils/tools.py CHANGED
@@ -1,6 +1,10 @@
1
1
  import torch
2
+ import triton
2
3
  from torch import Tensor, Size
3
4
 
5
+ # Capture scalar outputs for JIT compilation
6
+ torch._dynamo.config.capture_scalar_outputs = True
7
+
4
8
 
5
9
  def do_shape_blocksparse(x: Tensor):
6
10
  if x.dim() == 3:
@@ -16,10 +20,6 @@ def undo_shape_blocksparse(x: Tensor, shape: Size):
16
20
  return x.reshape((*shape[:-2], *x.shape[-2:]))
17
21
 
18
22
 
19
- def get_triton_block_size(sparsity_block_size: int, limit: int = 128):
20
- return min(sparsity_block_size, limit)
21
-
22
-
23
23
  def stride(x: Tensor):
24
24
  if x.dim() == 2:
25
25
  return x.size(1), 1
@@ -27,3 +27,30 @@ def stride(x: Tensor):
27
27
  return x.size(1) * x.size(2), x.size(2), 1
28
28
  else:
29
29
  raise NotImplementedError
30
+
31
+
32
+ @torch.compile
33
+ def get_autotune_configs():
34
+ configs = []
35
+ config_parameters = [
36
+ (16, 3, 8),
37
+ (16, 4, 4),
38
+ (16, 5, 2),
39
+
40
+ (32, 3, 8),
41
+ (32, 4, 4),
42
+ (32, 5, 2),
43
+
44
+ (64, 3, 8),
45
+ (64, 4, 4),
46
+ (64, 5, 2),
47
+
48
+ (128, 3, 8),
49
+ (128, 4, 4),
50
+ (128, 5, 2),
51
+ ]
52
+
53
+ for block_size, num_stages, num_warps in config_parameters:
54
+ configs.append(triton.Config({"TRITON_BLOCK_SIZE": block_size}, num_stages=num_stages, num_warps=num_warps))
55
+
56
+ return configs
@@ -104,20 +104,6 @@ def validate_sparsity_block_size(sparsity_block_size: int, *tensors):
104
104
  raise ValueError("Tensor sizes must be divisible by sparsity block size")
105
105
 
106
106
 
107
- def validate_triton_block_size(triton_block_size: int, sparsity_block_size: int):
108
- if _check_skip_validation():
109
- return
110
-
111
- if triton_block_size is None:
112
- return
113
-
114
- if not (triton_block_size & (triton_block_size - 1)) == 0:
115
- raise ValueError("Triton block size must be a power of 2")
116
-
117
- if triton_block_size > sparsity_block_size:
118
- raise ValueError("Triton block size cannot be larger than sparsity block size")
119
-
120
-
121
107
  def _check_skip_validation():
122
108
  return not VALIDATION
123
109
 
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.4
2
2
  Name: blksprs
3
- Version: 1.10.2
3
+ Version: 2.0rc1
4
4
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
5
5
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
6
6
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -14,9 +14,8 @@ Requires-Dist: pytest; extra == "test"
14
14
  Requires-Dist: pytest-xdist; extra == "test"
15
15
  Requires-Dist: pytest-cov; extra == "test"
16
16
  Requires-Dist: coverage; extra == "test"
17
+ Requires-Dist: build; extra == "test"
17
18
  Requires-Dist: matplotlib; extra == "test"
18
- Provides-Extra: build
19
- Requires-Dist: build; extra == "build"
20
19
 
21
20
  # blksprs
22
21
 
@@ -25,6 +24,13 @@ Requires-Dist: build; extra == "build"
25
24
 
26
25
  ## Overview
27
26
 
27
+ ### News
28
+
29
+ 🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
30
+ LUTs, and makes use of `torch.library.triton_op()`!
31
+
32
+ ---
33
+
28
34
  A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
29
35
 
30
36
  Currently supported operations (includes gradient calculation):
@@ -52,23 +58,25 @@ These include, e.g.,
52
58
  Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
53
59
  match.
54
60
 
55
- Further helpful operations (included in the ``bs.ops.misc`` module) that do **not** support gradient calculation include:
61
+ Further helpful operations (included in the ``bs.ops.misc`` module) that do **not** support gradient calculation
62
+ include:
56
63
 
57
64
  - Row-wise sum, max, addition, and subtraction
58
65
  - Broadcast addition and subtraction between slices
59
66
 
60
- Furthermore, the library provides a set of utility functions
67
+ Furthermore, the library provides a set of utility functions
61
68
 
62
69
  - for the creation of sparsity layouts based on existing
63
- dense tensors and for the scatter operation (module ``bs.layouting``),
70
+ dense tensors and for the scatter operation (module ``bs.layouting``),
64
71
  - for the application of ``nn.Linear``, ``nn.Dropout``, and ``nn.LayerNorm`` layers to block-sparse tensors,
65
72
  - as well as utility functions to ensure correct input dimensionality, and validate input (module ``bs.utils``).
66
73
 
67
- _* see the [Roadmap](#roadmap) section for more information_
74
+ _* see the [Roadmap](#roadmap) section for more information_
68
75
 
69
76
  ## Installation
70
77
 
71
- Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible with
78
+ Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible
79
+ with
72
80
  the Linux platform**.
73
81
  Keep track of this [issue](https://github.com/triton-lang/triton/issues/1640) for updates.
74
82
 
@@ -78,8 +86,8 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u
78
86
 
79
87
  ### Dependencies
80
88
 
81
- - [PyTorch](https://pytorch.org/) (built with v2.5.1)
82
- - _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.2.0)_
89
+ - [PyTorch](https://pytorch.org/) (built with v2.6)
90
+ - _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.2.4)_
83
91
  - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_
84
92
 
85
93
  ## Changelog
@@ -89,12 +97,14 @@ See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.
89
97
  ## Roadmap
90
98
 
91
99
  Note that since this library covers all our current needs it is in a **bugfix-only** state.
92
- This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and ``merge`` operations.
100
+ This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and
101
+ ``merge`` operations.
93
102
  We will continue to maintain the library and fix any issues that arise.
94
103
  Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
95
104
  We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
96
105
 
97
- It might be that this changes with future projects, but as of December 2024, we are content with the current state of the library.
106
+ It might be that this changes with future projects, but as of March 2025, we are content with the current state of the
107
+ library.
98
108
 
99
109
  ## Usage
100
110
 
@@ -120,10 +130,6 @@ def test_readme():
120
130
  # Must be a power of two, greater than or equal to 16 for matmul, and divide m, n, and k
121
131
  sparsity_block_size = 16
122
132
 
123
- # Must be a power of two and smaller than or equal to sparsity_block_size
124
- # If it is set to ``none`` a value will be chosen automatically
125
- triton_block_size = None
126
-
127
133
  # Initialise random (dense) tensors
128
134
  x = torch.randn(size=(b, h, m, k), device="cuda")
129
135
  y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()
@@ -133,53 +139,53 @@ def test_readme():
133
139
  y_dense, y_shape_original = bs.utils.do_shape_blocksparse(y)
134
140
 
135
141
  # Create sparsity layouts from existing tensors
136
- sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size,
137
- triton_block_size=triton_block_size)
138
- sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size,
139
- triton_block_size=triton_block_size)
142
+ sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size)
143
+ sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size)
140
144
 
141
145
  # Create random sparsity layout for output tensor
142
146
  sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)
143
147
 
144
148
  # Convert tensors to sparse tensors for matrix multiplication
145
- x_sparse = bs.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size, triton_block_size=triton_block_size)
146
- y_sparse = bs.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size, triton_block_size=triton_block_size)
149
+ x_sparse = bs.ops.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size)
150
+ y_sparse = bs.ops.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size)
151
+
152
+ # As of version 2.0, blksprs supports JIT compilation
153
+ matmul_compiled = torch.compile(bs.ops.matmul)
147
154
 
148
155
  # Perform matrix multiplication
149
- o_sparse = bs.matmul(x_sparse, sparsity_layout_x, y_sparse, sparsity_layout_y, sparsity_layout_o,
150
- sparsity_block_size,
151
- triton_block_size=triton_block_size)
156
+ o_sparse = matmul_compiled(x_sparse, sparsity_layout_x,
157
+ y_sparse, sparsity_layout_y,
158
+ sparsity_layout_o, sparsity_block_size)
152
159
 
153
160
  # Apply element-wise operation
154
161
  o_sparse = torch.add(o_sparse, 1)
155
162
 
156
- o_dense = bs.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
163
+ o_dense = bs.ops.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size)
157
164
 
158
165
  # Sanity check
159
166
  o_torch = torch.matmul(x_dense, y_dense)
160
167
  o_torch = torch.add(o_torch, 1)
161
168
 
162
169
  # Perform round trip to set sparse blocks to 0
163
- o_torch_round_trip = bs.to_dense(
164
- bs.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size),
165
- sparsity_layout_o, sparsity_block_size, fill_value=0, triton_block_size=triton_block_size)
170
+ o_torch_round_trip = bs.ops.to_dense(
171
+ bs.ops.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size),
172
+ sparsity_layout_o, sparsity_block_size, fill_value=0)
166
173
 
167
174
  # Assert that the output is correct
168
175
  assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2) # Note that small numerical differences are expected
169
176
 
170
177
  # Assert that the output has the correct sparsity layout
171
- actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size,
172
- triton_block_size=triton_block_size)
178
+ actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size)
173
179
  assert torch.allclose(actual_sparsity_layout_o.to(torch.int), sparsity_layout_o)
174
180
 
175
181
  # Convert output tensor back to original shape
176
182
  o = bs.utils.undo_shape_blocksparse(o_dense, x_shape_original)
177
183
 
178
184
  # Other available functions
179
- bs.transpose(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
180
- bs.softmax(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
181
- bs.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
182
- bs.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
185
+ bs.ops.transpose(o_sparse, sparsity_layout_o, sparsity_block_size)
186
+ bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size)
187
+ bs.ops.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size)
188
+ bs.ops.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size)
183
189
 
184
190
 
185
191
  def _get_random_sparsity_layout(b, m, n, sparsity_block_size, sparsity_percentage):
@@ -0,0 +1,22 @@
1
+ blksprs/__init__.py,sha256=OHfpwJCZWGUfpT-DVfC1YSaeZl4aCMNt9CrzMPymywU,1577
2
+ blksprs/layouting/distribution_layout.py,sha256=0glIteoY5oDkiEu5rjLIC-BB_oC4sa3rFWVkohsAG00,5329
3
+ blksprs/layouting/sparsity_layout.py,sha256=UzMcdW7l4zoiLB_LMEbBR1JBdqVSgINDGYvoCYIOulk,10283
4
+ blksprs/ops/conversion.py,sha256=_JKOovDZOmYJLcurJGhgNt5iQB9kOKp3fufFxD8QCZs,22204
5
+ blksprs/ops/distribution.py,sha256=5gE19kPQGQljVbRpDZeqNaOe8ehRhxdQS7PiJp6mMug,21352
6
+ blksprs/ops/flow.py,sha256=G8L_sMAWIM77gv-YLJtyutEzXqyaaofnSX2QKvmDr44,8409
7
+ blksprs/ops/matmul.py,sha256=YAurJcXa_39gRdh2nWUOmbhm8h99arLoO-SN-l134II,11879
8
+ blksprs/ops/partitioning.py,sha256=AooYZOw0oZgA9zXSu09O60hkJcnpWT1OTosr2T2wdQo,9700
9
+ blksprs/ops/repeat.py,sha256=qty0qIFcfiWzROV2A2FB2KiPCC2Pe4q5TwJyGuDBAQE,8839
10
+ blksprs/ops/softmax.py,sha256=eaZ8pfCpNZCX6Gk5Tk-lhNIrBQDhvfHqNNPltqxp91k,12793
11
+ blksprs/ops/transpose.py,sha256=30pGCSjZs42Sg6TEXUdJNCDgmlN1n8aN88uNbV5wOtA,3941
12
+ blksprs/ops/misc/broadcast_ops.py,sha256=lZ5bBIftUKffzeYz77SWB1xmtZTRGMvjF-tG9rqkOXA,6018
13
+ blksprs/ops/misc/row_wise.py,sha256=iwOrHU8HiJGxq2hEmgJGZ60asRm72WLi10-PrpNrdeQ,19532
14
+ blksprs/utils/benchmarking.py,sha256=dLabDscTFn5NkmOI1g7DnKeTneUYW3RIVv9MDF-8BKc,1271
15
+ blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
16
+ blksprs/utils/processing.py,sha256=WLuMJQ8v-YovXwcDjhlDn3N31WMZXrtyeeyKSgq_zn4,3642
17
+ blksprs/utils/tools.py,sha256=RL18P4NAj7d8gXTTKbMZt4SHCynsw1wPu9yvlrnBQlo,1220
18
+ blksprs/utils/validation.py,sha256=_Ee6bqu7CxdYLFSy4WZOFoXJgd0p_RBMumCwGCk2_Hw,3763
19
+ blksprs-2.0rc1.dist-info/METADATA,sha256=zXzVOvuwgYSyx-lCBycdFvRUmHUD_qYbK8sFkKWZnp8,8601
20
+ blksprs-2.0rc1.dist-info/WHEEL,sha256=1tXe9gY0PYatrMPMDd6jXqjfpz_B-Wqm32CPfRC58XU,91
21
+ blksprs-2.0rc1.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
22
+ blksprs-2.0rc1.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.6.0)
2
+ Generator: setuptools (77.0.3)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
blksprs/ops/misc/exp.py DELETED
@@ -1,104 +0,0 @@
1
- import torch
2
- import triton
3
- from torch import Tensor
4
- from triton import language as tl
5
-
6
- from blksprs.utils.blksprs_tensor import BlksprsTensor
7
- from blksprs.utils.tools import get_triton_block_size, stride
8
- from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
9
- validate_sparsity_block_size, validate_triton_block_size
10
-
11
-
12
- def exp(x: BlksprsTensor, sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
13
- """Applies the element-wise exponential function to a block-sparse tensor.
14
-
15
- Note:
16
- This operation does not consider sparse blocks, i.e., these will not be set to ``e^0``.
17
- Consider this when converting back to tensors in regular form.
18
-
19
- Args:
20
- x (BlksprsTensor): A block-sparse tensor in compressed form.
21
- sparsity_block_size (int): The size of the sparsity blocks.
22
- triton_block_size (int): The block size to use for the triton kernel (default ``None``).
23
-
24
- Returns:
25
- BlksprsTensor: The exponential function applied to all elements of the input tensor as a block-sparse tensor in
26
- compressed form.
27
-
28
- """
29
- x = x.contiguous()
30
-
31
- validate_dimensions(x)
32
- validate_contiguous(x)
33
- validate_device(x)
34
- validate_sparsity_block_size(sparsity_block_size, x)
35
- validate_triton_block_size(triton_block_size, sparsity_block_size)
36
-
37
- return BlksprsTensor(_BlocksparseExp.apply(x, sparsity_block_size, triton_block_size))
38
-
39
-
40
- class _BlocksparseExp(torch.autograd.Function):
41
-
42
- @staticmethod
43
- def forward(ctx, x: Tensor, sparsity_block_size: int, triton_block_size: int) -> Tensor:
44
- output = torch.empty_like(x)
45
-
46
- x_b, x_r, x_c = x.shape
47
- x_b_s, x_r_s, x_c_s = stride(x)
48
- o_b, o_r, o_c = output.shape
49
- o_b_s, o_r_s, o_c_s = stride(output)
50
-
51
- if triton_block_size is None:
52
- triton_block_size = get_triton_block_size(sparsity_block_size)
53
-
54
- triton_grid = lambda meta: [o_b,
55
- triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
56
- triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
57
-
58
- (_BlocksparseExp.kernel_blocksparse_exp[triton_grid]
59
- (x,
60
- x_b, x_b_s, x_r_s, x_c_s,
61
- output,
62
- o_b, o_b_s, o_r_s, o_c_s,
63
- triton_block_size))
64
-
65
- ctx.save_for_backward(output)
66
-
67
- return output
68
-
69
- @staticmethod
70
- def backward(ctx, grad_output):
71
- o = ctx.saved_tensors[0]
72
-
73
- grad_x = torch.mul(grad_output, o)
74
-
75
- return grad_x, None, None
76
-
77
- @staticmethod
78
- @triton.jit
79
- def kernel_blocksparse_exp(x,
80
- x_b, x_b_s, x_r_s, x_c_s,
81
- o,
82
- o_b, o_b_s, o_r_s, o_c_s,
83
- TRITON_BLOCK_SIZE: tl.constexpr) -> None:
84
- # Get triton block indices
85
- pid_blk = tl.program_id(axis=0)
86
- pid_row = tl.program_id(axis=1)
87
- pid_col = tl.program_id(axis=2)
88
-
89
- # Load block
90
- blk_x_idx = ((pid_blk * x_b_s) +
91
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
92
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
93
- blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
94
- blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
95
-
96
- # Compute exp
97
- buf = tl.exp(blk_x)
98
-
99
- # Store block
100
- blk_o_idx = ((pid_blk * o_b_s) +
101
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
102
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
103
- blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
104
- tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
@@ -1,17 +0,0 @@
1
- import math
2
-
3
- import torch
4
- import triton
5
- from torch import Tensor
6
- from torch.xpu import device
7
- from triton import language as tl
8
-
9
- from blksprs.utils.blksprs_tensor import BlksprsTensor
10
- from blksprs.utils.tools import get_triton_block_size, stride
11
- from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
12
- validate_contiguous, validate_sparsity, validate_sparsity_block_size
13
-
14
-
15
- def build_full_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
16
- return torch.ones(size=(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size),
17
- dtype=torch.bool, device=x.device)