blksprs 1.4-py3-none-any.whl → 1.4.1-py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries; it is provided for informational purposes only.
blksprs/__init__.py CHANGED
@@ -15,4 +15,4 @@ class misc:
     from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_add, row_wise_sub
 
 class util:
-    from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
+    from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse, disable_validation
blksprs/layouting/distribution_layout.py CHANGED
@@ -31,7 +31,7 @@ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
     sparsity_lut_i = torch.nonzero(sparsity_layout_indices).contiguous()
 
     output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size, size_target[2] // sparsity_block_size,
-                         device=indices.device, dtype=torch.int32)
+                         dtype=torch.bool, device=indices.device)
 
     i_b, i_r, i_c = indices.size()
     i_b_s, i_r_s, i_c_s = indices.stride()
blksprs/layouting/sparsity_layout.py CHANGED
@@ -27,7 +27,7 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_size
     validate_device(x)
 
     output = torch.zeros(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size,
-                         device=x.device, dtype=torch.int32)
+                         dtype=torch.bool, device=x.device)
 
     x_b, x_r, x_c = x.size()
     x_b_s, x_r_s, x_c_s = x.stride()
@@ -117,7 +117,7 @@ def build_sparsity_layout_adaption(x: Tensor, sparsity_layout_from: Tensor,
     o_r = math.ceil(sparsity_layout_from.size(1) * sparsity_block_size_from // sparsity_block_size_to)
     o_c = math.ceil(sparsity_layout_from.size(2) * sparsity_block_size_from // sparsity_block_size_to)
 
-    output = torch.zeros(o_b, o_r, o_c, device=x.device, dtype=torch.int32)
+    output = torch.zeros(o_b, o_r, o_c, dtype=torch.bool, device=x.device)
 
     x_b, x_r, x_c = x.size()
     x_b_s, x_r_s, x_c_s = x.stride()
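Note: build_distribution_layout, build_sparsity_layout, and build_sparsity_layout_adaption now allocate their layout tensors with dtype=torch.bool instead of torch.int32. A minimal sketch of the knock-on effect for callers, assuming a CUDA device and the README's `import blksprs as bs` alias; the shapes and block size are illustrative:

    import torch
    import blksprs as bs

    sparsity_block_size = 64
    x = torch.randn(2, 128, 128, device="cuda")

    # As of 1.4.1 the returned layout is boolean rather than int32.
    sparsity_layout = bs.layout.build_sparsity_layout(x, sparsity_block_size)
    assert sparsity_layout.dtype == torch.bool

    # Comparisons against integer layouts need an explicit cast, as in the
    # updated README test further down in this diff.
    sparsity_layout_int = sparsity_layout.to(torch.int)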
blksprs/misc/broadcast_ops.py CHANGED
@@ -25,6 +25,9 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
         output tensor corresponds to x(i) + y(j).
 
     """
+    x = x.contiguous()
+    y = y.contiguous()
+
     validate_device(x, y)
     validate_contiguous(x, y)
     if x.size(-1) != y.size(-1):
blksprs/misc/repeat_interleave.py CHANGED
@@ -27,6 +27,8 @@ def repeat_interleave(x: Tensor, sparsity_layout: Tensor, repeats: int,
         Tensor: The sparsity layout of the resulting output tensor.
 
     """
+    x = x.contiguous()
+
     validate_dimensions(x)
     validate_contiguous(x)
     validate_device(x)
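Note: broadcast_add and repeat_interleave (and, below, the row-wise reductions, conversions, distribution ops, exp, matmul, softmax, and transpose) now call .contiguous() on their tensor arguments before validation, so non-contiguous views no longer fail validate_contiguous. A minimal sketch of what this enables, assuming a CUDA device with Triton available; the imports mirror the module paths shown in this diff and the shapes are illustrative:

    import torch
    from blksprs.layouting.sparsity_layout import build_sparsity_layout
    from blksprs.ops.conversion import to_sparse

    sparsity_block_size = 64
    x = torch.randn(2, 128, 128, device="cuda")

    # A transposed view is non-contiguous.
    x_t = x.transpose(-1, -2)
    assert not x_t.is_contiguous()

    # In 1.4.1 the op makes its input contiguous before validating it, so this
    # call no longer rejects the non-contiguous view.
    sparsity_layout_t = build_sparsity_layout(x_t.contiguous(), sparsity_block_size)
    x_t_sparse = to_sparse(x_t, sparsity_layout_t, sparsity_block_size)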
blksprs/misc/row_wise.py CHANGED
@@ -31,6 +31,8 @@ def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
         of the input and the sparsity layout of the output tensor.
 
     """
+    x = x.contiguous()
+
     validate_dimensions(x)
     validate_contiguous(x)
     validate_device(x)
@@ -151,6 +153,8 @@ def row_wise_max(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
         of the input and the sparsity layout of the output tensor.
 
     """
+    x = x.contiguous()
+
     validate_dimensions(x)
     validate_contiguous(x)
     validate_device(x)
blksprs/ops/conversion.py CHANGED
@@ -28,6 +28,8 @@ def to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_
         Tensor: The block-sparse tensor converted to regular form.
 
     """
+    x = x.contiguous()
+
     validate_dimensions(x)
     validate_contiguous(x, sparsity_layout)
     validate_device(x)
@@ -156,6 +158,8 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
         Tensor: The block-sparse tensor converted to compressed form.
 
     """
+    x = x.contiguous()
+
     validate_dimensions(x)
     validate_contiguous(x)
     validate_device(x)
@@ -282,6 +286,8 @@ def adapt_layout(x: Tensor, sparsity_layout_from: Tensor, sparsity_block_size_fr
         Tensor: The block-sparse tensor in compressed form with the adapted sparsity layout and sparsity block size.
 
     """
+    x = x.contiguous()
+
     validate_dimensions(x)
     validate_contiguous(x, sparsity_layout_from)
     validate_device(x)
blksprs/ops/distribution.py CHANGED
@@ -24,6 +24,9 @@ def gather(src: Tensor, sparsity_layout_src: Tensor, idx: Tensor, sparsity_layou
         Tensor: The result of the gather operation as a block-sparse tensor in compressed form.
 
     """
+    src = src.contiguous()
+    idx = idx.contiguous()
+
     validate_dimensions(src, idx)
     validate_contiguous(src, idx)
     validate_dtype_int(idx)
@@ -200,6 +203,9 @@ def scatter_reduce(src: Tensor, sparsity_layout_src: Tensor,
         Tensor: The result of the scatter operation as a block-sparse tensor in compressed form.
 
     """
+    src = src.contiguous()
+    idx = idx.contiguous()
+
     validate_dimensions(src, idx)
     validate_contiguous(src, idx)
     validate_dtype_int(idx)
blksprs/ops/exp.py CHANGED
@@ -25,6 +25,8 @@ def exp(x: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> T
         compressed form.
 
     """
+    x = x.contiguous()
+
     validate_dimensions(x)
     validate_contiguous(x)
     validate_device(x)
blksprs/ops/matmul.py CHANGED
@@ -6,7 +6,7 @@ from triton import language as tl
 from blksprs.ops.transpose import transpose
 from blksprs.utils.tools import get_triton_block_size
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-    validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
+    validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_dtype_float
 
 
 def matmul(x: Tensor, sparsity_layout_x: Tensor,
@@ -30,8 +30,12 @@ def matmul(x: Tensor, sparsity_layout_x: Tensor,
         Tensor: The result of the matrix multiplication as a block-sparse tensor in compressed form.
 
     """
+    x = x.contiguous()
+    y = y.contiguous()
+
     validate_dimensions(x, y)
     validate_contiguous(x, y)
+    validate_dtype_float(x, y)
     validate_device(x, y)
     validate_sparsity(sparsity_block_size, (x, sparsity_layout_x), (y, sparsity_layout_y))
     if sparsity_layout_x.size(-1) != sparsity_layout_y.size(-2):
@@ -211,7 +215,7 @@ class _BlocksparseMatmulSSS(torch.autograd.Function):
             blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
 
             # Perform matrix multiplication
-            buf += tl.dot(blk_x, blk_y)
+            buf += tl.dot(blk_x, blk_y, input_precision="tf32")
 
             # Store output
             blk_o_idx = ((pid_blk * o_b_s) +
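Note: the kernel now passes input_precision="tf32" to tl.dot, so float32 blocks are multiplied with TensorFloat-32 precision on GPUs that support it. A sketch of the numerical impact one might want to account for in tests, using plain dense PyTorch matmuls for comparison; the sizes and expected error magnitude are illustrative assumptions, not statements from the release:

    import torch

    a = torch.randn(256, 256, device="cuda")
    b = torch.randn(256, 256, device="cuda")

    # Full-precision float32 reference.
    torch.backends.cuda.matmul.allow_tf32 = False
    ref_fp32 = a @ b

    # TF32, roughly what the updated kernel requests via input_precision="tf32".
    torch.backends.cuda.matmul.allow_tf32 = True
    ref_tf32 = a @ b

    # TF32 keeps ~10 mantissa bits, so comparisons against dense float32
    # results typically need looser tolerances than torch.allclose defaults.
    print((ref_fp32 - ref_tf32).abs().max())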
blksprs/ops/softmax.py CHANGED
@@ -26,6 +26,8 @@ def softmax(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton
         Tensor: The result of the softmax operation as a block-sparse tensor in compressed form.
 
     """
+    x = x.contiguous()
+
     validate_dimensions(x)
     validate_contiguous(x)
     validate_device(x)
blksprs/ops/transpose.py CHANGED
@@ -26,6 +26,8 @@ def transpose(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
         Tensor: The sparsity layout of the transposed tensor.
 
     """
+    x = x.contiguous()
+
     validate_dimensions(x)
     validate_contiguous(x)
     validate_device(x)
blksprs/utils/tools.py CHANGED
@@ -1,10 +1,12 @@
 import torch
 from torch import Tensor, Size
 
+from blksprs.utils.validation import _set_skip_validation
+
 
 def do_shape_blocksparse(x: Tensor):
     if x.dim() == 3:
-        return x, x.size()
+        return x.contiguous(), x.size()
 
     return x.reshape(-1, x.size(-2), x.size(-1)), x.size()
 
@@ -18,3 +20,7 @@ def undo_shape_blocksparse(x: Tensor, shape: Size):
 
 def get_triton_block_size(sparsity_block_size: int, limit: int = 128):
     return min(sparsity_block_size, limit)
+
+
+def disable_validation():
+    _set_skip_validation(True)
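Note: the new disable_validation helper is re-exported through blksprs.util (see the __init__.py change at the top of this diff), giving callers a public switch for the argument checks. A minimal usage sketch, assuming the README's `import blksprs as bs` alias; the performance motivation is an assumption, not stated in the release:

    import blksprs as bs

    # Skips validate_dimensions, validate_contiguous, validate_device, etc. on
    # every subsequent blksprs call (VALIDATION is a module-level flag).
    bs.util.disable_validation()

    # There is no public re-enable helper in this release; the flag lives in
    # blksprs.utils.validation and is toggled via the private _set_skip_validation.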
blksprs/utils/validation.py CHANGED
@@ -1,9 +1,10 @@
 import torch
 from torch import Tensor
 
+VALIDATION = True
 
 def validate_dimensions(*tensors: Tensor) -> None:
-    if _skip_validation():
+    if _check_skip_validation():
         return
 
     for tensor in tensors:
@@ -12,7 +13,7 @@ def validate_dimensions(*tensors: Tensor) -> None:
 
 
 def validate_contiguous(*tensors: Tensor) -> None:
-    if _skip_validation():
+    if _check_skip_validation():
         return
 
     for tensor in tensors:
@@ -21,7 +22,7 @@ def validate_contiguous(*tensors: Tensor) -> None:
 
 
 def validate_dtype_float(*tensors: Tensor) -> None:
-    if _skip_validation():
+    if _check_skip_validation():
         return
 
     for tensor in tensors:
@@ -30,7 +31,7 @@ def validate_dtype_float(*tensors: Tensor) -> None:
 
 
 def validate_dtype_int(*tensors: Tensor) -> None:
-    if _skip_validation():
+    if _check_skip_validation():
         return
 
     for tensor in tensors:
@@ -39,7 +40,7 @@ def validate_dtype_int(*tensors: Tensor) -> None:
 
 
 def validate_device(*tensors: Tensor) -> None:
-    if _skip_validation():
+    if _check_skip_validation():
         return
 
     device = None
@@ -56,7 +57,7 @@ def validate_device(*tensors: Tensor) -> None:
 
 
 def validate_sparsity(sparsity_block_size: int, *tensor_sparsity_layout_tuples: tuple[Tensor, Tensor]) -> None:
-    if _skip_validation():
+    if _check_skip_validation():
         return
 
     for (tensor, sparsity_layout) in tensor_sparsity_layout_tuples:
@@ -73,7 +74,7 @@ def _validate_sparsity_layout_values(sparsity_layout: Tensor):
         raise ValueError("Sparsity layout values must be either 0 or 1")
 
 def validate_sparsity_block_size(sparsity_block_size: int, *tensors):
-    if _skip_validation():
+    if _check_skip_validation():
         return
 
     if not (sparsity_block_size & (sparsity_block_size - 1)) == 0:
@@ -84,7 +85,7 @@ def validate_sparsity_block_size(sparsity_block_size: int, *tensors):
         raise ValueError("Tensor sizes must be divisible by sparsity block size")
 
 def validate_triton_block_size(triton_block_size: int, sparsity_block_size: int):
-    if _skip_validation():
+    if _check_skip_validation():
         return
 
     if triton_block_size is None:
@@ -93,5 +94,9 @@ def validate_triton_block_size(triton_block_size: int, sparsity_block_size: int):
     if triton_block_size > sparsity_block_size:
         raise ValueError("Triton block size cannot be larger than sparsity block size")
 
-def _skip_validation():
-    return False
+def _check_skip_validation():
+    return not VALIDATION
+
+def _set_skip_validation(skip_validation: bool):
+    global VALIDATION
+    VALIDATION = not skip_validation
blksprs-1.4.dist-info/METADATA → blksprs-1.4.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.4
+Version: 1.4.1
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -8,10 +8,8 @@ Project-URL: Bugtracker, https://github.com/FelixSchoen/blksprs/issues
 Requires-Python: >=3.11
 Description-Content-Type: text/markdown
 Requires-Dist: torch
-Provides-Extra: deploy
-Requires-Dist: build; extra == "deploy"
-Requires-Dist: twine; extra == "deploy"
-Requires-Dist: pdoc3; extra == "deploy"
+Provides-Extra: build
+Requires-Dist: build; extra == "build"
 Provides-Extra: test
 Requires-Dist: pytest; extra == "test"
 Requires-Dist: pytest-xdist; extra == "test"
@@ -146,7 +144,7 @@ def test_readme():
     # Assert that the output has the correct sparsity layout
     actual_sparsity_layout_o = bs.layout.build_sparsity_layout(o_dense, sparsity_block_size,
                                                                triton_block_size=triton_block_size)
-    assert torch.allclose(actual_sparsity_layout_o, sparsity_layout_o)
+    assert torch.allclose(actual_sparsity_layout_o.to(torch.int), sparsity_layout_o)
 
     # Convert output tensor back to original shape
     o = bs.util.undo_shape_blocksparse(o_dense, x_shape_original)
blksprs-1.4.1.dist-info/RECORD ADDED
@@ -0,0 +1,19 @@
+blksprs/__init__.py,sha256=ORAVhGR91G1wyIOs9Wecv-xfmjju3bJ4Jynq_SGOVY4,833
+blksprs/layouting/distribution_layout.py,sha256=Xd8KjZwI87L9EL1Bw5SGUW9YztFD5q0Ygr99sffvdak,4939
+blksprs/layouting/sparsity_layout.py,sha256=vZL8r5LkMwILYYqTYPZcN_NYFJuVFIB6mmBkdtRyXmI,7893
+blksprs/misc/broadcast_ops.py,sha256=RTcqvx6X_THRBb55jipeEe63YSLIAh27jdpuze0aSek,5308
+blksprs/misc/repeat_interleave.py,sha256=KJeapmxbpA7zGFfa5hUhCGrk4aFmhOhlMw-hbTh9PLI,5668
+blksprs/misc/row_wise.py,sha256=KCDO5ry5TkjI88LLD_QINZwBkzfmjoQpOOvYLfpUn5I,16853
+blksprs/ops/conversion.py,sha256=h1c5T74rQjqYgY9dwWXfPTXRpgzy0dtAhCmtUp8-6uo,21332
+blksprs/ops/distribution.py,sha256=KhtHRVcv4_woyNlldAjIWF-7021-KX-xyIcN6rE-UgE,16879
+blksprs/ops/exp.py,sha256=CVWVq_emO2CnS_xk6Unx67P7EI7IL26dwtsmBJZOLzQ,3698
+blksprs/ops/matmul.py,sha256=6DaYxecJgwiW8L-UISkgyNyzQ31AAkmDL-Oq1EjHt98,11210
+blksprs/ops/softmax.py,sha256=cSTxDnNmMRlJGOlCSpdg1U5KUIFpVtHulz8fteJFeh0,11972
+blksprs/ops/transpose.py,sha256=et8R124L29TUqihci18ms_hBoYXTtPu5LXgEA8sxk_w,6744
+blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
+blksprs/utils/tools.py,sha256=RKGWCGd5h1qFOIoShsdJObx4-QsS0RxCyzFie0geNxo,596
+blksprs/utils/validation.py,sha256=Gsx3aah6355bWXRPpbFuZ1p0fOrYduIqaM3ON9d5NiI,3197
+blksprs-1.4.1.dist-info/METADATA,sha256=3xRmBFHv2U2KnrW3_QX3003SHLkQ1JCaSqh4AUBsJD4,7609
+blksprs-1.4.1.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+blksprs-1.4.1.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
+blksprs-1.4.1.dist-info/RECORD,,
blksprs-1.4.dist-info/RECORD DELETED
@@ -1,19 +0,0 @@
-blksprs/__init__.py,sha256=vUthykoYgmHqo2rNYgfrKTNMq7IDalRpCa1nVdFEOqA,813
-blksprs/layouting/distribution_layout.py,sha256=GQ-ZRXbeImiLcbaqnL2FuUZ6DoFwmB0naT_YrOpD84Q,4940
-blksprs/layouting/sparsity_layout.py,sha256=TtADT_WWcZpW3zyGy6KAgkAo44gDryXZqdJLZGEX2V8,7895
-blksprs/misc/broadcast_ops.py,sha256=xLj7CH5yEBihI5gT8SRFqQta1DXvl3iSskhbHsOX_EM,5261
-blksprs/misc/repeat_interleave.py,sha256=WrIp7uJsnvjIhFeLYPfkL2j5vXyKmDQGrJ69b3Y0lQ8,5644
-blksprs/misc/row_wise.py,sha256=Fa57BVfmneXT_8Ms-Vao8H8fh89sT3Z0b_gtN-7gano,16805
-blksprs/ops/conversion.py,sha256=-AOzj_j3WrBLGIgd2oVPvYS8XKfzlvGtSIWzW_qP1lk,21260
-blksprs/ops/distribution.py,sha256=_fQb6fWpLxocAh86D74ATahChi0EK0eBb4eUOUEBVps,16769
-blksprs/ops/exp.py,sha256=qs8fVtCzxl4CKT4GepaqurjEL62jyi8VjMY12JFrFAU,3674
-blksprs/ops/matmul.py,sha256=x3lrYg4g8fIf5PeMtZY_SEpi11kP9RFcRoemCIxcSDE,11086
-blksprs/ops/softmax.py,sha256=1lxgS12oJ5UcRkDxq13OOjp9AHwhgzSfBosEO1GzKvs,11948
-blksprs/ops/transpose.py,sha256=cX_E3b-QMhsUDNn9D8HVkYesc2JBc-EcVBUZfCWExM8,6720
-blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
-blksprs/utils/tools.py,sha256=DwophH01AeNTZAo0B1uWbKFSGBQjI5z0WmFnYKh-BBk,465
-blksprs/utils/validation.py,sha256=gJYZO5C48YUrXV3Fy_Z_lCaOpiFj951FT-Od7sKfprg,3007
-blksprs-1.4.dist-info/METADATA,sha256=mC9Vql8wtF_gLYwnGXx8_p9aKL7PnxrQSoZNkQegxic,7675
-blksprs-1.4.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
-blksprs-1.4.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
-blksprs-1.4.dist-info/RECORD,,