blksprs 1.10.2__py3-none-any.whl → 2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +4 -6
- blksprs/layouting/distribution_layout.py +64 -48
- blksprs/layouting/sparsity_layout.py +96 -72
- blksprs/ops/conversion.py +350 -312
- blksprs/ops/distribution.py +320 -266
- blksprs/ops/flow.py +135 -89
- blksprs/ops/matmul.py +184 -151
- blksprs/ops/misc/broadcast_ops.py +68 -53
- blksprs/ops/misc/row_wise.py +223 -176
- blksprs/ops/partitioning.py +140 -89
- blksprs/ops/repeat.py +118 -108
- blksprs/ops/softmax.py +201 -167
- blksprs/ops/transpose.py +71 -131
- blksprs/utils/autotuning.py +78 -0
- blksprs/utils/benchmarking.py +3 -3
- blksprs/utils/processing.py +2 -1
- blksprs/utils/tools.py +5 -6
- blksprs/utils/validation.py +22 -16
- {blksprs-1.10.2.dist-info → blksprs-2.0.dist-info}/METADATA +55 -36
- blksprs-2.0.dist-info/RECORD +23 -0
- {blksprs-1.10.2.dist-info → blksprs-2.0.dist-info}/WHEEL +1 -1
- blksprs/ops/misc/exp.py +0 -104
- blksprs/utils/layout_utils.py +0 -17
- blksprs-1.10.2.dist-info/RECORD +0 -24
- {blksprs-1.10.2.dist-info → blksprs-2.0.dist-info}/top_level.txt +0 -0
|
@@ -1,16 +1,20 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
import triton
|
|
3
3
|
from torch import Tensor
|
|
4
|
+
from torch._library import triton_op
|
|
5
|
+
from torch._library.triton import wrap_triton
|
|
4
6
|
from triton import language as tl
|
|
5
7
|
|
|
6
8
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
7
|
-
from blksprs.utils.tools import
|
|
9
|
+
from blksprs.utils.tools import stride
|
|
10
|
+
from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
|
|
8
11
|
from blksprs.utils.validation import validate_contiguous, validate_device, \
|
|
9
|
-
validate_sparsity_block_size
|
|
12
|
+
validate_sparsity_block_size
|
|
10
13
|
|
|
11
14
|
|
|
15
|
+
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
|
|
12
16
|
def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
|
|
13
|
-
sparsity_block_size: int
|
|
17
|
+
sparsity_block_size: int) -> BlksprsTensor:
|
|
14
18
|
"""Performs a broadcast and subsequent addition of two dense tensors x and y. Returns a block-sparse tensor in
|
|
15
19
|
compressed form.
|
|
16
20
|
|
|
@@ -19,7 +23,6 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
|
|
|
19
23
|
y (Tensor): A dense input tensor.
|
|
20
24
|
sparsity_layout_output (Tensor): The sparsity layout of the output tensor.
|
|
21
25
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
22
|
-
triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
|
|
23
26
|
|
|
24
27
|
Returns:
|
|
25
28
|
BlksprsTensor: The result of the operation as a block-sparse tensor in compressed form. Each element o(i, j) of the
|
|
@@ -34,7 +37,6 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
|
|
|
34
37
|
if x.size(-1) != y.size(-1):
|
|
35
38
|
raise ValueError("Dimensions of tensors must match")
|
|
36
39
|
validate_sparsity_block_size(sparsity_block_size)
|
|
37
|
-
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
38
40
|
|
|
39
41
|
sparsity_lut_o = torch.nonzero(sparsity_layout_output).contiguous()
|
|
40
42
|
|
|
@@ -42,56 +44,66 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
|
|
|
42
44
|
|
|
43
45
|
validate_contiguous(sparsity_layout_output, sparsity_lut_o)
|
|
44
46
|
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
x_b, x_c = x.size()
|
|
48
|
-
x_b_s, x_c_s = stride(x)
|
|
49
|
-
y_b, y_c = y.size()
|
|
50
|
-
y_b_s, y_c_s = stride(y)
|
|
51
|
-
o_b, o_r, o_c = output.size()
|
|
52
|
-
o_b_s, o_r_s, o_c_s = stride(output)
|
|
53
|
-
s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
|
|
54
|
-
s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
|
|
55
|
-
|
|
56
|
-
if triton_block_size is None:
|
|
57
|
-
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
58
|
-
|
|
59
|
-
triton_grid = lambda meta: [o_b,
|
|
60
|
-
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
61
|
-
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
62
|
-
|
|
63
|
-
(kernel_broadcast_addition[triton_grid]
|
|
64
|
-
(x,
|
|
65
|
-
x_b, x_b_s, x_c_s,
|
|
66
|
-
y,
|
|
67
|
-
y_b, y_b_s, y_c_s,
|
|
68
|
-
output,
|
|
69
|
-
o_b, o_b_s, o_r_s, o_c_s,
|
|
70
|
-
sparsity_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
|
|
71
|
-
sparsity_block_size,
|
|
72
|
-
triton_block_size))
|
|
73
|
-
|
|
74
|
-
return BlksprsTensor(output)
|
|
47
|
+
return BlksprsTensor(broadcast_add_forward(x, y, sparsity_lut_o, sparsity_block_size, n_sparse_blocks))
|
|
75
48
|
|
|
76
49
|
|
|
77
50
|
def broadcast_sub(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
|
|
78
|
-
sparsity_block_size: int
|
|
51
|
+
sparsity_block_size: int) -> BlksprsTensor:
|
|
79
52
|
"""Wrapper for ``broadcast_add`` with negated y.
|
|
80
53
|
|
|
81
54
|
"""
|
|
82
|
-
return broadcast_add(x, torch.neg(y), sparsity_layout_output, sparsity_block_size
|
|
83
|
-
|
|
84
|
-
|
|
55
|
+
return broadcast_add(x, torch.neg(y), sparsity_layout_output, sparsity_block_size)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@triton_op("blksprs::broadcast_add_forward", mutates_args={})
|
|
59
|
+
def broadcast_add_forward(x: Tensor, y: Tensor,
|
|
60
|
+
sparsity_lut_o: Tensor,
|
|
61
|
+
sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
|
|
62
|
+
with torch.no_grad():
|
|
63
|
+
output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, dtype=x.dtype, device=x.device)
|
|
64
|
+
|
|
65
|
+
x_b, x_c = x.size()
|
|
66
|
+
x_b_s, x_c_s = stride(x)
|
|
67
|
+
y_b, y_c = y.size()
|
|
68
|
+
y_b_s, y_c_s = stride(y)
|
|
69
|
+
o_b, o_r, o_c = output.size()
|
|
70
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
71
|
+
s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
|
|
72
|
+
s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
|
|
73
|
+
|
|
74
|
+
triton_grid = lambda meta: [o_b,
|
|
75
|
+
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
76
|
+
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
77
|
+
|
|
78
|
+
(wrap_triton(broadcast_add_kernel)[triton_grid]
|
|
79
|
+
(x,
|
|
80
|
+
x_b, x_b_s, x_c_s,
|
|
81
|
+
y,
|
|
82
|
+
y_b, y_b_s, y_c_s,
|
|
83
|
+
output,
|
|
84
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
85
|
+
sparsity_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
|
|
86
|
+
sparsity_block_size))
|
|
87
|
+
|
|
88
|
+
return output
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@triton.autotune(
|
|
92
|
+
configs=get_autotune_configs(),
|
|
93
|
+
key=["sparsity_block_size"],
|
|
94
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
95
|
+
reset_to_zero=["o"]
|
|
96
|
+
)
|
|
85
97
|
@triton.jit
|
|
86
|
-
def
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
98
|
+
def broadcast_add_kernel(x,
|
|
99
|
+
x_b, x_b_s, x_c_s,
|
|
100
|
+
y,
|
|
101
|
+
y_b, y_b_s, y_c_s,
|
|
102
|
+
o,
|
|
103
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
104
|
+
s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
|
|
105
|
+
sparsity_block_size,
|
|
106
|
+
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
95
107
|
# Get triton block indices
|
|
96
108
|
pid_blk = tl.program_id(axis=0)
|
|
97
109
|
pid_row = tl.program_id(axis=1)
|
|
@@ -112,16 +124,18 @@ def kernel_broadcast_addition(x,
|
|
|
112
124
|
|
|
113
125
|
# Load x block
|
|
114
126
|
blk_x_idx = (spa_bat_o * x_b_s +
|
|
115
|
-
((
|
|
127
|
+
((pid_row * TRITON_BLOCK_SIZE + spa_row_o * sparsity_block_size +
|
|
116
128
|
tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
117
|
-
blk_x_msk = (blk_x_idx >= 0 and
|
|
129
|
+
blk_x_msk = (blk_x_idx >= 0 and
|
|
130
|
+
blk_x_idx < x_b * x_b_s)
|
|
118
131
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
119
132
|
|
|
120
133
|
# Load y block
|
|
121
134
|
blk_y_idx = (spa_bat_o * y_b_s +
|
|
122
|
-
((
|
|
135
|
+
((pid_col * TRITON_BLOCK_SIZE + spa_col_o * sparsity_block_size +
|
|
123
136
|
tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
|
|
124
|
-
blk_y_msk = (blk_y_idx >= 0 and
|
|
137
|
+
blk_y_msk = (blk_y_idx >= 0 and
|
|
138
|
+
blk_y_idx < y_b * y_b_s)
|
|
125
139
|
blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
|
|
126
140
|
|
|
127
141
|
# Compute sum
|
|
@@ -132,5 +146,6 @@ def kernel_broadcast_addition(x,
|
|
|
132
146
|
blk_o_idx = ((pid_blk * o_b_s) +
|
|
133
147
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
134
148
|
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
135
|
-
blk_o_msk = (blk_o_idx >= 0 and
|
|
149
|
+
blk_o_msk = (blk_o_idx >= 0 and
|
|
150
|
+
blk_o_idx < o_b * o_b_s)
|
|
136
151
|
tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
|