blksprs 1.4.1-py3-none-any.whl → 1.4.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/misc/broadcast_ops.py CHANGED
@@ -41,7 +41,7 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
 
     validate_contiguous(sparsity_layout_output, sparsity_lut_o)
 
-    output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, device=x.device)
+    output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, dtype=x.dtype, device=x.device)
 
     x_b, x_c = x.size()
     x_b_s, x_c_s = x.stride()
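Without an explicit dtype, torch.zeros and torch.empty allocate using PyTorch's global default dtype (float32 unless torch.set_default_dtype was called), so a half-precision input previously produced a float32 output buffer here. The same dtype=x.dtype fix recurs in the row_wise, conversion, matmul, and transpose hunks below. A minimal sketch of the difference:

    import torch

    x = torch.randn(2, 4, 4, dtype=torch.float16)

    out_old = torch.zeros(2, 4, 4, device=x.device)                 # default dtype: float32
    out_new = torch.zeros(2, 4, 4, dtype=x.dtype, device=x.device)  # matches the input

    print(out_old.dtype, out_new.dtype)  # torch.float32 torch.float16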
blksprs/misc/row_wise.py CHANGED
@@ -56,6 +56,7 @@ def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
     output = torch.zeros(size=(n_sparse_blocks_output,
                                sparsity_block_size,
                                1 if flag_slice_only else sparsity_block_size),
+                               dtype=x.dtype,
                                device=x.device)
 
     x_b, x_r, x_c = x.size()
blksprs/ops/conversion.py CHANGED
@@ -186,8 +186,8 @@ class _BlocksparseToSparse(torch.autograd.Function):
     def forward(ctx, x: Tensor,
                 sparsity_layout: Tensor, sparsity_lut: Tensor,
                 sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
-        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size), dtype=x.dtype,
-                             device=x.device)
+        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                             dtype=x.dtype, device=x.device)
 
         x_b, x_r, x_c = x.size()
         x_b_s, x_r_s, x_c_s = x.stride()
blksprs/ops/matmul.py CHANGED
@@ -78,7 +78,8 @@ class _BlocksparseMatmulSSS(torch.autograd.Function):
                 sparsity_layout_y: Tensor, sparsity_reverse_lut_y: Tensor,
                 sparsity_layout_o: Tensor, sparsity_lut_o: Tensor,
                 sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
-        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size), device=x.device)
+        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                             dtype=x.dtype, device=x.device)
 
         x_b, x_r, x_c = x.size()
         x_b_s, x_r_s, x_c_s = x.stride()
blksprs/ops/softmax.py CHANGED
@@ -127,7 +127,7 @@ class _BlocksparseSoftmax(torch.autograd.Function):
        s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
        s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = sparsity_layout_s.stride()
 
-       grad_x = torch.empty_like(o)
+       grad_x = torch.empty_like(o, dtype=torch.float)
 
        triton_grid = lambda meta: [o_b,
                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
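Unlike the other allocations in this release, the softmax backward buffer is pinned to torch.float rather than the input's dtype, presumably so the gradient is accumulated in full precision even when the forward pass ran in half precision. torch.empty_like accepts a dtype override for exactly this:

    import torch

    o = torch.randn(2, 4, 4, dtype=torch.float16)
    grad_x = torch.empty_like(o, dtype=torch.float)  # same shape/device as o, float32 storage
    print(grad_x.dtype)  # torch.float32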
blksprs/ops/transpose.py CHANGED
@@ -59,7 +59,8 @@ class _BlocksparseTranspose(torch.autograd.Function):
     def forward(ctx, x: Tensor,
                 sparsity_layout: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor, sparsity_block_size: int,
                 n_sparse_blocks: int, triton_block_size: int) -> (Tensor, Tensor):
-        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size), device=x.device)
+        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                             dtype=x.dtype, device=x.device)
 
         x_b, x_r, x_c = x.size()
         x_b_s, x_r_s, x_c_s = x.stride()
@@ -101,7 +102,8 @@ class _BlocksparseTranspose(torch.autograd.Function):
        sparsity_block_size = ctx.sparsity_block_size
        triton_block_size = ctx.triton_block_size
 
-       return transpose(grad_output, sparsity_layout, sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None
+       return transpose(grad_output, sparsity_layout, sparsity_block_size, triton_block_size)[
+           0], None, None, None, None, None, None
 
     @staticmethod
     @triton.jit
blksprs/utils/tools.py CHANGED
@@ -1,4 +1,3 @@
-import torch
 from torch import Tensor, Size
 
 from blksprs.utils.validation import _set_skip_validation
@@ -8,7 +7,7 @@ def do_shape_blocksparse(x: Tensor):
     if x.dim() == 3:
         return x.contiguous(), x.size()
 
-    return x.reshape(-1, x.size(-2), x.size(-1)), x.size()
+    return x.reshape(-1, x.size(-2), x.size(-1)).contiguous(), x.size()
 
 
 def undo_shape_blocksparse(x: Tensor, shape: Size):
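torch.reshape returns a view of the input whenever the requested shape can be expressed over the existing storage, and such a view is not guaranteed to be contiguous; the appended .contiguous() ensures downstream Triton kernels always see densely strided memory. A small demonstration of reshape alone not sufficing:

    import torch

    x = torch.randn(4, 1).expand(4, 5)     # non-contiguous view (stride 0 along dim 1)
    y = x.reshape(4, 5)                    # shape already matches, so reshape returns a view
    print(y.is_contiguous())               # False
    print(y.contiguous().is_contiguous())  # True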
blksprs/utils/validation.py CHANGED
@@ -3,13 +3,13 @@ from torch import Tensor
 
 VALIDATION = True
 
-def validate_dimensions(*tensors: Tensor) -> None:
+def validate_dimensions(*tensors: Tensor, dims=3) -> None:
     if _check_skip_validation():
         return
 
     for tensor in tensors:
-        if tensor.dim() != 3:
-            raise ValueError("Tensor must have 3 dimensions")
+        if tensor.dim() != dims:
+            raise ValueError(f"Tensor must have {dims} dimensions")
 
 
 def validate_contiguous(*tensors: Tensor) -> None:
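The new dims keyword generalizes the validator from a hard-coded rank of 3 to any fixed rank while leaving existing call sites unchanged. A sketch of how callers might use it (the tensor names are illustrative, not from the library):

    import torch
    from blksprs.utils.validation import validate_dimensions

    x = torch.randn(2, 8, 8)   # blocksparse tensors are 3-D
    lut = torch.zeros(16, 2)   # e.g. a 2-D lookup table

    validate_dimensions(x)            # default behaviour: requires 3 dimensions
    validate_dimensions(lut, dims=2)  # new in 1.4.2: validate a different rank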
@@ -91,6 +91,9 @@ def validate_triton_block_size(triton_block_size: int, sparsity_block_size: int)
     if triton_block_size is None:
         return
 
+    if not (triton_block_size & (triton_block_size - 1)) == 0:
+        raise ValueError("Triton block size must be a power of 2")
+
     if triton_block_size > sparsity_block_size:
         raise ValueError("Triton block size cannot be larger than sparsity block size")
{blksprs-1.4.1.dist-info → blksprs-1.4.2.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.4.1
+Version: 1.4.2
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
blksprs-1.4.2.dist-info/RECORD ADDED
@@ -0,0 +1,19 @@
+blksprs/__init__.py,sha256=ORAVhGR91G1wyIOs9Wecv-xfmjju3bJ4Jynq_SGOVY4,833
+blksprs/layouting/distribution_layout.py,sha256=Xd8KjZwI87L9EL1Bw5SGUW9YztFD5q0Ygr99sffvdak,4939
+blksprs/layouting/sparsity_layout.py,sha256=vZL8r5LkMwILYYqTYPZcN_NYFJuVFIB6mmBkdtRyXmI,7893
+blksprs/misc/broadcast_ops.py,sha256=ahm7_lI12bJ6VTKRuSkwEeaEYWRY-BeMIOhtei35zpQ,5323
+blksprs/misc/repeat_interleave.py,sha256=KJeapmxbpA7zGFfa5hUhCGrk4aFmhOhlMw-hbTh9PLI,5668
+blksprs/misc/row_wise.py,sha256=1UtjLplrGx1FkxhzQ2hjSBBY11ToLQs0JiLaXKRAkL4,16893
+blksprs/ops/conversion.py,sha256=vuiNwrwyuGI6H4PKrS_UHI7OKWJwNZd2i3LSjf6RetU,21332
+blksprs/ops/distribution.py,sha256=KhtHRVcv4_woyNlldAjIWF-7021-KX-xyIcN6rE-UgE,16879
+blksprs/ops/exp.py,sha256=CVWVq_emO2CnS_xk6Unx67P7EI7IL26dwtsmBJZOLzQ,3698
+blksprs/ops/matmul.py,sha256=743XeD5M4iUv28sYf7q6mVXDd4jZpV04JAx8bF7hWkw,11254
+blksprs/ops/softmax.py,sha256=cs1utM6UCzHhdJpf-ZysBr6CwbjI-5aQG0ahYY37Zy0,11991
+blksprs/ops/transpose.py,sha256=Ru4YKyg796WT6OnDSTCYG45tMmdgvju3hMFzkwsJnO8,6801
+blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
+blksprs/utils/tools.py,sha256=JAuwsLISr_hcvxIgUVvKz5ZPf9M5ycquplsBU5dVfDc,596
+blksprs/utils/validation.py,sha256=rP6yr-C2ghXfJEERry_pfvVJ0g0VyqV4sL4HkBRlJg8,3345
+blksprs-1.4.2.dist-info/METADATA,sha256=wpv1H29xlts3Muvlg_dtA1KW3TUeBtlD4rr4MHRZm5c,7609
+blksprs-1.4.2.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+blksprs-1.4.2.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
+blksprs-1.4.2.dist-info/RECORD,,
{blksprs-1.4.1.dist-info → blksprs-1.4.2.dist-info}/WHEEL RENAMED
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (75.1.0)
+Generator: setuptools (75.2.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 
blksprs-1.4.1.dist-info/RECORD DELETED
@@ -1,19 +0,0 @@
-blksprs/__init__.py,sha256=ORAVhGR91G1wyIOs9Wecv-xfmjju3bJ4Jynq_SGOVY4,833
-blksprs/layouting/distribution_layout.py,sha256=Xd8KjZwI87L9EL1Bw5SGUW9YztFD5q0Ygr99sffvdak,4939
-blksprs/layouting/sparsity_layout.py,sha256=vZL8r5LkMwILYYqTYPZcN_NYFJuVFIB6mmBkdtRyXmI,7893
-blksprs/misc/broadcast_ops.py,sha256=RTcqvx6X_THRBb55jipeEe63YSLIAh27jdpuze0aSek,5308
-blksprs/misc/repeat_interleave.py,sha256=KJeapmxbpA7zGFfa5hUhCGrk4aFmhOhlMw-hbTh9PLI,5668
-blksprs/misc/row_wise.py,sha256=KCDO5ry5TkjI88LLD_QINZwBkzfmjoQpOOvYLfpUn5I,16853
-blksprs/ops/conversion.py,sha256=h1c5T74rQjqYgY9dwWXfPTXRpgzy0dtAhCmtUp8-6uo,21332
-blksprs/ops/distribution.py,sha256=KhtHRVcv4_woyNlldAjIWF-7021-KX-xyIcN6rE-UgE,16879
-blksprs/ops/exp.py,sha256=CVWVq_emO2CnS_xk6Unx67P7EI7IL26dwtsmBJZOLzQ,3698
-blksprs/ops/matmul.py,sha256=6DaYxecJgwiW8L-UISkgyNyzQ31AAkmDL-Oq1EjHt98,11210
-blksprs/ops/softmax.py,sha256=cSTxDnNmMRlJGOlCSpdg1U5KUIFpVtHulz8fteJFeh0,11972
-blksprs/ops/transpose.py,sha256=et8R124L29TUqihci18ms_hBoYXTtPu5LXgEA8sxk_w,6744
-blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
-blksprs/utils/tools.py,sha256=RKGWCGd5h1qFOIoShsdJObx4-QsS0RxCyzFie0geNxo,596
-blksprs/utils/validation.py,sha256=Gsx3aah6355bWXRPpbFuZ1p0fOrYduIqaM3ON9d5NiI,3197
-blksprs-1.4.1.dist-info/METADATA,sha256=3xRmBFHv2U2KnrW3_QX3003SHLkQ1JCaSqh4AUBsJD4,7609
-blksprs-1.4.1.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
-blksprs-1.4.1.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
-blksprs-1.4.1.dist-info/RECORD,,