blksprs 2.0rc3__py3-none-any.whl → 2.0rc6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/layouting/distribution_layout.py +11 -15
- blksprs/layouting/sparsity_layout.py +26 -31
- blksprs/ops/conversion.py +48 -64
- blksprs/ops/distribution.py +39 -57
- blksprs/ops/flow.py +24 -34
- blksprs/ops/matmul.py +21 -21
- blksprs/ops/misc/broadcast_ops.py +14 -19
- blksprs/ops/misc/row_wise.py +37 -55
- blksprs/ops/softmax.py +34 -46
- blksprs/utils/autotuning.py +78 -0
- blksprs/utils/tools.py +6 -25
- blksprs/utils/validation.py +3 -0
- {blksprs-2.0rc3.dist-info → blksprs-2.0rc6.dist-info}/METADATA +14 -5
- blksprs-2.0rc6.dist-info/RECORD +23 -0
- blksprs-2.0rc3.dist-info/RECORD +0 -22
- {blksprs-2.0rc3.dist-info → blksprs-2.0rc6.dist-info}/WHEEL +0 -0
- {blksprs-2.0rc3.dist-info → blksprs-2.0rc6.dist-info}/top_level.txt +0 -0
blksprs/ops/flow.py
CHANGED
|
@@ -5,14 +5,15 @@ from torch._library import triton_op
|
|
|
5
5
|
from torch._library.triton import wrap_triton
|
|
6
6
|
from triton import language as tl
|
|
7
7
|
|
|
8
|
-
from blksprs.utils.tools import stride
|
|
8
|
+
from blksprs.utils.tools import stride
|
|
9
|
+
from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
|
|
9
10
|
|
|
10
11
|
|
|
11
12
|
@triton_op("blksprs::flow_pull", mutates_args={})
|
|
12
13
|
def flow_pull_forward(x: Tensor, sparsity_layout_o: Tensor,
|
|
13
14
|
sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
14
15
|
sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
|
|
15
|
-
output = torch.
|
|
16
|
+
output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
|
|
16
17
|
dtype=x.dtype, device=x.device)
|
|
17
18
|
|
|
18
19
|
x_b, x_r, x_c = x.size()
|
|
@@ -43,7 +44,9 @@ def flow_pull_forward(x: Tensor, sparsity_layout_o: Tensor,
|
|
|
43
44
|
|
|
44
45
|
@triton.autotune(
|
|
45
46
|
configs=get_autotune_configs(),
|
|
46
|
-
key=[],
|
|
47
|
+
key=["sparsity_block_size"],
|
|
48
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
49
|
+
reset_to_zero=["o"]
|
|
47
50
|
)
|
|
48
51
|
@triton.jit
|
|
49
52
|
def flow_pull_kernel(x,
|
|
@@ -60,9 +63,6 @@ def flow_pull_kernel(x,
|
|
|
60
63
|
pid_row = tl.program_id(axis=1)
|
|
61
64
|
pid_col = tl.program_id(axis=2)
|
|
62
65
|
|
|
63
|
-
# Get valid triton block size
|
|
64
|
-
val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
|
|
65
|
-
|
|
66
66
|
# Get sparsity index of current output block consisting of its batch, row, and column index
|
|
67
67
|
spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
|
|
68
68
|
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
|
|
@@ -85,21 +85,17 @@ def flow_pull_kernel(x,
|
|
|
85
85
|
|
|
86
86
|
if rev_idx_spa >= 0:
|
|
87
87
|
blk_x_idx = (rev_idx_spa * x_b_s +
|
|
88
|
-
((pid_row *
|
|
89
|
-
((pid_col *
|
|
90
|
-
blk_x_msk = (
|
|
91
|
-
|
|
92
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
93
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
88
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
89
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
90
|
+
blk_x_msk = (blk_x_idx >= 0 and
|
|
91
|
+
blk_x_idx < x_b * x_b_s)
|
|
94
92
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
95
93
|
|
|
96
94
|
blk_o_idx = (pid_blk * o_b_s +
|
|
97
|
-
((pid_row *
|
|
98
|
-
((pid_col *
|
|
99
|
-
blk_o_msk = (
|
|
100
|
-
|
|
101
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
102
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
95
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
96
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
97
|
+
blk_o_msk = (blk_o_idx >= 0 and
|
|
98
|
+
blk_o_idx < o_b * o_b_s)
|
|
103
99
|
tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
|
|
104
100
|
|
|
105
101
|
|
|
@@ -137,7 +133,8 @@ def flow_push_forward(x: Tensor, sparsity_layout_x: Tensor, sparsity_lut: Tensor
|
|
|
137
133
|
|
|
138
134
|
@triton.autotune(
|
|
139
135
|
configs=get_autotune_configs(),
|
|
140
|
-
key=[],
|
|
136
|
+
key=["sparsity_block_size"],
|
|
137
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
141
138
|
reset_to_zero=["o"]
|
|
142
139
|
)
|
|
143
140
|
@triton.jit
|
|
@@ -155,9 +152,6 @@ def flow_push_kernel(x,
|
|
|
155
152
|
pid_row = tl.program_id(axis=1)
|
|
156
153
|
pid_col = tl.program_id(axis=2)
|
|
157
154
|
|
|
158
|
-
# Get valid triton block size
|
|
159
|
-
val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
|
|
160
|
-
|
|
161
155
|
# Get sparsity index of current input block consisting of its batch, row, and column index
|
|
162
156
|
spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
|
|
163
157
|
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
|
|
@@ -180,19 +174,15 @@ def flow_push_kernel(x,
|
|
|
180
174
|
|
|
181
175
|
if rev_idx_spa >= 0:
|
|
182
176
|
blk_x_idx = (pid_blk * x_b_s +
|
|
183
|
-
((pid_row *
|
|
184
|
-
((pid_col *
|
|
185
|
-
blk_x_msk = (
|
|
186
|
-
|
|
187
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
188
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
177
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
178
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
179
|
+
blk_x_msk = (blk_x_idx >= 0 and
|
|
180
|
+
blk_x_idx < x_b * x_b_s)
|
|
189
181
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
190
182
|
|
|
191
183
|
blk_o_idx = (rev_idx_spa * o_b_s +
|
|
192
|
-
((pid_row *
|
|
193
|
-
((pid_col *
|
|
194
|
-
blk_o_msk = (
|
|
195
|
-
|
|
196
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
197
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
184
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
185
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
186
|
+
blk_o_msk = (blk_o_idx >= 0 and
|
|
187
|
+
blk_o_idx < o_b * o_b_s)
|
|
198
188
|
tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
|
blksprs/ops/matmul.py
CHANGED
|
@@ -6,7 +6,8 @@ from triton import language as tl
|
|
|
6
6
|
|
|
7
7
|
from blksprs.ops.transpose import transpose
|
|
8
8
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
9
|
-
from blksprs.utils.tools import stride
|
|
9
|
+
from blksprs.utils.tools import stride
|
|
10
|
+
from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
|
|
10
11
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
11
12
|
validate_sparsity, validate_sparsity_block_size, validate_dtype_float
|
|
12
13
|
|
|
@@ -60,7 +61,7 @@ def matmul_forward(x: Tensor, y: Tensor,
|
|
|
60
61
|
sparsity_layout_y: Tensor, sparsity_reverse_lut_y: Tensor,
|
|
61
62
|
_: Tensor, sparsity_lut_o: Tensor,
|
|
62
63
|
sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
|
|
63
|
-
output = torch.
|
|
64
|
+
output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
|
|
64
65
|
dtype=x.dtype, device=x.device)
|
|
65
66
|
|
|
66
67
|
x_b, x_r, x_c = x.size()
|
|
@@ -117,7 +118,9 @@ def matmul_backward(ctx, grad_output):
|
|
|
117
118
|
|
|
118
119
|
@triton.autotune(
|
|
119
120
|
configs=get_autotune_configs(),
|
|
120
|
-
key=[],
|
|
121
|
+
key=["sparsity_block_size"],
|
|
122
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
123
|
+
reset_to_zero=["o"]
|
|
121
124
|
)
|
|
122
125
|
@triton.jit
|
|
123
126
|
def matmul_kernel(x,
|
|
@@ -140,9 +143,6 @@ def matmul_kernel(x,
|
|
|
140
143
|
pid_row = tl.program_id(axis=1)
|
|
141
144
|
pid_col = tl.program_id(axis=2)
|
|
142
145
|
|
|
143
|
-
# Get valid triton block size
|
|
144
|
-
val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
|
|
145
|
-
|
|
146
146
|
# Get position of current sparsity block consisting of its batch, row, and column index
|
|
147
147
|
spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
|
|
148
148
|
spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
|
|
@@ -160,11 +160,11 @@ def matmul_kernel(x,
|
|
|
160
160
|
buf = tl.zeros(shape=(TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), dtype=tl.float32)
|
|
161
161
|
|
|
162
162
|
# Slide over triton block sized segments of input tensors
|
|
163
|
-
for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size,
|
|
163
|
+
for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size, TRITON_BLOCK_SIZE)):
|
|
164
164
|
# Convert to segment index of sparsity layout
|
|
165
|
-
i_seg_spa = (i_seg_tri *
|
|
165
|
+
i_seg_spa = (i_seg_tri * TRITON_BLOCK_SIZE) // sparsity_block_size
|
|
166
166
|
# Calculate the triton segment index within a block
|
|
167
|
-
i_seg_tri_mod = i_seg_tri % (sparsity_block_size //
|
|
167
|
+
i_seg_tri_mod = i_seg_tri % (sparsity_block_size // TRITON_BLOCK_SIZE)
|
|
168
168
|
|
|
169
169
|
# Get reverse sparsity indices for input tensors x and y
|
|
170
170
|
# These are either -1 if the block is empty or equal to the index of the block in the sparse tensor
|
|
@@ -184,23 +184,23 @@ def matmul_kernel(x,
|
|
|
184
184
|
# If both blocks are present commence calculation
|
|
185
185
|
if rev_idx_spa_x >= 0 and rev_idx_spa_y >= 0:
|
|
186
186
|
blk_x_idx = ((rev_idx_spa_x * x_b_s) +
|
|
187
|
-
((pid_row *
|
|
188
|
-
((i_seg_tri_mod *
|
|
187
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
188
|
+
((i_seg_tri_mod * TRITON_BLOCK_SIZE +
|
|
189
189
|
tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
190
190
|
blk_x_msk = ((blk_x_idx >= 0 and
|
|
191
191
|
blk_x_idx < x_b * x_b_s) and
|
|
192
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] <
|
|
193
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] <
|
|
192
|
+
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < TRITON_BLOCK_SIZE and
|
|
193
|
+
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < TRITON_BLOCK_SIZE))
|
|
194
194
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
195
195
|
|
|
196
196
|
blk_y_idx = ((rev_idx_spa_y * y_b_s) +
|
|
197
|
-
((i_seg_tri_mod *
|
|
197
|
+
((i_seg_tri_mod * TRITON_BLOCK_SIZE +
|
|
198
198
|
tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
|
|
199
|
-
((pid_col *
|
|
199
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
|
|
200
200
|
blk_y_msk = ((blk_y_idx >= 0 and
|
|
201
201
|
blk_y_idx < y_b * y_b_s) and
|
|
202
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] <
|
|
203
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] <
|
|
202
|
+
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < TRITON_BLOCK_SIZE and
|
|
203
|
+
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < TRITON_BLOCK_SIZE))
|
|
204
204
|
blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
|
|
205
205
|
|
|
206
206
|
# Perform matrix multiplication
|
|
@@ -211,12 +211,12 @@ def matmul_kernel(x,
|
|
|
211
211
|
|
|
212
212
|
# Store output
|
|
213
213
|
blk_o_idx = ((pid_blk * o_b_s) +
|
|
214
|
-
((pid_row *
|
|
215
|
-
((pid_col *
|
|
214
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
215
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
216
216
|
blk_o_msk = ((blk_o_idx >= 0 and
|
|
217
217
|
blk_o_idx < o_b * o_b_s) and
|
|
218
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] <
|
|
219
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] <
|
|
218
|
+
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < TRITON_BLOCK_SIZE and
|
|
219
|
+
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < TRITON_BLOCK_SIZE))
|
|
220
220
|
tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
|
|
221
221
|
|
|
222
222
|
|
|
@@ -6,7 +6,8 @@ from torch._library.triton import wrap_triton
|
|
|
6
6
|
from triton import language as tl
|
|
7
7
|
|
|
8
8
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
9
|
-
from blksprs.utils.tools import stride
|
|
9
|
+
from blksprs.utils.tools import stride
|
|
10
|
+
from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
|
|
10
11
|
from blksprs.utils.validation import validate_contiguous, validate_device, \
|
|
11
12
|
validate_sparsity_block_size
|
|
12
13
|
|
|
@@ -87,7 +88,8 @@ def broadcast_add_forward(x: Tensor, y: Tensor,
|
|
|
87
88
|
|
|
88
89
|
@triton.autotune(
|
|
89
90
|
configs=get_autotune_configs(),
|
|
90
|
-
key=[],
|
|
91
|
+
key=["sparsity_block_size"],
|
|
92
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
91
93
|
reset_to_zero=["o"]
|
|
92
94
|
)
|
|
93
95
|
@triton.jit
|
|
@@ -105,9 +107,6 @@ def broadcast_add_kernel(x,
|
|
|
105
107
|
pid_row = tl.program_id(axis=1)
|
|
106
108
|
pid_col = tl.program_id(axis=2)
|
|
107
109
|
|
|
108
|
-
# Get valid triton block size
|
|
109
|
-
val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
|
|
110
|
-
|
|
111
110
|
# Get position of current sparsity block consisting of its batch, row, and column index
|
|
112
111
|
spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
|
|
113
112
|
spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
|
|
@@ -123,20 +122,18 @@ def broadcast_add_kernel(x,
|
|
|
123
122
|
|
|
124
123
|
# Load x block
|
|
125
124
|
blk_x_idx = (spa_bat_o * x_b_s +
|
|
126
|
-
((pid_row *
|
|
125
|
+
((pid_row * TRITON_BLOCK_SIZE + spa_row_o * sparsity_block_size +
|
|
127
126
|
tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
128
|
-
blk_x_msk = (
|
|
129
|
-
|
|
130
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
127
|
+
blk_x_msk = (blk_x_idx >= 0 and
|
|
128
|
+
blk_x_idx < x_b * x_b_s)
|
|
131
129
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
132
130
|
|
|
133
131
|
# Load y block
|
|
134
132
|
blk_y_idx = (spa_bat_o * y_b_s +
|
|
135
|
-
((pid_col *
|
|
133
|
+
((pid_col * TRITON_BLOCK_SIZE + spa_col_o * sparsity_block_size +
|
|
136
134
|
tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
|
|
137
|
-
blk_y_msk = (
|
|
138
|
-
|
|
139
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
135
|
+
blk_y_msk = (blk_y_idx >= 0 and
|
|
136
|
+
blk_y_idx < y_b * y_b_s)
|
|
140
137
|
blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
|
|
141
138
|
|
|
142
139
|
# Compute sum
|
|
@@ -145,10 +142,8 @@ def broadcast_add_kernel(x,
|
|
|
145
142
|
|
|
146
143
|
# Store result
|
|
147
144
|
blk_o_idx = ((pid_blk * o_b_s) +
|
|
148
|
-
((pid_row *
|
|
149
|
-
((pid_col *
|
|
150
|
-
blk_o_msk = (
|
|
151
|
-
|
|
152
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
153
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
145
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
146
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
147
|
+
blk_o_msk = (blk_o_idx >= 0 and
|
|
148
|
+
blk_o_idx < o_b * o_b_s)
|
|
154
149
|
tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
|
blksprs/ops/misc/row_wise.py
CHANGED
|
@@ -5,7 +5,8 @@ from torch._library.triton import wrap_triton, triton_op
|
|
|
5
5
|
from triton import language as tl
|
|
6
6
|
|
|
7
7
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
8
|
-
from blksprs.utils.tools import stride,
|
|
8
|
+
from blksprs.utils.tools import stride, get_autocast_min_val
|
|
9
|
+
from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
|
|
9
10
|
from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, validate_sparsity, \
|
|
10
11
|
validate_sparsity_block_size
|
|
11
12
|
|
|
@@ -96,7 +97,8 @@ def row_wise_sum_forward(x: Tensor, sparsity_lut: Tensor,
|
|
|
96
97
|
|
|
97
98
|
@triton.autotune(
|
|
98
99
|
configs=get_autotune_configs(),
|
|
99
|
-
key=[],
|
|
100
|
+
key=["sparsity_block_size"],
|
|
101
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
100
102
|
reset_to_zero=["o"]
|
|
101
103
|
)
|
|
102
104
|
@triton.jit
|
|
@@ -114,9 +116,6 @@ def row_wise_sum_kernel(x,
|
|
|
114
116
|
pid_row = tl.program_id(axis=1)
|
|
115
117
|
pid_col = tl.program_id(axis=2)
|
|
116
118
|
|
|
117
|
-
# Get valid triton block size
|
|
118
|
-
val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
|
|
119
|
-
|
|
120
119
|
# Get position of current sparsity block consisting of its batch and row index
|
|
121
120
|
spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
|
|
122
121
|
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
|
|
@@ -137,23 +136,19 @@ def row_wise_sum_kernel(x,
|
|
|
137
136
|
return
|
|
138
137
|
|
|
139
138
|
blk_idx = ((pid_blk * x_b_s) +
|
|
140
|
-
((pid_row *
|
|
141
|
-
((pid_col *
|
|
142
|
-
blk_msk = (
|
|
143
|
-
|
|
144
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
145
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
139
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
140
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
141
|
+
blk_msk = (blk_idx >= 0 and
|
|
142
|
+
blk_idx < x_b * x_b_s)
|
|
146
143
|
blk = tl.load(x + blk_idx, mask=blk_msk)
|
|
147
144
|
|
|
148
145
|
buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
|
|
149
146
|
|
|
150
147
|
o_idx = (rev_idx_spa * o_b_s +
|
|
151
|
-
((pid_row *
|
|
148
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
152
149
|
(tl.arange(0, 1))[None, :])
|
|
153
|
-
o_msk = (
|
|
154
|
-
|
|
155
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
156
|
-
tl.arange(0, 1)[None, :] < val_tbs))
|
|
150
|
+
o_msk = (o_idx >= 0 and
|
|
151
|
+
o_idx < o_b * o_b_s)
|
|
157
152
|
tl.atomic_add(o + o_idx, buf, o_msk)
|
|
158
153
|
|
|
159
154
|
|
|
@@ -214,7 +209,7 @@ def row_wise_max_forward(x: Tensor, sparsity_lut: Tensor,
|
|
|
214
209
|
output = torch.full(size=(n_sparse_blocks_output,
|
|
215
210
|
sparsity_block_size,
|
|
216
211
|
1 if flag_slice_only else sparsity_block_size),
|
|
217
|
-
fill_value=
|
|
212
|
+
fill_value=get_autocast_min_val(),
|
|
218
213
|
device=x.device)
|
|
219
214
|
|
|
220
215
|
x_b, x_r, x_c = x.size()
|
|
@@ -245,7 +240,8 @@ def row_wise_max_forward(x: Tensor, sparsity_lut: Tensor,
|
|
|
245
240
|
|
|
246
241
|
@triton.autotune(
|
|
247
242
|
configs=get_autotune_configs(),
|
|
248
|
-
key=[],
|
|
243
|
+
key=["sparsity_block_size"],
|
|
244
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
249
245
|
restore_value=["o"]
|
|
250
246
|
)
|
|
251
247
|
@triton.jit
|
|
@@ -263,9 +259,6 @@ def row_wise_max_kernel(x,
|
|
|
263
259
|
pid_row = tl.program_id(axis=1)
|
|
264
260
|
pid_col = tl.program_id(axis=2)
|
|
265
261
|
|
|
266
|
-
# Get valid triton block size
|
|
267
|
-
val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
|
|
268
|
-
|
|
269
262
|
# Get position of current sparsity block consisting of its batch and row index
|
|
270
263
|
spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
|
|
271
264
|
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
|
|
@@ -286,23 +279,19 @@ def row_wise_max_kernel(x,
|
|
|
286
279
|
return
|
|
287
280
|
|
|
288
281
|
blk_idx = ((pid_blk * x_b_s) +
|
|
289
|
-
((pid_row *
|
|
290
|
-
((pid_col *
|
|
291
|
-
blk_msk = (
|
|
292
|
-
|
|
293
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
294
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
282
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
283
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
284
|
+
blk_msk = (blk_idx >= 0 and
|
|
285
|
+
blk_idx < x_b * x_b_s)
|
|
295
286
|
blk = tl.load(x + blk_idx, mask=blk_msk)
|
|
296
287
|
|
|
297
288
|
buf = tl.reshape(tl.max(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
|
|
298
289
|
|
|
299
290
|
o_idx = (rev_idx_spa * o_b_s +
|
|
300
|
-
((pid_row *
|
|
291
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
301
292
|
(tl.arange(0, 1))[None, :])
|
|
302
|
-
o_msk = (
|
|
303
|
-
|
|
304
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
305
|
-
tl.arange(0, 1)[None, :] < val_tbs))
|
|
293
|
+
o_msk = (o_idx >= 0 and
|
|
294
|
+
o_idx < o_b * o_b_s)
|
|
306
295
|
tl.atomic_max(o + o_idx, buf, o_msk)
|
|
307
296
|
|
|
308
297
|
|
|
@@ -354,7 +343,7 @@ def row_wise_sub(x: BlksprsTensor, sparsity_layout_x: Tensor, y: Tensor,
|
|
|
354
343
|
def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
|
|
355
344
|
sparsity_layout_x_rwm: Tensor, sparsity_reverse_x_lut_rwm: Tensor,
|
|
356
345
|
y: Tensor, sparsity_block_size: int) -> Tensor:
|
|
357
|
-
output = torch.
|
|
346
|
+
output = torch.zeros_like(x)
|
|
358
347
|
|
|
359
348
|
x_b, x_r, x_c = x.size()
|
|
360
349
|
x_b_s, x_r_s, x_c_s = stride(x)
|
|
@@ -371,7 +360,7 @@ def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
|
|
|
371
360
|
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
372
361
|
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
373
362
|
|
|
374
|
-
(kernel_blocksparse_row_wise_add[triton_grid]
|
|
363
|
+
(wrap_triton(kernel_blocksparse_row_wise_add)[triton_grid]
|
|
375
364
|
(x,
|
|
376
365
|
x_b, x_b_s, x_r_s, x_c_s,
|
|
377
366
|
sparsity_lut_x, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
@@ -387,7 +376,9 @@ def row_wise_add_forward(x: Tensor, sparsity_lut_x: Tensor,
|
|
|
387
376
|
|
|
388
377
|
@triton.autotune(
|
|
389
378
|
configs=get_autotune_configs(),
|
|
390
|
-
key=[]
|
|
379
|
+
key=["sparsity_block_size"],
|
|
380
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
381
|
+
reset_to_zero=["o"]
|
|
391
382
|
)
|
|
392
383
|
@triton.jit
|
|
393
384
|
def kernel_blocksparse_row_wise_add(x,
|
|
@@ -405,9 +396,6 @@ def kernel_blocksparse_row_wise_add(x,
|
|
|
405
396
|
pid_row = tl.program_id(axis=1)
|
|
406
397
|
pid_col = tl.program_id(axis=2)
|
|
407
398
|
|
|
408
|
-
# Get valid triton block size
|
|
409
|
-
val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
|
|
410
|
-
|
|
411
399
|
# Get position of current sparsity block consisting of its batch and row index
|
|
412
400
|
spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
|
|
413
401
|
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
|
|
@@ -429,22 +417,18 @@ def kernel_blocksparse_row_wise_add(x,
|
|
|
429
417
|
|
|
430
418
|
# Load x block
|
|
431
419
|
blk_x_idx = ((pid_blk * x_b_s) +
|
|
432
|
-
((pid_row *
|
|
433
|
-
((pid_col *
|
|
434
|
-
blk_x_msk = (
|
|
435
|
-
|
|
436
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
437
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
420
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
421
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
422
|
+
blk_x_msk = (blk_x_idx >= 0 and
|
|
423
|
+
blk_x_idx < x_b * x_b_s)
|
|
438
424
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
439
425
|
|
|
440
426
|
# Load sum block
|
|
441
427
|
blk_s_idx = (rev_idx_spa_s * y_b_s +
|
|
442
|
-
((pid_row *
|
|
428
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
|
|
443
429
|
(tl.arange(0, 1) * y_c_s)[None, :])
|
|
444
|
-
blk_s_msk = (
|
|
445
|
-
|
|
446
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
447
|
-
tl.arange(0, 1)[None, :] < val_tbs))
|
|
430
|
+
blk_s_msk = (blk_s_idx >= 0 and
|
|
431
|
+
blk_s_idx < y_b * y_b_s)
|
|
448
432
|
blk_s = tl.load(y + blk_s_idx, mask=blk_s_msk)
|
|
449
433
|
|
|
450
434
|
# Compute exp
|
|
@@ -452,10 +436,8 @@ def kernel_blocksparse_row_wise_add(x,
|
|
|
452
436
|
|
|
453
437
|
# Store block
|
|
454
438
|
blk_o_idx = ((pid_blk * o_b_s) +
|
|
455
|
-
((pid_row *
|
|
456
|
-
((pid_col *
|
|
457
|
-
blk_o_msk = (
|
|
458
|
-
|
|
459
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
460
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
439
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
440
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
441
|
+
blk_o_msk = (blk_o_idx >= 0 and
|
|
442
|
+
blk_o_idx < o_b * o_b_s)
|
|
461
443
|
tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
|
blksprs/ops/softmax.py
CHANGED
|
@@ -7,7 +7,8 @@ from triton import language as tl
|
|
|
7
7
|
|
|
8
8
|
from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
|
|
9
9
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
10
|
-
from blksprs.utils.tools import stride
|
|
10
|
+
from blksprs.utils.tools import stride
|
|
11
|
+
from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
|
|
11
12
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
12
13
|
validate_sparsity, validate_sparsity_block_size, validate_dtype_float_32
|
|
13
14
|
|
|
@@ -51,7 +52,7 @@ def softmax_forward(x: Tensor, sparsity_layout: Tensor,
|
|
|
51
52
|
sparsity_lut: Tensor,
|
|
52
53
|
sparsity_reverse_lut_rws: Tensor,
|
|
53
54
|
sparsity_block_size: int) -> Tensor:
|
|
54
|
-
output = torch.
|
|
55
|
+
output = torch.zeros_like(x)
|
|
55
56
|
|
|
56
57
|
x_b, x_r, x_c = x.size()
|
|
57
58
|
x_b_s, x_r_s, x_c_s = stride(x)
|
|
@@ -108,13 +109,14 @@ def softmax_backward(ctx, grad_output):
|
|
|
108
109
|
s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
|
|
109
110
|
s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)
|
|
110
111
|
|
|
111
|
-
grad_x = torch.
|
|
112
|
+
grad_x = torch.zeros_like(o, dtype=torch.float)
|
|
112
113
|
|
|
113
114
|
triton_grid = lambda meta: [o_b,
|
|
114
115
|
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
115
116
|
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
116
117
|
|
|
117
|
-
|
|
118
|
+
# TODO wrap
|
|
119
|
+
(softmax_kernel_grad[triton_grid]
|
|
118
120
|
(grad_output,
|
|
119
121
|
o_b, o_b_s, o_r_s, o_c_s,
|
|
120
122
|
o,
|
|
@@ -133,7 +135,9 @@ def softmax_backward(ctx, grad_output):
|
|
|
133
135
|
|
|
134
136
|
@triton.autotune(
|
|
135
137
|
configs=get_autotune_configs(),
|
|
136
|
-
key=[]
|
|
138
|
+
key=["sparsity_block_size"],
|
|
139
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
140
|
+
reset_to_zero=["o"]
|
|
137
141
|
)
|
|
138
142
|
@triton.jit
|
|
139
143
|
def softmax_kernel(x,
|
|
@@ -150,9 +154,6 @@ def softmax_kernel(x,
|
|
|
150
154
|
pid_row = tl.program_id(axis=1)
|
|
151
155
|
pid_col = tl.program_id(axis=2)
|
|
152
156
|
|
|
153
|
-
# Get valid triton block size
|
|
154
|
-
val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
|
|
155
|
-
|
|
156
157
|
# Get position of current sparsity block consisting of its batch and row index
|
|
157
158
|
spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
|
|
158
159
|
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
|
|
@@ -171,22 +172,18 @@ def softmax_kernel(x,
|
|
|
171
172
|
if rev_idx_spa_s >= 0:
|
|
172
173
|
# Load x block
|
|
173
174
|
blk_x_idx = ((pid_blk * x_b_s) +
|
|
174
|
-
((pid_row *
|
|
175
|
-
((pid_col *
|
|
176
|
-
blk_x_msk = (
|
|
177
|
-
|
|
178
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
179
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
175
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
176
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
177
|
+
blk_x_msk = (blk_x_idx >= 0 and
|
|
178
|
+
blk_x_idx < x_b * x_b_s)
|
|
180
179
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
181
180
|
|
|
182
181
|
# Load sum block
|
|
183
182
|
blk_s_idx = (rev_idx_spa_s * s_b_s +
|
|
184
|
-
((pid_row *
|
|
183
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
|
|
185
184
|
(tl.arange(0, 1) * s_c_s)[None, :])
|
|
186
|
-
blk_s_msk = (
|
|
187
|
-
|
|
188
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
189
|
-
tl.arange(0, 1)[None, :] < val_tbs))
|
|
185
|
+
blk_s_msk = (blk_s_idx >= 0 and
|
|
186
|
+
blk_s_idx < s_b * s_b_s)
|
|
190
187
|
blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
|
|
191
188
|
|
|
192
189
|
# Compute softmax
|
|
@@ -198,7 +195,9 @@ def softmax_kernel(x,
|
|
|
198
195
|
|
|
199
196
|
@triton.autotune(
|
|
200
197
|
configs=get_autotune_configs(),
|
|
201
|
-
key=[]
|
|
198
|
+
key=["sparsity_block_size"],
|
|
199
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
200
|
+
reset_to_zero=["o"]
|
|
202
201
|
)
|
|
203
202
|
@triton.jit
|
|
204
203
|
def softmax_kernel_grad(g,
|
|
@@ -219,9 +218,6 @@ def softmax_kernel_grad(g,
|
|
|
219
218
|
pid_row = tl.program_id(axis=1)
|
|
220
219
|
pid_col = tl.program_id(axis=2)
|
|
221
220
|
|
|
222
|
-
# Get valid triton block size
|
|
223
|
-
val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
|
|
224
|
-
|
|
225
221
|
# Get position of current sparsity block consisting of its batch and row index
|
|
226
222
|
spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
|
|
227
223
|
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
|
|
@@ -238,41 +234,33 @@ def softmax_kernel_grad(g,
|
|
|
238
234
|
|
|
239
235
|
if rev_idx_spa_s >= 0:
|
|
240
236
|
blk_s_idx = (rev_idx_spa_s * s_b_s +
|
|
241
|
-
((pid_row *
|
|
237
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
|
|
242
238
|
(tl.arange(0, 1) * s_c_s)[None, :])
|
|
243
|
-
blk_s_msk = (
|
|
244
|
-
|
|
245
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
246
|
-
tl.arange(0, 1)[None, :] < val_tbs))
|
|
239
|
+
blk_s_msk = (blk_s_idx >= 0 and
|
|
240
|
+
blk_s_idx < s_b * s_b_s)
|
|
247
241
|
blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
|
|
248
242
|
|
|
249
243
|
blk_g_idx = ((pid_blk * g_b_s) +
|
|
250
|
-
((pid_row *
|
|
251
|
-
((pid_col *
|
|
252
|
-
blk_g_msk = (
|
|
253
|
-
|
|
254
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
255
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
244
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +
|
|
245
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])
|
|
246
|
+
blk_g_msk = (blk_g_idx >= 0 and
|
|
247
|
+
blk_g_idx < g_b * g_b_s)
|
|
256
248
|
blk_g = tl.load(g + blk_g_idx, mask=blk_g_msk)
|
|
257
249
|
|
|
258
250
|
blk_x_idx = ((pid_blk * x_b_s) +
|
|
259
|
-
((pid_row *
|
|
260
|
-
((pid_col *
|
|
261
|
-
blk_x_msk = (
|
|
262
|
-
|
|
263
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
264
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
251
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
252
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
253
|
+
blk_x_msk = (blk_x_idx >= 0 and
|
|
254
|
+
blk_x_idx < x_b * x_b_s)
|
|
265
255
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
266
256
|
|
|
267
257
|
buf = blk_x * (blk_g - blk_s)
|
|
268
258
|
|
|
269
259
|
blk_o_idx = ((pid_blk * o_b_s) +
|
|
270
|
-
((pid_row *
|
|
271
|
-
((pid_col *
|
|
272
|
-
blk_o_msk = (
|
|
273
|
-
|
|
274
|
-
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
275
|
-
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
260
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
261
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
262
|
+
blk_o_msk = (blk_o_idx >= 0 and
|
|
263
|
+
blk_o_idx < o_b * o_b_s)
|
|
276
264
|
tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
|
|
277
265
|
|
|
278
266
|
|