blksprs 2.0rc3__py3-none-any.whl → 2.0rc6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/layouting/distribution_layout.py +11 -15
- blksprs/layouting/sparsity_layout.py +26 -31
- blksprs/ops/conversion.py +48 -64
- blksprs/ops/distribution.py +39 -57
- blksprs/ops/flow.py +24 -34
- blksprs/ops/matmul.py +21 -21
- blksprs/ops/misc/broadcast_ops.py +14 -19
- blksprs/ops/misc/row_wise.py +37 -55
- blksprs/ops/softmax.py +34 -46
- blksprs/utils/autotuning.py +78 -0
- blksprs/utils/tools.py +6 -25
- blksprs/utils/validation.py +3 -0
- {blksprs-2.0rc3.dist-info → blksprs-2.0rc6.dist-info}/METADATA +14 -5
- blksprs-2.0rc6.dist-info/RECORD +23 -0
- blksprs-2.0rc3.dist-info/RECORD +0 -22
- {blksprs-2.0rc3.dist-info → blksprs-2.0rc6.dist-info}/WHEEL +0 -0
- {blksprs-2.0rc3.dist-info → blksprs-2.0rc6.dist-info}/top_level.txt +0 -0
|
@@ -4,7 +4,8 @@ from torch import Tensor
|
|
|
4
4
|
from triton import language as tl
|
|
5
5
|
|
|
6
6
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
7
|
-
from blksprs.utils.tools import stride
|
|
7
|
+
from blksprs.utils.tools import stride
|
|
8
|
+
from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
|
|
8
9
|
from blksprs.utils.validation import validate_dimensions, validate_device, \
|
|
9
10
|
validate_contiguous
|
|
10
11
|
|
|
@@ -47,6 +48,7 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
|
|
|
47
48
|
triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
|
|
48
49
|
triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]
|
|
49
50
|
|
|
51
|
+
# TODO wrap
|
|
50
52
|
(build_distribution_layout_kernel[triton_grid]
|
|
51
53
|
(indices,
|
|
52
54
|
i_b, i_b_s, i_r_s, i_c_s,
|
|
@@ -62,7 +64,8 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
|
|
|
62
64
|
|
|
63
65
|
@triton.autotune(
|
|
64
66
|
configs=get_autotune_configs(),
|
|
65
|
-
key=[],
|
|
67
|
+
key=["sparsity_block_size"],
|
|
68
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
66
69
|
reset_to_zero=["o"]
|
|
67
70
|
)
|
|
68
71
|
@triton.jit
|
|
@@ -80,9 +83,6 @@ def build_distribution_layout_kernel(i,
|
|
|
80
83
|
pid_row = tl.program_id(axis=1)
|
|
81
84
|
pid_col = tl.program_id(axis=2)
|
|
82
85
|
|
|
83
|
-
# Get valid triton block size
|
|
84
|
-
val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
|
|
85
|
-
|
|
86
86
|
# Get position of current sparsity block consisting of its batch, row, and column index
|
|
87
87
|
spa_bat_i_idx = (pid_blk * s_lut_i_r_s + 0 * s_lut_i_c_s)
|
|
88
88
|
spa_bat_i_msk = (spa_bat_i_idx >= 0 and spa_bat_i_idx < s_lut_i_r * s_lut_i_r_s)
|
|
@@ -97,12 +97,10 @@ def build_distribution_layout_kernel(i,
|
|
|
97
97
|
spa_col_i = tl.load(s_lut_i + spa_col_i_idx, mask=spa_col_i_msk)
|
|
98
98
|
|
|
99
99
|
blk_i_idx = (pid_blk * i_b_s +
|
|
100
|
-
((pid_row *
|
|
101
|
-
((pid_col *
|
|
102
|
-
blk_i_msk = (
|
|
103
|
-
|
|
104
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
105
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
100
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
|
|
101
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
|
|
102
|
+
blk_i_msk = (blk_i_idx >= 0 and
|
|
103
|
+
blk_i_idx < i_b * i_b_s)
|
|
106
104
|
blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)
|
|
107
105
|
|
|
108
106
|
dst_bat_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_i, dtype=tl.int32)
|
|
@@ -120,8 +118,6 @@ def build_distribution_layout_kernel(i,
|
|
|
120
118
|
blk_o_idx = ((dst_bat_idx * o_b_s) +
|
|
121
119
|
(dst_row_idx * o_r_s) +
|
|
122
120
|
(dst_col_idx * o_c_s))
|
|
123
|
-
blk_o_msk = (
|
|
124
|
-
|
|
125
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
126
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
121
|
+
blk_o_msk = (blk_o_idx >= 0 and
|
|
122
|
+
blk_o_idx < o_b * o_b_s)
|
|
127
123
|
tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
|
|
@@ -7,7 +7,8 @@ from torch._library.triton import wrap_triton
|
|
|
7
7
|
from triton import language as tl
|
|
8
8
|
|
|
9
9
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
10
|
-
from blksprs.utils.tools import stride
|
|
10
|
+
from blksprs.utils.tools import stride
|
|
11
|
+
from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs, prune_autotune_configs_conversion
|
|
11
12
|
from blksprs.utils.validation import validate_dimensions, validate_device, \
|
|
12
13
|
validate_contiguous, validate_sparsity, validate_sparsity_block_size
|
|
13
14
|
|
|
@@ -37,10 +38,11 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
|
|
|
37
38
|
o_b_s, o_r_s, o_c_s = stride(output)
|
|
38
39
|
|
|
39
40
|
triton_grid = lambda meta: [x_b,
|
|
40
|
-
triton.cdiv(x_r,
|
|
41
|
-
triton.cdiv(x_c,
|
|
41
|
+
triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
|
|
42
|
+
triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
|
|
42
43
|
|
|
43
|
-
|
|
44
|
+
# TODO wrap
|
|
45
|
+
(build_sparsity_layout_kernel[triton_grid]
|
|
44
46
|
(x,
|
|
45
47
|
x_b, x_b_s, x_r_s, x_c_s,
|
|
46
48
|
output,
|
|
@@ -52,7 +54,8 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
|
|
|
52
54
|
|
|
53
55
|
@triton.autotune(
|
|
54
56
|
configs=get_autotune_configs(),
|
|
55
|
-
key=[],
|
|
57
|
+
key=["sparsity_block_size"],
|
|
58
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
56
59
|
reset_to_zero=["o"]
|
|
57
60
|
)
|
|
58
61
|
@triton.jit
|
|
@@ -67,24 +70,19 @@ def build_sparsity_layout_kernel(x,
|
|
|
67
70
|
pid_row = tl.program_id(axis=1)
|
|
68
71
|
pid_col = tl.program_id(axis=2)
|
|
69
72
|
|
|
70
|
-
# Get valid triton block size
|
|
71
|
-
val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
|
|
72
|
-
|
|
73
73
|
# Load x values
|
|
74
74
|
blk_x_idx = (pid_bat * x_b_s +
|
|
75
|
-
((pid_row *
|
|
76
|
-
((pid_col *
|
|
77
|
-
blk_x_msk = (
|
|
78
|
-
|
|
79
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
80
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
75
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
76
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
77
|
+
blk_x_msk = (blk_x_idx >= 0 and
|
|
78
|
+
blk_x_idx < x_b * x_b_s)
|
|
81
79
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
82
80
|
|
|
83
81
|
# Store sparsity layout value
|
|
84
82
|
if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
|
|
85
83
|
blk_o_idx = (pid_bat * o_b_s +
|
|
86
|
-
(((pid_row *
|
|
87
|
-
((pid_col *
|
|
84
|
+
(((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +
|
|
85
|
+
((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))
|
|
88
86
|
blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
|
|
89
87
|
tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
|
|
90
88
|
|
|
@@ -129,10 +127,11 @@ def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tenso
|
|
|
129
127
|
o_b_s, o_r_s, o_c_s = stride(output)
|
|
130
128
|
|
|
131
129
|
triton_grid = lambda meta: [x_b,
|
|
132
|
-
triton.cdiv(x_r,
|
|
133
|
-
triton.cdiv(x_c,
|
|
130
|
+
triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
|
|
131
|
+
triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
|
|
134
132
|
|
|
135
|
-
|
|
133
|
+
# TODO wrap
|
|
134
|
+
(build_sparsity_layout_adaption_kernel[triton_grid]
|
|
136
135
|
(x,
|
|
137
136
|
x_b, x_b_s, x_r_s, x_c_s,
|
|
138
137
|
sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
@@ -146,7 +145,8 @@ def build_sparsity_layout_adaption(x: BlksprsTensor, sparsity_layout_from: Tenso
|
|
|
146
145
|
|
|
147
146
|
@triton.autotune(
|
|
148
147
|
configs=get_autotune_configs(),
|
|
149
|
-
key=[],
|
|
148
|
+
key=["sparsity_block_size_from", "sparsity_block_size_to"],
|
|
149
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs_conversion},
|
|
150
150
|
reset_to_zero=["o"]
|
|
151
151
|
)
|
|
152
152
|
@triton.jit
|
|
@@ -163,9 +163,6 @@ def build_sparsity_layout_adaption_kernel(x,
|
|
|
163
163
|
pid_row = tl.program_id(axis=1)
|
|
164
164
|
pid_col = tl.program_id(axis=2)
|
|
165
165
|
|
|
166
|
-
# Get valid triton block size
|
|
167
|
-
val_tbs = min(sparsity_block_size_to, TRITON_BLOCK_SIZE)
|
|
168
|
-
|
|
169
166
|
# Get sparsity index of current output block consisting of its batch, row, and column index
|
|
170
167
|
spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
|
|
171
168
|
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
|
|
@@ -181,20 +178,18 @@ def build_sparsity_layout_adaption_kernel(x,
|
|
|
181
178
|
|
|
182
179
|
# Load x values
|
|
183
180
|
blk_x_idx = ((pid_blk * x_b_s) +
|
|
184
|
-
((pid_row *
|
|
185
|
-
((pid_col *
|
|
186
|
-
blk_x_msk = (
|
|
187
|
-
|
|
188
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < sparsity_block_size_from and
|
|
189
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < sparsity_block_size_from))
|
|
181
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
182
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
183
|
+
blk_x_msk = (blk_x_idx >= 0 and
|
|
184
|
+
blk_x_idx < x_b * x_b_s)
|
|
190
185
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
191
186
|
|
|
192
187
|
# Store sparsity layout value
|
|
193
188
|
if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
|
|
194
189
|
blk_o_idx = ((spa_bat * o_b_s) +
|
|
195
|
-
(((pid_row *
|
|
190
|
+
(((pid_row * TRITON_BLOCK_SIZE + spa_row * sparsity_block_size_from)
|
|
196
191
|
// sparsity_block_size_to) * o_r_s) +
|
|
197
|
-
(((pid_col *
|
|
192
|
+
(((pid_col * TRITON_BLOCK_SIZE + spa_col * sparsity_block_size_from)
|
|
198
193
|
// sparsity_block_size_to) * o_c_s))
|
|
199
194
|
blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
|
|
200
195
|
tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
|
blksprs/ops/conversion.py
CHANGED
|
@@ -6,7 +6,8 @@ from triton import language as tl
|
|
|
6
6
|
|
|
7
7
|
from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
|
|
8
8
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
9
|
-
from blksprs.utils.tools import stride
|
|
9
|
+
from blksprs.utils.tools import stride
|
|
10
|
+
from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs, prune_autotune_configs_conversion
|
|
10
11
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
11
12
|
validate_sparsity, validate_sparsity_block_size, validate_sparsity_dense
|
|
12
13
|
|
|
@@ -54,7 +55,7 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor,
|
|
|
54
55
|
@triton_op("blksprs::to_sparse", mutates_args={})
|
|
55
56
|
def to_sparse_forward(x: Tensor, _: Tensor,
|
|
56
57
|
sparsity_lut: Tensor, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
|
|
57
|
-
output = torch.
|
|
58
|
+
output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
|
|
58
59
|
dtype=x.dtype, device=x.device)
|
|
59
60
|
|
|
60
61
|
x_b, x_r, x_c = x.size()
|
|
@@ -86,7 +87,9 @@ def to_sparse_backward(ctx, grad_output):
|
|
|
86
87
|
|
|
87
88
|
@triton.autotune(
|
|
88
89
|
configs=get_autotune_configs(),
|
|
89
|
-
key=[],
|
|
90
|
+
key=["sparsity_block_size"],
|
|
91
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
92
|
+
reset_to_zero=["o"]
|
|
90
93
|
)
|
|
91
94
|
@triton.jit
|
|
92
95
|
def to_sparse_kernel(x,
|
|
@@ -101,9 +104,6 @@ def to_sparse_kernel(x,
|
|
|
101
104
|
pid_row = tl.program_id(axis=1)
|
|
102
105
|
pid_col = tl.program_id(axis=2)
|
|
103
106
|
|
|
104
|
-
# Get valid triton block size
|
|
105
|
-
val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
|
|
106
|
-
|
|
107
107
|
# Get sparsity index of current output block consisting of its batch, row, and column index
|
|
108
108
|
spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
|
|
109
109
|
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
|
|
@@ -119,24 +119,20 @@ def to_sparse_kernel(x,
|
|
|
119
119
|
|
|
120
120
|
# Load block from dense tensor
|
|
121
121
|
blk_d_idx = (spa_bat * x_b_s +
|
|
122
|
-
((pid_row *
|
|
122
|
+
((pid_row * TRITON_BLOCK_SIZE + spa_row * sparsity_block_size +
|
|
123
123
|
tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
124
|
-
((pid_col *
|
|
124
|
+
((pid_col * TRITON_BLOCK_SIZE + spa_col * sparsity_block_size +
|
|
125
125
|
tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
126
|
-
blk_d_msk = (
|
|
127
|
-
|
|
128
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
129
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
126
|
+
blk_d_msk = (blk_d_idx >= 0 and
|
|
127
|
+
blk_d_idx < x_b * x_b_s)
|
|
130
128
|
blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)
|
|
131
129
|
|
|
132
130
|
# Store block in sparse tensor
|
|
133
131
|
blk_o_idx = ((pid_blk * o_b_s) +
|
|
134
|
-
((pid_row *
|
|
135
|
-
((pid_col *
|
|
136
|
-
blk_o_msk = (
|
|
137
|
-
|
|
138
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
139
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
132
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
133
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
|
|
134
|
+
blk_o_msk = (blk_o_idx >= 0 and
|
|
135
|
+
blk_o_idx < (pid_blk + 1) * o_b_s)
|
|
140
136
|
tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
|
|
141
137
|
|
|
142
138
|
|
|
@@ -227,8 +223,8 @@ def to_dense_forward(x: Tensor, sparsity_layout: Tensor,
|
|
|
227
223
|
o_b_s, o_r_s, o_c_s = stride(output)
|
|
228
224
|
|
|
229
225
|
triton_grid = lambda meta: [o_b,
|
|
230
|
-
triton.cdiv(o_r,
|
|
231
|
-
triton.cdiv(o_c,
|
|
226
|
+
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
227
|
+
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
232
228
|
|
|
233
229
|
(wrap_triton(to_dense_kernel)[triton_grid]
|
|
234
230
|
(x,
|
|
@@ -251,7 +247,9 @@ def to_dense_backward(ctx, grad_output):
|
|
|
251
247
|
|
|
252
248
|
@triton.autotune(
|
|
253
249
|
configs=get_autotune_configs(),
|
|
254
|
-
key=[],
|
|
250
|
+
key=["sparsity_block_size"],
|
|
251
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
252
|
+
restore_value=["o"]
|
|
255
253
|
)
|
|
256
254
|
@triton.jit
|
|
257
255
|
def to_dense_kernel(x,
|
|
@@ -267,12 +265,9 @@ def to_dense_kernel(x,
|
|
|
267
265
|
pid_row = tl.program_id(axis=1)
|
|
268
266
|
pid_col = tl.program_id(axis=2)
|
|
269
267
|
|
|
270
|
-
# Get valid triton block size
|
|
271
|
-
val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
|
|
272
|
-
|
|
273
268
|
# Get sparsity index of current block
|
|
274
|
-
spa_row = (pid_row *
|
|
275
|
-
spa_col = (pid_col *
|
|
269
|
+
spa_row = (pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size
|
|
270
|
+
spa_col = (pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size
|
|
276
271
|
|
|
277
272
|
# Get reverse sparsity index for current block
|
|
278
273
|
rev_idx_spa_idx = (pid_blk * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)
|
|
@@ -282,22 +277,18 @@ def to_dense_kernel(x,
|
|
|
282
277
|
# If block is present commence operations
|
|
283
278
|
if rev_idx_spa >= 0:
|
|
284
279
|
blk_idx = (rev_idx_spa * x_b_s +
|
|
285
|
-
(((pid_row % (sparsity_block_size //
|
|
280
|
+
(((pid_row % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
|
|
286
281
|
tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
287
|
-
(((pid_col % (sparsity_block_size //
|
|
282
|
+
(((pid_col % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
|
|
288
283
|
tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
289
|
-
blk_msk = (
|
|
290
|
-
|
|
291
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
292
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
284
|
+
blk_msk = (blk_idx >= 0 and
|
|
285
|
+
blk_idx < x_b * x_b_s)
|
|
293
286
|
blk = tl.load(x + blk_idx, mask=blk_msk)
|
|
294
287
|
|
|
295
288
|
o_idx = (pid_blk * o_b_s +
|
|
296
|
-
((pid_row *
|
|
297
|
-
((pid_col *
|
|
298
|
-
o_msk = (
|
|
299
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
300
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
289
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
290
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
291
|
+
o_msk = (o_idx >= 0 and o_idx < o_b * o_b_s)
|
|
301
292
|
tl.store(o + o_idx, blk, o_msk)
|
|
302
293
|
|
|
303
294
|
|
|
@@ -401,12 +392,11 @@ def adapt_layout_forward(x: Tensor,
|
|
|
401
392
|
s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_to)
|
|
402
393
|
|
|
403
394
|
triton_grid = lambda meta: [o_b,
|
|
404
|
-
triton.cdiv(o_r,
|
|
405
|
-
|
|
406
|
-
triton.cdiv(o_c, min(meta["sparsity_block_size_from"], meta["sparsity_block_size_to"],
|
|
407
|
-
meta["TRITON_BLOCK_SIZE"]))]
|
|
395
|
+
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
396
|
+
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
408
397
|
|
|
409
|
-
|
|
398
|
+
# TODO wrap
|
|
399
|
+
(adapt_layout_kernel[triton_grid]
|
|
410
400
|
(x,
|
|
411
401
|
x_b, x_b_s, x_r_s, x_c_s,
|
|
412
402
|
s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
|
|
@@ -432,7 +422,8 @@ def adapt_layout_backward(ctx, grad_output):
|
|
|
432
422
|
|
|
433
423
|
@triton.autotune(
|
|
434
424
|
configs=get_autotune_configs(),
|
|
435
|
-
key=[],
|
|
425
|
+
key=["sparsity_block_size_from", "sparsity_block_size_to"],
|
|
426
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs_conversion},
|
|
436
427
|
reset_to_zero=["o"]
|
|
437
428
|
)
|
|
438
429
|
@triton.jit
|
|
@@ -451,9 +442,6 @@ def adapt_layout_kernel(x,
|
|
|
451
442
|
pid_row = tl.program_id(axis=1)
|
|
452
443
|
pid_col = tl.program_id(axis=2)
|
|
453
444
|
|
|
454
|
-
# Get valid triton block size (Triton can only handle 2-valued min)
|
|
455
|
-
val_tbs = min(min(sparsity_block_size_from, sparsity_block_size_to), TRITON_BLOCK_SIZE)
|
|
456
|
-
|
|
457
445
|
# Get position of current sparsity block consisting of its batch, row, and column index
|
|
458
446
|
spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
|
|
459
447
|
spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
|
|
@@ -469,8 +457,8 @@ def adapt_layout_kernel(x,
|
|
|
469
457
|
|
|
470
458
|
# Get equivalent sparsity block in from layout
|
|
471
459
|
spa_bat_x = spa_bat_o
|
|
472
|
-
spa_row_x = (spa_row_o * sparsity_block_size_to + pid_row *
|
|
473
|
-
spa_col_x = (spa_col_o * sparsity_block_size_to + pid_col *
|
|
460
|
+
spa_row_x = (spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size_from
|
|
461
|
+
spa_col_x = (spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size_from
|
|
474
462
|
|
|
475
463
|
# Get reverse sparsity indices for x
|
|
476
464
|
rev_idx_spa_x_idx = (spa_bat_x * s_l_x_b_s +
|
|
@@ -482,29 +470,25 @@ def adapt_layout_kernel(x,
|
|
|
482
470
|
# If block is present commence operations
|
|
483
471
|
if rev_idx_spa_x >= 0:
|
|
484
472
|
# Calculate triton block size shifts
|
|
485
|
-
shift_row_x = ((spa_row_o * sparsity_block_size_to + pid_row *
|
|
486
|
-
% sparsity_block_size_from) //
|
|
487
|
-
shift_col_x = ((spa_col_o * sparsity_block_size_to + pid_col *
|
|
488
|
-
% sparsity_block_size_from) //
|
|
473
|
+
shift_row_x = ((spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE)
|
|
474
|
+
% sparsity_block_size_from) // TRITON_BLOCK_SIZE
|
|
475
|
+
shift_col_x = ((spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE)
|
|
476
|
+
% sparsity_block_size_from) // TRITON_BLOCK_SIZE
|
|
489
477
|
|
|
490
478
|
# Load x values
|
|
491
479
|
blk_x_idx = ((rev_idx_spa_x * x_b_s) +
|
|
492
|
-
((shift_row_x *
|
|
493
|
-
((shift_col_x *
|
|
494
|
-
blk_x_msk = (
|
|
495
|
-
|
|
496
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
497
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
480
|
+
((shift_row_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
481
|
+
((shift_col_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
482
|
+
blk_x_msk = (blk_x_idx >= 0 and
|
|
483
|
+
blk_x_idx < x_b * x_b_s)
|
|
498
484
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
499
485
|
|
|
500
486
|
# Store output
|
|
501
487
|
blk_o_idx = ((pid_blk * o_b_s) +
|
|
502
|
-
((pid_row *
|
|
503
|
-
((pid_col *
|
|
504
|
-
blk_o_msk = (
|
|
505
|
-
|
|
506
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
507
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
488
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
489
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
490
|
+
blk_o_msk = (blk_o_idx >= 0 and
|
|
491
|
+
blk_o_idx < o_b * o_b_s)
|
|
508
492
|
tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
|
|
509
493
|
|
|
510
494
|
|
blksprs/ops/distribution.py
CHANGED
|
@@ -6,7 +6,8 @@ from torch._library.triton import wrap_triton
|
|
|
6
6
|
from triton import language as tl
|
|
7
7
|
|
|
8
8
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
9
|
-
from blksprs.utils.tools import stride
|
|
9
|
+
from blksprs.utils.tools import stride
|
|
10
|
+
from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
|
|
10
11
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
11
12
|
validate_sparsity, validate_dtype_int, validate_sparsity_block_size
|
|
12
13
|
|
|
@@ -54,7 +55,7 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
|
|
|
54
55
|
def gather_forward(x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
|
|
55
56
|
dim: int, i: Tensor, _: Tensor, sparsity_lut_i: Tensor,
|
|
56
57
|
sparsity_block_size: int) -> Tensor:
|
|
57
|
-
output = torch.
|
|
58
|
+
output = torch.zeros_like(i, dtype=x.dtype)
|
|
58
59
|
|
|
59
60
|
x_b, x_r, x_c = x.size()
|
|
60
61
|
x_b_s, x_r_s, x_c_s = stride(x)
|
|
@@ -100,7 +101,9 @@ def gather_backward(ctx, grad_output):
|
|
|
100
101
|
|
|
101
102
|
@triton.autotune(
|
|
102
103
|
configs=get_autotune_configs(),
|
|
103
|
-
key=[],
|
|
104
|
+
key=["sparsity_block_size"],
|
|
105
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
106
|
+
reset_to_zero=["o"]
|
|
104
107
|
)
|
|
105
108
|
@triton.jit
|
|
106
109
|
def gather_kernel(x,
|
|
@@ -120,9 +123,6 @@ def gather_kernel(x,
|
|
|
120
123
|
pid_row = tl.program_id(axis=1)
|
|
121
124
|
pid_col = tl.program_id(axis=2)
|
|
122
125
|
|
|
123
|
-
# Get valid triton block size
|
|
124
|
-
val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
|
|
125
|
-
|
|
126
126
|
# Get position of current sparsity block consisting of its batch, row, and column index
|
|
127
127
|
spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
|
|
128
128
|
spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
|
|
@@ -138,12 +138,10 @@ def gather_kernel(x,
|
|
|
138
138
|
|
|
139
139
|
# Load index values
|
|
140
140
|
blk_i_idx = ((pid_blk * i_b_s) +
|
|
141
|
-
((pid_row *
|
|
142
|
-
((pid_col *
|
|
143
|
-
blk_i_msk = (
|
|
144
|
-
|
|
145
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
146
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
141
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
|
|
142
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
|
|
143
|
+
blk_i_msk = (blk_i_idx >= 0 and
|
|
144
|
+
blk_i_idx < i_b * i_b_s)
|
|
147
145
|
blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
|
|
148
146
|
|
|
149
147
|
# Get indices of sparsity blocks and positions within the blocks
|
|
@@ -153,9 +151,9 @@ def gather_kernel(x,
|
|
|
153
151
|
rev_dst_bat_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_o, dtype=tl.int32)
|
|
154
152
|
rev_dst_row_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_o, dtype=tl.int32)
|
|
155
153
|
rev_dst_col_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_o, dtype=tl.int32)
|
|
156
|
-
dst_row_x = (((pid_row *
|
|
154
|
+
dst_row_x = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
|
|
157
155
|
.broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
|
|
158
|
-
dst_col_x = (((pid_col *
|
|
156
|
+
dst_col_x = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
|
|
159
157
|
.broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
|
|
160
158
|
if dim == 0:
|
|
161
159
|
rev_dst_bat_x = blk_i
|
|
@@ -170,32 +168,26 @@ def gather_kernel(x,
|
|
|
170
168
|
rev_idx_spa_x_idx = ((rev_dst_bat_x * s_l_x_b_s) +
|
|
171
169
|
(rev_dst_row_x * s_l_x_r_s) +
|
|
172
170
|
(rev_dst_col_x * s_l_x_c_s))
|
|
173
|
-
rev_idx_spa_x_msk = (
|
|
174
|
-
|
|
175
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
176
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
171
|
+
rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and
|
|
172
|
+
rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
|
|
177
173
|
rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
|
|
178
174
|
|
|
179
175
|
# Load x values
|
|
180
176
|
blk_x_idx = ((rev_idx_spa_x * x_b_s) +
|
|
181
177
|
dst_row_x +
|
|
182
178
|
dst_col_x)
|
|
183
|
-
blk_x_msk = ((
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
187
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
179
|
+
blk_x_msk = ((blk_x_idx >= 0 and
|
|
180
|
+
blk_x_idx < x_b * x_b_s) and
|
|
181
|
+
rev_idx_spa_x_msk != -1)
|
|
188
182
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
189
183
|
|
|
190
184
|
# Store output
|
|
191
185
|
blk_o_idx = ((pid_blk * o_b_s) +
|
|
192
|
-
((pid_row *
|
|
193
|
-
((pid_col *
|
|
194
|
-
blk_o_msk = ((
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
198
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
186
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
187
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
188
|
+
blk_o_msk = ((blk_o_idx >= 0 and
|
|
189
|
+
blk_o_idx < o_b * o_b_s) and
|
|
190
|
+
rev_idx_spa_x_msk != -1)
|
|
199
191
|
tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
|
|
200
192
|
|
|
201
193
|
|
|
@@ -356,7 +348,8 @@ def scatter_reduce_backward(ctx, grad_output):
|
|
|
356
348
|
|
|
357
349
|
@triton.autotune(
|
|
358
350
|
configs=get_autotune_configs(),
|
|
359
|
-
key=[],
|
|
351
|
+
key=["sparsity_block_size"],
|
|
352
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
360
353
|
reset_to_zero=["o"]
|
|
361
354
|
)
|
|
362
355
|
@triton.jit
|
|
@@ -378,9 +371,6 @@ def scatter_reduce_kernel(x,
|
|
|
378
371
|
pid_row = tl.program_id(axis=1)
|
|
379
372
|
pid_col = tl.program_id(axis=2)
|
|
380
373
|
|
|
381
|
-
# Get valid triton block size
|
|
382
|
-
val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
|
|
383
|
-
|
|
384
374
|
# Get position of current sparsity block consisting of its batch, row, and column index
|
|
385
375
|
spa_bat_x_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
|
|
386
376
|
spa_bat_x_msk = (spa_bat_x_idx >= 0 and spa_bat_x_idx < s_lut_x_r * s_lut_x_r_s)
|
|
@@ -396,22 +386,18 @@ def scatter_reduce_kernel(x,
|
|
|
396
386
|
|
|
397
387
|
# Load x values
|
|
398
388
|
blk_x_idx = ((pid_blk * x_b_s) +
|
|
399
|
-
((pid_row *
|
|
400
|
-
((pid_col *
|
|
401
|
-
blk_x_msk = (
|
|
402
|
-
|
|
403
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
404
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
389
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
390
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
391
|
+
blk_x_msk = (blk_x_idx >= 0 and
|
|
392
|
+
blk_x_idx < x_b * x_b_s)
|
|
405
393
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
406
394
|
|
|
407
395
|
# Load index values
|
|
408
396
|
blk_i_idx = ((pid_blk * i_b_s) +
|
|
409
|
-
((pid_row *
|
|
410
|
-
((pid_col *
|
|
411
|
-
blk_i_msk = (
|
|
412
|
-
|
|
413
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
414
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
397
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
|
|
398
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
|
|
399
|
+
blk_i_msk = (blk_i_idx >= 0 and
|
|
400
|
+
blk_i_idx < i_b * i_b_s)
|
|
415
401
|
blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
|
|
416
402
|
|
|
417
403
|
# Get indices of sparsity blocks and positions within the blocks
|
|
@@ -421,9 +407,9 @@ def scatter_reduce_kernel(x,
|
|
|
421
407
|
rev_dst_bat_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_x, dtype=tl.int32)
|
|
422
408
|
rev_dst_row_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_x, dtype=tl.int32)
|
|
423
409
|
rev_dst_col_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_x, dtype=tl.int32)
|
|
424
|
-
dst_row_o = (((pid_row *
|
|
410
|
+
dst_row_o = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
|
|
425
411
|
.broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
|
|
426
|
-
dst_col_o = (((pid_col *
|
|
412
|
+
dst_col_o = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
|
|
427
413
|
.broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
|
|
428
414
|
if dim == 0:
|
|
429
415
|
rev_dst_bat_o = blk_i
|
|
@@ -438,21 +424,17 @@ def scatter_reduce_kernel(x,
|
|
|
438
424
|
rev_idx_spa_o_idx = ((rev_dst_bat_o * s_l_o_b_s) +
|
|
439
425
|
(rev_dst_row_o * s_l_o_r_s) +
|
|
440
426
|
(rev_dst_col_o * s_l_o_c_s))
|
|
441
|
-
rev_idx_spa_o_msk = (
|
|
442
|
-
|
|
443
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
444
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
427
|
+
rev_idx_spa_o_msk = (rev_idx_spa_o_idx >= 0 and
|
|
428
|
+
rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
|
|
445
429
|
rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
|
|
446
430
|
|
|
447
431
|
# Store output
|
|
448
432
|
blk_o_idx = ((rev_idx_spa_o * o_b_s) +
|
|
449
433
|
dst_row_o +
|
|
450
434
|
dst_col_o)
|
|
451
|
-
blk_o_msk = ((
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
455
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
435
|
+
blk_o_msk = ((blk_o_idx >= 0 and
|
|
436
|
+
blk_o_idx < o_b * o_b_s) and
|
|
437
|
+
rev_idx_spa_o_msk != -1)
|
|
456
438
|
|
|
457
439
|
if reduce_op_ind == 0:
|
|
458
440
|
tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
|