blksprs-2.0rc4-py3-none-any.whl → blksprs-2.0rc6-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/layouting/distribution_layout.py +11 -15
- blksprs/layouting/sparsity_layout.py +26 -31
- blksprs/ops/conversion.py +45 -63
- blksprs/ops/distribution.py +37 -56
- blksprs/ops/flow.py +22 -33
- blksprs/ops/matmul.py +19 -20
- blksprs/ops/misc/broadcast_ops.py +14 -19
- blksprs/ops/misc/row_wise.py +35 -54
- blksprs/ops/softmax.py +30 -44
- blksprs/utils/autotuning.py +78 -0
- blksprs/utils/tools.py +6 -25
- blksprs/utils/validation.py +3 -0
- {blksprs-2.0rc4.dist-info → blksprs-2.0rc6.dist-info}/METADATA +14 -5
- blksprs-2.0rc6.dist-info/RECORD +23 -0
- blksprs-2.0rc4.dist-info/RECORD +0 -22
- {blksprs-2.0rc4.dist-info → blksprs-2.0rc6.dist-info}/WHEEL +0 -0
- {blksprs-2.0rc4.dist-info → blksprs-2.0rc6.dist-info}/top_level.txt +0 -0
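Most of the changes below follow two patterns: every @triton.autotune decorator now keys on sparsity_block_size and prunes candidate configs through the new blksprs/utils/autotuning.py module, and the kernels drop the per-launch val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE) clamp in favour of plain bounds masks. The autotuning module itself is not shown in this diff; the following is a minimal sketch of what its two entry points could look like, inferred purely from the call sites (the config values and the pruning signature are assumptions, not the package's actual code):

import triton


def get_autotune_configs():
    # Hypothetical candidate sweep; the real tile sizes and warp counts
    # are not visible in this diff.
    return [triton.Config({"TRITON_BLOCK_SIZE": tbs}, num_warps=w)
            for tbs in (16, 32, 64, 128)
            for w in (4, 8)]


def prune_autotune_configs(configs, nargs, **kwargs):
    # Assumed behaviour: drop configs whose tile would exceed one sparsity
    # block. With that guarantee the kernels can drop the per-launch
    # min(sparsity_block_size, TRITON_BLOCK_SIZE) clamp, as this diff does.
    sparsity_block_size = nargs["sparsity_block_size"]
    pruned = [cfg for cfg in configs
              if cfg.kwargs["TRITON_BLOCK_SIZE"] <= sparsity_block_size]
    return pruned or configs[:1]

If pruning indeed enforces TRITON_BLOCK_SIZE <= sparsity_block_size, every in-block offset is valid by construction, which is what makes the simplified masks in the kernels below sufficient.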
blksprs/ops/flow.py
CHANGED

@@ -5,7 +5,8 @@ from torch._library import triton_op
 from torch._library.triton import wrap_triton
 from triton import language as tl
 
-from blksprs.utils.tools import stride
+from blksprs.utils.tools import stride
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 
 
 @triton_op("blksprs::flow_pull", mutates_args={})

@@ -43,7 +44,8 @@ def flow_pull_forward(x: Tensor, sparsity_layout_o: Tensor,
 
 @triton.autotune(
     configs=get_autotune_configs(),
-    key=[],
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
     reset_to_zero=["o"]
 )
 @triton.jit

@@ -61,9 +63,6 @@ def flow_pull_kernel(x,
     pid_row = tl.program_id(axis=1)
     pid_col = tl.program_id(axis=2)
 
-    # Get valid triton block size
-    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
-
     # Get sparsity index of current output block consisting of its batch, row, and column index
     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
     spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)

@@ -86,21 +85,17 @@ def flow_pull_kernel(x,
 
     if rev_idx_spa >= 0:
         blk_x_idx = (rev_idx_spa * x_b_s +
-                     ((pid_row *
-                     ((pid_col *
-        blk_x_msk = (
-
-                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                      tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+        blk_x_msk = (blk_x_idx >= 0 and
+                     blk_x_idx < x_b * x_b_s)
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
         blk_o_idx = (pid_blk * o_b_s +
-                     ((pid_row *
-                     ((pid_col *
-        blk_o_msk = (
-
-                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                      tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+        blk_o_msk = (blk_o_idx >= 0 and
+                     blk_o_idx < o_b * o_b_s)
         tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
 
 

@@ -138,7 +133,8 @@ def flow_push_forward(x: Tensor, sparsity_layout_x: Tensor, sparsity_lut: Tensor
 
 @triton.autotune(
     configs=get_autotune_configs(),
-    key=[],
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
     reset_to_zero=["o"]
 )
 @triton.jit

@@ -156,9 +152,6 @@ def flow_push_kernel(x,
     pid_row = tl.program_id(axis=1)
     pid_col = tl.program_id(axis=2)
 
-    # Get valid triton block size
-    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
-
     # Get sparsity index of current input block consisting of its batch, row, and column index
     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
     spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)

@@ -181,19 +174,15 @@ def flow_push_kernel(x,
 
     if rev_idx_spa >= 0:
         blk_x_idx = (pid_blk * x_b_s +
-                     ((pid_row *
-                     ((pid_col *
-        blk_x_msk = (
-
-                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                      tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+        blk_x_msk = (blk_x_idx >= 0 and
+                     blk_x_idx < x_b * x_b_s)
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
         blk_o_idx = (rev_idx_spa * o_b_s +
-                     ((pid_row *
-                     ((pid_col *
-        blk_o_msk = (
-
-                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                      tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+        blk_o_msk = (blk_o_idx >= 0 and
+                     blk_o_idx < o_b * o_b_s)
         tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
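With the clamp gone, the masks reduce to a range check on the flattened offsets. An illustrative NumPy check (assumed shapes, not package code) that the new two-part mask blk_idx >= 0 and blk_idx < x_b * x_b_s covers every loaded element once TRITON_BLOCK_SIZE <= sparsity_block_size:

import numpy as np

TBS = 4                                  # TRITON_BLOCK_SIZE, pruned so TBS <= sbs
sbs = 8                                  # sparsity_block_size
x_b = 2                                  # number of sparse blocks
x_b_s, x_r_s, x_c_s = sbs * sbs, sbs, 1  # contiguous strides of (x_b, sbs, sbs)

rev_idx_spa, pid_row, pid_col = 1, 1, 1  # arbitrary in-range program ids
ar = np.arange(TBS)
blk_x_idx = (rev_idx_spa * x_b_s
             + ((pid_row * TBS + ar) * x_r_s)[:, None]
             + ((pid_col * TBS + ar) * x_c_s)[None, :])
blk_x_msk = (blk_x_idx >= 0) & (blk_x_idx < x_b * x_b_s)
assert blk_x_msk.all()                   # every offset lands inside the flattened tensor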
blksprs/ops/matmul.py
CHANGED

@@ -6,7 +6,8 @@ from triton import language as tl
 
 from blksprs.ops.transpose import transpose
 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import stride
+from blksprs.utils.tools import stride
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_dtype_float
 

@@ -117,7 +118,8 @@ def matmul_backward(ctx, grad_output):
 
 @triton.autotune(
     configs=get_autotune_configs(),
-    key=[],
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
     reset_to_zero=["o"]
 )
 @triton.jit

@@ -141,9 +143,6 @@ def matmul_kernel(x,
     pid_row = tl.program_id(axis=1)
     pid_col = tl.program_id(axis=2)
 
-    # Get valid triton block size
-    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
-
     # Get position of current sparsity block consisting of its batch, row, and column index
     spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
     spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)

@@ -161,11 +160,11 @@ def matmul_kernel(x,
     buf = tl.zeros(shape=(TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), dtype=tl.float32)
 
     # Slide over triton block sized segments of input tensors
-    for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size,
+    for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size, TRITON_BLOCK_SIZE)):
         # Convert to segment index of sparsity layout
-        i_seg_spa = (i_seg_tri *
+        i_seg_spa = (i_seg_tri * TRITON_BLOCK_SIZE) // sparsity_block_size
         # Calculate the triton segment index within a block
-        i_seg_tri_mod = i_seg_tri % (sparsity_block_size //
+        i_seg_tri_mod = i_seg_tri % (sparsity_block_size // TRITON_BLOCK_SIZE)
 
         # Get reverse sparsity indices for input tensors x and y
         # These are either -1 if the block is empty or equal to the index of the block in the sparse tensor

@@ -185,23 +184,23 @@ def matmul_kernel(x,
         # If both blocks are present commence calculation
         if rev_idx_spa_x >= 0 and rev_idx_spa_y >= 0:
             blk_x_idx = ((rev_idx_spa_x * x_b_s) +
-                         ((pid_row *
-                         ((i_seg_tri_mod *
+                         ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                         ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
                            tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
             blk_x_msk = ((blk_x_idx >= 0 and
                           blk_x_idx < x_b * x_b_s) and
-                         (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] <
-                          tl.arange(0, TRITON_BLOCK_SIZE)[None, :] <
+                         (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < TRITON_BLOCK_SIZE and
+                          tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < TRITON_BLOCK_SIZE))
             blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
             blk_y_idx = ((rev_idx_spa_y * y_b_s) +
-                         ((i_seg_tri_mod *
+                         ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
                            tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
-                         ((pid_col *
+                         ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
             blk_y_msk = ((blk_y_idx >= 0 and
                           blk_y_idx < y_b * y_b_s) and
-                         (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] <
-                          tl.arange(0, TRITON_BLOCK_SIZE)[None, :] <
+                         (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < TRITON_BLOCK_SIZE and
+                          tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < TRITON_BLOCK_SIZE))
             blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
 
             # Perform matrix multiplication

@@ -212,12 +211,12 @@ def matmul_kernel(x,
 
     # Store output
     blk_o_idx = ((pid_blk * o_b_s) +
-                 ((pid_row *
-                 ((pid_col *
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
     blk_o_msk = ((blk_o_idx >= 0 and
                   blk_o_idx < o_b * o_b_s) and
-                 (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] <
-                  tl.arange(0, TRITON_BLOCK_SIZE)[None, :] <
+                 (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < TRITON_BLOCK_SIZE and
+                  tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < TRITON_BLOCK_SIZE))
     tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
 
 
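The loop rewrite makes the segment bookkeeping explicit: i_seg_spa locates the sparsity-layout column a triton-sized segment falls into, and i_seg_tri_mod its position within that sparsity block. A worked example with assumed sizes (sparsity_block_size = 64, TRITON_BLOCK_SIZE = 16, i.e. four triton segments per sparsity block):

sparsity_block_size, TRITON_BLOCK_SIZE = 64, 16

for i_seg_tri in range(6):
    # Which sparsity-layout column this segment belongs to
    i_seg_spa = (i_seg_tri * TRITON_BLOCK_SIZE) // sparsity_block_size
    # Offset of the segment within that sparsity block
    i_seg_tri_mod = i_seg_tri % (sparsity_block_size // TRITON_BLOCK_SIZE)
    print(i_seg_tri, i_seg_spa, i_seg_tri_mod)
# -> (0,0,0) (1,0,1) (2,0,2) (3,0,3) (4,1,0) (5,1,1)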
blksprs/ops/misc/broadcast_ops.py
CHANGED

@@ -6,7 +6,8 @@ from torch._library.triton import wrap_triton
 from triton import language as tl
 
 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import stride
+from blksprs.utils.tools import stride
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_device, \
     validate_sparsity_block_size
 

@@ -87,7 +88,8 @@ def broadcast_add_forward(x: Tensor, y: Tensor,
 
 @triton.autotune(
     configs=get_autotune_configs(),
-    key=[],
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
     reset_to_zero=["o"]
 )
 @triton.jit

@@ -105,9 +107,6 @@ def broadcast_add_kernel(x,
     pid_row = tl.program_id(axis=1)
     pid_col = tl.program_id(axis=2)
 
-    # Get valid triton block size
-    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
-
     # Get position of current sparsity block consisting of its batch, row, and column index
     spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
     spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)

@@ -123,20 +122,18 @@ def broadcast_add_kernel(x,
 
     # Load x block
     blk_x_idx = (spa_bat_o * x_b_s +
-                 ((pid_row *
+                 ((pid_row * TRITON_BLOCK_SIZE + spa_row_o * sparsity_block_size +
                    tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (
-
-                 (tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+    blk_x_msk = (blk_x_idx >= 0 and
+                 blk_x_idx < x_b * x_b_s)
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Load y block
     blk_y_idx = (spa_bat_o * y_b_s +
-                 ((pid_col *
+                 ((pid_col * TRITON_BLOCK_SIZE + spa_col_o * sparsity_block_size +
                    tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
-    blk_y_msk = (
-
-                 (tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+    blk_y_msk = (blk_y_idx >= 0 and
+                 blk_y_idx < y_b * y_b_s)
     blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
 
     # Compute sum

@@ -145,10 +142,8 @@ def broadcast_add_kernel(x,
 
     # Store result
     blk_o_idx = ((pid_blk * o_b_s) +
-                 ((pid_row *
-                 ((pid_col *
-    blk_o_msk = (
-
-                 (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                  tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+    blk_o_msk = (blk_o_idx >= 0 and
+                 blk_o_idx < o_b * o_b_s)
     tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
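Judging by the index arithmetic above (x is read at row positions offset by spa_row_o, y at column positions offset by spa_col_o), broadcast_add appears to compute pairwise sums of two batched vectors. A dense NumPy reference under that assumption (not package code):

import numpy as np

x = np.arange(4.0)[None, :]              # (batch=1, rows)
y = np.array([[0.0, 10.0, 20.0]])        # (batch=1, cols)
o = x[:, :, None] + y[:, None, :]        # o[b, i, j] = x[b, i] + y[b, j]
assert o.shape == (1, 4, 3)
assert o[0, 2, 1] == x[0, 2] + y[0, 1]   # 2.0 + 10.0 == 12.0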
blksprs/ops/misc/row_wise.py
CHANGED

@@ -5,7 +5,8 @@ from torch._library.triton import wrap_triton, triton_op
 from triton import language as tl
 
 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import stride,
+from blksprs.utils.tools import stride, get_autocast_min_val
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, validate_sparsity, \
     validate_sparsity_block_size
 

@@ -96,7 +97,8 @@ def row_wise_sum_forward(x: Tensor, sparsity_lut: Tensor,
 
 @triton.autotune(
     configs=get_autotune_configs(),
-    key=[],
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
     reset_to_zero=["o"]
 )
 @triton.jit

@@ -114,9 +116,6 @@ def row_wise_sum_kernel(x,
     pid_row = tl.program_id(axis=1)
     pid_col = tl.program_id(axis=2)
 
-    # Get valid triton block size
-    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
-
     # Get position of current sparsity block consisting of its batch and row index
     spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
     spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)

@@ -137,23 +136,19 @@ def row_wise_sum_kernel(x,
         return
 
     blk_idx = ((pid_blk * x_b_s) +
-               ((pid_row *
-               ((pid_col *
-    blk_msk = (
-
-               (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+               ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+               ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_msk = (blk_idx >= 0 and
+               blk_idx < x_b * x_b_s)
     blk = tl.load(x + blk_idx, mask=blk_msk)
 
     buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
 
     o_idx = (rev_idx_spa * o_b_s +
-             ((pid_row *
+             ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
              (tl.arange(0, 1))[None, :])
-    o_msk = (
-
-             (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-              tl.arange(0, 1)[None, :] < val_tbs))
+    o_msk = (o_idx >= 0 and
+             o_idx < o_b * o_b_s)
     tl.atomic_add(o + o_idx, buf, o_msk)
 
 

@@ -214,7 +209,7 @@ def row_wise_max_forward(x: Tensor, sparsity_lut: Tensor,
     output = torch.full(size=(n_sparse_blocks_output,
                               sparsity_block_size,
                               1 if flag_slice_only else sparsity_block_size),
-                        fill_value=
+                        fill_value=get_autocast_min_val(),
                         device=x.device)
 
     x_b, x_r, x_c = x.size()

@@ -245,7 +240,8 @@ def row_wise_max_forward(x: Tensor, sparsity_lut: Tensor,
 
 @triton.autotune(
     configs=get_autotune_configs(),
-    key=[],
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
     restore_value=["o"]
 )
 @triton.jit

@@ -263,9 +259,6 @@ def row_wise_max_kernel(x,
     pid_row = tl.program_id(axis=1)
     pid_col = tl.program_id(axis=2)
 
-    # Get valid triton block size
-    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
-
     # Get position of current sparsity block consisting of its batch and row index
     spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
     spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)

@@ -286,23 +279,19 @@ def row_wise_max_kernel(x,
         return
 
     blk_idx = ((pid_blk * x_b_s) +
-               ((pid_row *
-               ((pid_col *
-    blk_msk = (
-
-               (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+               ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+               ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_msk = (blk_idx >= 0 and
+               blk_idx < x_b * x_b_s)
     blk = tl.load(x + blk_idx, mask=blk_msk)
 
     buf = tl.reshape(tl.max(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
 
     o_idx = (rev_idx_spa * o_b_s +
-             ((pid_row *
+             ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
              (tl.arange(0, 1))[None, :])
-    o_msk = (
-
-             (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-              tl.arange(0, 1)[None, :] < val_tbs))
+    o_msk = (o_idx >= 0 and
+             o_idx < o_b * o_b_s)
     tl.atomic_max(o + o_idx, buf, o_msk)
 
 

@@ -371,7 +360,7 @@ def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
 
-    (kernel_blocksparse_row_wise_add[triton_grid]
+    (wrap_triton(kernel_blocksparse_row_wise_add)[triton_grid]
     (x,
      x_b, x_b_s, x_r_s, x_c_s,
      sparsity_lut_x, s_lut_r, s_lut_r_s, s_lut_c_s,

@@ -387,7 +376,8 @@ def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
 
 @triton.autotune(
     configs=get_autotune_configs(),
-    key=[],
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
     reset_to_zero=["o"]
 )
 @triton.jit

@@ -406,9 +396,6 @@ def kernel_blocksparse_row_wise_add(x,
     pid_row = tl.program_id(axis=1)
     pid_col = tl.program_id(axis=2)
 
-    # Get valid triton block size
-    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
-
     # Get position of current sparsity block consisting of its batch and row index
     spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
     spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)

@@ -430,22 +417,18 @@ def kernel_blocksparse_row_wise_add(x,
 
     # Load x block
     blk_x_idx = ((pid_blk * x_b_s) +
-                 ((pid_row *
-                 ((pid_col *
-    blk_x_msk = (
-
-                 (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                  tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_x_msk = (blk_x_idx >= 0 and
+                 blk_x_idx < x_b * x_b_s)
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Load sum block
     blk_s_idx = (rev_idx_spa_s * y_b_s +
-                 ((pid_row *
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
                  (tl.arange(0, 1) * y_c_s)[None, :])
-    blk_s_msk = (
-
-                 (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                  tl.arange(0, 1)[None, :] < val_tbs))
+    blk_s_msk = (blk_s_idx >= 0 and
+                 blk_s_idx < y_b * y_b_s)
     blk_s = tl.load(y + blk_s_idx, mask=blk_s_msk)
 
     # Compute exp

@@ -453,10 +436,8 @@ def kernel_blocksparse_row_wise_add(x,
 
     # Store block
     blk_o_idx = ((pid_blk * o_b_s) +
-                 ((pid_row *
-                 ((pid_col *
-    blk_o_msk = (
-
-                 (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                  tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+    blk_o_msk = (blk_o_idx >= 0 and
+                 blk_o_idx < o_b * o_b_s)
     tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
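row_wise_max_forward now seeds its output via get_autocast_min_val() (the value it replaces is truncated in this diff). The helper's implementation is not shown; a plausible sketch, an assumption inferred from the name, returns the minimum of the effective dtype, which is the identity element for the tl.atomic_max that follows (a hard-coded torch.finfo(torch.float32).min would overflow to -inf under fp16 autocast):

import torch


def get_autocast_min_val() -> float:
    # Hypothetical helper: smallest representable value of the dtype
    # that will actually be used under the current autocast state.
    dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() \
        else torch.get_default_dtype()
    return torch.finfo(dtype).min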
blksprs/ops/softmax.py
CHANGED

@@ -7,7 +7,8 @@ from triton import language as tl
 
 from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import stride
+from blksprs.utils.tools import stride
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_dtype_float_32
 

@@ -114,7 +115,8 @@ def softmax_backward(ctx, grad_output):
                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
 
-
+    # TODO wrap
+    (softmax_kernel_grad[triton_grid]
     (grad_output,
      o_b, o_b_s, o_r_s, o_c_s,
      o,

@@ -133,7 +135,8 @@ def softmax_backward(ctx, grad_output):
 
 @triton.autotune(
     configs=get_autotune_configs(),
-    key=[],
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
     reset_to_zero=["o"]
 )
 @triton.jit

@@ -151,9 +154,6 @@ def softmax_kernel(x,
     pid_row = tl.program_id(axis=1)
     pid_col = tl.program_id(axis=2)
 
-    # Get valid triton block size
-    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
-
     # Get position of current sparsity block consisting of its batch and row index
     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
     spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)

@@ -172,22 +172,18 @@ def softmax_kernel(x,
     if rev_idx_spa_s >= 0:
         # Load x block
         blk_x_idx = ((pid_blk * x_b_s) +
-                     ((pid_row *
-                     ((pid_col *
-        blk_x_msk = (
-
-                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                      tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+        blk_x_msk = (blk_x_idx >= 0 and
+                     blk_x_idx < x_b * x_b_s)
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
         # Load sum block
         blk_s_idx = (rev_idx_spa_s * s_b_s +
-                     ((pid_row *
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
                      (tl.arange(0, 1) * s_c_s)[None, :])
-        blk_s_msk = (
-
-                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                      tl.arange(0, 1)[None, :] < val_tbs))
+        blk_s_msk = (blk_s_idx >= 0 and
+                     blk_s_idx < s_b * s_b_s)
         blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
 
         # Compute softmax

@@ -199,7 +195,8 @@ def softmax_kernel(x,
 
 @triton.autotune(
     configs=get_autotune_configs(),
-    key=[],
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
     reset_to_zero=["o"]
 )
 @triton.jit

@@ -221,9 +218,6 @@ def softmax_kernel_grad(g,
     pid_row = tl.program_id(axis=1)
     pid_col = tl.program_id(axis=2)
 
-    # Get valid triton block size
-    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
-
     # Get position of current sparsity block consisting of its batch and row index
     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
     spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)

@@ -240,41 +234,33 @@ def softmax_kernel_grad(g,
 
     if rev_idx_spa_s >= 0:
         blk_s_idx = (rev_idx_spa_s * s_b_s +
-                     ((pid_row *
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
                      (tl.arange(0, 1) * s_c_s)[None, :])
-        blk_s_msk = (
-
-                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                      tl.arange(0, 1)[None, :] < val_tbs))
+        blk_s_msk = (blk_s_idx >= 0 and
+                     blk_s_idx < s_b * s_b_s)
         blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
 
         blk_g_idx = ((pid_blk * g_b_s) +
-                     ((pid_row *
-                     ((pid_col *
-        blk_g_msk = (
-
-                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                      tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])
+        blk_g_msk = (blk_g_idx >= 0 and
+                     blk_g_idx < g_b * g_b_s)
         blk_g = tl.load(g + blk_g_idx, mask=blk_g_msk)
 
         blk_x_idx = ((pid_blk * x_b_s) +
-                     ((pid_row *
-                     ((pid_col *
-        blk_x_msk = (
-
-                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                      tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+        blk_x_msk = (blk_x_idx >= 0 and
+                     blk_x_idx < x_b * x_b_s)
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
         buf = blk_x * (blk_g - blk_s)
 
         blk_o_idx = ((pid_blk * o_b_s) +
-                     ((pid_row *
-                     ((pid_col *
-        blk_o_msk = (
-
-                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
-                      tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+        blk_o_msk = (blk_o_idx >= 0 and
+                     blk_o_idx < o_b * o_b_s)
         tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
 
 
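The rewritten softmax_kernel_grad keeps the identity buf = blk_x * (blk_g - blk_s). Assuming blk_x holds the saved forward softmax output and blk_s the row-wise sum of grad * output (the call site is not fully visible in this diff), this is the standard softmax backward rule dx = y * (g - sum(g * y)); a dense PyTorch check against autograd:

import torch

x = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)
y = torch.softmax(x, dim=-1)
g = torch.randn_like(y)

s = (g * y).sum(dim=-1, keepdim=True)    # the row-wise sum block the kernel loads
manual = y * (g - s)                     # buf = blk_x * (blk_g - blk_s)
auto, = torch.autograd.grad(y, x, grad_outputs=g)
assert torch.allclose(manual, auto)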