blksprs 1.10.2__py3-none-any.whl → 2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +4 -6
- blksprs/layouting/distribution_layout.py +64 -48
- blksprs/layouting/sparsity_layout.py +96 -72
- blksprs/ops/conversion.py +350 -312
- blksprs/ops/distribution.py +320 -266
- blksprs/ops/flow.py +135 -89
- blksprs/ops/matmul.py +184 -151
- blksprs/ops/misc/broadcast_ops.py +68 -53
- blksprs/ops/misc/row_wise.py +223 -176
- blksprs/ops/partitioning.py +140 -89
- blksprs/ops/repeat.py +118 -108
- blksprs/ops/softmax.py +201 -167
- blksprs/ops/transpose.py +71 -131
- blksprs/utils/autotuning.py +78 -0
- blksprs/utils/benchmarking.py +3 -3
- blksprs/utils/processing.py +2 -1
- blksprs/utils/tools.py +5 -6
- blksprs/utils/validation.py +22 -16
- {blksprs-1.10.2.dist-info → blksprs-2.0.dist-info}/METADATA +55 -36
- blksprs-2.0.dist-info/RECORD +23 -0
- {blksprs-1.10.2.dist-info → blksprs-2.0.dist-info}/WHEEL +1 -1
- blksprs/ops/misc/exp.py +0 -104
- blksprs/utils/layout_utils.py +0 -17
- blksprs-1.10.2.dist-info/RECORD +0 -24
- {blksprs-1.10.2.dist-info → blksprs-2.0.dist-info}/top_level.txt +0 -0
blksprs/ops/softmax.py
CHANGED
|
@@ -1,18 +1,20 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
import triton
|
|
3
3
|
from torch import Tensor
|
|
4
|
+
from torch._library import triton_op
|
|
5
|
+
from torch._library.triton import wrap_triton
|
|
4
6
|
from triton import language as tl
|
|
5
7
|
|
|
6
|
-
from blksprs.ops.misc.exp import exp
|
|
7
8
|
from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
|
|
8
9
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
9
|
-
from blksprs.utils.tools import
|
|
10
|
+
from blksprs.utils.tools import stride
|
|
11
|
+
from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
|
|
10
12
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
11
|
-
validate_sparsity, validate_sparsity_block_size,
|
|
13
|
+
validate_sparsity, validate_sparsity_block_size, validate_dtype_float_32
|
|
12
14
|
|
|
13
15
|
|
|
14
|
-
|
|
15
|
-
|
|
16
|
+
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
|
|
17
|
+
def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
|
|
16
18
|
"""Computes the softmax of a block-sparse tensor in compressed form.
|
|
17
19
|
|
|
18
20
|
Note:
|
|
@@ -22,7 +24,7 @@ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
|
22
24
|
x (BlksprsTensor): A block-sparse tensor in compressed form.
|
|
23
25
|
sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
|
|
24
26
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
25
|
-
|
|
27
|
+
lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
|
|
26
28
|
|
|
27
29
|
Returns:
|
|
28
30
|
BlksprsTensor: The result of the softmax operation as a block-sparse tensor in compressed form.
|
|
@@ -32,88 +34,73 @@ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
|
32
34
|
|
|
33
35
|
validate_dimensions(x)
|
|
34
36
|
validate_contiguous(x)
|
|
37
|
+
validate_dtype_float_32(x)
|
|
35
38
|
validate_device(x)
|
|
36
39
|
validate_sparsity(sparsity_block_size, (x, sparsity_layout))
|
|
37
40
|
validate_sparsity_block_size(sparsity_block_size, x)
|
|
38
|
-
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
39
41
|
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
# Save for backward pass
|
|
103
|
-
ctx.save_for_backward(output, sparsity_layout, sparsity_lut)
|
|
104
|
-
ctx.sparsity_block_size = sparsity_block_size
|
|
105
|
-
ctx.triton_block_size = triton_block_size
|
|
106
|
-
|
|
107
|
-
return output
|
|
108
|
-
|
|
109
|
-
@staticmethod
|
|
110
|
-
def backward(ctx, grad_output):
|
|
111
|
-
o, sparsity_layout, sparsity_lut = ctx.saved_tensors
|
|
112
|
-
sparsity_block_size = ctx.sparsity_block_size
|
|
113
|
-
triton_block_size = ctx.triton_block_size
|
|
114
|
-
|
|
115
|
-
s, sparsity_layout_s = row_wise_sum(grad_output * o, sparsity_layout, sparsity_block_size, flag_slice_only=True,
|
|
116
|
-
triton_block_size=triton_block_size)
|
|
42
|
+
lut = softmax_build_lut(lut, sparsity_layout)
|
|
43
|
+
|
|
44
|
+
return BlksprsTensor(softmax_forward(x, sparsity_layout,
|
|
45
|
+
lut["sparsity_lut"],
|
|
46
|
+
lut["sparsity_reverse_lut_rws"],
|
|
47
|
+
sparsity_block_size))
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@triton_op("blksprs::softmax_forward", mutates_args={})
|
|
51
|
+
def softmax_forward(x: Tensor, sparsity_layout: Tensor,
|
|
52
|
+
sparsity_lut: Tensor,
|
|
53
|
+
sparsity_reverse_lut_rws: Tensor,
|
|
54
|
+
sparsity_block_size: int) -> Tensor:
|
|
55
|
+
output = torch.zeros_like(x)
|
|
56
|
+
|
|
57
|
+
x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,
|
|
58
|
+
flag_slice_only=True)
|
|
59
|
+
x_scaled = row_wise_sub(x, sparsity_layout, x_row_wise_max, sparsity_block_size)
|
|
60
|
+
x_exp = torch.exp(x_scaled)
|
|
61
|
+
x_exp_row_wise_sum, sparsity_layout_rws = row_wise_sum(x_exp, sparsity_layout, sparsity_block_size,
|
|
62
|
+
flag_slice_only=True)
|
|
63
|
+
|
|
64
|
+
x_b, x_r, x_c = x.size()
|
|
65
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
66
|
+
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
67
|
+
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
68
|
+
o_b, o_r, o_c = output.size()
|
|
69
|
+
s_b, s_r, s_c = x_exp_row_wise_sum.shape
|
|
70
|
+
s_b_s, s_r_s, s_c_s = stride(x_exp_row_wise_sum)
|
|
71
|
+
s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_rws.shape
|
|
72
|
+
s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_rws)
|
|
73
|
+
|
|
74
|
+
triton_grid = lambda meta: [o_b,
|
|
75
|
+
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
76
|
+
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
77
|
+
|
|
78
|
+
(wrap_triton(softmax_kernel)[triton_grid]
|
|
79
|
+
(x_exp,
|
|
80
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
81
|
+
sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
82
|
+
x_exp_row_wise_sum, s_b, s_b_s, s_r_s, s_c_s,
|
|
83
|
+
s_l_s_b, s_l_s_b_s, s_l_s_r_s,
|
|
84
|
+
sparsity_reverse_lut_rws,
|
|
85
|
+
output,
|
|
86
|
+
sparsity_block_size))
|
|
87
|
+
|
|
88
|
+
return output
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def softmax_backward_wrapper(ctx, grad_output):
|
|
92
|
+
o, sparsity_layout, sparsity_lut = ctx.saved_tensors
|
|
93
|
+
sparsity_block_size = ctx.sparsity_block_size
|
|
94
|
+
|
|
95
|
+
return softmax_backward(grad_output, o, sparsity_lut, sparsity_layout,
|
|
96
|
+
sparsity_block_size), None, None, None, None, None
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
@triton_op("blksprs::softmax_backward", mutates_args={})
|
|
100
|
+
def softmax_backward(grad_output: Tensor, o: Tensor, sparsity_lut: Tensor, sparsity_layout: Tensor,
|
|
101
|
+
sparsity_block_size: int) -> Tensor:
|
|
102
|
+
with torch.no_grad():
|
|
103
|
+
s, sparsity_layout_s = row_wise_sum(grad_output * o, sparsity_layout, sparsity_block_size, flag_slice_only=True)
|
|
117
104
|
|
|
118
105
|
sparsity_layout_s_flat = sparsity_layout_s.reshape(-1)
|
|
119
106
|
sparsity_reverse_lut_s = ((torch.cumsum(sparsity_layout_s_flat, dim=-1) - 1) *
|
|
@@ -129,13 +116,13 @@ class _BlocksparseSoftmax(torch.autograd.Function):
|
|
|
129
116
|
s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
|
|
130
117
|
s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)
|
|
131
118
|
|
|
132
|
-
grad_x = torch.
|
|
119
|
+
grad_x = torch.zeros_like(o, dtype=torch.float)
|
|
133
120
|
|
|
134
121
|
triton_grid = lambda meta: [o_b,
|
|
135
122
|
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
136
123
|
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
137
124
|
|
|
138
|
-
(
|
|
125
|
+
(wrap_triton(softmax_kernel_grad)[triton_grid]
|
|
139
126
|
(grad_output,
|
|
140
127
|
o_b, o_b_s, o_r_s, o_c_s,
|
|
141
128
|
o,
|
|
@@ -147,57 +134,63 @@ class _BlocksparseSoftmax(torch.autograd.Function):
|
|
|
147
134
|
sparsity_reverse_lut_s,
|
|
148
135
|
grad_x,
|
|
149
136
|
o_b, o_b_s, o_r_s, o_c_s,
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
137
|
+
sparsity_block_size))
|
|
138
|
+
|
|
139
|
+
return grad_x
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
# noinspection PyUnusedLocal
|
|
143
|
+
@triton.autotune(
|
|
144
|
+
configs=get_autotune_configs(),
|
|
145
|
+
key=["sparsity_block_size"],
|
|
146
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
147
|
+
reset_to_zero=["o"]
|
|
148
|
+
)
|
|
149
|
+
@triton.jit
|
|
150
|
+
def softmax_kernel(x,
|
|
151
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
152
|
+
s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
153
|
+
s, s_b, s_b_s, s_r_s, s_c_s,
|
|
154
|
+
s_l_s_b, s_l_s_b_s, s_l_s_r_s,
|
|
155
|
+
r_lut_s,
|
|
156
|
+
o,
|
|
157
|
+
sparsity_block_size,
|
|
158
|
+
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
159
|
+
# Get triton block indices
|
|
160
|
+
pid_blk = tl.program_id(axis=0)
|
|
161
|
+
pid_row = tl.program_id(axis=1)
|
|
162
|
+
pid_col = tl.program_id(axis=2)
|
|
163
|
+
|
|
164
|
+
# Get position of current sparsity block consisting of its batch and row index
|
|
165
|
+
spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
|
|
166
|
+
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
|
|
167
|
+
spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
|
|
168
|
+
|
|
169
|
+
spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
|
|
170
|
+
spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
|
|
171
|
+
spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
|
|
172
|
+
|
|
173
|
+
# Get reverse sparsity indices for s
|
|
174
|
+
rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
|
|
175
|
+
spa_row * s_l_s_r_s)
|
|
176
|
+
rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
|
|
177
|
+
rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
|
|
178
|
+
|
|
179
|
+
if rev_idx_spa_s >= 0:
|
|
189
180
|
# Load x block
|
|
190
181
|
blk_x_idx = ((pid_blk * x_b_s) +
|
|
191
182
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
192
183
|
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
193
|
-
blk_x_msk = (blk_x_idx >= 0 and
|
|
184
|
+
blk_x_msk = (blk_x_idx >= 0 and
|
|
185
|
+
blk_x_idx < x_b * x_b_s)
|
|
194
186
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
195
187
|
|
|
196
188
|
# Load sum block
|
|
197
189
|
blk_s_idx = (rev_idx_spa_s * s_b_s +
|
|
198
190
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
|
|
199
191
|
(tl.arange(0, 1) * s_c_s)[None, :])
|
|
200
|
-
blk_s_msk = (blk_s_idx >= 0 and
|
|
192
|
+
blk_s_msk = (blk_s_idx >= 0 and
|
|
193
|
+
blk_s_idx < s_b * s_b_s)
|
|
201
194
|
blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
|
|
202
195
|
|
|
203
196
|
# Compute softmax
|
|
@@ -206,59 +199,67 @@ class _BlocksparseSoftmax(torch.autograd.Function):
|
|
|
206
199
|
# Store output
|
|
207
200
|
tl.store(o + blk_x_idx, buf, mask=blk_x_msk)
|
|
208
201
|
|
|
209
|
-
@staticmethod
|
|
210
|
-
@triton.jit
|
|
211
|
-
def kernel_blocksparse_softmax_grad_x(g,
|
|
212
|
-
g_b, g_b_s, g_r_s, g_c_s,
|
|
213
|
-
x,
|
|
214
|
-
x_b, x_b_s, x_r_s, x_c_s,
|
|
215
|
-
s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
216
|
-
s,
|
|
217
|
-
s_b, s_b_s, s_r_s, s_c_s,
|
|
218
|
-
s_l_s_b, s_l_s_b_s, s_l_s_r_s,
|
|
219
|
-
r_lut_s,
|
|
220
|
-
o,
|
|
221
|
-
o_b, o_b_s, o_r_s, o_c_s,
|
|
222
|
-
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
223
|
-
# Get triton block indices
|
|
224
|
-
pid_blk = tl.program_id(axis=0)
|
|
225
|
-
pid_row = tl.program_id(axis=1)
|
|
226
|
-
pid_col = tl.program_id(axis=2)
|
|
227
|
-
|
|
228
|
-
# Get position of current sparsity block consisting of its batch and row index
|
|
229
|
-
spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
|
|
230
|
-
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
|
|
231
|
-
spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
|
|
232
|
-
|
|
233
|
-
spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
|
|
234
|
-
spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
|
|
235
|
-
spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
|
|
236
|
-
|
|
237
|
-
rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
|
|
238
|
-
spa_row * s_l_s_r_s)
|
|
239
|
-
rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
|
|
240
|
-
rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
|
|
241
|
-
|
|
242
|
-
if rev_idx_spa_s == -1:
|
|
243
|
-
tl.device_assert(False)
|
|
244
|
-
return
|
|
245
202
|
|
|
203
|
+
# noinspection PyUnusedLocal
|
|
204
|
+
@triton.autotune(
|
|
205
|
+
configs=get_autotune_configs(),
|
|
206
|
+
key=["sparsity_block_size"],
|
|
207
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
208
|
+
reset_to_zero=["o"]
|
|
209
|
+
)
|
|
210
|
+
@triton.jit
|
|
211
|
+
def softmax_kernel_grad(g,
|
|
212
|
+
g_b, g_b_s, g_r_s, g_c_s,
|
|
213
|
+
x,
|
|
214
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
215
|
+
s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
216
|
+
s,
|
|
217
|
+
s_b, s_b_s, s_r_s, s_c_s,
|
|
218
|
+
s_l_s_b, s_l_s_b_s, s_l_s_r_s,
|
|
219
|
+
r_lut_s,
|
|
220
|
+
o,
|
|
221
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
222
|
+
sparsity_block_size,
|
|
223
|
+
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
224
|
+
# Get triton block indices
|
|
225
|
+
pid_blk = tl.program_id(axis=0)
|
|
226
|
+
pid_row = tl.program_id(axis=1)
|
|
227
|
+
pid_col = tl.program_id(axis=2)
|
|
228
|
+
|
|
229
|
+
# Get position of current sparsity block consisting of its batch and row index
|
|
230
|
+
spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
|
|
231
|
+
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
|
|
232
|
+
spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
|
|
233
|
+
|
|
234
|
+
spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
|
|
235
|
+
spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
|
|
236
|
+
spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
|
|
237
|
+
|
|
238
|
+
rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
|
|
239
|
+
spa_row * s_l_s_r_s)
|
|
240
|
+
rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
|
|
241
|
+
rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
|
|
242
|
+
|
|
243
|
+
if rev_idx_spa_s >= 0:
|
|
246
244
|
blk_s_idx = (rev_idx_spa_s * s_b_s +
|
|
247
245
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
|
|
248
246
|
(tl.arange(0, 1) * s_c_s)[None, :])
|
|
249
|
-
blk_s_msk = (blk_s_idx >= 0 and
|
|
247
|
+
blk_s_msk = (blk_s_idx >= 0 and
|
|
248
|
+
blk_s_idx < s_b * s_b_s)
|
|
250
249
|
blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
|
|
251
250
|
|
|
252
251
|
blk_g_idx = ((pid_blk * g_b_s) +
|
|
253
252
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +
|
|
254
253
|
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])
|
|
255
|
-
blk_g_msk = (blk_g_idx >= 0 and
|
|
254
|
+
blk_g_msk = (blk_g_idx >= 0 and
|
|
255
|
+
blk_g_idx < g_b * g_b_s)
|
|
256
256
|
blk_g = tl.load(g + blk_g_idx, mask=blk_g_msk)
|
|
257
257
|
|
|
258
258
|
blk_x_idx = ((pid_blk * x_b_s) +
|
|
259
259
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
260
260
|
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
261
|
-
blk_x_msk = (blk_x_idx >= 0 and
|
|
261
|
+
blk_x_msk = (blk_x_idx >= 0 and
|
|
262
|
+
blk_x_idx < x_b * x_b_s)
|
|
262
263
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
263
264
|
|
|
264
265
|
buf = blk_x * (blk_g - blk_s)
|
|
@@ -266,5 +267,38 @@ class _BlocksparseSoftmax(torch.autograd.Function):
|
|
|
266
267
|
blk_o_idx = ((pid_blk * o_b_s) +
|
|
267
268
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
268
269
|
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
269
|
-
blk_o_msk = (blk_o_idx >= 0 and
|
|
270
|
+
blk_o_msk = (blk_o_idx >= 0 and
|
|
271
|
+
blk_o_idx < o_b * o_b_s)
|
|
270
272
|
tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
def softmax_build_lut(lut: dict, sparsity_layout: Tensor):
|
|
276
|
+
if lut is None:
|
|
277
|
+
lut = dict()
|
|
278
|
+
|
|
279
|
+
if "sparsity_lut" not in lut:
|
|
280
|
+
sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
|
|
281
|
+
lut["sparsity_lut"] = sparsity_lut
|
|
282
|
+
|
|
283
|
+
if "sparsity_reverse_lut_rws" not in lut:
|
|
284
|
+
sparsity_layout_rws, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
|
|
285
|
+
sparsity_layout_rws_flat = sparsity_layout_rws.reshape(-1)
|
|
286
|
+
sparsity_reverse_lut_rws = ((torch.cumsum(sparsity_layout_rws_flat, dim=-1) - 1) *
|
|
287
|
+
(sparsity_layout_rws_flat == 1) -
|
|
288
|
+
(1 * (sparsity_layout_rws_flat == 0)))
|
|
289
|
+
lut["sparsity_reverse_lut_rws"] = sparsity_reverse_lut_rws
|
|
290
|
+
|
|
291
|
+
validate_contiguous(sparsity_layout, lut["sparsity_lut"], lut["sparsity_reverse_lut_rws"])
|
|
292
|
+
|
|
293
|
+
return lut
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
# noinspection PyUnusedLocal
|
|
297
|
+
def softmax_setup_context(ctx, inputs, output):
|
|
298
|
+
(_, sparsity_layout, sparsity_lut, _, sparsity_block_size) = inputs
|
|
299
|
+
|
|
300
|
+
ctx.save_for_backward(output, sparsity_layout, sparsity_lut)
|
|
301
|
+
ctx.sparsity_block_size = sparsity_block_size
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
softmax_forward.register_autograd(softmax_backward_wrapper, setup_context=softmax_setup_context)
|