blksprs 1.10.2__py3-none-any.whl → 2.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +2 -6
- blksprs/layouting/distribution_layout.py +32 -25
- blksprs/layouting/sparsity_layout.py +65 -52
- blksprs/ops/conversion.py +423 -374
- blksprs/ops/distribution.py +403 -335
- blksprs/ops/flow.py +135 -83
- blksprs/ops/matmul.py +221 -187
- blksprs/ops/misc/broadcast_ops.py +53 -35
- blksprs/ops/misc/row_wise.py +151 -91
- blksprs/ops/partitioning.py +136 -89
- blksprs/ops/repeat.py +115 -108
- blksprs/ops/softmax.py +244 -208
- blksprs/ops/transpose.py +69 -131
- blksprs/utils/benchmarking.py +3 -3
- blksprs/utils/tools.py +31 -4
- blksprs/utils/validation.py +0 -14
- {blksprs-1.10.2.dist-info → blksprs-2.0rc1.dist-info}/METADATA +42 -36
- blksprs-2.0rc1.dist-info/RECORD +22 -0
- {blksprs-1.10.2.dist-info → blksprs-2.0rc1.dist-info}/WHEEL +1 -1
- blksprs/ops/misc/exp.py +0 -104
- blksprs/utils/layout_utils.py +0 -17
- blksprs-1.10.2.dist-info/RECORD +0 -24
- {blksprs-1.10.2.dist-info → blksprs-2.0rc1.dist-info}/top_level.txt +0 -0
blksprs/ops/flow.py
CHANGED
|
@@ -1,25 +1,68 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
import triton
|
|
3
3
|
from torch import Tensor
|
|
4
|
+
from torch._library import triton_op
|
|
5
|
+
from torch._library.triton import wrap_triton
|
|
4
6
|
from triton import language as tl
|
|
5
7
|
|
|
6
|
-
from blksprs.utils.tools import stride,
|
|
8
|
+
from blksprs.utils.tools import stride, get_autotune_configs
|
|
7
9
|
|
|
8
10
|
|
|
11
|
+
@triton_op("blksprs::flow_pull", mutates_args={})
def flow_pull_forward(x: Tensor, sparsity_layout_o: Tensor,
                      sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
                      sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
    """Pull blocks of ``x`` into a new block-sparse tensor described by ``sparsity_layout_o``.

    The launch grid has one program column per output block (axis 0 is the output's block
    count). For output block ``i`` the kernel reads the block's position from row ``i`` of
    ``sparsity_lut`` and looks up the matching block index of ``x`` in
    ``sparsity_reverse_lut``. If that index is non-negative, the block of ``x`` is copied
    into output block ``i``; otherwise the kernel does not write that output block.

    Args:
        x: Input blocks; unpacked below as a 3-D ``(blocks, rows, cols)`` tensor.
        sparsity_layout_o: 3-D sparsity layout of the output.
        sparsity_lut: 2-D look-up table with one row per output block.
        sparsity_reverse_lut: Reverse look-up table yielding block indices of ``x``;
            negative entries mean "no source block".
        sparsity_block_size: Edge length of each square sparse block.
        n_sparse_blocks: Number of blocks to allocate in the output.

    Returns:
        A tensor of shape ``(n_sparse_blocks, sparsity_block_size, sparsity_block_size)``
        with the dtype and device of ``x``. Output blocks without a source block are zero.
    """
    # Zero-initialise: the kernel skips output blocks whose reverse-LUT entry is negative,
    # so allocating with ``torch.empty`` would leave those blocks holding uninitialised
    # memory. This also matches the allocation used by ``flow_push_forward``.
    output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                         dtype=x.dtype, device=x.device)

    # Shapes and strides are passed to the kernel as scalar arguments.
    x_b, x_r, x_c = x.size()
    x_b_s, x_r_s, x_c_s = stride(x)
    o_b, o_r, o_c = output.size()
    o_b_s, o_r_s, o_c_s = stride(output)
    s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
    s_lut_r, s_lut_c = sparsity_lut.size()
    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)

    # One program per (output block, row tile, column tile); the tile size is chosen by
    # the autotuner through the TRITON_BLOCK_SIZE meta-parameter.
    triton_grid = lambda meta: [o_b,
                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

    (wrap_triton(flow_pull_kernel)[triton_grid]
     (x,
      x_b, x_b_s, x_r_s, x_c_s,
      output,
      o_b, o_b_s, o_r_s, o_c_s,
      s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
      sparsity_reverse_lut,
      sparsity_block_size))

    return output
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@triton.autotune(
|
|
45
|
+
configs=get_autotune_configs(),
|
|
46
|
+
key=[],
|
|
47
|
+
)
|
|
9
48
|
@triton.jit
|
|
10
|
-
def
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
49
|
+
def flow_pull_kernel(x,
|
|
50
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
51
|
+
o,
|
|
52
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
53
|
+
s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
|
|
54
|
+
s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
55
|
+
r_lut,
|
|
56
|
+
sparsity_block_size,
|
|
57
|
+
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
18
58
|
# Get triton block indices
|
|
19
59
|
pid_blk = tl.program_id(axis=0)
|
|
20
60
|
pid_row = tl.program_id(axis=1)
|
|
21
61
|
pid_col = tl.program_id(axis=2)
|
|
22
62
|
|
|
63
|
+
# Get valid triton block size
|
|
64
|
+
val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
|
|
65
|
+
|
|
23
66
|
# Get sparsity index of current output block consisting of its batch, row, and column index
|
|
24
67
|
spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
|
|
25
68
|
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
|
|
@@ -40,37 +83,81 @@ def kernel_blocksparse_flow_pull(x,
|
|
|
40
83
|
rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
|
|
41
84
|
rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
|
|
42
85
|
|
|
43
|
-
if rev_idx_spa
|
|
44
|
-
|
|
45
|
-
|
|
86
|
+
if rev_idx_spa >= 0:
|
|
87
|
+
blk_x_idx = (rev_idx_spa * x_b_s +
|
|
88
|
+
((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
89
|
+
((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
90
|
+
blk_x_msk = ((blk_x_idx >= 0 and
|
|
91
|
+
blk_x_idx < x_b * x_b_s) and
|
|
92
|
+
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
93
|
+
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
94
|
+
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
95
|
+
|
|
96
|
+
blk_o_idx = (pid_blk * o_b_s +
|
|
97
|
+
((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
98
|
+
((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
99
|
+
blk_o_msk = ((blk_o_idx >= 0 and
|
|
100
|
+
blk_o_idx < o_b * o_b_s) and
|
|
101
|
+
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
102
|
+
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
103
|
+
tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
@triton_op("blksprs::flow_push", mutates_args={})
def flow_push_forward(x: Tensor, sparsity_layout_x: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
                      sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
    """Push (scatter-add) blocks of ``x`` into a new, zero-initialised block-sparse tensor.

    The launch grid has one program column per block of ``x`` (axis 0 is the input's block
    count). For input block ``i`` the kernel reads the block's position from row ``i`` of
    ``sparsity_lut`` and looks up a destination block index in ``sparsity_reverse_lut``.
    When that index is non-negative, the block of ``x`` is added into the destination
    with atomic adds, so several input blocks may safely target the same output block.

    Args:
        x: Input blocks; unpacked below as a 3-D ``(blocks, rows, cols)`` tensor.
        sparsity_layout_x: 3-D sparsity layout of the input.
        sparsity_lut: 2-D look-up table with one row per input block.
        sparsity_reverse_lut: Reverse look-up table yielding destination block indices;
            negative entries mean "no destination".
        sparsity_block_size: Edge length of each square sparse block.
        n_sparse_blocks: Number of blocks to allocate in the output.

    Returns:
        A tensor of shape ``(n_sparse_blocks, sparsity_block_size, sparsity_block_size)``
        with the dtype and device of ``x``.
    """
    out_shape = (n_sparse_blocks, sparsity_block_size, sparsity_block_size)
    # The kernel accumulates into this buffer with atomic adds, so it must start at zero.
    output = torch.zeros(size=out_shape, dtype=x.dtype, device=x.device)

    # Shapes and strides handed to the kernel as scalar arguments. Only the leading
    # extent of the layout, LUT and output is needed; the rest are discarded.
    n_x, rows_x, cols_x = x.size()
    x_b_s, x_r_s, x_c_s = stride(x)
    n_layout, _, _ = sparsity_layout_x.size()
    layout_b_s, layout_r_s, layout_c_s = stride(sparsity_layout_x)
    n_lut, _ = sparsity_lut.size()
    lut_r_s, lut_c_s = stride(sparsity_lut)
    n_out, _, _ = output.size()
    out_b_s, out_r_s, out_c_s = stride(output)

    def triton_grid(meta):
        # One program per (input block, row tile, column tile).
        tile = meta["TRITON_BLOCK_SIZE"]
        return [n_x, triton.cdiv(rows_x, tile), triton.cdiv(cols_x, tile)]

    launcher = wrap_triton(flow_push_kernel)[triton_grid]
    launcher(x,
             n_x, x_b_s, x_r_s, x_c_s,
             n_layout, layout_b_s, layout_r_s, layout_c_s,
             sparsity_lut, n_lut, lut_r_s, lut_c_s,
             sparsity_reverse_lut,
             output,
             n_out, out_b_s, out_r_s, out_c_s,
             sparsity_block_size)

    return output
|
|
58
136
|
|
|
59
137
|
|
|
138
|
+
@triton.autotune(
|
|
139
|
+
configs=get_autotune_configs(),
|
|
140
|
+
key=[],
|
|
141
|
+
reset_to_zero=["o"]
|
|
142
|
+
)
|
|
60
143
|
@triton.jit
|
|
61
|
-
def
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
144
|
+
def flow_push_kernel(x,
|
|
145
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
146
|
+
s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
|
|
147
|
+
s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
148
|
+
r_lut,
|
|
149
|
+
o,
|
|
150
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
151
|
+
sparsity_block_size,
|
|
152
|
+
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
69
153
|
# Get triton block indices
|
|
70
154
|
pid_blk = tl.program_id(axis=0)
|
|
71
155
|
pid_row = tl.program_id(axis=1)
|
|
72
156
|
pid_col = tl.program_id(axis=2)
|
|
73
157
|
|
|
158
|
+
# Get valid triton block size
|
|
159
|
+
val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
|
|
160
|
+
|
|
74
161
|
# Get sparsity index of current input block consisting of its batch, row, and column index
|
|
75
162
|
spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
|
|
76
163
|
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
|
|
@@ -91,56 +178,21 @@ def kernel_blocksparse_flow_push(x,
|
|
|
91
178
|
rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_x_b * s_l_x_b_s)
|
|
92
179
|
rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
|
|
93
180
|
|
|
94
|
-
if rev_idx_spa
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
|
|
113
|
-
output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
|
|
114
|
-
dtype=x.dtype, device=x.device)
|
|
115
|
-
|
|
116
|
-
x_b, x_r, x_c = x.size()
|
|
117
|
-
x_b_s, x_r_s, x_c_s = stride(x)
|
|
118
|
-
o_b, o_r, o_c = output.size()
|
|
119
|
-
o_b_s, o_r_s, o_c_s = stride(output)
|
|
120
|
-
s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
|
|
121
|
-
s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
|
|
122
|
-
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
123
|
-
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
124
|
-
|
|
125
|
-
if triton_block_size is None:
|
|
126
|
-
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
127
|
-
|
|
128
|
-
triton_grid = lambda meta: [o_b,
|
|
129
|
-
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
130
|
-
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
131
|
-
|
|
132
|
-
(kernel_blocksparse_flow_pull[triton_grid]
|
|
133
|
-
(x,
|
|
134
|
-
x_b, x_b_s, x_r_s, x_c_s,
|
|
135
|
-
output,
|
|
136
|
-
o_b, o_b_s, o_r_s, o_c_s,
|
|
137
|
-
s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
|
|
138
|
-
sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
139
|
-
sparsity_reverse_lut,
|
|
140
|
-
triton_block_size))
|
|
141
|
-
|
|
142
|
-
# Save for backward pass
|
|
143
|
-
ctx.sparsity_block_size = sparsity_block_size
|
|
144
|
-
ctx.triton_block_size = triton_block_size
|
|
145
|
-
|
|
146
|
-
return output
|
|
181
|
+
if rev_idx_spa >= 0:
|
|
182
|
+
blk_x_idx = (pid_blk * x_b_s +
|
|
183
|
+
((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
184
|
+
((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
185
|
+
blk_x_msk = ((blk_x_idx >= 0 and
|
|
186
|
+
blk_x_idx < x_b * x_b_s) and
|
|
187
|
+
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
188
|
+
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
189
|
+
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
190
|
+
|
|
191
|
+
blk_o_idx = (rev_idx_spa * o_b_s +
|
|
192
|
+
((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
193
|
+
((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
194
|
+
blk_o_msk = ((blk_o_idx >= 0 and
|
|
195
|
+
blk_o_idx < o_b * o_b_s) and
|
|
196
|
+
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
197
|
+
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
198
|
+
tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
|