blksprs 2.0rc4__py3-none-any.whl → 2.0rc6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/layouting/distribution_layout.py +11 -15
- blksprs/layouting/sparsity_layout.py +26 -31
- blksprs/ops/conversion.py +45 -63
- blksprs/ops/distribution.py +37 -56
- blksprs/ops/flow.py +22 -33
- blksprs/ops/matmul.py +19 -20
- blksprs/ops/misc/broadcast_ops.py +14 -19
- blksprs/ops/misc/row_wise.py +35 -54
- blksprs/ops/softmax.py +30 -44
- blksprs/utils/autotuning.py +78 -0
- blksprs/utils/tools.py +6 -25
- blksprs/utils/validation.py +3 -0
- {blksprs-2.0rc4.dist-info → blksprs-2.0rc6.dist-info}/METADATA +14 -5
- blksprs-2.0rc6.dist-info/RECORD +23 -0
- blksprs-2.0rc4.dist-info/RECORD +0 -22
- {blksprs-2.0rc4.dist-info → blksprs-2.0rc6.dist-info}/WHEEL +0 -0
- {blksprs-2.0rc4.dist-info → blksprs-2.0rc6.dist-info}/top_level.txt +0 -0
blksprs/layouting/distribution_layout.py
CHANGED
|
@@ -4,7 +4,8 @@ from torch import Tensor
|
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
6
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
7
|
-
from blksprs.utils.tools import stride
|
|
7
|
+
from blksprs.utils.tools import stride
|
|
8
|
+
from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
|
|
8
9
|
from blksprs.utils.validation import validate_dimensions, validate_device, \
|
|
9
10
|
validate_contiguous
|
|
10
11
|
|
|
@@ -47,6 +48,7 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
|
|
|
47
48
|
triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
|
|
48
49
|
triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]
|
|
49
50
|
|
|
51
|
+
# TODO wrap
|
|
50
52
|
(build_distribution_layout_kernel[triton_grid]
|
|
51
53
|
(indices,
|
|
52
54
|
i_b, i_b_s, i_r_s, i_c_s,
|
|
@@ -62,7 +64,8 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
|
|
|
62
64
|
|
|
63
65
|
@triton.autotune(
|
|
64
66
|
configs=get_autotune_configs(),
|
|
65
|
-
key=[],
|
|
67
|
+
key=["sparsity_block_size"],
|
|
68
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
66
69
|
reset_to_zero=["o"]
|
|
67
70
|
)
|
|
68
71
|
@triton.jit
|
|
@@ -80,9 +83,6 @@ def build_distribution_layout_kernel(i,
|
|
|
80
83
|
pid_row = tl.program_id(axis=1)
|
|
81
84
|
pid_col = tl.program_id(axis=2)
|
|
82
85
|
|
|
83
|
-
# Get valid triton block size
|
|
84
|
-
val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
|
|
85
|
-
|
|
86
86
|
# Get position of current sparsity block consisting of its batch, row, and column index
|
|
87
87
|
spa_bat_i_idx = (pid_blk * s_lut_i_r_s + 0 * s_lut_i_c_s)
|
|
88
88
|
spa_bat_i_msk = (spa_bat_i_idx >= 0 and spa_bat_i_idx < s_lut_i_r * s_lut_i_r_s)
|
|
@@ -97,12 +97,10 @@ def build_distribution_layout_kernel(i,
|
|
|
97
97
|
spa_col_i = tl.load(s_lut_i + spa_col_i_idx, mask=spa_col_i_msk)
|
|
98
98
|
|
|
99
99
|
blk_i_idx = (pid_blk * i_b_s +
|
|
100
|
-
((pid_row *
|
|
101
|
-
((pid_col *
|
|
102
|
-
blk_i_msk = (
|
|
103
|
-
|
|
104
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
105
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
100
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
|
|
101
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
|
|
102
|
+
blk_i_msk = (blk_i_idx >= 0 and
|
|
103
|
+
blk_i_idx < i_b * i_b_s)
|
|
106
104
|
blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)
|
|
107
105
|
|
|
108
106
|
dst_bat_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_i, dtype=tl.int32)
|
|
@@ -120,8 +118,6 @@ def build_distribution_layout_kernel(i,
|
|
|
120
118
|
blk_o_idx = ((dst_bat_idx * o_b_s) +
|
|
121
119
|
(dst_row_idx * o_r_s) +
|
|
122
120
|
(dst_col_idx * o_c_s))
|
|
123
|
-
blk_o_msk = (
|
|
124
|
-
|
|
125
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
126
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
121
|
+
blk_o_msk = (blk_o_idx >= 0 and
|
|
122
|
+
blk_o_idx < o_b * o_b_s)
|
|
127
123
|
tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
|
blksprs/layouting/sparsity_layout.py
CHANGED
|
@@ -7,7 +7,8 @@ from torch._library.triton import wrap_triton
|
|
|
7
7
|
from triton import language as tl
|
|
8
8
|
|
|
9
9
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
10
|
-
from blksprs.utils.tools import stride
|
|
10
|
+
from blksprs.utils.tools import stride
|
|
11
|
+
from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs, prune_autotune_configs_conversion
|
|
11
12
|
from blksprs.utils.validation import validate_dimensions, validate_device, \
|
|
12
13
|
validate_contiguous, validate_sparsity, validate_sparsity_block_size
|
|
13
14
|
|
|
@@ -37,10 +38,11 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
|
|
|
37
38
|
o_b_s, o_r_s, o_c_s = stride(output)
|
|
38
39
|
|
|
39
40
|
triton_grid = lambda meta: [x_b,
|
|
40
|
-
triton.cdiv(x_r,
|
|
41
|
-
triton.cdiv(x_c,
|
|
41
|
+
triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
|
|
42
|
+
triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
|
|
42
43
|
|
|
43
|
-
|
|
44
|
+
# TODO wrap
|
|
45
|
+
(build_sparsity_layout_kernel[triton_grid]
|
|
44
46
|
(x,
|
|
45
47
|
x_b, x_b_s, x_r_s, x_c_s,
|
|
46
48
|
output,
|
|
@@ -52,7 +54,8 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
|
|
|
52
54
|
|
|
53
55
|
@triton.autotune(
|
|
54
56
|
configs=get_autotune_configs(),
|
|
55
|
-
key=[],
|
|
57
|
+
key=["sparsity_block_size"],
|
|
58
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
56
59
|
reset_to_zero=["o"]
|
|
57
60
|
)
|
|
58
61
|
@triton.jit
|
|
@@ -67,24 +70,19 @@ def build_sparsity_layout_kernel(x,
|
|
|
67
70
|
pid_row = tl.program_id(axis=1)
|
|
68
71
|
pid_col = tl.program_id(axis=2)
|
|
69
72
|
|
|
70
|
-
# Get valid triton block size
|
|
71
|
-
val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
|
|
72
|
-
|
|
73
73
|
# Load x values
|
|
74
74
|
blk_x_idx = (pid_bat * x_b_s +
|
|
75
|
-
((pid_row *
|
|
76
|
-
((pid_col *
|
|
77
|
-
blk_x_msk = (
|
|
78
|
-
|
|
79
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
80
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
75
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
76
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
77
|
+
blk_x_msk = (blk_x_idx >= 0 and
|
|
78
|
+
blk_x_idx < x_b * x_b_s)
|
|
81
79
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
82
80
|
|
|
83
81
|
# Store sparsity layout value
|
|
84
82
|
if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
|
|
85
83
|
blk_o_idx = (pid_bat * o_b_s +
|
|
86
|
-
(((pid_row *
|
|
87
|
-
((pid_col *
|
|
84
|
+
(((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +
|
|
85
|
+
((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))
|
|
88
86
|
blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
|
|
89
87
|
tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
|
|
90
88
|
|
|
@@ -129,10 +127,11 @@ def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tenso
|
|
|
129
127
|
o_b_s, o_r_s, o_c_s = stride(output)
|
|
130
128
|
|
|
131
129
|
triton_grid = lambda meta: [x_b,
|
|
132
|
-
triton.cdiv(x_r,
|
|
133
|
-
triton.cdiv(x_c,
|
|
130
|
+
triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
|
|
131
|
+
triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
|
|
134
132
|
|
|
135
|
-
|
|
133
|
+
# TODO wrap
|
|
134
|
+
(build_sparsity_layout_adaption_kernel[triton_grid]
|
|
136
135
|
(x,
|
|
137
136
|
x_b, x_b_s, x_r_s, x_c_s,
|
|
138
137
|
sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
@@ -146,7 +145,8 @@ def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tenso
|
|
|
146
145
|
|
|
147
146
|
@triton.autotune(
|
|
148
147
|
configs=get_autotune_configs(),
|
|
149
|
-
key=[],
|
|
148
|
+
key=["sparsity_block_size_from", "sparsity_block_size_to"],
|
|
149
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs_conversion},
|
|
150
150
|
reset_to_zero=["o"]
|
|
151
151
|
)
|
|
152
152
|
@triton.jit
|
|
@@ -163,9 +163,6 @@ def build_sparsity_layout_adaption_kernel(x,
|
|
|
163
163
|
pid_row = tl.program_id(axis=1)
|
|
164
164
|
pid_col = tl.program_id(axis=2)
|
|
165
165
|
|
|
166
|
-
# Get valid triton block size
|
|
167
|
-
val_tbs = min(sparsity_block_size_to, TRITON_BLOCK_SIZE)
|
|
168
|
-
|
|
169
166
|
# Get sparsity index of current output block consisting of its batch, row, and column index
|
|
170
167
|
spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
|
|
171
168
|
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
|
|
@@ -181,20 +178,18 @@ def build_sparsity_layout_adaption_kernel(x,
|
|
|
181
178
|
|
|
182
179
|
# Load x values
|
|
183
180
|
blk_x_idx = ((pid_blk * x_b_s) +
|
|
184
|
-
((pid_row *
|
|
185
|
-
((pid_col *
|
|
186
|
-
blk_x_msk = (
|
|
187
|
-
|
|
188
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < sparsity_block_size_from and
|
|
189
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < sparsity_block_size_from))
|
|
181
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
182
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
183
|
+
blk_x_msk = (blk_x_idx >= 0 and
|
|
184
|
+
blk_x_idx < x_b * x_b_s)
|
|
190
185
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
191
186
|
|
|
192
187
|
# Store sparsity layout value
|
|
193
188
|
if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
|
|
194
189
|
blk_o_idx = ((spa_bat * o_b_s) +
|
|
195
|
-
(((pid_row *
|
|
190
|
+
(((pid_row * TRITON_BLOCK_SIZE + spa_row * sparsity_block_size_from)
|
|
196
191
|
// sparsity_block_size_to) * o_r_s) +
|
|
197
|
-
(((pid_col *
|
|
192
|
+
(((pid_col * TRITON_BLOCK_SIZE + spa_col * sparsity_block_size_from)
|
|
198
193
|
// sparsity_block_size_to) * o_c_s))
|
|
199
194
|
blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
|
|
200
195
|
tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
|
blksprs/ops/conversion.py
CHANGED
|
@@ -6,7 +6,8 @@ from triton import language as tl
|
|
|
6
6
|
|
|
7
7
|
from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
|
|
8
8
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
9
|
-
from blksprs.utils.tools import stride
|
|
9
|
+
from blksprs.utils.tools import stride
|
|
10
|
+
from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs, prune_autotune_configs_conversion
|
|
10
11
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
11
12
|
validate_sparsity, validate_sparsity_block_size, validate_sparsity_dense
|
|
12
13
|
|
|
@@ -86,7 +87,8 @@ def to_sparse_backward(ctx, grad_output):
|
|
|
86
87
|
|
|
87
88
|
@triton.autotune(
|
|
88
89
|
configs=get_autotune_configs(),
|
|
89
|
-
key=[],
|
|
90
|
+
key=["sparsity_block_size"],
|
|
91
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
90
92
|
reset_to_zero=["o"]
|
|
91
93
|
)
|
|
92
94
|
@triton.jit
|
|
@@ -102,9 +104,6 @@ def to_sparse_kernel(x,
|
|
|
102
104
|
pid_row = tl.program_id(axis=1)
|
|
103
105
|
pid_col = tl.program_id(axis=2)
|
|
104
106
|
|
|
105
|
-
# Get valid triton block size
|
|
106
|
-
val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
|
|
107
|
-
|
|
108
107
|
# Get sparsity index of current output block consisting of its batch, row, and column index
|
|
109
108
|
spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
|
|
110
109
|
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
|
|
@@ -120,24 +119,20 @@ def to_sparse_kernel(x,
|
|
|
120
119
|
|
|
121
120
|
# Load block from dense tensor
|
|
122
121
|
blk_d_idx = (spa_bat * x_b_s +
|
|
123
|
-
((pid_row *
|
|
122
|
+
((pid_row * TRITON_BLOCK_SIZE + spa_row * sparsity_block_size +
|
|
124
123
|
tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
125
|
-
((pid_col *
|
|
124
|
+
((pid_col * TRITON_BLOCK_SIZE + spa_col * sparsity_block_size +
|
|
126
125
|
tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
127
|
-
blk_d_msk = (
|
|
128
|
-
|
|
129
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
130
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
126
|
+
blk_d_msk = (blk_d_idx >= 0 and
|
|
127
|
+
blk_d_idx < x_b * x_b_s)
|
|
131
128
|
blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)
|
|
132
129
|
|
|
133
130
|
# Store block in sparse tensor
|
|
134
131
|
blk_o_idx = ((pid_blk * o_b_s) +
|
|
135
|
-
((pid_row *
|
|
136
|
-
((pid_col *
|
|
137
|
-
blk_o_msk = (
|
|
138
|
-
|
|
139
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
140
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
132
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
133
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
|
|
134
|
+
blk_o_msk = (blk_o_idx >= 0 and
|
|
135
|
+
blk_o_idx < (pid_blk + 1) * o_b_s)
|
|
141
136
|
tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
|
|
142
137
|
|
|
143
138
|
|
|
@@ -228,8 +223,8 @@ def to_dense_forward(x: Tensor, sparsity_layout: Tensor,
|
|
|
228
223
|
o_b_s, o_r_s, o_c_s = stride(output)
|
|
229
224
|
|
|
230
225
|
triton_grid = lambda meta: [o_b,
|
|
231
|
-
triton.cdiv(o_r,
|
|
232
|
-
triton.cdiv(o_c,
|
|
226
|
+
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
227
|
+
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
233
228
|
|
|
234
229
|
(wrap_triton(to_dense_kernel)[triton_grid]
|
|
235
230
|
(x,
|
|
@@ -252,7 +247,8 @@ def to_dense_backward(ctx, grad_output):
|
|
|
252
247
|
|
|
253
248
|
@triton.autotune(
|
|
254
249
|
configs=get_autotune_configs(),
|
|
255
|
-
key=[],
|
|
250
|
+
key=["sparsity_block_size"],
|
|
251
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
256
252
|
restore_value=["o"]
|
|
257
253
|
)
|
|
258
254
|
@triton.jit
|
|
@@ -269,12 +265,9 @@ def to_dense_kernel(x,
|
|
|
269
265
|
pid_row = tl.program_id(axis=1)
|
|
270
266
|
pid_col = tl.program_id(axis=2)
|
|
271
267
|
|
|
272
|
-
# Get valid triton block size
|
|
273
|
-
val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
|
|
274
|
-
|
|
275
268
|
# Get sparsity index of current block
|
|
276
|
-
spa_row = (pid_row *
|
|
277
|
-
spa_col = (pid_col *
|
|
269
|
+
spa_row = (pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size
|
|
270
|
+
spa_col = (pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size
|
|
278
271
|
|
|
279
272
|
# Get reverse sparsity index for current block
|
|
280
273
|
rev_idx_spa_idx = (pid_blk * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)
|
|
@@ -284,22 +277,18 @@ def to_dense_kernel(x,
|
|
|
284
277
|
# If block is present commence operations
|
|
285
278
|
if rev_idx_spa >= 0:
|
|
286
279
|
blk_idx = (rev_idx_spa * x_b_s +
|
|
287
|
-
(((pid_row % (sparsity_block_size //
|
|
280
|
+
(((pid_row % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
|
|
288
281
|
tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
289
|
-
(((pid_col % (sparsity_block_size //
|
|
282
|
+
(((pid_col % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
|
|
290
283
|
tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
291
|
-
blk_msk = (
|
|
292
|
-
|
|
293
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
294
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
284
|
+
blk_msk = (blk_idx >= 0 and
|
|
285
|
+
blk_idx < x_b * x_b_s)
|
|
295
286
|
blk = tl.load(x + blk_idx, mask=blk_msk)
|
|
296
287
|
|
|
297
288
|
o_idx = (pid_blk * o_b_s +
|
|
298
|
-
((pid_row *
|
|
299
|
-
((pid_col *
|
|
300
|
-
o_msk = (
|
|
301
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
302
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
289
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
290
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
291
|
+
o_msk = (o_idx >= 0 and o_idx < o_b * o_b_s)
|
|
303
292
|
tl.store(o + o_idx, blk, o_msk)
|
|
304
293
|
|
|
305
294
|
|
|
@@ -403,12 +392,11 @@ def adapt_layout_forward(x: Tensor,
|
|
|
403
392
|
s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_to)
|
|
404
393
|
|
|
405
394
|
triton_grid = lambda meta: [o_b,
|
|
406
|
-
triton.cdiv(o_r,
|
|
407
|
-
|
|
408
|
-
triton.cdiv(o_c, min(meta["sparsity_block_size_from"], meta["sparsity_block_size_to"],
|
|
409
|
-
meta["TRITON_BLOCK_SIZE"]))]
|
|
395
|
+
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
396
|
+
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
410
397
|
|
|
411
|
-
|
|
398
|
+
# TODO wrap
|
|
399
|
+
(adapt_layout_kernel[triton_grid]
|
|
412
400
|
(x,
|
|
413
401
|
x_b, x_b_s, x_r_s, x_c_s,
|
|
414
402
|
s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
|
|
@@ -434,7 +422,8 @@ def adapt_layout_backward(ctx, grad_output):
|
|
|
434
422
|
|
|
435
423
|
@triton.autotune(
|
|
436
424
|
configs=get_autotune_configs(),
|
|
437
|
-
key=[],
|
|
425
|
+
key=["sparsity_block_size_from", "sparsity_block_size_to"],
|
|
426
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs_conversion},
|
|
438
427
|
reset_to_zero=["o"]
|
|
439
428
|
)
|
|
440
429
|
@triton.jit
|
|
@@ -453,9 +442,6 @@ def adapt_layout_kernel(x,
|
|
|
453
442
|
pid_row = tl.program_id(axis=1)
|
|
454
443
|
pid_col = tl.program_id(axis=2)
|
|
455
444
|
|
|
456
|
-
# Get valid triton block size (Triton can only handle 2-valued min)
|
|
457
|
-
val_tbs = min(min(sparsity_block_size_from, sparsity_block_size_to), TRITON_BLOCK_SIZE)
|
|
458
|
-
|
|
459
445
|
# Get position of current sparsity block consisting of its batch, row, and column index
|
|
460
446
|
spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
|
|
461
447
|
spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
|
|
@@ -471,8 +457,8 @@ def adapt_layout_kernel(x,
|
|
|
471
457
|
|
|
472
458
|
# Get equivalent sparsity block in from layout
|
|
473
459
|
spa_bat_x = spa_bat_o
|
|
474
|
-
spa_row_x = (spa_row_o * sparsity_block_size_to + pid_row *
|
|
475
|
-
spa_col_x = (spa_col_o * sparsity_block_size_to + pid_col *
|
|
460
|
+
spa_row_x = (spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size_from
|
|
461
|
+
spa_col_x = (spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size_from
|
|
476
462
|
|
|
477
463
|
# Get reverse sparsity indices for x
|
|
478
464
|
rev_idx_spa_x_idx = (spa_bat_x * s_l_x_b_s +
|
|
@@ -484,29 +470,25 @@ def adapt_layout_kernel(x,
|
|
|
484
470
|
# If block is present commence operations
|
|
485
471
|
if rev_idx_spa_x >= 0:
|
|
486
472
|
# Calculate triton block size shifts
|
|
487
|
-
shift_row_x = ((spa_row_o * sparsity_block_size_to + pid_row *
|
|
488
|
-
% sparsity_block_size_from) //
|
|
489
|
-
shift_col_x = ((spa_col_o * sparsity_block_size_to + pid_col *
|
|
490
|
-
% sparsity_block_size_from) //
|
|
473
|
+
shift_row_x = ((spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE)
|
|
474
|
+
% sparsity_block_size_from) // TRITON_BLOCK_SIZE
|
|
475
|
+
shift_col_x = ((spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE)
|
|
476
|
+
% sparsity_block_size_from) // TRITON_BLOCK_SIZE
|
|
491
477
|
|
|
492
478
|
# Load x values
|
|
493
479
|
blk_x_idx = ((rev_idx_spa_x * x_b_s) +
|
|
494
|
-
((shift_row_x *
|
|
495
|
-
((shift_col_x *
|
|
496
|
-
blk_x_msk = (
|
|
497
|
-
|
|
498
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
499
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
480
|
+
((shift_row_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
481
|
+
((shift_col_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
482
|
+
blk_x_msk = (blk_x_idx >= 0 and
|
|
483
|
+
blk_x_idx < x_b * x_b_s)
|
|
500
484
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
501
485
|
|
|
502
486
|
# Store output
|
|
503
487
|
blk_o_idx = ((pid_blk * o_b_s) +
|
|
504
|
-
((pid_row *
|
|
505
|
-
((pid_col *
|
|
506
|
-
blk_o_msk = (
|
|
507
|
-
|
|
508
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
509
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
488
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
489
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
490
|
+
blk_o_msk = (blk_o_idx >= 0 and
|
|
491
|
+
blk_o_idx < o_b * o_b_s)
|
|
510
492
|
tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
|
|
511
493
|
|
|
512
494
|
|
blksprs/ops/distribution.py
CHANGED
|
@@ -6,7 +6,8 @@ from torch._library.triton import wrap_triton
|
|
|
6
6
|
from triton import language as tl
|
|
7
7
|
|
|
8
8
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
9
|
-
from blksprs.utils.tools import stride
|
|
9
|
+
from blksprs.utils.tools import stride
|
|
10
|
+
from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
|
|
10
11
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
11
12
|
validate_sparsity, validate_dtype_int, validate_sparsity_block_size
|
|
12
13
|
|
|
@@ -100,7 +101,8 @@ def gather_backward(ctx, grad_output):
|
|
|
100
101
|
|
|
101
102
|
@triton.autotune(
|
|
102
103
|
configs=get_autotune_configs(),
|
|
103
|
-
key=[],
|
|
104
|
+
key=["sparsity_block_size"],
|
|
105
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
104
106
|
reset_to_zero=["o"]
|
|
105
107
|
)
|
|
106
108
|
@triton.jit
|
|
@@ -121,9 +123,6 @@ def gather_kernel(x,
|
|
|
121
123
|
pid_row = tl.program_id(axis=1)
|
|
122
124
|
pid_col = tl.program_id(axis=2)
|
|
123
125
|
|
|
124
|
-
# Get valid triton block size
|
|
125
|
-
val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
|
|
126
|
-
|
|
127
126
|
# Get position of current sparsity block consisting of its batch, row, and column index
|
|
128
127
|
spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
|
|
129
128
|
spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
|
|
@@ -139,12 +138,10 @@ def gather_kernel(x,
|
|
|
139
138
|
|
|
140
139
|
# Load index values
|
|
141
140
|
blk_i_idx = ((pid_blk * i_b_s) +
|
|
142
|
-
((pid_row *
|
|
143
|
-
((pid_col *
|
|
144
|
-
blk_i_msk = (
|
|
145
|
-
|
|
146
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
147
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
141
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
|
|
142
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
|
|
143
|
+
blk_i_msk = (blk_i_idx >= 0 and
|
|
144
|
+
blk_i_idx < i_b * i_b_s)
|
|
148
145
|
blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
|
|
149
146
|
|
|
150
147
|
# Get indices of sparsity blocks and positions within the blocks
|
|
@@ -154,9 +151,9 @@ def gather_kernel(x,
|
|
|
154
151
|
rev_dst_bat_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_o, dtype=tl.int32)
|
|
155
152
|
rev_dst_row_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_o, dtype=tl.int32)
|
|
156
153
|
rev_dst_col_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_o, dtype=tl.int32)
|
|
157
|
-
dst_row_x = (((pid_row *
|
|
154
|
+
dst_row_x = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
|
|
158
155
|
.broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
|
|
159
|
-
dst_col_x = (((pid_col *
|
|
156
|
+
dst_col_x = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
|
|
160
157
|
.broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
|
|
161
158
|
if dim == 0:
|
|
162
159
|
rev_dst_bat_x = blk_i
|
|
@@ -171,32 +168,26 @@ def gather_kernel(x,
|
|
|
171
168
|
rev_idx_spa_x_idx = ((rev_dst_bat_x * s_l_x_b_s) +
|
|
172
169
|
(rev_dst_row_x * s_l_x_r_s) +
|
|
173
170
|
(rev_dst_col_x * s_l_x_c_s))
|
|
174
|
-
rev_idx_spa_x_msk = (
|
|
175
|
-
|
|
176
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
177
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
171
|
+
rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and
|
|
172
|
+
rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
|
|
178
173
|
rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
|
|
179
174
|
|
|
180
175
|
# Load x values
|
|
181
176
|
blk_x_idx = ((rev_idx_spa_x * x_b_s) +
|
|
182
177
|
dst_row_x +
|
|
183
178
|
dst_col_x)
|
|
184
|
-
blk_x_msk = ((
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
188
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
179
|
+
blk_x_msk = ((blk_x_idx >= 0 and
|
|
180
|
+
blk_x_idx < x_b * x_b_s) and
|
|
181
|
+
rev_idx_spa_x_msk != -1)
|
|
189
182
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
190
183
|
|
|
191
184
|
# Store output
|
|
192
185
|
blk_o_idx = ((pid_blk * o_b_s) +
|
|
193
|
-
((pid_row *
|
|
194
|
-
((pid_col *
|
|
195
|
-
blk_o_msk = ((
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
199
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
186
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
187
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
188
|
+
blk_o_msk = ((blk_o_idx >= 0 and
|
|
189
|
+
blk_o_idx < o_b * o_b_s) and
|
|
190
|
+
rev_idx_spa_x_msk != -1)
|
|
200
191
|
tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
|
|
201
192
|
|
|
202
193
|
|
|
@@ -357,7 +348,8 @@ def scatter_reduce_backward(ctx, grad_output):
|
|
|
357
348
|
|
|
358
349
|
@triton.autotune(
|
|
359
350
|
configs=get_autotune_configs(),
|
|
360
|
-
key=[],
|
|
351
|
+
key=["sparsity_block_size"],
|
|
352
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
361
353
|
reset_to_zero=["o"]
|
|
362
354
|
)
|
|
363
355
|
@triton.jit
|
|
@@ -379,9 +371,6 @@ def scatter_reduce_kernel(x,
|
|
|
379
371
|
pid_row = tl.program_id(axis=1)
|
|
380
372
|
pid_col = tl.program_id(axis=2)
|
|
381
373
|
|
|
382
|
-
# Get valid triton block size
|
|
383
|
-
val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
|
|
384
|
-
|
|
385
374
|
# Get position of current sparsity block consisting of its batch, row, and column index
|
|
386
375
|
spa_bat_x_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
|
|
387
376
|
spa_bat_x_msk = (spa_bat_x_idx >= 0 and spa_bat_x_idx < s_lut_x_r * s_lut_x_r_s)
|
|
@@ -397,22 +386,18 @@ def scatter_reduce_kernel(x,
|
|
|
397
386
|
|
|
398
387
|
# Load x values
|
|
399
388
|
blk_x_idx = ((pid_blk * x_b_s) +
|
|
400
|
-
((pid_row *
|
|
401
|
-
((pid_col *
|
|
402
|
-
blk_x_msk = (
|
|
403
|
-
|
|
404
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
405
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
389
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
390
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
391
|
+
blk_x_msk = (blk_x_idx >= 0 and
|
|
392
|
+
blk_x_idx < x_b * x_b_s)
|
|
406
393
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
407
394
|
|
|
408
395
|
# Load index values
|
|
409
396
|
blk_i_idx = ((pid_blk * i_b_s) +
|
|
410
|
-
((pid_row *
|
|
411
|
-
((pid_col *
|
|
412
|
-
blk_i_msk = (
|
|
413
|
-
|
|
414
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
415
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
397
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
|
|
398
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
|
|
399
|
+
blk_i_msk = (blk_i_idx >= 0 and
|
|
400
|
+
blk_i_idx < i_b * i_b_s)
|
|
416
401
|
blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
|
|
417
402
|
|
|
418
403
|
# Get indices of sparsity blocks and positions within the blocks
|
|
@@ -422,9 +407,9 @@ def scatter_reduce_kernel(x,
|
|
|
422
407
|
rev_dst_bat_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_x, dtype=tl.int32)
|
|
423
408
|
rev_dst_row_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_x, dtype=tl.int32)
|
|
424
409
|
rev_dst_col_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_x, dtype=tl.int32)
|
|
425
|
-
dst_row_o = (((pid_row *
|
|
410
|
+
dst_row_o = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
|
|
426
411
|
.broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
|
|
427
|
-
dst_col_o = (((pid_col *
|
|
412
|
+
dst_col_o = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
|
|
428
413
|
.broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
|
|
429
414
|
if dim == 0:
|
|
430
415
|
rev_dst_bat_o = blk_i
|
|
@@ -439,21 +424,17 @@ def scatter_reduce_kernel(x,
|
|
|
439
424
|
rev_idx_spa_o_idx = ((rev_dst_bat_o * s_l_o_b_s) +
|
|
440
425
|
(rev_dst_row_o * s_l_o_r_s) +
|
|
441
426
|
(rev_dst_col_o * s_l_o_c_s))
|
|
442
|
-
rev_idx_spa_o_msk = (
|
|
443
|
-
|
|
444
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
445
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
427
|
+
rev_idx_spa_o_msk = (rev_idx_spa_o_idx >= 0 and
|
|
428
|
+
rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
|
|
446
429
|
rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
|
|
447
430
|
|
|
448
431
|
# Store output
|
|
449
432
|
blk_o_idx = ((rev_idx_spa_o * o_b_s) +
|
|
450
433
|
dst_row_o +
|
|
451
434
|
dst_col_o)
|
|
452
|
-
blk_o_msk = ((
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
456
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
435
|
+
blk_o_msk = ((blk_o_idx >= 0 and
|
|
436
|
+
blk_o_idx < o_b * o_b_s) and
|
|
437
|
+
rev_idx_spa_o_msk != -1)
|
|
457
438
|
|
|
458
439
|
if reduce_op_ind == 0:
|
|
459
440
|
tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
|