blksprs 1.11-py3-none-any.whl → 2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/ops/transpose.py CHANGED
@@ -1,17 +1,16 @@
  import torch
- import triton
  from torch import Tensor
- from triton import language as tl
+ from torch._library import triton_op

- from blksprs.ops.flow import flow_forward_pull
+ from blksprs.ops.flow import flow_pull_forward
  from blksprs.utils.blksprs_tensor import BlksprsTensor
- from blksprs.utils.tools import get_triton_block_size, stride
  from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
-     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
+     validate_sparsity, validate_sparsity_block_size


- def transpose(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None,
-               lut: dict = None) -> (BlksprsTensor, Tensor):
+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
+ def transpose(x: BlksprsTensor, sparsity_layout: Tensor,
+               sparsity_block_size: int, lut: dict = None) -> (BlksprsTensor, Tensor):
      """Transposes a block-sparse tensor in compressed form.

      Note:
@@ -21,7 +20,6 @@ def transpose(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: in
          x (BlksprsTensor): A block-sparse tensor in compressed form.
          sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
          sparsity_block_size (int): The size of the sparsity blocks.
-         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
          lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

      Returns:
@@ -30,68 +28,73 @@ def transpose(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: in

      """
      x = x.contiguous()
-     x_t = x.transpose(-1, -2).contiguous()

      validate_dimensions(x)
      validate_contiguous(x)
      validate_device(x)
      validate_sparsity(sparsity_block_size, (x, sparsity_layout))
      validate_sparsity_block_size(sparsity_block_size, x)
-     validate_triton_block_size(triton_block_size, sparsity_block_size)

-     lut = _BlocksparseTranspose.build_lut(lut, sparsity_layout)
+     lut = transpose_build_lut(lut, sparsity_layout)

-     return BlksprsTensor(
-         _BlocksparseTranspose.apply(x_t, lut["sparsity_layout_t"], lut["sparsity_lut"], lut["sparsity_reverse_lut"],
-                                     sparsity_block_size,
-                                     lut["n_sparse_blocks"], triton_block_size)), lut["sparsity_layout_t"]
+     return BlksprsTensor(transpose_forward(x, lut["sparsity_layout_t"],
+                                            lut["sparsity_lut"], lut["sparsity_reverse_lut"],
+                                            sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_t"]


- class _BlocksparseTranspose(torch.autograd.Function):
+ @triton_op("blksprs::transpose_forward", mutates_args={})
+ def transpose_forward(x: Tensor, sparsity_layout_o: Tensor,
+                       sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                       sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+     with torch.no_grad():
+         x_t = x.transpose(-1, -2).contiguous()
+         return flow_pull_forward(x_t, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
+                                  sparsity_block_size, n_sparse_blocks)

-     @staticmethod
-     def build_lut(lut: dict, sparsity_layout: Tensor):
-         if lut is None:
-             lut = dict()

-         if "sparsity_layout_t" not in lut:
-             sparsity_layout_t = sparsity_layout.transpose(-1, -2).contiguous()
-             lut["sparsity_layout_t"] = sparsity_layout_t
+ def transpose_wrapper_backward(ctx, grad_output):
+     sparsity_layout = ctx.saved_tensors[0]
+     sparsity_block_size = ctx.sparsity_block_size

-         if "sparsity_lut" not in lut:
-             sparsity_lut = torch.nonzero(lut["sparsity_layout_t"]).contiguous()
-             lut["sparsity_lut"] = sparsity_lut
+     return transpose(grad_output, sparsity_layout, sparsity_block_size)[
+         0], None, None, None, None, None

-         if "sparsity_reverse_lut" not in lut:
-             sparsity_layout_flat = sparsity_layout.reshape(-1)
-             sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                                      (sparsity_layout_flat == 1) -
-                                      (1 * (sparsity_layout_flat == 0)))
-                                     .reshape(sparsity_layout.size()).transpose(-1, -2).contiguous().reshape(-1))
-             lut["sparsity_reverse_lut"] = sparsity_reverse_lut

-         if "n_sparse_blocks" not in lut:
-             n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
-             lut["n_sparse_blocks"] = n_sparse_blocks
+ def transpose_build_lut(lut: dict, sparsity_layout: Tensor):
+     if lut is None:
+         lut = dict()

-         validate_contiguous(lut["sparsity_layout_t"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])
+     if "sparsity_layout_t" not in lut:
+         sparsity_layout_t = sparsity_layout.transpose(-1, -2).contiguous()
+         lut["sparsity_layout_t"] = sparsity_layout_t

-         return lut
+     if "sparsity_lut" not in lut:
+         sparsity_lut = torch.nonzero(lut["sparsity_layout_t"]).contiguous()
+         lut["sparsity_lut"] = sparsity_lut

-     @staticmethod
-     def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
-                 sparsity_block_size: int,
-                 n_sparse_blocks: int, triton_block_size: int) -> Tensor:
-         ctx.save_for_backward(sparsity_layout_o)
+     if "sparsity_reverse_lut" not in lut:
+         sparsity_layout_flat = sparsity_layout.reshape(-1)
+         sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                  (sparsity_layout_flat == 1) -
+                                  (1 * (sparsity_layout_flat == 0)))
+                                 .reshape(sparsity_layout.size()).transpose(-1, -2).contiguous().reshape(-1))
+         lut["sparsity_reverse_lut"] = sparsity_reverse_lut

-         return flow_forward_pull(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
-                                  sparsity_block_size, n_sparse_blocks, triton_block_size)
+     if "n_sparse_blocks" not in lut:
+         n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
+         lut["n_sparse_blocks"] = n_sparse_blocks

-     @staticmethod
-     def backward(ctx, grad_output):
-         sparsity_layout = ctx.saved_tensors[0]
-         sparsity_block_size = ctx.sparsity_block_size
-         triton_block_size = ctx.triton_block_size
+     validate_contiguous(lut["sparsity_layout_t"], lut["sparsity_lut"], lut["sparsity_reverse_lut"])

-         return transpose(grad_output, sparsity_layout, sparsity_block_size, triton_block_size)[
-             0], None, None, None, None, None, None
+     return lut
+
+
+ # noinspection PyUnusedLocal
+ def transpose_setup_context(ctx, inputs, output):
+     (_, sparsity_layout_o, _, _, sparsity_block_size, _) = inputs
+
+     ctx.save_for_backward(sparsity_layout_o)
+     ctx.sparsity_block_size = sparsity_block_size
+
+
+ transpose_forward.register_autograd(transpose_wrapper_backward, setup_context=transpose_setup_context)
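For orientation, the 2.0 `transpose` API drops the `triton_block_size` argument and instead takes an optional `lut` dict; since `transpose_build_lut` only fills in missing keys of a passed-in dict, the look-up tables can be reused across calls. A minimal usage sketch (the tensor shape is an arbitrary assumption; a CUDA device is required):

```python
import torch
import blksprs as bs

sparsity_block_size = 16
x = torch.randn(size=(2, 64, 64), device="cuda")

# Build a layout and compress the tensor (same API as in the README further below).
sparsity_layout = bs.layouting.build_sparsity_layout(x, sparsity_block_size)
x_sparse = bs.ops.to_sparse(x, sparsity_layout, sparsity_block_size)

# The first call populates the LUT dict in place; subsequent calls reuse it.
lut = {}
x_t, sparsity_layout_t = bs.ops.transpose(x_sparse, sparsity_layout, sparsity_block_size, lut=lut)
x_t_again, _ = bs.ops.transpose(x_sparse, sparsity_layout, sparsity_block_size, lut=lut)
```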
@@ -0,0 +1,78 @@
+ import os
+
+ blksprs_autotune_mode = os.getenv("BLKSPRS_AUTOTUNE", "DEFAULT")
+
+ if blksprs_autotune_mode == "DEFAULT":
+     autotune_parameters = [
+         (16, 3, 8),
+         (16, 4, 4),
+         (16, 5, 2),
+
+         (32, 3, 8),
+         (32, 4, 4),
+         (32, 5, 2),
+
+         (64, 3, 8),
+         (64, 4, 4),
+         (64, 5, 2),
+
+         (128, 3, 8),
+         (128, 4, 4),
+         (128, 5, 2),
+     ]
+ elif blksprs_autotune_mode == "TEST":
+     autotune_parameters = [
+         (16, 3, 8),
+
+         (32, 3, 8),
+
+         (64, 3, 8),
+     ]
+ else:
+     raise NotImplementedError(f"Unknown autotune mode: {blksprs_autotune_mode}")
+
+ import torch
+ import triton
+
+
+ def prune_autotune_configs(autotune_configs, kernel_args, **kwargs):
+     sparsity_block_size = kernel_args["sparsity_block_size"]
+
+     pruned_configs = []
+
+     for config in autotune_configs:
+         if config.kwargs["TRITON_BLOCK_SIZE"] <= sparsity_block_size:
+             pruned_configs.append(config)
+
+     assert len(pruned_configs) > 0, f"No valid autotune configs found for sparsity block size {sparsity_block_size}"
+
+     return pruned_configs
+
+
+ def prune_autotune_configs_conversion(autotune_configs, kernel_args, **kwargs):
+     sparsity_block_size_from = kernel_args["sparsity_block_size_from"]
+     sparsity_block_size_to = kernel_args["sparsity_block_size_to"]
+     sparsity_block_size = min(sparsity_block_size_from, sparsity_block_size_to)
+
+     pruned_configs = []
+
+     for config in autotune_configs:
+         if config.kwargs["TRITON_BLOCK_SIZE"] <= sparsity_block_size:
+             pruned_configs.append(config)
+
+     assert len(pruned_configs) > 0, f"No valid autotune configs found for sparsity block size {sparsity_block_size}"
+
+     return pruned_configs
+
+
+ @torch.compile
+ def get_autotune_configs():
+     global autotune_parameters
+
+     autotune_configs = []
+
+     for block_size, num_stages, num_warps in autotune_parameters:
+         autotune_configs.append(
+             triton.Config({"TRITON_BLOCK_SIZE": block_size}, num_stages=num_stages, num_warps=num_warps))
+
+     return autotune_configs
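The helpers above generate and prune `triton.Config` candidates keyed on `TRITON_BLOCK_SIZE`. A sketch of how they would typically be attached to a kernel via `triton.autotune` follows; the kernel is a hypothetical stand-in rather than a blksprs kernel, and the import path assumes this new module is `blksprs/utils/autotuning.py` as listed in the new RECORD:

```python
import torch
import triton
import triton.language as tl

from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs


@triton.autotune(
    configs=get_autotune_configs(),
    key=["sparsity_block_size"],
    prune_configs_by={"early_config_prune": prune_autotune_configs},
)
@triton.jit
def copy_kernel(x_ptr, o_ptr, n_elements, sparsity_block_size,
                TRITON_BLOCK_SIZE: tl.constexpr):
    # Each program instance copies one TRITON_BLOCK_SIZE-wide slice.
    pid = tl.program_id(axis=0)
    offsets = pid * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)
    mask = offsets < n_elements
    tl.store(o_ptr + offsets, tl.load(x_ptr + offsets, mask=mask), mask=mask)


x = torch.randn(4096, device="cuda")
o = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(x.numel(), meta["TRITON_BLOCK_SIZE"]),)
# Configs with TRITON_BLOCK_SIZE larger than the sparsity block size (64 here)
# are removed by prune_autotune_configs before benchmarking starts.
copy_kernel[grid](x, o, x.numel(), 64)
```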
@@ -5,13 +5,13 @@ from matplotlib import pyplot as plt


  def benchmark(method_labels: list[str], func_input_generator: Callable,
-               matrix_sizes: list[int], sparsity_block_sizes: list[int], triton_block_sizes: list[int],
+               matrix_sizes: list[int], sparsity_block_sizes: list[int],
                *funcs_test_subject: Callable, y_lim_top: int = None):
      quantiles = [0.5, 0.2, 0.8]
      results = {}

-     for matrix_size, sparsity_block_size, triton_block_size in zip(matrix_sizes, sparsity_block_sizes, triton_block_sizes):
-         arguments = func_input_generator(matrix_size, sparsity_block_size, triton_block_size)
+     for matrix_size, sparsity_block_size in zip(matrix_sizes, sparsity_block_sizes):
+         arguments = func_input_generator(matrix_size, sparsity_block_size)

          for i, func_test_subject in enumerate(funcs_test_subject):
              func_ms_avg, func_ms_min, func_ms_max = triton.testing.do_bench(
@@ -11,6 +11,7 @@ from blksprs.ops.repeat import repeat
  from blksprs.utils.blksprs_tensor import BlksprsTensor


+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
  def apply_torch_linear(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
                         linear: nn.Linear, bias: nn.Parameter = None) -> (BlksprsTensor, Tensor):
      # Extract weight and bias
@@ -25,7 +26,7 @@ def apply_torch_linear(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block

      # Apply weights
      sparsity_layout_xw = build_sparsity_layout_matmul_fast(sparsity_layout, sparsity_layout_w_t)
-     xw = matmul(x, sparsity_layout, w_t_bs, sparsity_layout_w_t, sparsity_layout_xw, sparsity_block_size)
+     xw = matmul(x, sparsity_layout, BlksprsTensor(w_t_bs.to(x.dtype)), sparsity_layout_w_t, sparsity_layout_xw, sparsity_block_size)
      interim = xw

      # Apply bias
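The hunks above decorate `apply_torch_linear` with `torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)`, so that inside an autocast region the tensor inputs are cast to float16 and autocast is disabled for the body. As a point of reference, the documented pattern for this decorator on a custom autograd function looks roughly like this (toy op, not blksprs code):

```python
import torch


class Float16Add(torch.autograd.Function):
    @staticmethod
    @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
    def forward(ctx, x, y):
        # Under autocast, x and y arrive here already cast to float16.
        return x + y

    @staticmethod
    @torch.amp.custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        return grad_output, grad_output


x = torch.randn(8, device="cuda", requires_grad=True)  # float32
y = torch.randn(8, device="cuda", requires_grad=True)  # float32
with torch.autocast(device_type="cuda", dtype=torch.float16):
    out = Float16Add.apply(x, y)
print(out.dtype)  # torch.float16
```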
blksprs/utils/tools.py CHANGED
@@ -1,25 +1,24 @@
  import torch
  from torch import Tensor, Size

+ # Capture scalar outputs for JIT compilation
+ torch._dynamo.config.capture_scalar_outputs = True

- def do_shape_blocksparse(x: Tensor):
+
+ def do_shape_blocksparse(x: Tensor) -> tuple[Tensor, Size]:
      if x.dim() == 3:
          return x.contiguous(), x.size()

      return x.reshape(-1, x.size(-2), x.size(-1)).contiguous(), x.size()


- def undo_shape_blocksparse(x: Tensor, shape: Size):
+ def undo_shape_blocksparse(x: Tensor, shape: Size | tuple[int, ...]) -> Tensor:
      if x.shape[:-2] == shape[:-2]:
          return x

      return x.reshape((*shape[:-2], *x.shape[-2:]))


- def get_triton_block_size(sparsity_block_size: int, limit: int = 128):
-     return min(sparsity_block_size, limit)
-
-
  def stride(x: Tensor):
      if x.dim() == 2:
          return x.size(1), 1
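As a quick illustration of the two shape helpers kept in `tools.py`: `do_shape_blocksparse` flattens all leading dimensions into a single batch dimension and returns the original size, and `undo_shape_blocksparse` restores it (the 4-D shape below is an arbitrary assumption):

```python
import torch
import blksprs as bs

x = torch.randn(size=(2, 3, 64, 32))
x_flat, original_size = bs.utils.do_shape_blocksparse(x)  # x_flat has shape (6, 64, 32)
x_back = bs.utils.undo_shape_blocksparse(x_flat, original_size)
assert x_back.shape == x.shape
```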
@@ -26,6 +26,23 @@ def validate_dtype_float(*tensors: Tensor) -> None:
      if _check_skip_validation():
          return

+     dtype = None
+
+     for i, tensor in enumerate(tensors):
+         if i == 0:
+             dtype = tensor.dtype
+
+         if tensor.dtype != torch.float16 and tensor.dtype != torch.float32:
+             raise ValueError("Tensor must have either float16 or float32 dtype")
+
+         if tensor.dtype != dtype:
+             raise ValueError("Tensors must have same dtype")
+
+
+ def validate_dtype_float_32(*tensors: Tensor) -> None:
+     if _check_skip_validation():
+         return
+
      for tensor in tensors:
          if tensor.dtype != torch.float32:
              raise ValueError("Tensor must have float32 dtype")
@@ -38,7 +55,7 @@ def validate_dtype_int(*tensors: Tensor) -> None:
      for tensor in tensors:
          if (tensor.dtype !=
                  torch.int32 and tensor.dtype != torch.int64):
-             raise ValueError("Tensor must have int32 or int64 dtype")
+             raise ValueError("Tensor must have either int32 or int64 dtype")


  def validate_device(*tensors: Tensor) -> None:
@@ -51,7 +68,7 @@ def validate_device(*tensors: Tensor) -> None:
          if i == 0:
              device = tensor.device

-         if not device.type == 'cuda':
+         if not device.type == "cuda":
              raise ValueError("Tensors must be on GPU")

          if tensor.device != device:
@@ -96,6 +113,9 @@ def validate_sparsity_block_size(sparsity_block_size: int, *tensors):
      if _check_skip_validation():
          return

+     if not sparsity_block_size >= 16:
+         raise ValueError("Sparsity block size must be at least 16")
+
      if not (sparsity_block_size & (sparsity_block_size - 1)) == 0:
          raise ValueError("Sparsity block size must be a power of 2")

@@ -104,20 +124,6 @@ def validate_sparsity_block_size(sparsity_block_size: int, *tensors):
              raise ValueError("Tensor sizes must be divisible by sparsity block size")


- def validate_triton_block_size(triton_block_size: int, sparsity_block_size: int):
-     if _check_skip_validation():
-         return
-
-     if triton_block_size is None:
-         return
-
-     if not (triton_block_size & (triton_block_size - 1)) == 0:
-         raise ValueError("Triton block size must be a power of 2")
-
-     if triton_block_size > sparsity_block_size:
-         raise ValueError("Triton block size cannot be larger than sparsity block size")
-
-
  def _check_skip_validation():
      return not VALIDATION

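The validation hunks above tighten `validate_sparsity_block_size`: the block size must now be at least 16, remain a power of two, and divide the relevant tensor dimensions, while `validate_triton_block_size` disappears along with the removed parameter. A standalone illustration of the constraint (hypothetical helper, not blksprs code):

```python
def is_valid_sparsity_block_size(sparsity_block_size: int, rows: int, cols: int) -> bool:
    # Power-of-two test via the usual bit trick, plus the new minimum of 16.
    power_of_two = (sparsity_block_size & (sparsity_block_size - 1)) == 0
    return (sparsity_block_size >= 16
            and power_of_two
            and rows % sparsity_block_size == 0
            and cols % sparsity_block_size == 0)


assert is_valid_sparsity_block_size(16, 64, 64)
assert not is_valid_sparsity_block_size(8, 64, 64)   # below the new minimum of 16
assert not is_valid_sparsity_block_size(24, 96, 96)  # not a power of two
```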
@@ -1,6 +1,6 @@
- Metadata-Version: 2.2
+ Metadata-Version: 2.4
  Name: blksprs
- Version: 1.11
+ Version: 2.0
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -14,9 +14,8 @@ Requires-Dist: pytest; extra == "test"
  Requires-Dist: pytest-xdist; extra == "test"
  Requires-Dist: pytest-cov; extra == "test"
  Requires-Dist: coverage; extra == "test"
+ Requires-Dist: build; extra == "test"
  Requires-Dist: matplotlib; extra == "test"
- Provides-Extra: build
- Requires-Dist: build; extra == "build"

  # blksprs

@@ -25,6 +24,13 @@ Requires-Dist: build; extra == "build"

  ## Overview

+ ### News
+
+ 🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
+ LUTs, autocasting, and makes use of `torch.library.triton_op()`!
+
+ ---
+
  A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.

  Currently supported operations (includes gradient calculation):
@@ -52,23 +58,25 @@ These include, e.g.,
  Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
  match.

- Further helpful operations (included in the ``bs.ops.misc`` module) that do **not** support gradient calculation include:
+ Further helpful operations (included in the ``bs.ops.misc`` module) that do **not** support gradient calculation
+ include:

  - Row-wise sum, max, addition, and subtraction
  - Broadcast addition and subtraction between slices

- Furthermore, the library provides a set of utility functions
+ Furthermore, the library provides a set of utility functions

  - for the creation of sparsity layouts based on existing
- dense tensors and for the scatter operation (module ``bs.layouting``),
+ dense tensors and for the scatter operation (module ``bs.layouting``),
  - for the application of ``nn.Linear``, ``nn.Dropout``, and ``nn.LayerNorm`` layers to block-sparse tensors,
  - as well as utility functions to ensure correct input dimensionality, and validate input (module ``bs.utils``).

- _* see the [Roadmap](#roadmap) section for more information_
+ _* see the [Roadmap](#roadmap) section for more information_

  ## Installation

- Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible with
+ Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible
+ with
  the Linux platform**.
  Keep track of this [issue](https://github.com/triton-lang/triton/issues/1640) for updates.

@@ -78,8 +86,8 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u

  ### Dependencies

- - [PyTorch](https://pytorch.org/) (built with v2.5.1)
- - _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.2.0)_
+ - [PyTorch](https://pytorch.org/) (built with v2.6)
+ - _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.2.4)_
  - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_

  ## Changelog
@@ -89,12 +97,27 @@ See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.
  ## Roadmap

  Note that since this library covers all our current needs it is in a **bugfix-only** state.
- This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and ``merge`` operations.
+ This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and
+ ``merge`` operations.
  We will continue to maintain the library and fix any issues that arise.
  Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
  We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).

- It might be that this changes with future projects, but as of December 2024, we are content with the current state of the library.
+ It might be that this changes with future projects, but as of March 2025, we are content with the current state of the
+ library.
+
+ ## Known Limitations and Issues
+
+ - Triton has a bug with `tl.atomic_max()` used for the row-wise max operation.
+ In order to work around this bug, a manual conversion of some values is needed, which (slightly) impacts
+ performance negatively.
+ Watch the [issue](https://github.com/triton-lang/triton/issues/6376) on Triton's issue tracker for more information.
+ - PyTorch's `wrap_triton()` currently does not support config pruning. It thus cannot be used for some of the kernels,
+ which could impact graph compilation.
+ - There seem to be some issues with autocasting, forcing some operations to cast manually.
+ - There will be some slight numerical differences between vanilla and blksprs operations.
+ These instabilities are due to Triton and thus cannot be fixed by this library alone.
+ However, for all intents and purposes, these very minor differences should not matter and can safely be ignored.

  ## Usage

@@ -120,10 +143,6 @@ def test_readme():
      # Must be a power of two, greater than or equal to 16 for matmul, and divide m, n, and k
      sparsity_block_size = 16

-     # Must be a power of two and smaller than or equal to sparsity_block_size
-     # If it is set to ``none`` a value will be chosen automatically
-     triton_block_size = None
-
      # Initialise random (dense) tensors
      x = torch.randn(size=(b, h, m, k), device="cuda")
      y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()
@@ -133,53 +152,53 @@ def test_readme():
      y_dense, y_shape_original = bs.utils.do_shape_blocksparse(y)

      # Create sparsity layouts from existing tensors
-     sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size,
-                                                            triton_block_size=triton_block_size)
-     sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size,
-                                                            triton_block_size=triton_block_size)
+     sparsity_layout_x = bs.layouting.build_sparsity_layout(x_dense, sparsity_block_size)
+     sparsity_layout_y = bs.layouting.build_sparsity_layout(y_dense, sparsity_block_size)

      # Create random sparsity layout for output tensor
      sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)

      # Convert tensors to sparse tensors for matrix multiplication
-     x_sparse = bs.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size, triton_block_size=triton_block_size)
-     y_sparse = bs.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size, triton_block_size=triton_block_size)
+     x_sparse = bs.ops.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size)
+     y_sparse = bs.ops.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size)
+
+     # As of version 2.0, blksprs supports JIT compilation
+     matmul_compiled = torch.compile(bs.ops.matmul)

      # Perform matrix multiplication
-     o_sparse = bs.matmul(x_sparse, sparsity_layout_x, y_sparse, sparsity_layout_y, sparsity_layout_o,
-                          sparsity_block_size,
-                          triton_block_size=triton_block_size)
+     o_sparse = matmul_compiled(x_sparse, sparsity_layout_x,
+                                y_sparse, sparsity_layout_y,
+                                sparsity_layout_o, sparsity_block_size)

      # Apply element-wise operation
      o_sparse = torch.add(o_sparse, 1)

-     o_dense = bs.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+     o_dense = bs.ops.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size)

      # Sanity check
      o_torch = torch.matmul(x_dense, y_dense)
      o_torch = torch.add(o_torch, 1)

      # Perform round trip to set sparse blocks to 0
-     o_torch_round_trip = bs.to_dense(
-         bs.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size),
-         sparsity_layout_o, sparsity_block_size, fill_value=0, triton_block_size=triton_block_size)
+     o_torch_round_trip = bs.ops.to_dense(
+         bs.ops.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size),
+         sparsity_layout_o, sparsity_block_size, fill_value=0)

      # Assert that the output is correct
      assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2)  # Note that small numerical differences are expected

      # Assert that the output has the correct sparsity layout
-     actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size,
-                                                                   triton_block_size=triton_block_size)
+     actual_sparsity_layout_o = bs.layouting.build_sparsity_layout(o_dense, sparsity_block_size)
      assert torch.allclose(actual_sparsity_layout_o.to(torch.int), sparsity_layout_o)

      # Convert output tensor back to original shape
      o = bs.utils.undo_shape_blocksparse(o_dense, x_shape_original)

      # Other available functions
-     bs.transpose(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
-     bs.softmax(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
-     bs.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
-     bs.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+     bs.ops.transpose(o_sparse, sparsity_layout_o, sparsity_block_size)
+     bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size)
+     bs.ops.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size)
+     bs.ops.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size)


  def _get_random_sparsity_layout(b, m, n, sparsity_block_size, sparsity_percentage):
@@ -0,0 +1,23 @@
+ blksprs/__init__.py,sha256=QXYiDairqn5LzvpoNyBXHdQfyac-aX3IDjTQUMkaH-A,1598
+ blksprs/layouting/distribution_layout.py,sha256=ur1ty_2U-Hfj78hMWsLZvu7ZuGhzW3qGLKMc72DfTZM,5861
+ blksprs/layouting/sparsity_layout.py,sha256=eXHmu2h7K5Q-YUpfOxocJoeP_5ZoQFZf_eHLxRZQbYU,11207
+ blksprs/ops/conversion.py,sha256=kf5HKofZ4nVeHCIqQoYKiIlgsAhq33Tnmnr1c17Fkqs,21906
+ blksprs/ops/distribution.py,sha256=0tPldv0ARzmCV1CU2jvfqpHBgOuHPrDFiCtqsLs7CZc,20789
+ blksprs/ops/flow.py,sha256=PDZAD8u4y9qW1IXERki6ItKbEKnm_ChG8SKWM3_P9Oc,8245
+ blksprs/ops/matmul.py,sha256=5tVBKU_lglUjaLDi6J_dscdqlmzRz38OGxqAxZxZXDs,11879
+ blksprs/ops/partitioning.py,sha256=cfQmY9BZqGTvvJorIhtb-EyuGRJGPraWR-wTKdb47aI,9954
+ blksprs/ops/repeat.py,sha256=TLYNxwPuT9y5K9xyM41WK5gnggAJF3lI61Q2K7zWjns,9035
+ blksprs/ops/softmax.py,sha256=BwrRQdtRdkiSvl2mf5bpsTmyIxWiJOpa1HFg0st5yGU,12778
+ blksprs/ops/transpose.py,sha256=U-VAyLRT6_NDv9qYSFzBqfVlDeIpTqAMEXkqto0VF6w,4072
+ blksprs/ops/misc/broadcast_ops.py,sha256=-PrHiSJikZh8nXUmXxSCtFEP27TTxFr4wcrNxBjnimk,5987
+ blksprs/ops/misc/row_wise.py,sha256=n5FJjAuOd8BHBJQx4bsQwr-HmXkR9PYVAqfk77wjOFU,19653
+ blksprs/utils/autotuning.py,sha256=a-kmWRjJ3eED2XbjkQeOJSyW8bdIs27HgKMPvAKqWeU,2052
+ blksprs/utils/benchmarking.py,sha256=dLabDscTFn5NkmOI1g7DnKeTneUYW3RIVv9MDF-8BKc,1271
+ blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
+ blksprs/utils/processing.py,sha256=RNkEDc0g-sNHRuMPkRzNWU13d3_lIkXMJdoqES4yQTM,3738
+ blksprs/utils/tools.py,sha256=CPf7viQ2OTcZFrB1aSL8_us4VE9M6YEfDz2dE30jr9I,715
+ blksprs/utils/validation.py,sha256=G8eQlvJVMKfEX3k2AwBD0A6Ck-gFoRLpLNY6HXsB3fA,4348
+ blksprs-2.0.dist-info/METADATA,sha256=7YVN_akf-ewrAW5thDZzhT3hogn0dVxV73lvlPMk59c,9506
+ blksprs-2.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ blksprs-2.0.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
+ blksprs-2.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (76.0.0)
+ Generator: setuptools (80.9.0)
  Root-Is-Purelib: true
  Tag: py3-none-any

@@ -1,17 +0,0 @@
- import math
-
- import torch
- import triton
- from torch import Tensor
- from torch.xpu import device
- from triton import language as tl
-
- from blksprs.utils.blksprs_tensor import BlksprsTensor
- from blksprs.utils.tools import get_triton_block_size, stride
- from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
-     validate_contiguous, validate_sparsity, validate_sparsity_block_size
-
-
- def build_full_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
-     return torch.ones(size=(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size),
-                       dtype=torch.bool, device=x.device)
@@ -1,23 +0,0 @@
- blksprs/__init__.py,sha256=AJYVfR40nOfE5F3waHPVSuajwYDcoGkiEQc8HhQbUBU,1721
- blksprs/layouting/distribution_layout.py,sha256=xDGY5-J7uSD8oenlf8bEJ2amMiQG3NBf2klTTydbTJE,5140
- blksprs/layouting/sparsity_layout.py,sha256=IVtHc_nN3ZM2y4GFcys70PqDWmWc7tkHlVGlToErANk,9894
- blksprs/ops/conversion.py,sha256=QFtZ-nmY2JAWutheiO07vatXqz3eSZBP5Ym_U2Q1oWk,23299
- blksprs/ops/distribution.py,sha256=nHTuE7Tq0Q404VN8bWNC2sEwmmdAtgZI6I7auRICdps,21749
- blksprs/ops/flow.py,sha256=7tOXfTBKOAixYmDa_VXg7TwviLV5ZQMHQjtbyOjqA00,7879
- blksprs/ops/matmul.py,sha256=eVj_BGj78bJkXYuvw4KctMfcfveQBt5OdYmeXzdpO88,12631
- blksprs/ops/partitioning.py,sha256=qMv9w3yFWXwXIhIppdcJ_JMsoZ25HCH38vb6GRneoLM,10416
- blksprs/ops/repeat.py,sha256=i824ijprfYpCaEjiSG5FTUZz7wMS5ksVy_-vY7ZX8Fg,9729
- blksprs/ops/softmax.py,sha256=_mGkA2jHN8cXwtWXYswobEPyM7UC0JyzRszoE4ZYs7w,13063
- blksprs/ops/transpose.py,sha256=O1XhGIGiVkhOSKcBD0HrYaeK6HmpvEEzLb7zJl7xsIM,4246
- blksprs/ops/misc/broadcast_ops.py,sha256=pv0nssSDOdDbQFttpqUIs2ZXShqfm2RYCfJH-C5x3H0,5544
- blksprs/ops/misc/row_wise.py,sha256=DnV5-xEJUbqZlK2fETwHiPQDUMwT-lkc0VUhBlnJ5Y0,17458
- blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
- blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
- blksprs/utils/layout_utils.py,sha256=49ZdPS_gMn_IrWty3FARbi2rda5a8g5DmAEL8LOrC30,670
- blksprs/utils/processing.py,sha256=WLuMJQ8v-YovXwcDjhlDn3N31WMZXrtyeeyKSgq_zn4,3642
- blksprs/utils/tools.py,sha256=k2OfEplbQiAwVjP84zZf7SNB8FzvMtOFBL9sC98OCbI,683
- blksprs/utils/validation.py,sha256=CbxBbeQWJo8wox5eMoVzaTlP9FVBwt3-gxUOmi3EUgw,4213
- blksprs-1.11.dist-info/METADATA,sha256=NUEiHexWiFNbMxQI2TUEzMw9iGBhxqflhWr2xCgOw28,9105
- blksprs-1.11.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
- blksprs-1.11.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
- blksprs-1.11.dist-info/RECORD,,