blksprs 2.0rc1__py3-none-any.whl → 2.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/ops/matmul.py CHANGED
@@ -205,6 +205,9 @@ def matmul_kernel(x,
205
205
  # Perform matrix multiplication
206
206
  buf += tl.dot(blk_x, blk_y)
207
207
 
208
+ # Cast buffer
209
+ buf = buf.to(o.dtype.element_ty)
210
+
208
211
  # Store output
209
212
  blk_o_idx = ((pid_blk * o_b_s) +
210
213
  ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
@@ -27,8 +27,8 @@ def validate_dtype_float(*tensors: Tensor) -> None:
27
27
  return
28
28
 
29
29
  for tensor in tensors:
30
- if tensor.dtype != torch.float32:
31
- raise ValueError("Tensor must have float32 dtype")
30
+ if tensor.dtype != torch.float16 and tensor.dtype != torch.float32:
31
+ raise ValueError("Tensor must have either float16 or float32 dtype")
32
32
 
33
33
 
34
34
  def validate_dtype_int(*tensors: Tensor) -> None:
@@ -38,7 +38,7 @@ def validate_dtype_int(*tensors: Tensor) -> None:
38
38
  for tensor in tensors:
39
39
  if (tensor.dtype !=
40
40
  torch.int32 and tensor.dtype != torch.int64):
41
- raise ValueError("Tensor must have int32 or int64 dtype")
41
+ raise ValueError("Tensor must have either int32 or int64 dtype")
42
42
 
43
43
 
44
44
  def validate_device(*tensors: Tensor) -> None:
@@ -51,7 +51,7 @@ def validate_device(*tensors: Tensor) -> None:
51
51
  if i == 0:
52
52
  device = tensor.device
53
53
 
54
- if not device.type == 'cuda':
54
+ if not device.type == "cuda":
55
55
  raise ValueError("Tensors must be on GPU")
56
56
 
57
57
  if tensor.device != device:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: blksprs
3
- Version: 2.0rc1
3
+ Version: 2.0rc2
4
4
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
5
5
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
6
6
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -4,7 +4,7 @@ blksprs/layouting/sparsity_layout.py,sha256=UzMcdW7l4zoiLB_LMEbBR1JBdqVSgINDGYvo
4
4
  blksprs/ops/conversion.py,sha256=_JKOovDZOmYJLcurJGhgNt5iQB9kOKp3fufFxD8QCZs,22204
5
5
  blksprs/ops/distribution.py,sha256=5gE19kPQGQljVbRpDZeqNaOe8ehRhxdQS7PiJp6mMug,21352
6
6
  blksprs/ops/flow.py,sha256=G8L_sMAWIM77gv-YLJtyutEzXqyaaofnSX2QKvmDr44,8409
7
- blksprs/ops/matmul.py,sha256=YAurJcXa_39gRdh2nWUOmbhm8h99arLoO-SN-l134II,11879
7
+ blksprs/ops/matmul.py,sha256=b4Bic8xjKt7P52nUsQn7vgvH4huuEEMf6ntXtiebRNg,11935
8
8
  blksprs/ops/partitioning.py,sha256=AooYZOw0oZgA9zXSu09O60hkJcnpWT1OTosr2T2wdQo,9700
9
9
  blksprs/ops/repeat.py,sha256=qty0qIFcfiWzROV2A2FB2KiPCC2Pe4q5TwJyGuDBAQE,8839
10
10
  blksprs/ops/softmax.py,sha256=eaZ8pfCpNZCX6Gk5Tk-lhNIrBQDhvfHqNNPltqxp91k,12793
@@ -15,8 +15,8 @@ blksprs/utils/benchmarking.py,sha256=dLabDscTFn5NkmOI1g7DnKeTneUYW3RIVv9MDF-8BKc
15
15
  blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
16
16
  blksprs/utils/processing.py,sha256=WLuMJQ8v-YovXwcDjhlDn3N31WMZXrtyeeyKSgq_zn4,3642
17
17
  blksprs/utils/tools.py,sha256=RL18P4NAj7d8gXTTKbMZt4SHCynsw1wPu9yvlrnBQlo,1220
18
- blksprs/utils/validation.py,sha256=_Ee6bqu7CxdYLFSy4WZOFoXJgd0p_RBMumCwGCk2_Hw,3763
19
- blksprs-2.0rc1.dist-info/METADATA,sha256=zXzVOvuwgYSyx-lCBycdFvRUmHUD_qYbK8sFkKWZnp8,8601
20
- blksprs-2.0rc1.dist-info/WHEEL,sha256=1tXe9gY0PYatrMPMDd6jXqjfpz_B-Wqm32CPfRC58XU,91
21
- blksprs-2.0rc1.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
22
- blksprs-2.0rc1.dist-info/RECORD,,
18
+ blksprs/utils/validation.py,sha256=kYRERD4DbQ9lKs0Kd7BQbTx6LW9BBWzf4NlVvkYCyGw,3822
19
+ blksprs-2.0rc2.dist-info/METADATA,sha256=UJ439QdVHceVCaTvz1Qd44C5IhG9QQz60yU0xVGxjR0,8601
20
+ blksprs-2.0rc2.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
21
+ blksprs-2.0rc2.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
22
+ blksprs-2.0rc2.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (77.0.3)
2
+ Generator: setuptools (78.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5