blksprs-1.11-py3-none-any.whl → blksprs-2.0rc1-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between those package versions as they appear in the public registry.
- blksprs/__init__.py +2 -5
- blksprs/layouting/distribution_layout.py +32 -25
- blksprs/layouting/sparsity_layout.py +65 -52
- blksprs/ops/conversion.py +421 -399
- blksprs/ops/distribution.py +404 -366
- blksprs/ops/flow.py +125 -106
- blksprs/ops/matmul.py +220 -204
- blksprs/ops/misc/broadcast_ops.py +53 -35
- blksprs/ops/misc/row_wise.py +151 -91
- blksprs/ops/partitioning.py +136 -132
- blksprs/ops/repeat.py +115 -120
- blksprs/ops/softmax.py +274 -246
- blksprs/ops/transpose.py +52 -51
- blksprs/utils/benchmarking.py +3 -3
- blksprs/utils/tools.py +31 -4
- blksprs/utils/validation.py +0 -14
- {blksprs-1.11.dist-info → blksprs-2.0rc1.dist-info}/METADATA +42 -36
- blksprs-2.0rc1.dist-info/RECORD +22 -0
- {blksprs-1.11.dist-info → blksprs-2.0rc1.dist-info}/WHEEL +1 -1
- blksprs/utils/layout_utils.py +0 -17
- blksprs-1.11.dist-info/RECORD +0 -23
- {blksprs-1.11.dist-info → blksprs-2.0rc1.dist-info}/top_level.txt +0 -0
blksprs/ops/flow.py
CHANGED
|
@@ -1,25 +1,68 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
import triton
|
|
3
3
|
from torch import Tensor
|
|
4
|
+
from torch._library import triton_op
|
|
5
|
+
from torch._library.triton import wrap_triton
|
|
4
6
|
from triton import language as tl
|
|
5
7
|
|
|
6
|
-
from blksprs.utils.tools import stride,
|
|
8
|
+
from blksprs.utils.tools import stride, get_autotune_configs
|
|
7
9
|
|
|
8
10
|
|
|
11
|
+
@triton_op("blksprs::flow_pull", mutates_args={})
|
|
12
|
+
def flow_pull_forward(x: Tensor, sparsity_layout_o: Tensor,
|
|
13
|
+
sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
14
|
+
sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
|
|
15
|
+
output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
|
|
16
|
+
dtype=x.dtype, device=x.device)
|
|
17
|
+
|
|
18
|
+
x_b, x_r, x_c = x.size()
|
|
19
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
20
|
+
o_b, o_r, o_c = output.size()
|
|
21
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
22
|
+
s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
|
|
23
|
+
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
|
|
24
|
+
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
25
|
+
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
26
|
+
|
|
27
|
+
triton_grid = lambda meta: [o_b,
|
|
28
|
+
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
29
|
+
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
30
|
+
|
|
31
|
+
(wrap_triton(flow_pull_kernel)[triton_grid]
|
|
32
|
+
(x,
|
|
33
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
34
|
+
output,
|
|
35
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
36
|
+
s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
|
|
37
|
+
sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
38
|
+
sparsity_reverse_lut,
|
|
39
|
+
sparsity_block_size))
|
|
40
|
+
|
|
41
|
+
return output
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@triton.autotune(
|
|
45
|
+
configs=get_autotune_configs(),
|
|
46
|
+
key=[],
|
|
47
|
+
)
|
|
9
48
|
@triton.jit
|
|
10
|
-
def
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
49
|
+
def flow_pull_kernel(x,
|
|
50
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
51
|
+
o,
|
|
52
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
53
|
+
s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
|
|
54
|
+
s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
55
|
+
r_lut,
|
|
56
|
+
sparsity_block_size,
|
|
57
|
+
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
18
58
|
# Get triton block indices
|
|
19
59
|
pid_blk = tl.program_id(axis=0)
|
|
20
60
|
pid_row = tl.program_id(axis=1)
|
|
21
61
|
pid_col = tl.program_id(axis=2)
|
|
22
62
|
|
|
63
|
+
# Get valid triton block size
|
|
64
|
+
val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
|
|
65
|
+
|
|
23
66
|
# Get sparsity index of current output block consisting of its batch, row, and column index
|
|
24
67
|
spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
|
|
25
68
|
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
|
|
@@ -42,32 +85,79 @@ def kernel_blocksparse_flow_pull(x,
|
|
|
42
85
|
|
|
43
86
|
if rev_idx_spa >= 0:
|
|
44
87
|
blk_x_idx = (rev_idx_spa * x_b_s +
|
|
45
|
-
((pid_row *
|
|
46
|
-
((pid_col *
|
|
47
|
-
blk_x_msk = (blk_x_idx >= 0 and
|
|
88
|
+
((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
89
|
+
((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
90
|
+
blk_x_msk = ((blk_x_idx >= 0 and
|
|
91
|
+
blk_x_idx < x_b * x_b_s) and
|
|
92
|
+
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
93
|
+
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
48
94
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
49
95
|
|
|
50
96
|
blk_o_idx = (pid_blk * o_b_s +
|
|
51
|
-
((pid_row *
|
|
52
|
-
((pid_col *
|
|
53
|
-
blk_o_msk = (blk_o_idx >= 0 and
|
|
97
|
+
((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
98
|
+
((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
99
|
+
blk_o_msk = ((blk_o_idx >= 0 and
|
|
100
|
+
blk_o_idx < o_b * o_b_s) and
|
|
101
|
+
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
102
|
+
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
54
103
|
tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
|
|
55
104
|
|
|
56
105
|
|
|
106
|
+
@triton_op("blksprs::flow_push", mutates_args={})
|
|
107
|
+
def flow_push_forward(x: Tensor, sparsity_layout_x: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
108
|
+
sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
|
|
109
|
+
output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
|
|
110
|
+
dtype=x.dtype, device=x.device)
|
|
111
|
+
|
|
112
|
+
x_b, x_r, x_c = x.size()
|
|
113
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
114
|
+
s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
|
|
115
|
+
s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
|
|
116
|
+
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
117
|
+
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
118
|
+
o_b, o_r, o_c = output.size()
|
|
119
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
120
|
+
|
|
121
|
+
triton_grid = lambda meta: [x_b,
|
|
122
|
+
triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
|
|
123
|
+
triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
|
|
124
|
+
|
|
125
|
+
(wrap_triton(flow_push_kernel)[triton_grid]
|
|
126
|
+
(x,
|
|
127
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
128
|
+
s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
|
|
129
|
+
sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
130
|
+
sparsity_reverse_lut,
|
|
131
|
+
output,
|
|
132
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
133
|
+
sparsity_block_size))
|
|
134
|
+
|
|
135
|
+
return output
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
@triton.autotune(
|
|
139
|
+
configs=get_autotune_configs(),
|
|
140
|
+
key=[],
|
|
141
|
+
reset_to_zero=["o"]
|
|
142
|
+
)
|
|
57
143
|
@triton.jit
|
|
58
|
-
def
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
144
|
+
def flow_push_kernel(x,
|
|
145
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
146
|
+
s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
|
|
147
|
+
s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
148
|
+
r_lut,
|
|
149
|
+
o,
|
|
150
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
151
|
+
sparsity_block_size,
|
|
152
|
+
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
66
153
|
# Get triton block indices
|
|
67
154
|
pid_blk = tl.program_id(axis=0)
|
|
68
155
|
pid_row = tl.program_id(axis=1)
|
|
69
156
|
pid_col = tl.program_id(axis=2)
|
|
70
157
|
|
|
158
|
+
# Get valid triton block size
|
|
159
|
+
val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
|
|
160
|
+
|
|
71
161
|
# Get sparsity index of current input block consisting of its batch, row, and column index
|
|
72
162
|
spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
|
|
73
163
|
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
|
|
@@ -90,90 +180,19 @@ def kernel_blocksparse_flow_push(x,
|
|
|
90
180
|
|
|
91
181
|
if rev_idx_spa >= 0:
|
|
92
182
|
blk_x_idx = (pid_blk * x_b_s +
|
|
93
|
-
((pid_row *
|
|
94
|
-
((pid_col *
|
|
95
|
-
blk_x_msk = (blk_x_idx >= 0 and
|
|
183
|
+
((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
184
|
+
((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
185
|
+
blk_x_msk = ((blk_x_idx >= 0 and
|
|
186
|
+
blk_x_idx < x_b * x_b_s) and
|
|
187
|
+
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
188
|
+
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
96
189
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
97
190
|
|
|
98
191
|
blk_o_idx = (rev_idx_spa * o_b_s +
|
|
99
|
-
((pid_row *
|
|
100
|
-
((pid_col *
|
|
101
|
-
blk_o_msk = (blk_o_idx >= 0 and
|
|
192
|
+
((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
193
|
+
((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
194
|
+
blk_o_msk = ((blk_o_idx >= 0 and
|
|
195
|
+
blk_o_idx < o_b * o_b_s) and
|
|
196
|
+
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
197
|
+
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
102
198
|
tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
def flow_forward_pull(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
106
|
-
sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
|
|
107
|
-
output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
|
|
108
|
-
dtype=x.dtype, device=x.device)
|
|
109
|
-
|
|
110
|
-
x_b, x_r, x_c = x.size()
|
|
111
|
-
x_b_s, x_r_s, x_c_s = stride(x)
|
|
112
|
-
o_b, o_r, o_c = output.size()
|
|
113
|
-
o_b_s, o_r_s, o_c_s = stride(output)
|
|
114
|
-
s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
|
|
115
|
-
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
|
|
116
|
-
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
117
|
-
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
118
|
-
|
|
119
|
-
if triton_block_size is None:
|
|
120
|
-
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
121
|
-
|
|
122
|
-
triton_grid = lambda meta: [o_b,
|
|
123
|
-
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
124
|
-
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
125
|
-
|
|
126
|
-
(kernel_blocksparse_flow_pull[triton_grid]
|
|
127
|
-
(x,
|
|
128
|
-
x_b, x_b_s, x_r_s, x_c_s,
|
|
129
|
-
output,
|
|
130
|
-
o_b, o_b_s, o_r_s, o_c_s,
|
|
131
|
-
s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
|
|
132
|
-
sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
133
|
-
sparsity_reverse_lut,
|
|
134
|
-
triton_block_size))
|
|
135
|
-
|
|
136
|
-
# Save for backward pass
|
|
137
|
-
ctx.sparsity_block_size = sparsity_block_size
|
|
138
|
-
ctx.triton_block_size = triton_block_size
|
|
139
|
-
|
|
140
|
-
return output
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
def flow_forward_push(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
144
|
-
sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
|
|
145
|
-
output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
|
|
146
|
-
dtype=x.dtype, device=x.device)
|
|
147
|
-
|
|
148
|
-
x_b, x_r, x_c = x.size()
|
|
149
|
-
x_b_s, x_r_s, x_c_s = stride(x)
|
|
150
|
-
s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
|
|
151
|
-
s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
|
|
152
|
-
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
153
|
-
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
154
|
-
o_b, o_r, o_c = output.size()
|
|
155
|
-
o_b_s, o_r_s, o_c_s = stride(output)
|
|
156
|
-
|
|
157
|
-
if triton_block_size is None:
|
|
158
|
-
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
159
|
-
|
|
160
|
-
triton_grid = lambda meta: [x_b,
|
|
161
|
-
triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
|
|
162
|
-
triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
|
|
163
|
-
|
|
164
|
-
(kernel_blocksparse_flow_push[triton_grid]
|
|
165
|
-
(x,
|
|
166
|
-
x_b, x_b_s, x_r_s, x_c_s,
|
|
167
|
-
s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
|
|
168
|
-
sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
169
|
-
sparsity_reverse_lut,
|
|
170
|
-
output,
|
|
171
|
-
o_b, o_b_s, o_r_s, o_c_s,
|
|
172
|
-
triton_block_size))
|
|
173
|
-
|
|
174
|
-
# Save for backward pass
|
|
175
|
-
if ctx is not None:
|
|
176
|
-
ctx.sparsity_block_size = sparsity_block_size
|
|
177
|
-
ctx.triton_block_size = triton_block_size
|
|
178
|
-
|
|
179
|
-
return output
|