blksprs 1.11__py3-none-any.whl → 2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +4 -5
- blksprs/layouting/distribution_layout.py +64 -48
- blksprs/layouting/sparsity_layout.py +96 -72
- blksprs/ops/conversion.py +349 -338
- blksprs/ops/distribution.py +318 -294
- blksprs/ops/flow.py +113 -100
- blksprs/ops/matmul.py +187 -172
- blksprs/ops/misc/broadcast_ops.py +68 -53
- blksprs/ops/misc/row_wise.py +223 -176
- blksprs/ops/partitioning.py +140 -132
- blksprs/ops/repeat.py +118 -120
- blksprs/ops/softmax.py +240 -214
- blksprs/ops/transpose.py +55 -52
- blksprs/utils/autotuning.py +78 -0
- blksprs/utils/benchmarking.py +3 -3
- blksprs/utils/processing.py +2 -1
- blksprs/utils/tools.py +5 -6
- blksprs/utils/validation.py +22 -16
- {blksprs-1.11.dist-info → blksprs-2.0.dist-info}/METADATA +55 -36
- blksprs-2.0.dist-info/RECORD +23 -0
- {blksprs-1.11.dist-info → blksprs-2.0.dist-info}/WHEEL +1 -1
- blksprs/utils/layout_utils.py +0 -17
- blksprs-1.11.dist-info/RECORD +0 -23
- {blksprs-1.11.dist-info → blksprs-2.0.dist-info}/top_level.txt +0 -0
blksprs/ops/flow.py
CHANGED
|
@@ -1,20 +1,65 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
import triton
|
|
3
3
|
from torch import Tensor
|
|
4
|
+
from torch._library import triton_op
|
|
5
|
+
from torch._library.triton import wrap_triton
|
|
4
6
|
from triton import language as tl
|
|
5
7
|
|
|
6
|
-
from blksprs.utils.tools import stride, get_triton_block_size
|
|
7
|
-
|
|
8
|
-
|
|
8
|
+
from blksprs.utils.tools import stride
|
|
9
|
+
from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@triton_op("blksprs::flow_pull_forward", mutates_args={})
|
|
13
|
+
def flow_pull_forward(x: Tensor, sparsity_layout_o: Tensor,
|
|
14
|
+
sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
15
|
+
sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
|
|
16
|
+
with torch.no_grad():
|
|
17
|
+
output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
|
|
18
|
+
dtype=x.dtype, device=x.device)
|
|
19
|
+
|
|
20
|
+
x_b, x_r, x_c = x.size()
|
|
21
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
22
|
+
o_b, o_r, o_c = output.size()
|
|
23
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
24
|
+
s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
|
|
25
|
+
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
|
|
26
|
+
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
27
|
+
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
28
|
+
|
|
29
|
+
triton_grid = lambda meta: [o_b,
|
|
30
|
+
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
31
|
+
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
32
|
+
|
|
33
|
+
(wrap_triton(flow_pull_kernel)[triton_grid]
|
|
34
|
+
(x,
|
|
35
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
36
|
+
output,
|
|
37
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
38
|
+
s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
|
|
39
|
+
sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
40
|
+
sparsity_reverse_lut,
|
|
41
|
+
sparsity_block_size))
|
|
42
|
+
|
|
43
|
+
return output
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
# noinspection PyUnusedLocal
|
|
47
|
+
@triton.autotune(
|
|
48
|
+
configs=get_autotune_configs(),
|
|
49
|
+
key=["sparsity_block_size"],
|
|
50
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
51
|
+
reset_to_zero=["o"]
|
|
52
|
+
)
|
|
9
53
|
@triton.jit
|
|
10
|
-
def kernel_blocksparse_flow_pull(x,
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
54
|
+
def flow_pull_kernel(x,
|
|
55
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
56
|
+
o,
|
|
57
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
58
|
+
s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
|
|
59
|
+
s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
60
|
+
r_lut,
|
|
61
|
+
sparsity_block_size,
|
|
62
|
+
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
18
63
|
# Get triton block indices
|
|
19
64
|
pid_blk = tl.program_id(axis=0)
|
|
20
65
|
pid_row = tl.program_id(axis=1)
|
|
@@ -44,25 +89,68 @@ def kernel_blocksparse_flow_pull(x,
|
|
|
44
89
|
blk_x_idx = (rev_idx_spa * x_b_s +
|
|
45
90
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
46
91
|
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
47
|
-
blk_x_msk = (blk_x_idx >= 0 and
|
|
92
|
+
blk_x_msk = (blk_x_idx >= 0 and
|
|
93
|
+
blk_x_idx < x_b * x_b_s)
|
|
48
94
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
49
95
|
|
|
50
96
|
blk_o_idx = (pid_blk * o_b_s +
|
|
51
97
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
52
98
|
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
53
|
-
blk_o_msk = (blk_o_idx >= 0 and
|
|
99
|
+
blk_o_msk = (blk_o_idx >= 0 and
|
|
100
|
+
blk_o_idx < o_b * o_b_s)
|
|
54
101
|
tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
|
|
55
102
|
|
|
56
103
|
|
|
104
|
+
@triton_op("blksprs::flow_push_forward", mutates_args={})
|
|
105
|
+
def flow_push_forward(x: Tensor, sparsity_layout_x: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
106
|
+
sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
|
|
107
|
+
with torch.no_grad():
|
|
108
|
+
output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
|
|
109
|
+
dtype=x.dtype, device=x.device)
|
|
110
|
+
|
|
111
|
+
x_b, x_r, x_c = x.size()
|
|
112
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
113
|
+
s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
|
|
114
|
+
s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
|
|
115
|
+
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
116
|
+
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
117
|
+
o_b, o_r, o_c = output.size()
|
|
118
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
119
|
+
|
|
120
|
+
triton_grid = lambda meta: [x_b,
|
|
121
|
+
triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
|
|
122
|
+
triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
|
|
123
|
+
|
|
124
|
+
(wrap_triton(flow_push_kernel)[triton_grid]
|
|
125
|
+
(x,
|
|
126
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
127
|
+
s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
|
|
128
|
+
sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
129
|
+
sparsity_reverse_lut,
|
|
130
|
+
output,
|
|
131
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
132
|
+
sparsity_block_size))
|
|
133
|
+
|
|
134
|
+
return output
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
# noinspection PyUnusedLocal
|
|
138
|
+
@triton.autotune(
|
|
139
|
+
configs=get_autotune_configs(),
|
|
140
|
+
key=["sparsity_block_size"],
|
|
141
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
142
|
+
reset_to_zero=["o"]
|
|
143
|
+
)
|
|
57
144
|
@triton.jit
|
|
58
|
-
def kernel_blocksparse_flow_push(x,
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
145
|
+
def flow_push_kernel(x,
|
|
146
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
147
|
+
s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
|
|
148
|
+
s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
149
|
+
r_lut,
|
|
150
|
+
o,
|
|
151
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
152
|
+
sparsity_block_size,
|
|
153
|
+
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
66
154
|
# Get triton block indices
|
|
67
155
|
pid_blk = tl.program_id(axis=0)
|
|
68
156
|
pid_row = tl.program_id(axis=1)
|
|
@@ -92,88 +180,13 @@ def kernel_blocksparse_flow_push(x,
|
|
|
92
180
|
blk_x_idx = (pid_blk * x_b_s +
|
|
93
181
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
94
182
|
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
95
|
-
blk_x_msk = (blk_x_idx >= 0 and
|
|
183
|
+
blk_x_msk = (blk_x_idx >= 0 and
|
|
184
|
+
blk_x_idx < x_b * x_b_s)
|
|
96
185
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
97
186
|
|
|
98
187
|
blk_o_idx = (rev_idx_spa * o_b_s +
|
|
99
188
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
100
189
|
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
101
|
-
blk_o_msk = (blk_o_idx >= 0 and
|
|
190
|
+
blk_o_msk = (blk_o_idx >= 0 and
|
|
191
|
+
blk_o_idx < o_b * o_b_s)
|
|
102
192
|
tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
def flow_forward_pull(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
106
|
-
sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
|
|
107
|
-
output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
|
|
108
|
-
dtype=x.dtype, device=x.device)
|
|
109
|
-
|
|
110
|
-
x_b, x_r, x_c = x.size()
|
|
111
|
-
x_b_s, x_r_s, x_c_s = stride(x)
|
|
112
|
-
o_b, o_r, o_c = output.size()
|
|
113
|
-
o_b_s, o_r_s, o_c_s = stride(output)
|
|
114
|
-
s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
|
|
115
|
-
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
|
|
116
|
-
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
117
|
-
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
118
|
-
|
|
119
|
-
if triton_block_size is None:
|
|
120
|
-
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
121
|
-
|
|
122
|
-
triton_grid = lambda meta: [o_b,
|
|
123
|
-
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
124
|
-
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
125
|
-
|
|
126
|
-
(kernel_blocksparse_flow_pull[triton_grid]
|
|
127
|
-
(x,
|
|
128
|
-
x_b, x_b_s, x_r_s, x_c_s,
|
|
129
|
-
output,
|
|
130
|
-
o_b, o_b_s, o_r_s, o_c_s,
|
|
131
|
-
s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
|
|
132
|
-
sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
133
|
-
sparsity_reverse_lut,
|
|
134
|
-
triton_block_size))
|
|
135
|
-
|
|
136
|
-
# Save for backward pass
|
|
137
|
-
ctx.sparsity_block_size = sparsity_block_size
|
|
138
|
-
ctx.triton_block_size = triton_block_size
|
|
139
|
-
|
|
140
|
-
return output
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
def flow_forward_push(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
144
|
-
sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
|
|
145
|
-
output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
|
|
146
|
-
dtype=x.dtype, device=x.device)
|
|
147
|
-
|
|
148
|
-
x_b, x_r, x_c = x.size()
|
|
149
|
-
x_b_s, x_r_s, x_c_s = stride(x)
|
|
150
|
-
s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
|
|
151
|
-
s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
|
|
152
|
-
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
153
|
-
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
154
|
-
o_b, o_r, o_c = output.size()
|
|
155
|
-
o_b_s, o_r_s, o_c_s = stride(output)
|
|
156
|
-
|
|
157
|
-
if triton_block_size is None:
|
|
158
|
-
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
159
|
-
|
|
160
|
-
triton_grid = lambda meta: [x_b,
|
|
161
|
-
triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
|
|
162
|
-
triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
|
|
163
|
-
|
|
164
|
-
(kernel_blocksparse_flow_push[triton_grid]
|
|
165
|
-
(x,
|
|
166
|
-
x_b, x_b_s, x_r_s, x_c_s,
|
|
167
|
-
s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
|
|
168
|
-
sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
169
|
-
sparsity_reverse_lut,
|
|
170
|
-
output,
|
|
171
|
-
o_b, o_b_s, o_r_s, o_c_s,
|
|
172
|
-
triton_block_size))
|
|
173
|
-
|
|
174
|
-
# Save for backward pass
|
|
175
|
-
if ctx is not None:
|
|
176
|
-
ctx.sparsity_block_size = sparsity_block_size
|
|
177
|
-
ctx.triton_block_size = triton_block_size
|
|
178
|
-
|
|
179
|
-
return output
|