blksprs 2.0rc1__tar.gz → 2.0rc2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. {blksprs-2.0rc1 → blksprs-2.0rc2}/PKG-INFO +1 -1
  2. {blksprs-2.0rc1 → blksprs-2.0rc2}/blksprs/ops/matmul.py +3 -0
  3. {blksprs-2.0rc1 → blksprs-2.0rc2}/blksprs/utils/validation.py +4 -4
  4. {blksprs-2.0rc1 → blksprs-2.0rc2}/blksprs.egg-info/PKG-INFO +1 -1
  5. {blksprs-2.0rc1 → blksprs-2.0rc2}/pyproject.toml +1 -1
  6. {blksprs-2.0rc1 → blksprs-2.0rc2}/README.md +0 -0
  7. {blksprs-2.0rc1 → blksprs-2.0rc2}/blksprs/__init__.py +0 -0
  8. {blksprs-2.0rc1 → blksprs-2.0rc2}/blksprs/layouting/distribution_layout.py +0 -0
  9. {blksprs-2.0rc1 → blksprs-2.0rc2}/blksprs/layouting/sparsity_layout.py +0 -0
  10. {blksprs-2.0rc1 → blksprs-2.0rc2}/blksprs/ops/conversion.py +0 -0
  11. {blksprs-2.0rc1 → blksprs-2.0rc2}/blksprs/ops/distribution.py +0 -0
  12. {blksprs-2.0rc1 → blksprs-2.0rc2}/blksprs/ops/flow.py +0 -0
  13. {blksprs-2.0rc1 → blksprs-2.0rc2}/blksprs/ops/misc/broadcast_ops.py +0 -0
  14. {blksprs-2.0rc1 → blksprs-2.0rc2}/blksprs/ops/misc/row_wise.py +0 -0
  15. {blksprs-2.0rc1 → blksprs-2.0rc2}/blksprs/ops/partitioning.py +0 -0
  16. {blksprs-2.0rc1 → blksprs-2.0rc2}/blksprs/ops/repeat.py +0 -0
  17. {blksprs-2.0rc1 → blksprs-2.0rc2}/blksprs/ops/softmax.py +0 -0
  18. {blksprs-2.0rc1 → blksprs-2.0rc2}/blksprs/ops/transpose.py +0 -0
  19. {blksprs-2.0rc1 → blksprs-2.0rc2}/blksprs/utils/benchmarking.py +0 -0
  20. {blksprs-2.0rc1 → blksprs-2.0rc2}/blksprs/utils/blksprs_tensor.py +0 -0
  21. {blksprs-2.0rc1 → blksprs-2.0rc2}/blksprs/utils/processing.py +0 -0
  22. {blksprs-2.0rc1 → blksprs-2.0rc2}/blksprs/utils/tools.py +0 -0
  23. {blksprs-2.0rc1 → blksprs-2.0rc2}/blksprs.egg-info/SOURCES.txt +0 -0
  24. {blksprs-2.0rc1 → blksprs-2.0rc2}/blksprs.egg-info/dependency_links.txt +0 -0
  25. {blksprs-2.0rc1 → blksprs-2.0rc2}/blksprs.egg-info/requires.txt +0 -0
  26. {blksprs-2.0rc1 → blksprs-2.0rc2}/blksprs.egg-info/top_level.txt +0 -0
  27. {blksprs-2.0rc1 → blksprs-2.0rc2}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: blksprs
3
- Version: 2.0rc1
3
+ Version: 2.0rc2
4
4
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
5
5
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
6
6
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -205,6 +205,9 @@ def matmul_kernel(x,
205
205
  # Perform matrix multiplication
206
206
  buf += tl.dot(blk_x, blk_y)
207
207
 
208
+ # Cast buffer
209
+ buf = buf.to(o.dtype.element_ty)
210
+
208
211
  # Store output
209
212
  blk_o_idx = ((pid_blk * o_b_s) +
210
213
  ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
@@ -27,8 +27,8 @@ def validate_dtype_float(*tensors: Tensor) -> None:
27
27
  return
28
28
 
29
29
  for tensor in tensors:
30
- if tensor.dtype != torch.float32:
31
- raise ValueError("Tensor must have float32 dtype")
30
+ if tensor.dtype != torch.float16 and tensor.dtype != torch.float32:
31
+ raise ValueError("Tensor must have either float16 or float32 dtype")
32
32
 
33
33
 
34
34
  def validate_dtype_int(*tensors: Tensor) -> None:
@@ -38,7 +38,7 @@ def validate_dtype_int(*tensors: Tensor) -> None:
38
38
  for tensor in tensors:
39
39
  if (tensor.dtype !=
40
40
  torch.int32 and tensor.dtype != torch.int64):
41
- raise ValueError("Tensor must have int32 or int64 dtype")
41
+ raise ValueError("Tensor must have either int32 or int64 dtype")
42
42
 
43
43
 
44
44
  def validate_device(*tensors: Tensor) -> None:
@@ -51,7 +51,7 @@ def validate_device(*tensors: Tensor) -> None:
51
51
  if i == 0:
52
52
  device = tensor.device
53
53
 
54
- if not device.type == 'cuda':
54
+ if not device.type == "cuda":
55
55
  raise ValueError("Tensors must be on GPU")
56
56
 
57
57
  if tensor.device != device:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: blksprs
3
- Version: 2.0rc1
3
+ Version: 2.0rc2
4
4
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
5
5
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
6
6
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "blksprs"
3
- version = "2.0-rc.1"
3
+ version = "2.0-rc.2"
4
4
  authors = [{ name = "Felix Schön", email = "schoen@kr.tuwien.ac.at" }]
5
5
  description = "A lightweight library for operations on blocksparse matrices in PyTorch."
6
6
  readme = "README.md"
File without changes
File without changes
File without changes
File without changes
File without changes