blksprs-2.0rc2-py3-none-any.whl → blksprs-2.0rc3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/layouting/sparsity_layout.py CHANGED
@@ -12,6 +12,7 @@ from blksprs.utils.validation import validate_dimensions, validate_device, \
     validate_contiguous, validate_sparsity, validate_sparsity_block_size


+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def build_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
     """Builds the sparsity layout of a dense tensor in regular form covering its sparse blocks.

@@ -199,6 +200,7 @@ def build_sparsity_layout_adaption_kernel(x,
     tl.store(o + blk_o_idx, 1, mask=blk_o_msk)


+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def build_sparsity_layout_matmul(sparsity_layout_x: Tensor, sparsity_layout_y: Tensor) -> Tensor:
     """Builds the precise sparsity layout of the result of a matrix multiplication between the two input tensors.

@@ -213,6 +215,7 @@ def build_sparsity_layout_matmul(sparsity_layout_x: Tensor, sparsity_layout_y: T
     return torch.matmul(sparsity_layout_x.to(torch.float), sparsity_layout_y.to(torch.float)).to(torch.bool)


+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def build_sparsity_layout_matmul_fast(sparsity_layout_x: Tensor, sparsity_layout_y: Tensor):
     """Builds the approximate sparsity layout of the result of a matrix multiplication between the two input tensors.

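This `@torch.amp.custom_fwd` decorator recurs throughout the release. As a minimal sketch of the PyTorch mechanism it relies on (the snippet below is illustrative only and not part of the package):

```python
# Minimal sketch (not blksprs code): with cast_inputs set, custom_fwd casts
# incoming floating-point CUDA tensors to the given dtype whenever an autocast
# region is active, and runs the wrapped function with autocast disabled.
# Outside autocast it is a no-op.
import torch

@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
def report_dtype(x: torch.Tensor) -> torch.dtype:
    return x.dtype

if torch.cuda.is_available():
    x = torch.randn(4, 4, device="cuda")          # float32
    print(report_dtype(x))                        # torch.float32 (no autocast)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        print(report_dtype(x))                    # torch.float16 (cast on entry)
```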
blksprs/ops/conversion.py CHANGED
@@ -18,6 +18,7 @@ def to_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int) ->
     return to_sparse(x, sparsity_layout, sparsity_block_size)


+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def to_sparse(x: Tensor, sparsity_layout: Tensor,
              sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
     """Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
@@ -175,6 +176,7 @@ def from_blksprs(x: BlksprsTensor, sparsity_layout: Tensor,
     return to_dense(x, sparsity_layout, sparsity_block_size, fill_value=fill_value, lut=lut)


+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def to_dense(x: BlksprsTensor, sparsity_layout: Tensor,
             sparsity_block_size: int, fill_value: float = 0, lut: dict = None) -> Tensor:
     """Converts a block-sparse tensor in compressed form to a block-sparse tensor in regular form based on the given
@@ -326,6 +328,7 @@ def to_dense_setup_context(ctx, inputs, output):
 to_dense_forward.register_autograd(to_dense_backward, setup_context=to_dense_setup_context)


+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_size_from: int,
                 sparsity_block_size_to: int, sparsity_layout_to: Tensor = None) -> (BlksprsTensor, Tensor):
     """Adapts the sparsity layout of a block-sparse tensor, resulting in a new block-sparse tensor in compressed form
blksprs/ops/distribution.py CHANGED
@@ -11,6 +11,7 @@ from blksprs.utils.validation import validate_contiguous, validate_dimensions, v
     validate_sparsity, validate_dtype_int, validate_sparsity_block_size


+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
           dim: int,
           idx: BlksprsTensor, sparsity_layout_idx: Tensor,
@@ -247,6 +248,7 @@ def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
                          reduce_op="none", lut=lut)


+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
                   dim: int,
                   idx: BlksprsTensor,
blksprs/ops/matmul.py CHANGED
@@ -11,6 +11,7 @@ from blksprs.utils.validation import validate_contiguous, validate_dimensions, v
     validate_sparsity, validate_sparsity_block_size, validate_dtype_float


+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
           y: BlksprsTensor, sparsity_layout_y: Tensor,
           sparsity_layout_output: Tensor,
blksprs/ops/misc/row_wise.py CHANGED
@@ -10,6 +10,7 @@ from blksprs.utils.validation import validate_dimensions, validate_contiguous, v
     validate_sparsity_block_size


+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
 def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
                 flag_slice_only: bool = False) -> (BlksprsTensor, Tensor):
     """Computes the row-wise sum of a block-sparse tensor.
@@ -156,6 +157,7 @@ def row_wise_sum_kernel(x,
     tl.atomic_add(o + o_idx, buf, o_msk)


+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
                 flag_slice_only: bool = False) -> (BlksprsTensor, Tensor):
     """Computes the row-wise max of a block-sparse tensor.
@@ -304,6 +306,7 @@ def row_wise_max_kernel(x,
     tl.atomic_max(o + o_idx, buf, o_msk)


+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def row_wise_add(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
                 sparsity_block_size: int) -> BlksprsTensor:
     """For each row in ``y`` adds the value to each value in the corresponding row of the block-sparse tensor ``x``.
blksprs/ops/partitioning.py CHANGED
@@ -8,6 +8,7 @@ from blksprs.utils.validation import validate_dimensions, validate_contiguous, v
     validate_sparsity, validate_sparsity_block_size


+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
          dim: int, sparsity_block_size: int, lut: dict = None) -> (
        BlksprsTensor, Tensor):
@@ -111,6 +112,7 @@ def split_setup_context(ctx, inputs, output):
 split_forward.register_autograd(split_backward, setup_context=split_setup_context)


+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
          dim: int, sparsity_block_size: int, lut: dict = None) -> (
        BlksprsTensor, Tensor):
blksprs/ops/repeat.py CHANGED
@@ -8,6 +8,7 @@ from blksprs.utils.validation import validate_dimensions, validate_contiguous, v
     validate_sparsity, validate_sparsity_block_size


+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
           sparsity_block_size: int, sparsity_layout_output: Tensor = None, lut: dict = None) -> (
        BlksprsTensor, Tensor):
@@ -50,6 +51,7 @@ def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int,
        lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_o"]


+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
                      sparsity_block_size: int, sparsity_layout_output: Tensor = None, lut: dict = None) -> (
        BlksprsTensor, Tensor):
blksprs/ops/softmax.py CHANGED
@@ -9,9 +9,10 @@ from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import stride, get_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-    validate_sparsity, validate_sparsity_block_size
+    validate_sparsity, validate_sparsity_block_size, validate_dtype_float_32


+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
 def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
     """Computes the softmax of a block-sparse tensor in compressed form.

@@ -32,6 +33,7 @@ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,

     validate_dimensions(x)
     validate_contiguous(x)
+    validate_dtype_float_32(x)
     validate_device(x)
     validate_sparsity(sparsity_block_size, (x, sparsity_layout))
     validate_sparsity_block_size(sparsity_block_size, x)
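Note the float32-specific handling here: `softmax` (like `row_wise_sum` above) is decorated with `cast_inputs=torch.float32` and now validates its input dtype. A hedged, standalone sketch of how those two pieces interact (the helper below mirrors the diff and is not imported from blksprs):

```python
# Standalone sketch mirroring the diff (not the blksprs modules): under
# autocast the decorator upcasts half-precision CUDA inputs to float32 before
# the body runs, so the stricter float32 validation still passes; outside
# autocast a float16 input would be rejected explicitly.
import torch

def validate_dtype_float_32(*tensors: torch.Tensor) -> None:
    for tensor in tensors:
        if tensor.dtype != torch.float32:
            raise ValueError("Tensor must have float32 dtype")

@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
def softmax_like(x: torch.Tensor) -> torch.Tensor:
    validate_dtype_float_32(x)               # passes: x was upcast on entry
    return torch.softmax(x, dim=-1)

if torch.cuda.is_available():
    x_fp16 = torch.randn(2, 8, device="cuda", dtype=torch.float16)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        y = softmax_like(x_fp16)             # ok, runs in float32
    # softmax_like(x_fp16)                   # outside autocast: ValueError
```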
blksprs/ops/transpose.py CHANGED
@@ -8,6 +8,7 @@ from blksprs.utils.validation import validate_dimensions, validate_contiguous, v
     validate_sparsity, validate_sparsity_block_size


+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def transpose(x: BlksprsTensor, sparsity_layout: Tensor,
              sparsity_block_size: int, lut: dict = None) -> (BlksprsTensor, Tensor):
     """Transposes a block-sparse tensor in compressed form.
blksprs/utils/processing.py CHANGED
@@ -11,6 +11,7 @@ from blksprs.ops.repeat import repeat
 from blksprs.utils.blksprs_tensor import BlksprsTensor


+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def apply_torch_linear(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
                       linear: nn.Linear, bias: nn.Parameter = None) -> (BlksprsTensor, Tensor):
     # Extract weight and bias
@@ -25,7 +26,8 @@ def apply_torch_linear(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block

     # Apply weights
     sparsity_layout_xw = build_sparsity_layout_matmul_fast(sparsity_layout, sparsity_layout_w_t)
-    xw = matmul(x, sparsity_layout, w_t_bs, sparsity_layout_w_t, sparsity_layout_xw, sparsity_block_size)
+    # TODO At the moment, manual cast is needed. Bug with custom_fwd?
+    xw = matmul(x, sparsity_layout, BlksprsTensor(w_t_bs.to(x.dtype)), sparsity_layout_w_t, sparsity_layout_xw, sparsity_block_size)
     interim = xw

     # Apply bias
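The TODO above hints at a likely cause for the manual cast: `cast_inputs` only affects tensors passed into the decorated call, and the body then runs with autocast disabled, so anything derived inside from the layer's float32 weights keeps its dtype while `x` has already been cast to float16. A hedged sketch of that situation, in generic PyTorch rather than the blksprs implementation:

```python
# Hedged sketch (assumption about the cause, generic PyTorch code): the
# nn.Linear module is not a Tensor, so custom_fwd does not cast its weights;
# inside the decorated body autocast is disabled, so the weight stays float32
# while x arrives as float16, and the dtypes must be aligned by hand.
import torch
import torch.nn as nn

@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
def apply_linear_sketch(x: torch.Tensor, linear: nn.Linear) -> torch.Tensor:
    w_t = linear.weight.transpose(-1, -2)    # still float32 here
    return x @ w_t.to(x.dtype)               # manual cast, as in the diff above

if torch.cuda.is_available():
    lin = nn.Linear(8, 8).cuda()
    x = torch.randn(2, 8, device="cuda")
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        y = apply_linear_sketch(x, lin)
    print(y.dtype)                           # torch.float16
```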
blksprs/utils/validation.py CHANGED
@@ -26,10 +26,27 @@ def validate_dtype_float(*tensors: Tensor) -> None:
     if _check_skip_validation():
         return

-    for tensor in tensors:
+    dtype = None
+
+    for i, tensor in enumerate(tensors):
+        if i == 0:
+            dtype = tensor.dtype
+
         if tensor.dtype != torch.float16 and tensor.dtype != torch.float32:
             raise ValueError("Tensor must have either float16 or float32 dtype")

+        if tensor.dtype != dtype:
+            raise ValueError("Tensors must have same dtype")
+
+
+def validate_dtype_float_32(*tensors: Tensor) -> None:
+    if _check_skip_validation():
+        return
+
+    for tensor in tensors:
+        if tensor.dtype != torch.float32:
+            raise ValueError("Tensor must have float32 dtype")
+

 def validate_dtype_int(*tensors: Tensor) -> None:
     if _check_skip_validation():
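For reference, the tightened checks reject mixed-precision operand pairs as well as unsupported dtypes. A brief usage sketch (import path as used elsewhere in this diff; it assumes validation is not globally skipped):

```python
import torch
from blksprs.utils.validation import validate_dtype_float, validate_dtype_float_32

a = torch.randn(4, 4, dtype=torch.float16)
b = torch.randn(4, 4, dtype=torch.float32)

validate_dtype_float(a, a)         # ok: both float16
validate_dtype_float_32(b)         # ok: float32

try:
    validate_dtype_float(a, b)     # mixed dtypes now raise
except ValueError as err:
    print(err)                     # "Tensors must have same dtype"
```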
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: blksprs
-Version: 2.0rc2
+Version: 2.0rc3
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -27,7 +27,7 @@ Requires-Dist: matplotlib; extra == "test"
 ### News

 🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
-LUTs, and makes use of `torch.library.triton_op()`!
+LUTs, autocasting, and makes use of `torch.library.triton_op()`!

 ---

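To make the "autocasting" claim above concrete, here is a hedged end-to-end sketch. The module paths come from this wheel's RECORD and the call signatures from the hunks above, but the tensor shapes, the CUDA requirement, and the block size are assumptions rather than documented usage:

```python
import torch
from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_matmul
from blksprs.ops.conversion import to_sparse, to_dense
from blksprs.ops.matmul import matmul

sparsity_block_size = 16
x = torch.randn(2, 64, 64, device="cuda")    # assumed [batch, rows, cols] layout
y = torch.randn(2, 64, 64, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    sl_x = build_sparsity_layout(x, sparsity_block_size)
    sl_y = build_sparsity_layout(y, sparsity_block_size)
    sl_o = build_sparsity_layout_matmul(sl_x, sl_y)

    # The decorators added in this release cast the dense inputs to float16 here.
    x_bs = to_sparse(x, sl_x, sparsity_block_size)
    y_bs = to_sparse(y, sl_y, sparsity_block_size)
    o_bs = matmul(x_bs, sl_x, y_bs, sl_y, sl_o, sparsity_block_size)
    o = to_dense(o_bs, sl_o, sparsity_block_size)
```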
@@ -0,0 +1,22 @@
+blksprs/__init__.py,sha256=OHfpwJCZWGUfpT-DVfC1YSaeZl4aCMNt9CrzMPymywU,1577
+blksprs/layouting/distribution_layout.py,sha256=0glIteoY5oDkiEu5rjLIC-BB_oC4sa3rFWVkohsAG00,5329
+blksprs/layouting/sparsity_layout.py,sha256=ZUhJm1jJn-npiJWFjsVyzjXDQOp8z-Wjjv0MPQOXRvg,10490
+blksprs/ops/conversion.py,sha256=pdoWhqEbgsB4STr_NjDcuLUlzSGdYCMaGrW7IOSfxiA,22411
+blksprs/ops/distribution.py,sha256=hLpKUoS553jM_F13WyLNNf73PM1yLqgDTkZUdW_pleo,21490
+blksprs/ops/flow.py,sha256=G8L_sMAWIM77gv-YLJtyutEzXqyaaofnSX2QKvmDr44,8409
+blksprs/ops/matmul.py,sha256=t9JUujkG-sGu4iyM4bjgrZJeNtMk3l8tk7rzYvWBCR8,12004
+blksprs/ops/partitioning.py,sha256=nAV28f3NtvT4OFvDtnE0A-VxpDQmMXS0pZw4CJwzqGA,9838
+blksprs/ops/repeat.py,sha256=bQpJuwtt8aRdSzxT78lJ8f8fLDhPkYK5UvMfJ-PQrkc,8977
+blksprs/ops/softmax.py,sha256=-9wFmQpnnCGK-xOZe-5L_cCxl5Cn_GNc9QGvhSQbRe4,12918
+blksprs/ops/transpose.py,sha256=PQKteFnzNAOEC7voO7wh_dq9c54UjCboJz889aBCwKc,4010
+blksprs/ops/misc/broadcast_ops.py,sha256=lZ5bBIftUKffzeYz77SWB1xmtZTRGMvjF-tG9rqkOXA,6018
+blksprs/ops/misc/row_wise.py,sha256=NcnLaXlPM7aQSoKXHYInao8F0xSQHixbVz-xebF5Bx0,19739
+blksprs/utils/benchmarking.py,sha256=dLabDscTFn5NkmOI1g7DnKeTneUYW3RIVv9MDF-8BKc,1271
+blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
+blksprs/utils/processing.py,sha256=xuu9iDpwTvsqI_WKMSD8QCNuvPnfcKMRcuF2L4Zs6Ts,3808
+blksprs/utils/tools.py,sha256=RL18P4NAj7d8gXTTKbMZt4SHCynsw1wPu9yvlrnBQlo,1220
+blksprs/utils/validation.py,sha256=7ks9hdNKbov1JE9y1bpnIfjWCVhqINTZOIZPi6d7k8E,4241
+blksprs-2.0rc3.dist-info/METADATA,sha256=58xKs5zAesWFMPGu4d0jLPth4yUNS95MGPqqMpn-syM,8614
+blksprs-2.0rc3.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+blksprs-2.0rc3.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
+blksprs-2.0rc3.dist-info/RECORD,,
@@ -1,22 +0,0 @@
-blksprs/__init__.py,sha256=OHfpwJCZWGUfpT-DVfC1YSaeZl4aCMNt9CrzMPymywU,1577
-blksprs/layouting/distribution_layout.py,sha256=0glIteoY5oDkiEu5rjLIC-BB_oC4sa3rFWVkohsAG00,5329
-blksprs/layouting/sparsity_layout.py,sha256=UzMcdW7l4zoiLB_LMEbBR1JBdqVSgINDGYvoCYIOulk,10283
-blksprs/ops/conversion.py,sha256=_JKOovDZOmYJLcurJGhgNt5iQB9kOKp3fufFxD8QCZs,22204
-blksprs/ops/distribution.py,sha256=5gE19kPQGQljVbRpDZeqNaOe8ehRhxdQS7PiJp6mMug,21352
-blksprs/ops/flow.py,sha256=G8L_sMAWIM77gv-YLJtyutEzXqyaaofnSX2QKvmDr44,8409
-blksprs/ops/matmul.py,sha256=b4Bic8xjKt7P52nUsQn7vgvH4huuEEMf6ntXtiebRNg,11935
-blksprs/ops/partitioning.py,sha256=AooYZOw0oZgA9zXSu09O60hkJcnpWT1OTosr2T2wdQo,9700
-blksprs/ops/repeat.py,sha256=qty0qIFcfiWzROV2A2FB2KiPCC2Pe4q5TwJyGuDBAQE,8839
-blksprs/ops/softmax.py,sha256=eaZ8pfCpNZCX6Gk5Tk-lhNIrBQDhvfHqNNPltqxp91k,12793
-blksprs/ops/transpose.py,sha256=30pGCSjZs42Sg6TEXUdJNCDgmlN1n8aN88uNbV5wOtA,3941
-blksprs/ops/misc/broadcast_ops.py,sha256=lZ5bBIftUKffzeYz77SWB1xmtZTRGMvjF-tG9rqkOXA,6018
-blksprs/ops/misc/row_wise.py,sha256=iwOrHU8HiJGxq2hEmgJGZ60asRm72WLi10-PrpNrdeQ,19532
-blksprs/utils/benchmarking.py,sha256=dLabDscTFn5NkmOI1g7DnKeTneUYW3RIVv9MDF-8BKc,1271
-blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
-blksprs/utils/processing.py,sha256=WLuMJQ8v-YovXwcDjhlDn3N31WMZXrtyeeyKSgq_zn4,3642
-blksprs/utils/tools.py,sha256=RL18P4NAj7d8gXTTKbMZt4SHCynsw1wPu9yvlrnBQlo,1220
-blksprs/utils/validation.py,sha256=kYRERD4DbQ9lKs0Kd7BQbTx6LW9BBWzf4NlVvkYCyGw,3822
-blksprs-2.0rc2.dist-info/METADATA,sha256=UJ439QdVHceVCaTvz1Qd44C5IhG9QQz60yU0xVGxjR0,8601
-blksprs-2.0rc2.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
-blksprs-2.0rc2.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
-blksprs-2.0rc2.dist-info/RECORD,,