blksprs-2.0rc2.tar.gz → blksprs-2.0rc4.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {blksprs-2.0rc2 → blksprs-2.0rc4}/PKG-INFO +2 -2
- {blksprs-2.0rc2 → blksprs-2.0rc4}/README.md +1 -1
- {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/layouting/sparsity_layout.py +3 -0
- {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/ops/conversion.py +6 -1
- {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/ops/distribution.py +4 -1
- {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/ops/flow.py +2 -1
- {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/ops/matmul.py +3 -1
- {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/ops/misc/row_wise.py +6 -2
- {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/ops/partitioning.py +2 -0
- {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/ops/repeat.py +2 -0
- {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/ops/softmax.py +9 -5
- {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/ops/transpose.py +1 -0
- {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/utils/processing.py +3 -1
- {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/utils/validation.py +18 -1
- {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs.egg-info/PKG-INFO +2 -2
- {blksprs-2.0rc2 → blksprs-2.0rc4}/pyproject.toml +1 -1
- {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/__init__.py +0 -0
- {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/layouting/distribution_layout.py +0 -0
- {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/ops/misc/broadcast_ops.py +0 -0
- {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/utils/benchmarking.py +0 -0
- {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/utils/blksprs_tensor.py +0 -0
- {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/utils/tools.py +0 -0
- {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs.egg-info/SOURCES.txt +0 -0
- {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs.egg-info/dependency_links.txt +0 -0
- {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs.egg-info/requires.txt +0 -0
- {blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs.egg-info/top_level.txt +0 -0
- {blksprs-2.0rc2 → blksprs-2.0rc4}/setup.cfg +0 -0

{blksprs-2.0rc2 → blksprs-2.0rc4}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: blksprs
-Version: 2.0rc2
+Version: 2.0rc4
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -27,7 +27,7 @@ Requires-Dist: matplotlib; extra == "test"
 ### News
 
 🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
-LUTs, and makes use of `torch.library.triton_op()`!
+LUTs, autocasting, and makes use of `torch.library.triton_op()`!
 
 ---
 

{blksprs-2.0rc2 → blksprs-2.0rc4}/README.md

@@ -8,7 +8,7 @@
 ### News
 
 🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
-LUTs, and makes use of `torch.library.triton_op()`!
+LUTs, autocasting, and makes use of `torch.library.triton_op()`!
 
 ---
 
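
The News entry advertises autocasting support, which this release implements by decorating the public entry points with `torch.amp.custom_fwd` (see the per-file hunks below). A rough usage sketch, assuming the submodule import paths visible in this diff and the signatures shown in the hunks; the package may also re-export these names at the top level, so the project README remains the canonical reference:

```python
import torch

# Requires a CUDA device and Triton. Import paths follow the files in this diff
# (the package may re-export these names elsewhere); signatures follow the hunks.
from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_matmul_fast
from blksprs.ops.conversion import to_sparse, to_dense
from blksprs.ops.matmul import matmul

sparsity_block_size = 64
x = torch.randn(2, 256, 256, device="cuda")
y = torch.randn(2, 256, 256, device="cuda")

# Layouts marking which blocks are non-zero.
sparsity_layout_x = build_sparsity_layout(x, sparsity_block_size)
sparsity_layout_y = build_sparsity_layout(y, sparsity_block_size)
sparsity_layout_o = build_sparsity_layout_matmul_fast(sparsity_layout_x, sparsity_layout_y)

# Inside autocast, ops decorated with custom_fwd(cast_inputs=torch.float16)
# receive float16 inputs automatically, so float32 tensors can be passed as-is.
with torch.autocast(device_type="cuda", dtype=torch.float16):
    x_bs = to_sparse(x, sparsity_layout_x, sparsity_block_size)
    y_bs = to_sparse(y, sparsity_layout_y, sparsity_block_size)
    o_bs = matmul(x_bs, sparsity_layout_x, y_bs, sparsity_layout_y,
                  sparsity_layout_o, sparsity_block_size)
    o = to_dense(o_bs, sparsity_layout_o, sparsity_block_size)
```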

{blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/layouting/sparsity_layout.py

@@ -12,6 +12,7 @@ from blksprs.utils.validation import validate_dimensions, validate_device, \
     validate_contiguous, validate_sparsity, validate_sparsity_block_size
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def build_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
     """Builds the sparsity layout of a dense tensor in regular form covering its sparse blocks.
 
@@ -199,6 +200,7 @@ def build_sparsity_layout_adaption_kernel(x,
     tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def build_sparsity_layout_matmul(sparsity_layout_x: Tensor, sparsity_layout_y: Tensor) -> Tensor:
     """Builds the precise sparsity layout of the result of a matrix multiplication between the two input tensors.
 
@@ -213,6 +215,7 @@ def build_sparsity_layout_matmul(sparsity_layout_x: Tensor, sparsity_layout_y: T
     return torch.matmul(sparsity_layout_x.to(torch.float), sparsity_layout_y.to(torch.float)).to(torch.bool)
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def build_sparsity_layout_matmul_fast(sparsity_layout_x: Tensor, sparsity_layout_y: Tensor):
     """Builds the approximate sparsity layout of the result of a matrix multiplication between the two input tensors.
 
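
All three layout builders gain `@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)`. The decorator is a PyTorch AMP helper: inside a `torch.autocast` region it casts floating-point CUDA tensor arguments to the requested dtype and disables autocast for the body of the call; outside such a region it passes arguments through unchanged. A minimal sketch of that behavior on a plain function, mirroring how the hunks above apply it (not blksprs code):

```python
import torch

# Requires a CUDA device.
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
def entry_dtype(x: torch.Tensor) -> torch.dtype:
    # Returns the dtype the decorated function actually sees.
    return x.dtype

x = torch.randn(4, 4, device="cuda")      # float32 input

print(entry_dtype(x))                     # torch.float32: no autocast, pass-through

with torch.autocast(device_type="cuda", dtype=torch.float16):
    print(entry_dtype(x))                 # torch.float16: cast on entry, autocast disabled inside
```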

{blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/ops/conversion.py

@@ -18,6 +18,7 @@ def to_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int) ->
     return to_sparse(x, sparsity_layout, sparsity_block_size)
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def to_sparse(x: Tensor, sparsity_layout: Tensor,
               sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
     """Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
@@ -53,7 +54,7 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor,
 @triton_op("blksprs::to_sparse", mutates_args={})
 def to_sparse_forward(x: Tensor, _: Tensor,
                       sparsity_lut: Tensor, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
-    output = torch.
+    output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                          dtype=x.dtype, device=x.device)
 
     x_b, x_r, x_c = x.size()
@@ -86,6 +87,7 @@ def to_sparse_backward(ctx, grad_output):
 @triton.autotune(
     configs=get_autotune_configs(),
     key=[],
+    reset_to_zero=["o"]
 )
 @triton.jit
 def to_sparse_kernel(x,
@@ -175,6 +177,7 @@ def from_blksprs(x: BlksprsTensor, sparsity_layout: Tensor,
     return to_dense(x, sparsity_layout, sparsity_block_size, fill_value=fill_value, lut=lut)
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def to_dense(x: BlksprsTensor, sparsity_layout: Tensor,
              sparsity_block_size: int, fill_value: float = 0, lut: dict = None) -> Tensor:
     """Converts a block-sparse tensor in compressed form to a block-sparse tensor in regular form based on the given
@@ -250,6 +253,7 @@ def to_dense_backward(ctx, grad_output):
 @triton.autotune(
     configs=get_autotune_configs(),
     key=[],
+    restore_value=["o"]
 )
 @triton.jit
 def to_dense_kernel(x,
@@ -326,6 +330,7 @@ def to_dense_setup_context(ctx, inputs, output):
 to_dense_forward.register_autograd(to_dense_backward, setup_context=to_dense_setup_context)
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_size_from: int,
                  sparsity_block_size_to: int, sparsity_layout_to: Tensor = None) -> (BlksprsTensor, Tensor):
     """Adapts the sparsity layout of a block-sparse tensor, resulting in a new block-sparse tensor in compressed form
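
Besides the decorators, `conversion.py` adds `reset_to_zero=["o"]` and `restore_value=["o"]` to its `@triton.autotune` decorators, alongside changing the output allocation to `torch.zeros` (the removed allocation line is truncated in this diff view). The Triton autotuner benchmarks every candidate config against the real arguments, so a kernel that accumulates into, or only partially writes, its output buffer would otherwise see stale data from earlier trials; `reset_to_zero` zeroes the named argument before each trial, while `restore_value` snapshots and restores it. A minimal standalone sketch (not a blksprs kernel):

```python
import torch
import triton
import triton.language as tl

# Minimal sketch of triton.autotune's reset_to_zero on an accumulating kernel.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 128}),
        triton.Config({"BLOCK_SIZE": 256}),
    ],
    key=["n_elements"],
    reset_to_zero=["out_ptr"],  # zero this argument before each benchmarked run
)
@triton.jit
def sum_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    # Accumulates into the output; without reset_to_zero, every autotune trial
    # would add on top of the previous trial's result.
    tl.atomic_add(out_ptr, tl.sum(x, axis=0))


x = torch.ones(4096, device="cuda")
out = torch.zeros(1, device="cuda")
grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
sum_kernel[grid](x, out, x.numel())
print(out)  # tensor([4096.], device='cuda:0') regardless of how many configs were benchmarked
```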

{blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/ops/distribution.py

@@ -11,6 +11,7 @@ from blksprs.utils.validation import validate_contiguous, validate_dimensions, v
     validate_sparsity, validate_dtype_int, validate_sparsity_block_size
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
            dim: int,
            idx: BlksprsTensor, sparsity_layout_idx: Tensor,
@@ -53,7 +54,7 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
 def gather_forward(x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
                    dim: int, i: Tensor, _: Tensor, sparsity_lut_i: Tensor,
                    sparsity_block_size: int) -> Tensor:
-    output = torch.
+    output = torch.zeros_like(i, dtype=x.dtype)
 
     x_b, x_r, x_c = x.size()
     x_b_s, x_r_s, x_c_s = stride(x)
@@ -100,6 +101,7 @@ def gather_backward(ctx, grad_output):
 @triton.autotune(
     configs=get_autotune_configs(),
     key=[],
+    reset_to_zero=["o"]
 )
 @triton.jit
 def gather_kernel(x,
@@ -247,6 +249,7 @@ def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
                           reduce_op="none", lut=lut)
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
                    dim: int,
                    idx: BlksprsTensor,

{blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/ops/flow.py

@@ -12,7 +12,7 @@ from blksprs.utils.tools import stride, get_autotune_configs
 def flow_pull_forward(x: Tensor, sparsity_layout_o: Tensor,
                       sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
                       sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
-    output = torch.
+    output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                          dtype=x.dtype, device=x.device)
 
     x_b, x_r, x_c = x.size()
@@ -44,6 +44,7 @@ def flow_pull_forward(x: Tensor, sparsity_layout_o: Tensor,
 @triton.autotune(
     configs=get_autotune_configs(),
     key=[],
+    reset_to_zero=["o"]
 )
 @triton.jit
 def flow_pull_kernel(x,

{blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/ops/matmul.py

@@ -11,6 +11,7 @@ from blksprs.utils.validation import validate_contiguous, validate_dimensions, v
     validate_sparsity, validate_sparsity_block_size, validate_dtype_float
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
            y: BlksprsTensor, sparsity_layout_y: Tensor,
            sparsity_layout_output: Tensor,
@@ -59,7 +60,7 @@ def matmul_forward(x: Tensor, y: Tensor,
                    sparsity_layout_y: Tensor, sparsity_reverse_lut_y: Tensor,
                    _: Tensor, sparsity_lut_o: Tensor,
                    sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
-    output = torch.
+    output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                          dtype=x.dtype, device=x.device)
 
     x_b, x_r, x_c = x.size()
@@ -117,6 +118,7 @@ def matmul_backward(ctx, grad_output):
 @triton.autotune(
     configs=get_autotune_configs(),
     key=[],
+    reset_to_zero=["o"]
 )
 @triton.jit
 def matmul_kernel(x,

{blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/ops/misc/row_wise.py

@@ -10,6 +10,7 @@ from blksprs.utils.validation import validate_dimensions, validate_contiguous, v
     validate_sparsity_block_size
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
 def row_wise_sum(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
                  flag_slice_only: bool = False) -> (BlksprsTensor, Tensor):
     """Computes the row-wise sum of a block-sparse tensor.
@@ -156,6 +157,7 @@ def row_wise_sum_kernel(x,
     tl.atomic_add(o + o_idx, buf, o_msk)
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
                  flag_slice_only: bool = False) -> (BlksprsTensor, Tensor):
     """Computes the row-wise max of a block-sparse tensor.
@@ -304,6 +306,7 @@ def row_wise_max_kernel(x,
     tl.atomic_max(o + o_idx, buf, o_msk)
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def row_wise_add(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
                  sparsity_block_size: int) -> BlksprsTensor:
     """For each row in ``y`` adds the value to each value in the corresponding row of the block-sparse tensor ``x``.
@@ -351,7 +354,7 @@ def row_wise_sub(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
 def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
                          sparsity_layout_x_rwm: Tensor, sparsity_reverse_x_lut_rwm: Tensor,
                          y: Tensor, sparsity_block_size: int) -> Tensor:
-    output = torch.
+    output = torch.zeros_like(x)
 
     x_b, x_r, x_c = x.size()
     x_b_s, x_r_s, x_c_s = stride(x)
@@ -384,7 +387,8 @@ def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
 
 @triton.autotune(
     configs=get_autotune_configs(),
-    key=[]
+    key=[],
+    reset_to_zero=["o"]
 )
 @triton.jit
 def kernel_blocksparse_row_wise_add(x,
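
Note that `row_wise_sum` is pinned to `cast_inputs=torch.float32` while `row_wise_max` and `row_wise_add` stay at float16, presumably because the sum kernel accumulates with `tl.atomic_add`, where float16 runs out of precision quickly (a max reduction does not accumulate rounding error the same way). A small illustration of the underlying effect, independent of blksprs:

```python
import torch

# Above 2048, float16 cannot represent x + 1 exactly, so a running float16 sum
# of ones stalls while the float32 sum does not.
acc16 = torch.tensor(0.0, dtype=torch.float16)
acc32 = torch.tensor(0.0, dtype=torch.float32)
for _ in range(5000):
    acc16 += torch.tensor(1.0, dtype=torch.float16)
    acc32 += torch.tensor(1.0, dtype=torch.float32)

print(acc16)  # tensor(2048., dtype=torch.float16) -- stuck once the spacing exceeds 1
print(acc32)  # tensor(5000.)
```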

{blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/ops/partitioning.py

@@ -8,6 +8,7 @@ from blksprs.utils.validation import validate_dimensions, validate_contiguous, v
     validate_sparsity, validate_sparsity_block_size
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
           dim: int, sparsity_block_size: int, lut: dict = None) -> (
         BlksprsTensor, Tensor):
@@ -111,6 +112,7 @@ def split_setup_context(ctx, inputs, output):
 split_forward.register_autograd(split_backward, setup_context=split_setup_context)
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
           dim: int, sparsity_block_size: int, lut: dict = None) -> (
         BlksprsTensor, Tensor):

{blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/ops/repeat.py

@@ -8,6 +8,7 @@ from blksprs.utils.validation import validate_dimensions, validate_contiguous, v
     validate_sparsity, validate_sparsity_block_size
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
            sparsity_block_size: int, sparsity_layout_output: Tensor = None, lut: dict = None) -> (
         BlksprsTensor, Tensor):
@@ -50,6 +51,7 @@ def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int,
         lut["sparsity_reverse_lut"], sparsity_block_size, lut["n_sparse_blocks"])), lut["sparsity_layout_o"]
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
                       sparsity_block_size: int, sparsity_layout_output: Tensor = None, lut: dict = None) -> (
         BlksprsTensor, Tensor):

{blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/ops/softmax.py

@@ -9,9 +9,10 @@ from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import stride, get_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-    validate_sparsity, validate_sparsity_block_size
+    validate_sparsity, validate_sparsity_block_size, validate_dtype_float_32
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
 def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
     """Computes the softmax of a block-sparse tensor in compressed form.
 
@@ -32,6 +33,7 @@ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
 
     validate_dimensions(x)
     validate_contiguous(x)
+    validate_dtype_float_32(x)
     validate_device(x)
     validate_sparsity(sparsity_block_size, (x, sparsity_layout))
     validate_sparsity_block_size(sparsity_block_size, x)
@@ -49,7 +51,7 @@ def softmax_forward(x: Tensor, sparsity_layout: Tensor,
                     sparsity_lut: Tensor,
                     sparsity_reverse_lut_rws: Tensor,
                     sparsity_block_size: int) -> Tensor:
-    output = torch.
+    output = torch.zeros_like(x)
 
     x_b, x_r, x_c = x.size()
     x_b_s, x_r_s, x_c_s = stride(x)
@@ -106,7 +108,7 @@ def softmax_backward(ctx, grad_output):
     s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
     s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)
 
-    grad_x = torch.
+    grad_x = torch.zeros_like(o, dtype=torch.float)
 
     triton_grid = lambda meta: [o_b,
                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
@@ -131,7 +133,8 @@ def softmax_backward(ctx, grad_output):
 
 @triton.autotune(
     configs=get_autotune_configs(),
-    key=[]
+    key=[],
+    reset_to_zero=["o"]
 )
 @triton.jit
 def softmax_kernel(x,
@@ -196,7 +199,8 @@ def softmax_kernel(x,
 
 @triton.autotune(
     configs=get_autotune_configs(),
-    key=[]
+    key=[],
+    reset_to_zero=["o"]
 )
 @triton.jit
 def softmax_kernel_grad(g,
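
Softmax is the one op that both upcasts via `cast_inputs=torch.float32` and adds an explicit `validate_dtype_float_32` check. The decorator only takes effect inside an autocast region; called outside one it passes arguments through unchanged, which is presumably why the validator is still needed. Sketch of the two cases (not blksprs code):

```python
import torch

# Requires a CUDA device. cast_inputs only applies while autocast is enabled.
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
def entry_dtype(x: torch.Tensor) -> torch.dtype:
    return x.dtype

x_half = torch.randn(8, 8, device="cuda", dtype=torch.float16)

with torch.autocast(device_type="cuda", dtype=torch.float16):
    print(entry_dtype(x_half))  # torch.float32 -- upcast on entry

print(entry_dtype(x_half))      # torch.float16 -- unchanged; this is the case the new validator catches
```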

{blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/ops/transpose.py

@@ -8,6 +8,7 @@ from blksprs.utils.validation import validate_dimensions, validate_contiguous, v
     validate_sparsity, validate_sparsity_block_size
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def transpose(x: BlksprsTensor, sparsity_layout: Tensor,
               sparsity_block_size: int, lut: dict = None) -> (BlksprsTensor, Tensor):
     """Transposes a block-sparse tensor in compressed form.

{blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/utils/processing.py

@@ -11,6 +11,7 @@ from blksprs.ops.repeat import repeat
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def apply_torch_linear(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
                        linear: nn.Linear, bias: nn.Parameter = None) -> (BlksprsTensor, Tensor):
     # Extract weight and bias
@@ -25,7 +26,8 @@ def apply_torch_linear(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block
 
     # Apply weights
     sparsity_layout_xw = build_sparsity_layout_matmul_fast(sparsity_layout, sparsity_layout_w_t)
-
+    # TODO At the moment, manual cast is needed. Bug with custom_fwd?
+    xw = matmul(x, sparsity_layout, BlksprsTensor(w_t_bs.to(x.dtype)), sparsity_layout_w_t, sparsity_layout_xw, sparsity_block_size)
     interim = xw
 
     # Apply bias
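
One likely reason for the manual cast flagged by the TODO: `cast_inputs` only rewrites tensors that are passed directly as arguments to the decorated function, and `apply_torch_linear` receives the `nn.Linear` module rather than its weight, so the weight extracted inside keeps its original dtype and has to be cast to `x.dtype` by hand. Sketch of the mechanism (not blksprs code):

```python
import torch
from torch import nn

# Requires a CUDA device. Tensors reached through a module attribute are not
# touched by cast_inputs; only direct tensor arguments are cast.
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
def arg_vs_attr(x: torch.Tensor, linear: nn.Linear) -> tuple[torch.dtype, torch.dtype]:
    return x.dtype, linear.weight.dtype

x = torch.randn(4, 4, device="cuda")
linear = nn.Linear(4, 4).cuda()

with torch.autocast(device_type="cuda", dtype=torch.float16):
    print(arg_vs_attr(x, linear))  # (torch.float16, torch.float32)
```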

{blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs/utils/validation.py

@@ -26,10 +26,27 @@ def validate_dtype_float(*tensors: Tensor) -> None:
     if _check_skip_validation():
         return
 
-    for tensor in tensors:
+    dtype = None
+
+    for i, tensor in enumerate(tensors):
+        if i == 0:
+            dtype = tensor.dtype
+
         if tensor.dtype != torch.float16 and tensor.dtype != torch.float32:
             raise ValueError("Tensor must have either float16 or float32 dtype")
 
+        if tensor.dtype != dtype:
+            raise ValueError("Tensors must have same dtype")
+
+
+def validate_dtype_float_32(*tensors: Tensor) -> None:
+    if _check_skip_validation():
+        return
+
+    for tensor in tensors:
+        if tensor.dtype != torch.float32:
+            raise ValueError("Tensor must have float32 dtype")
+
 
 def validate_dtype_int(*tensors: Tensor) -> None:
     if _check_skip_validation():
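
Expected behavior of the updated validators, assuming validation has not been disabled via the library's skip switch (`_check_skip_validation`); the import path follows the file list above:

```python
import torch
from blksprs.utils.validation import validate_dtype_float, validate_dtype_float_32

a = torch.randn(2, 64, 64, dtype=torch.float16)
b = torch.randn(2, 64, 64, dtype=torch.float32)

validate_dtype_float(a, a)         # ok: both float16
validate_dtype_float(b, b)         # ok: both float32
try:
    validate_dtype_float(a, b)     # mixed dtypes now rejected
except ValueError as e:
    print(e)                       # "Tensors must have same dtype"

validate_dtype_float_32(b)         # ok
try:
    validate_dtype_float_32(a)     # float16 rejected by the stricter check
except ValueError as e:
    print(e)                       # "Tensor must have float32 dtype"
```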

{blksprs-2.0rc2 → blksprs-2.0rc4}/blksprs.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: blksprs
-Version: 2.0rc2
+Version: 2.0rc4
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -27,7 +27,7 @@ Requires-Dist: matplotlib; extra == "test"
 ### News
 
 🎉 ***Version 2.0 released***. blksprs now supports kernel auto-tuning, JIT compilation, specification of pre-calculated
-LUTs, and makes use of `torch.library.triton_op()`!
+LUTs, autocasting, and makes use of `torch.library.triton_op()`!
 
 ---
 