blksprs 2.0rc2.tar.gz → 2.0rc4.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. {blksprs-2.0rc2 → blksprs-2.0rc4}/PKG-INFO +2 -2
  2. {blksprs-2.0rc2 → blksprs-2.0rc4}/README.md +1 -1
  3. {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/layouting/sparsity_layout.py +3 -0
  4. {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/ops/conversion.py +6 -1
  5. {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/ops/distribution.py +4 -1
  6. {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/ops/flow.py +2 -1
  7. {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/ops/matmul.py +3 -1
  8. {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/ops/misc/row_wise.py +6 -2
  9. {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/ops/partitioning.py +2 -0
  10. {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/ops/repeat.py +2 -0
  11. {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/ops/softmax.py +9 -5
  12. {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/ops/transpose.py +1 -0
  13. {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/utils/processing.py +3 -1
  14. {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/utils/validation.py +18 -1
  15. {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs.egg-info/PKG-INFO +2 -2
  16. {blksprs-2.0rc2 → blksprs-2.0rc4}/pyproject.toml +1 -1
  17. {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/__init__.py +0 -0
  18. {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/layouting/distribution_layout.py +0 -0
  19. {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/ops/misc/broadcast_ops.py +0 -0
  20. {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/utils/benchmarking.py +0 -0
  21. {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/utils/blksprs_tensor.py +0 -0
  22. {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/utils/tools.py +0 -0
  23. {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs.egg-info/SOURCES.txt +0 -0
  24. {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs.egg-info/dependency_links.txt +0 -0
  25. {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs.egg-info/requires.txt +0 -0
  26. {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs.egg-info/top_level.txt +0 -0
  27. {blksprs-2.0rc2 → blksprs-2.0rc4}/setup.cfg +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: blksprs
-Version: 2.0rc2
+Version: 2.0rc4
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -27,7 +27,7 @@ Requires-Dist: matplotlib; extra == "test"
 ### News
 
 🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
-LUTs, and makes use of `torch.library.triton_op()`!
+LUTs, autocasting, and makes use of `torch.library.triton_op()`!
 
 ---
 
README.md
@@ -8,7 +8,7 @@
 ### News
 
 🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
-LUTs, and makes use of `torch.library.triton_op()`!
+LUTs, autocasting, and makes use of `torch.library.triton_op()`!
 
 ---
 
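The "autocasting" mentioned in the News entry above corresponds to the `@torch.amp.custom_fwd(...)` decorators added throughout this release. Below is a minimal sketch of how a caller might combine these ops with PyTorch autocast; it assumes only the function signatures visible in the hunks below, and the tensor shapes and 64-block size are illustrative, not taken from this diff.

```python
import torch

# Entry points as they appear in the hunks below; whether they are also
# re-exported at the package top level is not shown in this diff.
from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_matmul_fast
from blksprs.ops.conversion import to_sparse, to_dense
from blksprs.ops.matmul import matmul

# Illustrative sizes: batch of 2, 128x128 matrices, 64x64 sparsity blocks.
sparsity_block_size = 64
x = torch.randn(2, 128, 128, device="cuda")
y = torch.randn(2, 128, 128, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    # Layouts mark which 64x64 blocks are non-zero.
    sl_x = build_sparsity_layout(x, sparsity_block_size)
    sl_y = build_sparsity_layout(y, sparsity_block_size)
    sl_o = build_sparsity_layout_matmul_fast(sl_x, sl_y)

    # The custom_fwd(cast_inputs=...) decorators added in this release cast
    # the inputs of these ops when autocast is active.
    x_bs = to_sparse(x, sl_x, sparsity_block_size)
    y_bs = to_sparse(y, sl_y, sparsity_block_size)
    o_bs = matmul(x_bs, sl_x, y_bs, sl_y, sl_o, sparsity_block_size)

o = to_dense(o_bs, sl_o, sparsity_block_size)
```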
blksprs/layouting/sparsity_layout.py
@@ -12,6 +12,7 @@ from blksprs.utils.validation import validate_dimensions, validate_device, \
     validate_contiguous, validate_sparsity, validate_sparsity_block_size
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def build_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
     """Builds the sparsity layout of a dense tensor in regular form covering its sparse blocks.
 
@@ -199,6 +200,7 @@ def build_sparsity_layout_adaption_kernel(x,
     tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def build_sparsity_layout_matmul(sparsity_layout_x: Tensor, sparsity_layout_y: Tensor) -> Tensor:
     """Builds the precise sparsity layout of the result of a matrix multiplication between the two input tensors.
 
@@ -213,6 +215,7 @@ def build_sparsity_layout_matmul(sparsity_layout_x: Tensor, sparsity_layout_y: T
     return torch.matmul(sparsity_layout_x.to(torch.float), sparsity_layout_y.to(torch.float)).to(torch.bool)
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def build_sparsity_layout_matmul_fast(sparsity_layout_x: Tensor, sparsity_layout_y: Tensor):
     """Builds the approximate sparsity layout of the result of a matrix multiplication between the two input tensors.
 
blksprs/ops/conversion.py
@@ -18,6 +18,7 @@ def to_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int) ->
     return to_sparse(x, sparsity_layout, sparsity_block_size)
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def to_sparse(x: Tensor, sparsity_layout: Tensor,
               sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
     """Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
@@ -53,7 +54,7 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor,
 @triton_op("blksprs::to_sparse", mutates_args={})
 def to_sparse_forward(x: Tensor, _: Tensor,
                       sparsity_lut: Tensor, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
-    output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+    output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                          dtype=x.dtype, device=x.device)
 
     x_b, x_r, x_c = x.size()
@@ -86,6 +87,7 @@ def to_sparse_backward(ctx, grad_output):
 @triton.autotune(
     configs=get_autotune_configs(),
     key=[],
+    reset_to_zero=["o"]
 )
 @triton.jit
 def to_sparse_kernel(x,
@@ -175,6 +177,7 @@ def from_blksprs(x: BlksprsTensor, sparsity_layout: Tensor,
     return to_dense(x, sparsity_layout, sparsity_block_size, fill_value=fill_value, lut=lut)
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def to_dense(x: BlksprsTensor, sparsity_layout: Tensor,
              sparsity_block_size: int, fill_value: float = 0, lut: dict = None) -> Tensor:
     """Converts a block-sparse tensor in compressed form to a block-sparse tensor in regular form based on the given
@@ -250,6 +253,7 @@ def to_dense_backward(ctx, grad_output):
 @triton.autotune(
     configs=get_autotune_configs(),
     key=[],
+    restore_value=["o"]
 )
 @triton.jit
 def to_dense_kernel(x,
@@ -326,6 +330,7 @@ def to_dense_setup_context(ctx, inputs, output):
 to_dense_forward.register_autograd(to_dense_backward, setup_context=to_dense_setup_context)
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_size_from: int,
                  sparsity_block_size_to: int, sparsity_layout_to: Tensor = None) -> (BlksprsTensor, Tensor):
     """Adapts the sparsity layout of a block-sparse tensor, resulting in a new block-sparse tensor in compressed form
blksprs/ops/distribution.py
@@ -11,6 +11,7 @@ from blksprs.utils.validation import validate_contiguous, validate_dimensions, v
     validate_sparsity, validate_dtype_int, validate_sparsity_block_size
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
            dim: int,
            idx: BlksprsTensor, sparsity_layout_idx: Tensor,
@@ -53,7 +54,7 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
 def gather_forward(x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
                    dim: int, i: Tensor, _: Tensor, sparsity_lut_i: Tensor,
                    sparsity_block_size: int) -> Tensor:
-    output = torch.empty_like(i, dtype=x.dtype)
+    output = torch.zeros_like(i, dtype=x.dtype)
 
     x_b, x_r, x_c = x.size()
     x_b_s, x_r_s, x_c_s = stride(x)
@@ -100,6 +101,7 @@ def gather_backward(ctx, grad_output):
 @triton.autotune(
     configs=get_autotune_configs(),
     key=[],
+    reset_to_zero=["o"]
 )
 @triton.jit
 def gather_kernel(x,
@@ -247,6 +249,7 @@ def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
                           reduce_op="none", lut=lut)
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
                    dim: int,
                    idx: BlksprsTensor,
blksprs/ops/flow.py
@@ -12,7 +12,7 @@ from blksprs.utils.tools import stride, get_autotune_configs
 def flow_pull_forward(x: Tensor, sparsity_layout_o: Tensor,
                       sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
                       sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
-    output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+    output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                          dtype=x.dtype, device=x.device)
 
     x_b, x_r, x_c = x.size()
@@ -44,6 +44,7 @@ def flow_pull_forward(x: Tensor, sparsity_layout_o: Tensor,
 @triton.autotune(
     configs=get_autotune_configs(),
     key=[],
+    reset_to_zero=["o"]
 )
 @triton.jit
 def flow_pull_kernel(x,
blksprs/ops/matmul.py
@@ -11,6 +11,7 @@ from blksprs.utils.validation import validate_contiguous, validate_dimensions, v
     validate_sparsity, validate_sparsity_block_size, validate_dtype_float
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
            y: BlksprsTensor, sparsity_layout_y: Tensor,
            sparsity_layout_output: Tensor,
@@ -59,7 +60,7 @@ def matmul_forward(x: Tensor, y: Tensor,
                    sparsity_layout_y: Tensor, sparsity_reverse_lut_y: Tensor,
                    _: Tensor, sparsity_lut_o: Tensor,
                    sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
-    output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+    output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                          dtype=x.dtype, device=x.device)
 
     x_b, x_r, x_c = x.size()
@@ -117,6 +118,7 @@ def matmul_backward(ctx, grad_output):
 @triton.autotune(
     configs=get_autotune_configs(),
     key=[],
+    reset_to_zero=["o"]
 )
 @triton.jit
 def matmul_kernel(x,
blksprs/ops/misc/row_wise.py
@@ -10,6 +10,7 @@ from blksprs.utils.validation import validate_dimensions, validate_contiguous, v
     validate_sparsity_block_size
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
 def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
                  flag_slice_only: bool = False) -> (BlksprsTensor, Tensor):
     """Computes the row-wise sum of a block-sparse tensor.
@@ -156,6 +157,7 @@ def row_wise_sum_kernel(x,
     tl.atomic_add(o + o_idx, buf, o_msk)
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
                  flag_slice_only: bool = False) -> (BlksprsTensor, Tensor):
     """Computes the row-wise max of a block-sparse tensor.
@@ -304,6 +306,7 @@ def row_wise_max_kernel(x,
     tl.atomic_max(o + o_idx, buf, o_msk)
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def row_wise_add(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
                  sparsity_block_size: int) -> BlksprsTensor:
     """For each row in ``y`` adds the value to each value in the corresponding row of the block-sparse tensor ``x``.
@@ -351,7 +354,7 @@ def row_wise_sub(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
 def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
                          sparsity_layout_x_rwm: Tensor, sparsity_reverse_x_lut_rwm: Tensor,
                          y: Tensor, sparsity_block_size: int) -> Tensor:
-    output = torch.empty_like(x)
+    output = torch.zeros_like(x)
 
     x_b, x_r, x_c = x.size()
     x_b_s, x_r_s, x_c_s = stride(x)
@@ -384,7 +387,8 @@ def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
 
 @triton.autotune(
     configs=get_autotune_configs(),
-    key=[]
+    key=[],
+    reset_to_zero=["o"]
 )
 @triton.jit
 def kernel_blocksparse_row_wise_add(x,
blksprs/ops/partitioning.py
@@ -8,6 +8,7 @@ from blksprs.utils.validation import validate_dimensions, validate_contiguous, v
     validate_sparsity, validate_sparsity_block_size
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
           dim: int, sparsity_block_size: int, lut: dict = None) -> (
         BlksprsTensor, Tensor):
@@ -111,6 +112,7 @@ def split_setup_context(ctx, inputs, output):
 split_forward.register_autograd(split_backward, setup_context=split_setup_context)
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
           dim: int, sparsity_block_size: int, lut: dict = None) -> (
         BlksprsTensor, Tensor):
blksprs/ops/repeat.py
@@ -8,6 +8,7 @@ from blksprs.utils.validation import validate_dimensions, validate_contiguous, v
     validate_sparsity, validate_sparsity_block_size
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
            sparsity_block_size: int, sparsity_layout_output: Tensor = None, lut: dict = None) -> (
         BlksprsTensor, Tensor):
@@ -50,6 +51,7 @@ def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int,
         lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_o"]
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
                       sparsity_block_size: int, sparsity_layout_output: Tensor = None, lut: dict = None) -> (
         BlksprsTensor, Tensor):
blksprs/ops/softmax.py
@@ -9,9 +9,10 @@ from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import stride, get_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-    validate_sparsity, validate_sparsity_block_size
+    validate_sparsity, validate_sparsity_block_size, validate_dtype_float_32
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
 def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
     """Computes the softmax of a block-sparse tensor in compressed form.
 
@@ -32,6 +33,7 @@ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
 
     validate_dimensions(x)
     validate_contiguous(x)
+    validate_dtype_float_32(x)
     validate_device(x)
     validate_sparsity(sparsity_block_size, (x, sparsity_layout))
     validate_sparsity_block_size(sparsity_block_size, x)
@@ -49,7 +51,7 @@ def softmax_forward(x: Tensor, sparsity_layout: Tensor,
                     sparsity_lut: Tensor,
                     sparsity_reverse_lut_rws: Tensor,
                     sparsity_block_size: int) -> Tensor:
-    output = torch.empty_like(x)
+    output = torch.zeros_like(x)
 
     x_b, x_r, x_c = x.size()
     x_b_s, x_r_s, x_c_s = stride(x)
@@ -106,7 +108,7 @@ def softmax_backward(ctx, grad_output):
     s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
     s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)
 
-    grad_x = torch.empty_like(o, dtype=torch.float)
+    grad_x = torch.zeros_like(o, dtype=torch.float)
 
     triton_grid = lambda meta: [o_b,
                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
@@ -131,7 +133,8 @@ def softmax_backward(ctx, grad_output):
 
 @triton.autotune(
     configs=get_autotune_configs(),
-    key=[]
+    key=[],
+    reset_to_zero=["o"]
 )
 @triton.jit
 def softmax_kernel(x,
@@ -196,7 +199,8 @@ def softmax_kernel(x,
 
 @triton.autotune(
     configs=get_autotune_configs(),
-    key=[]
+    key=[],
+    reset_to_zero=["o"]
 )
 @triton.jit
 def softmax_kernel_grad(g,
blksprs/ops/transpose.py
@@ -8,6 +8,7 @@ from blksprs.utils.validation import validate_dimensions, validate_contiguous, v
     validate_sparsity, validate_sparsity_block_size
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def transpose(x: BlksprsTensor, sparsity_layout: Tensor,
               sparsity_block_size: int, lut: dict = None) -> (BlksprsTensor, Tensor):
     """Transposes a block-sparse tensor in compressed form.
blksprs/utils/processing.py
@@ -11,6 +11,7 @@ from blksprs.ops.repeat import repeat
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def apply_torch_linear(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
                        linear: nn.Linear, bias: nn.Parameter = None) -> (BlksprsTensor, Tensor):
     # Extract weight and bias
@@ -25,7 +26,8 @@ def apply_torch_linear(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block
 
     # Apply weights
     sparsity_layout_xw = build_sparsity_layout_matmul_fast(sparsity_layout, sparsity_layout_w_t)
-    xw = matmul(x, sparsity_layout, w_t_bs, sparsity_layout_w_t, sparsity_layout_xw, sparsity_block_size)
+    # TODO At the moment, manual cast is needed. Bug with custom_fwd?
+    xw = matmul(x, sparsity_layout, BlksprsTensor(w_t_bs.to(x.dtype)), sparsity_layout_w_t, sparsity_layout_xw, sparsity_block_size)
     interim = xw
 
     # Apply bias
blksprs/utils/validation.py
@@ -26,10 +26,27 @@ def validate_dtype_float(*tensors: Tensor) -> None:
     if _check_skip_validation():
         return
 
-    for tensor in tensors:
+    dtype = None
+
+    for i, tensor in enumerate(tensors):
+        if i == 0:
+            dtype = tensor.dtype
+
         if tensor.dtype != torch.float16 and tensor.dtype != torch.float32:
             raise ValueError("Tensor must have either float16 or float32 dtype")
 
+        if tensor.dtype != dtype:
+            raise ValueError("Tensors must have same dtype")
+
+
+def validate_dtype_float_32(*tensors: Tensor) -> None:
+    if _check_skip_validation():
+        return
+
+    for tensor in tensors:
+        if tensor.dtype != torch.float32:
+            raise ValueError("Tensor must have float32 dtype")
+
 
 def validate_dtype_int(*tensors: Tensor) -> None:
     if _check_skip_validation():
blksprs.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: blksprs
-Version: 2.0rc2
+Version: 2.0rc4
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -27,7 +27,7 @@ Requires-Dist: matplotlib; extra == "test"
 ### News
 
 🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
-LUTs, and makes use of `torch.library.triton_op()`!
+LUTs, autocasting, and makes use of `torch.library.triton_op()`!
 
 ---
 
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "blksprs"
-version = "2.0-rc.2"
+version = "2.0-rc.4"
 authors = [{ name = "Felix Schön", email = "schoen@kr.tuwien.ac.at" }]
 description = "A lightweight library for operations on blocksparse matrices in PyTorch."
 readme = "README.md"