blksprs-2.0rc4-py3-none-any.whl → blksprs-2.0rc7-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/layouting/distribution_layout.py +11 -15
- blksprs/layouting/sparsity_layout.py +26 -31
- blksprs/ops/conversion.py +45 -63
- blksprs/ops/distribution.py +38 -57
- blksprs/ops/flow.py +22 -33
- blksprs/ops/matmul.py +19 -20
- blksprs/ops/misc/broadcast_ops.py +15 -19
- blksprs/ops/misc/row_wise.py +39 -54
- blksprs/ops/softmax.py +30 -44
- blksprs/utils/autotuning.py +78 -0
- blksprs/utils/tools.py +0 -28
- blksprs/utils/validation.py +3 -0
- {blksprs-2.0rc4.dist-info → blksprs-2.0rc7.dist-info}/METADATA +18 -5
- blksprs-2.0rc7.dist-info/RECORD +23 -0
- blksprs-2.0rc4.dist-info/RECORD +0 -22
- {blksprs-2.0rc4.dist-info → blksprs-2.0rc7.dist-info}/WHEEL +0 -0
- {blksprs-2.0rc4.dist-info → blksprs-2.0rc7.dist-info}/top_level.txt +0 -0
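The new blksprs/utils/autotuning.py (+78 lines) is added in this release but its body is not part of the diff below; the kernels only reference get_autotune_configs() and prune_autotune_configs. A minimal sketch of what such a module might look like, assuming the configs vary the TRITON_BLOCK_SIZE meta-parameter and that pruning drops configs whose block size exceeds the kernel's sparsity_block_size argument (the concrete values and the prune-hook signature are assumptions, not taken from the package):

# Hypothetical sketch of blksprs/utils/autotuning.py; the real module is not shown in this diff.
import triton


def get_autotune_configs():
    # Candidate TRITON_BLOCK_SIZE values and num_warps are illustrative assumptions.
    return [
        triton.Config({"TRITON_BLOCK_SIZE": tbs}, num_warps=w)
        for tbs in (16, 32, 64, 128)
        for w in (2, 4)
    ]


def prune_autotune_configs(configs, named_args, **kwargs):
    # Drop configs whose block size exceeds the sparsity block size, which would make
    # the former min(sparsity_block_size, TRITON_BLOCK_SIZE) clamp in the kernels unnecessary.
    sparsity_block_size = named_args["sparsity_block_size"]
    return [c for c in configs if c.kwargs["TRITON_BLOCK_SIZE"] <= sparsity_block_size]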
blksprs/ops/flow.py
CHANGED
@@ -5,7 +5,8 @@ from torch._library import triton_op
 from torch._library.triton import wrap_triton
 from triton import language as tl
 
-from blksprs.utils.tools import stride
+from blksprs.utils.tools import stride
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 
 
 @triton_op("blksprs::flow_pull", mutates_args={})
@@ -43,7 +44,8 @@ def flow_pull_forward(x: Tensor, sparsity_layout_o: Tensor,
 
 @triton.autotune(
     configs=get_autotune_configs(),
-    key=[],
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
     reset_to_zero=["o"]
 )
 @triton.jit
@@ -61,9 +63,6 @@ def flow_pull_kernel(x,
     pid_row = tl.program_id(axis=1)
     pid_col = tl.program_id(axis=2)
 
-    # Get valid triton block size
-    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
-
     # Get sparsity index of current output block consisting of its batch, row, and column index
     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
     spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
@@ -86,21 +85,17 @@ def flow_pull_kernel(x,
 
     if rev_idx_spa >= 0:
         blk_x_idx = (rev_idx_spa * x_b_s +
-                     ((pid_row *
-                     ((pid_col *
-        blk_x_msk = (
-
-                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                      tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+        blk_x_msk = (blk_x_idx >= 0 and
+                     blk_x_idx < x_b * x_b_s)
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
         blk_o_idx = (pid_blk * o_b_s +
-                     ((pid_row *
-                     ((pid_col *
-        blk_o_msk = (
-
-                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                      tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+        blk_o_msk = (blk_o_idx >= 0 and
+                     blk_o_idx < o_b * o_b_s)
         tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
 
 
@@ -138,7 +133,8 @@ def flow_push_forward(x: Tensor, sparsity_layout_x: Tensor, sparsity_lut: Tensor
 
 @triton.autotune(
     configs=get_autotune_configs(),
-    key=[],
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
     reset_to_zero=["o"]
 )
 @triton.jit
@@ -156,9 +152,6 @@ def flow_push_kernel(x,
     pid_row = tl.program_id(axis=1)
     pid_col = tl.program_id(axis=2)
 
-    # Get valid triton block size
-    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
-
     # Get sparsity index of current input block consisting of its batch, row, and column index
     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
     spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
@@ -181,19 +174,15 @@ def flow_push_kernel(x,
 
     if rev_idx_spa >= 0:
         blk_x_idx = (pid_blk * x_b_s +
-                     ((pid_row *
-                     ((pid_col *
-        blk_x_msk = (
-
-                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                      tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+        blk_x_msk = (blk_x_idx >= 0 and
+                     blk_x_idx < x_b * x_b_s)
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
         blk_o_idx = (rev_idx_spa * o_b_s +
-                     ((pid_row *
-                     ((pid_col *
-        blk_o_msk = (
-
-                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                      tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+        blk_o_msk = (blk_o_idx >= 0 and
+                     blk_o_idx < o_b * o_b_s)
         tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
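Across these hunks the per-element bounds check against val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE) is replaced by a pure range check on the flattened offsets (blk_idx >= 0 and blk_idx < x_b * x_b_s). A small host-side PyTorch sketch of the equivalent addressing may help when reading the kernels; the names mirror flow_pull_kernel, while the shapes and indices are illustrative assumptions:

# Host-side illustration of the block addressing (assumed shapes, not blksprs code).
import torch

TRITON_BLOCK_SIZE = 16
x = torch.arange(4 * 32 * 32, dtype=torch.float32).reshape(4, 32, 32)  # stacked sparse blocks
x_b_s, x_r_s, x_c_s = x.stride()
x_b = x.size(0)

rev_idx_spa, pid_row, pid_col = 2, 1, 0
rows = pid_row * TRITON_BLOCK_SIZE + torch.arange(TRITON_BLOCK_SIZE)
cols = pid_col * TRITON_BLOCK_SIZE + torch.arange(TRITON_BLOCK_SIZE)

blk_x_idx = (rev_idx_spa * x_b_s
             + (rows * x_r_s)[:, None]
             + (cols * x_c_s)[None, :])
# New-style mask: only check that the flattened offset stays inside the buffer;
# tl.load would return zeros for masked-off positions.
blk_x_msk = (blk_x_idx >= 0) & (blk_x_idx < x_b * x_b_s)
blk_x = x.reshape(-1)[blk_x_idx]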
blksprs/ops/matmul.py
CHANGED
@@ -6,7 +6,8 @@ from triton import language as tl
 
 from blksprs.ops.transpose import transpose
 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import stride
+from blksprs.utils.tools import stride
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_dtype_float
 
@@ -117,7 +118,8 @@ def matmul_backward(ctx, grad_output):
 
 @triton.autotune(
     configs=get_autotune_configs(),
-    key=[],
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
     reset_to_zero=["o"]
 )
 @triton.jit
@@ -141,9 +143,6 @@ def matmul_kernel(x,
     pid_row = tl.program_id(axis=1)
     pid_col = tl.program_id(axis=2)
 
-    # Get valid triton block size
-    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
-
     # Get position of current sparsity block consisting of its batch, row, and column index
     spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
     spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
@@ -161,11 +160,11 @@ def matmul_kernel(x,
     buf = tl.zeros(shape=(TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), dtype=tl.float32)
 
     # Slide over triton block sized segments of input tensors
-    for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size,
+    for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size, TRITON_BLOCK_SIZE)):
         # Convert to segment index of sparsity layout
-        i_seg_spa = (i_seg_tri *
+        i_seg_spa = (i_seg_tri * TRITON_BLOCK_SIZE) // sparsity_block_size
         # Calculate the triton segment index within a block
-        i_seg_tri_mod = i_seg_tri % (sparsity_block_size //
+        i_seg_tri_mod = i_seg_tri % (sparsity_block_size // TRITON_BLOCK_SIZE)
 
         # Get reverse sparsity indices for input tensors x and y
         # These are either -1 if the block is empty or equal to the index of the block in the sparse tensor
@@ -185,23 +184,23 @@ def matmul_kernel(x,
         # If both blocks are present commence calculation
         if rev_idx_spa_x >= 0 and rev_idx_spa_y >= 0:
             blk_x_idx = ((rev_idx_spa_x * x_b_s) +
-                         ((pid_row *
-                         ((i_seg_tri_mod *
+                         ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                         ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
                            tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
             blk_x_msk = ((blk_x_idx >= 0 and
                          blk_x_idx < x_b * x_b_s) and
-                         (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] <
-                          tl.arange(0, TRITON_BLOCK_SIZE)[None, :] <
+                         (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < TRITON_BLOCK_SIZE and
+                          tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < TRITON_BLOCK_SIZE))
             blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
             blk_y_idx = ((rev_idx_spa_y * y_b_s) +
-                         ((i_seg_tri_mod *
+                         ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
                            tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
-                         ((pid_col *
+                         ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
             blk_y_msk = ((blk_y_idx >= 0 and
                          blk_y_idx < y_b * y_b_s) and
-                         (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] <
-                          tl.arange(0, TRITON_BLOCK_SIZE)[None, :] <
+                         (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < TRITON_BLOCK_SIZE and
+                          tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < TRITON_BLOCK_SIZE))
             blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
 
             # Perform matrix multiplication
@@ -212,12 +211,12 @@ def matmul_kernel(x,
 
     # Store output
     blk_o_idx = ((pid_blk * o_b_s) +
-                 ((pid_row *
-                 ((pid_col *
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
     blk_o_msk = ((blk_o_idx >= 0 and
                  blk_o_idx < o_b * o_b_s) and
-                 (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] <
-                  tl.arange(0, TRITON_BLOCK_SIZE)[None, :] <
+                 (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < TRITON_BLOCK_SIZE and
+                  tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < TRITON_BLOCK_SIZE))
     tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
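The rewritten loop in matmul_kernel walks the shared dimension in TRITON_BLOCK_SIZE-wide segments and maps each segment back to a sparsity-layout column (i_seg_spa) and an offset inside a sparsity block (i_seg_tri_mod). A small worked example of that index arithmetic, with illustrative values chosen here rather than taken from the package:

# Illustration of the segment indexing used in matmul_kernel (example values).
sparsity_block_size = 64
TRITON_BLOCK_SIZE = 16
s_l_x_c = 3  # number of sparsity-layout columns of x

num_segments = -(-(s_l_x_c * sparsity_block_size) // TRITON_BLOCK_SIZE)  # tl.cdiv -> 12

for i_seg_tri in range(num_segments):
    i_seg_spa = (i_seg_tri * TRITON_BLOCK_SIZE) // sparsity_block_size      # layout column: 0,0,0,0,1,1,...
    i_seg_tri_mod = i_seg_tri % (sparsity_block_size // TRITON_BLOCK_SIZE)  # segment within block: 0..3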
blksprs/ops/misc/broadcast_ops.py
CHANGED
@@ -6,11 +6,13 @@ from torch._library.triton import wrap_triton
 from triton import language as tl
 
 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import stride
+from blksprs.utils.tools import stride
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_device, \
     validate_sparsity_block_size
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
                   sparsity_block_size: int) -> BlksprsTensor:
     """Performs a broadcast and subsequent addition of two dense tensors x and y. Returns a block-sparse tensor in
@@ -87,7 +89,8 @@ def broadcast_add_forward(x: Tensor, y: Tensor,
 
 @triton.autotune(
     configs=get_autotune_configs(),
-    key=[],
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
     reset_to_zero=["o"]
 )
 @triton.jit
@@ -105,9 +108,6 @@ def broadcast_add_kernel(x,
     pid_row = tl.program_id(axis=1)
     pid_col = tl.program_id(axis=2)
 
-    # Get valid triton block size
-    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
-
     # Get position of current sparsity block consisting of its batch, row, and column index
     spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
     spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
@@ -123,20 +123,18 @@ def broadcast_add_kernel(x,
 
     # Load x block
     blk_x_idx = (spa_bat_o * x_b_s +
-                 ((pid_row *
+                 ((pid_row * TRITON_BLOCK_SIZE + spa_row_o * sparsity_block_size +
                    tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (
-
-                 (tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+    blk_x_msk = (blk_x_idx >= 0 and
+                 blk_x_idx < x_b * x_b_s)
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Load y block
     blk_y_idx = (spa_bat_o * y_b_s +
-                 ((pid_col *
+                 ((pid_col * TRITON_BLOCK_SIZE + spa_col_o * sparsity_block_size +
                    tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
-    blk_y_msk = (
-
-                 (tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+    blk_y_msk = (blk_y_idx >= 0 and
+                 blk_y_idx < y_b * y_b_s)
    blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
 
     # Compute sum
@@ -145,10 +143,8 @@ def broadcast_add_kernel(x,
 
     # Store result
     blk_o_idx = ((pid_blk * o_b_s) +
-                 ((pid_row *
-                 ((pid_col *
-    blk_o_msk = (
-
-                 (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                  tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+    blk_o_msk = (blk_o_idx >= 0 and
+                 blk_o_idx < o_b * o_b_s)
     tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
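broadcast_add is now decorated with @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16): when autocast is active at the call site, floating-point inputs are cast to float16 and the body runs with autocast disabled. A minimal standalone sketch of that decorator's effect on an arbitrary function (the function and shapes below are hypothetical, not part of blksprs):

import torch


@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
def scaled_sum(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # If autocast was enabled at the call site, x and y arrive here already cast to float16
    # and autocast is disabled inside this function.
    return x + y


if torch.cuda.is_available():
    a = torch.randn(8, 8, device="cuda", dtype=torch.float32)
    b = torch.randn(8, 8, device="cuda", dtype=torch.float32)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        out = scaled_sum(a, b)
    print(out.dtype)  # torch.float16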
blksprs/ops/misc/row_wise.py
CHANGED
@@ -4,8 +4,9 @@ from torch import Tensor
 from torch._library.triton import wrap_triton, triton_op
 from triton import language as tl
 
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import stride
+from blksprs.utils.tools import stride
 from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, validate_sparsity, \
     validate_sparsity_block_size
 
@@ -94,9 +95,11 @@ def row_wise_sum_forward(x: Tensor, sparsity_lut: Tensor,
     return output
 
 
+# noinspection PyUnusedLocal
 @triton.autotune(
     configs=get_autotune_configs(),
-    key=[],
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
     reset_to_zero=["o"]
 )
 @triton.jit
@@ -114,9 +117,6 @@ def row_wise_sum_kernel(x,
     pid_row = tl.program_id(axis=1)
     pid_col = tl.program_id(axis=2)
 
-    # Get valid triton block size
-    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
-
     # Get position of current sparsity block consisting of its batch and row index
     spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
     spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
@@ -137,23 +137,19 @@ def row_wise_sum_kernel(x,
         return
 
     blk_idx = ((pid_blk * x_b_s) +
-               ((pid_row *
-               ((pid_col *
-    blk_msk = (
-
-               (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+               ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+               ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_msk = (blk_idx >= 0 and
+               blk_idx < x_b * x_b_s)
     blk = tl.load(x + blk_idx, mask=blk_msk)
 
     buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
 
     o_idx = (rev_idx_spa * o_b_s +
-             ((pid_row *
+             ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
             (tl.arange(0, 1))[None, :])
-    o_msk = (
-
-             (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-              tl.arange(0, 1)[None, :] < val_tbs))
+    o_msk = (o_idx >= 0 and
+             o_idx < o_b * o_b_s)
     tl.atomic_add(o + o_idx, buf, o_msk)
 
 
@@ -180,6 +176,8 @@ def row_wise_max(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:
     of the input and the sparsity layout of the output tensor.
 
     """
+    # TODO Fix for triton bug, see https://github.com/triton-lang/triton/issues/6376
+    x = torch.where(x == -0.0, torch.tensor(0.0), x)
     x = x.contiguous()
 
     validate_dimensions(x)
@@ -214,7 +212,7 @@ def row_wise_max_forward(x: Tensor, sparsity_lut: Tensor,
     output = torch.full(size=(n_sparse_blocks_output,
                               sparsity_block_size,
                               1 if flag_slice_only else sparsity_block_size),
-                        fill_value=
+                        fill_value=torch.finfo(x.dtype).min,
                         device=x.device)
 
     x_b, x_r, x_c = x.size()
@@ -243,9 +241,11 @@ def row_wise_max_forward(x: Tensor, sparsity_lut: Tensor,
     return output
 
 
+# noinspection PyUnusedLocal
 @triton.autotune(
     configs=get_autotune_configs(),
-    key=[],
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
     restore_value=["o"]
 )
 @triton.jit
@@ -263,9 +263,6 @@ def row_wise_max_kernel(x,
     pid_row = tl.program_id(axis=1)
     pid_col = tl.program_id(axis=2)
 
-    # Get valid triton block size
-    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
-
     # Get position of current sparsity block consisting of its batch and row index
     spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
     spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
@@ -286,23 +283,19 @@ def row_wise_max_kernel(x,
         return
 
     blk_idx = ((pid_blk * x_b_s) +
-               ((pid_row *
-               ((pid_col *
-    blk_msk = (
-
-               (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+               ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+               ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_msk = (blk_idx >= 0 and
+               blk_idx < x_b * x_b_s)
     blk = tl.load(x + blk_idx, mask=blk_msk)
 
     buf = tl.reshape(tl.max(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
 
     o_idx = (rev_idx_spa * o_b_s +
-             ((pid_row *
+             ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
             (tl.arange(0, 1))[None, :])
-    o_msk = (
-
-             (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-              tl.arange(0, 1)[None, :] < val_tbs))
+    o_msk = (o_idx >= 0 and
+             o_idx < o_b * o_b_s)
     tl.atomic_max(o + o_idx, buf, o_msk)
 
 
@@ -371,7 +364,7 @@ def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
 
-    (kernel_blocksparse_row_wise_add[triton_grid]
+    (wrap_triton(kernel_blocksparse_row_wise_add)[triton_grid]
     (x,
      x_b, x_b_s, x_r_s, x_c_s,
      sparsity_lut_x, s_lut_r, s_lut_r_s, s_lut_c_s,
@@ -387,7 +380,8 @@ def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
 
 @triton.autotune(
     configs=get_autotune_configs(),
-    key=[],
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
     reset_to_zero=["o"]
 )
 @triton.jit
@@ -406,9 +400,6 @@ def kernel_blocksparse_row_wise_add(x,
     pid_row = tl.program_id(axis=1)
     pid_col = tl.program_id(axis=2)
 
-    # Get valid triton block size
-    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
-
     # Get position of current sparsity block consisting of its batch and row index
     spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
     spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
@@ -430,22 +421,18 @@ def kernel_blocksparse_row_wise_add(x,
 
     # Load x block
     blk_x_idx = ((pid_blk * x_b_s) +
-                 ((pid_row *
-                 ((pid_col *
-    blk_x_msk = (
-
-                 (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                  tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_x_msk = (blk_x_idx >= 0 and
+                 blk_x_idx < x_b * x_b_s)
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Load sum block
     blk_s_idx = (rev_idx_spa_s * y_b_s +
-                 ((pid_row *
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
                  (tl.arange(0, 1) * y_c_s)[None, :])
-    blk_s_msk = (
-
-                 (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                  tl.arange(0, 1)[None, :] < val_tbs))
+    blk_s_msk = (blk_s_idx >= 0 and
+                 blk_s_idx < y_b * y_b_s)
     blk_s = tl.load(y + blk_s_idx, mask=blk_s_msk)
 
     # Compute exp
@@ -453,10 +440,8 @@ def kernel_blocksparse_row_wise_add(x,
 
     # Store block
     blk_o_idx = ((pid_blk * o_b_s) +
-                 ((pid_row *
-                 ((pid_col *
-    blk_o_msk = (
-
-                 (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                  tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+    blk_o_msk = (blk_o_idx >= 0 and
+                 blk_o_idx < o_b * o_b_s)
     tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
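The new guard in row_wise_max works around the Triton negative-zero issue linked in the TODO by normalizing zeros on the host before the kernel runs. Because IEEE-754 compares -0.0 equal to 0.0, the torch.where rewrites both signed zeros to +0.0 and leaves every other value untouched; a short standalone check (not blksprs code) that illustrates this:

import torch

x = torch.tensor([-0.0, 0.0, -1.5, 2.0])
# -0.0 == 0.0 is True under IEEE-754, so both zeros match and become +0.0.
x = torch.where(x == -0.0, torch.tensor(0.0), x)
print(x)                 # tensor([ 0.0,  0.0, -1.5,  2.0])
print(torch.signbit(x))  # no set sign bits on the zeros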
blksprs/ops/softmax.py
CHANGED
@@ -7,7 +7,8 @@ from triton import language as tl
 
 from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import stride
+from blksprs.utils.tools import stride
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_dtype_float_32
 
@@ -114,7 +115,8 @@ def softmax_backward(ctx, grad_output):
                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
 
-
+    # TODO wrap
+    (softmax_kernel_grad[triton_grid]
     (grad_output,
      o_b, o_b_s, o_r_s, o_c_s,
      o,
@@ -133,7 +135,8 @@ def softmax_backward(ctx, grad_output):
 
 @triton.autotune(
     configs=get_autotune_configs(),
-    key=[],
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
     reset_to_zero=["o"]
 )
 @triton.jit
@@ -151,9 +154,6 @@ def softmax_kernel(x,
     pid_row = tl.program_id(axis=1)
     pid_col = tl.program_id(axis=2)
 
-    # Get valid triton block size
-    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
-
     # Get position of current sparsity block consisting of its batch and row index
     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
     spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
@@ -172,22 +172,18 @@ def softmax_kernel(x,
     if rev_idx_spa_s >= 0:
         # Load x block
         blk_x_idx = ((pid_blk * x_b_s) +
-                     ((pid_row *
-                     ((pid_col *
-        blk_x_msk = (
-
-                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                      tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+        blk_x_msk = (blk_x_idx >= 0 and
+                     blk_x_idx < x_b * x_b_s)
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
         # Load sum block
         blk_s_idx = (rev_idx_spa_s * s_b_s +
-                     ((pid_row *
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
                      (tl.arange(0, 1) * s_c_s)[None, :])
-        blk_s_msk = (
-
-                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                      tl.arange(0, 1)[None, :] < val_tbs))
+        blk_s_msk = (blk_s_idx >= 0 and
+                     blk_s_idx < s_b * s_b_s)
         blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
 
         # Compute softmax
@@ -199,7 +195,8 @@ def softmax_kernel(x,
 
 @triton.autotune(
     configs=get_autotune_configs(),
-    key=[],
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
     reset_to_zero=["o"]
 )
 @triton.jit
@@ -221,9 +218,6 @@ def softmax_kernel_grad(g,
     pid_row = tl.program_id(axis=1)
     pid_col = tl.program_id(axis=2)
 
-    # Get valid triton block size
-    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
-
     # Get position of current sparsity block consisting of its batch and row index
     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
     spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
@@ -240,41 +234,33 @@ def softmax_kernel_grad(g,
 
     if rev_idx_spa_s >= 0:
         blk_s_idx = (rev_idx_spa_s * s_b_s +
-                     ((pid_row *
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
                      (tl.arange(0, 1) * s_c_s)[None, :])
-        blk_s_msk = (
-
-                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                      tl.arange(0, 1)[None, :] < val_tbs))
+        blk_s_msk = (blk_s_idx >= 0 and
+                     blk_s_idx < s_b * s_b_s)
         blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
 
         blk_g_idx = ((pid_blk * g_b_s) +
-                     ((pid_row *
-                     ((pid_col *
-        blk_g_msk = (
-
-                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                      tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])
+        blk_g_msk = (blk_g_idx >= 0 and
+                     blk_g_idx < g_b * g_b_s)
         blk_g = tl.load(g + blk_g_idx, mask=blk_g_msk)
 
         blk_x_idx = ((pid_blk * x_b_s) +
-                     ((pid_row *
-                     ((pid_col *
-        blk_x_msk = (
-
-                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                      tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+        blk_x_msk = (blk_x_idx >= 0 and
+                     blk_x_idx < x_b * x_b_s)
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
         buf = blk_x * (blk_g - blk_s)
 
         blk_o_idx = ((pid_blk * o_b_s) +
-                     ((pid_row *
-                     ((pid_col *
-        blk_o_msk = (
-
-                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                      tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+        blk_o_msk = (blk_o_idx >= 0 and
+                     blk_o_idx < o_b * o_b_s)
         tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
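For orientation, the decomposition softmax.py builds from row_wise_max, row_wise_sub and row_wise_sum corresponds to the usual numerically stable softmax, and the gradient computed in softmax_kernel_grad (buf = blk_x * (blk_g - blk_s)) matches its Jacobian-vector product if the s tensor holds the row-wise sum of grad_output * output. A dense PyTorch sketch of both, as a reference only and not the block-sparse code path:

import torch

def reference_softmax(x: torch.Tensor) -> torch.Tensor:
    # Row-wise max subtraction for numerical stability, then exp and row-wise normalization.
    x = x - x.max(dim=-1, keepdim=True).values
    e = torch.exp(x)
    return e / e.sum(dim=-1, keepdim=True)

def reference_softmax_grad(softmax_out: torch.Tensor, grad_out: torch.Tensor) -> torch.Tensor:
    # s = sum_j g_j * y_j per row; dx_i = y_i * (g_i - s), mirroring blk_x * (blk_g - blk_s).
    s = (grad_out * softmax_out).sum(dim=-1, keepdim=True)
    return softmax_out * (grad_out - s)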