blksprs 1.11__py3-none-any.whl → 2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +4 -5
- blksprs/layouting/distribution_layout.py +64 -48
- blksprs/layouting/sparsity_layout.py +96 -72
- blksprs/ops/conversion.py +349 -338
- blksprs/ops/distribution.py +318 -294
- blksprs/ops/flow.py +113 -100
- blksprs/ops/matmul.py +187 -172
- blksprs/ops/misc/broadcast_ops.py +68 -53
- blksprs/ops/misc/row_wise.py +223 -176
- blksprs/ops/partitioning.py +140 -132
- blksprs/ops/repeat.py +118 -120
- blksprs/ops/softmax.py +240 -214
- blksprs/ops/transpose.py +55 -52
- blksprs/utils/autotuning.py +78 -0
- blksprs/utils/benchmarking.py +3 -3
- blksprs/utils/processing.py +2 -1
- blksprs/utils/tools.py +5 -6
- blksprs/utils/validation.py +22 -16
- {blksprs-1.11.dist-info → blksprs-2.0.dist-info}/METADATA +55 -36
- blksprs-2.0.dist-info/RECORD +23 -0
- {blksprs-1.11.dist-info → blksprs-2.0.dist-info}/WHEEL +1 -1
- blksprs/utils/layout_utils.py +0 -17
- blksprs-1.11.dist-info/RECORD +0 -23
- {blksprs-1.11.dist-info → blksprs-2.0.dist-info}/top_level.txt +0 -0
blksprs/ops/softmax.py
CHANGED
|
@@ -1,17 +1,20 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
import triton
|
|
3
3
|
from torch import Tensor
|
|
4
|
+
from torch._library import triton_op
|
|
5
|
+
from torch._library.triton import wrap_triton
|
|
4
6
|
from triton import language as tl
|
|
5
7
|
|
|
6
8
|
from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
|
|
7
9
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
8
|
-
from blksprs.utils.tools import
|
|
10
|
+
from blksprs.utils.tools import stride
|
|
11
|
+
from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
|
|
9
12
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
10
|
-
validate_sparsity, validate_sparsity_block_size,
|
|
13
|
+
validate_sparsity, validate_sparsity_block_size, validate_dtype_float_32
|
|
11
14
|
|
|
12
15
|
|
|
13
|
-
|
|
14
|
-
|
|
16
|
+
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
|
|
17
|
+
def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
|
|
15
18
|
"""Computes the softmax of a block-sparse tensor in compressed form.
|
|
16
19
|
|
|
17
20
|
Note:
|
|
@@ -21,7 +24,6 @@ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
|
21
24
|
x (BlksprsTensor): A block-sparse tensor in compressed form.
|
|
22
25
|
sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
|
|
23
26
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
24
|
-
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
25
27
|
lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
|
|
26
28
|
|
|
27
29
|
Returns:
|
|
@@ -32,102 +34,73 @@ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
|
32
34
|
|
|
33
35
|
validate_dimensions(x)
|
|
34
36
|
validate_contiguous(x)
|
|
37
|
+
validate_dtype_float_32(x)
|
|
35
38
|
validate_device(x)
|
|
36
39
|
validate_sparsity(sparsity_block_size, (x, sparsity_layout))
|
|
37
40
|
validate_sparsity_block_size(sparsity_block_size, x)
|
|
38
|
-
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
39
41
|
|
|
40
|
-
lut =
|
|
41
|
-
|
|
42
|
-
return BlksprsTensor(
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
triton_grid = lambda meta: [o_b,
|
|
103
|
-
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
104
|
-
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
105
|
-
|
|
106
|
-
(_BlocksparseSoftmax.kernel_blocksparse_softmax[triton_grid]
|
|
107
|
-
(x_exp,
|
|
108
|
-
x_b, x_b_s, x_r_s, x_c_s,
|
|
109
|
-
sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
110
|
-
x_exp_row_wise_sum, s_b, s_b_s, s_r_s, s_c_s,
|
|
111
|
-
s_l_s_b, s_l_s_b_s, s_l_s_r_s,
|
|
112
|
-
sparsity_reverse_lut_rws,
|
|
113
|
-
output,
|
|
114
|
-
triton_block_size))
|
|
115
|
-
|
|
116
|
-
# Save for backward pass
|
|
117
|
-
ctx.save_for_backward(output, sparsity_layout, sparsity_lut)
|
|
118
|
-
ctx.sparsity_block_size = sparsity_block_size
|
|
119
|
-
ctx.triton_block_size = triton_block_size
|
|
120
|
-
|
|
121
|
-
return output
|
|
122
|
-
|
|
123
|
-
@staticmethod
|
|
124
|
-
def backward(ctx, grad_output):
|
|
125
|
-
o, sparsity_layout, sparsity_lut = ctx.saved_tensors
|
|
126
|
-
sparsity_block_size = ctx.sparsity_block_size
|
|
127
|
-
triton_block_size = ctx.triton_block_size
|
|
128
|
-
|
|
129
|
-
s, sparsity_layout_s = row_wise_sum(grad_output * o, sparsity_layout, sparsity_block_size, flag_slice_only=True,
|
|
130
|
-
triton_block_size=triton_block_size)
|
|
42
|
+
lut = softmax_build_lut(lut, sparsity_layout)
|
|
43
|
+
|
|
44
|
+
return BlksprsTensor(softmax_forward(x, sparsity_layout,
|
|
45
|
+
lut["sparsity_lut"],
|
|
46
|
+
lut["sparsity_reverse_lut_rws"],
|
|
47
|
+
sparsity_block_size))
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@triton_op("blksprs::softmax_forward", mutates_args={})
|
|
51
|
+
def softmax_forward(x: Tensor, sparsity_layout: Tensor,
|
|
52
|
+
sparsity_lut: Tensor,
|
|
53
|
+
sparsity_reverse_lut_rws: Tensor,
|
|
54
|
+
sparsity_block_size: int) -> Tensor:
|
|
55
|
+
output = torch.zeros_like(x)
|
|
56
|
+
|
|
57
|
+
x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,
|
|
58
|
+
flag_slice_only=True)
|
|
59
|
+
x_scaled = row_wise_sub(x, sparsity_layout, x_row_wise_max, sparsity_block_size)
|
|
60
|
+
x_exp = torch.exp(x_scaled)
|
|
61
|
+
x_exp_row_wise_sum, sparsity_layout_rws = row_wise_sum(x_exp, sparsity_layout, sparsity_block_size,
|
|
62
|
+
flag_slice_only=True)
|
|
63
|
+
|
|
64
|
+
x_b, x_r, x_c = x.size()
|
|
65
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
66
|
+
s_lut_r, s_lut_c = sparsity_lut.size()
|
|
67
|
+
s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
|
|
68
|
+
o_b, o_r, o_c = output.size()
|
|
69
|
+
s_b, s_r, s_c = x_exp_row_wise_sum.shape
|
|
70
|
+
s_b_s, s_r_s, s_c_s = stride(x_exp_row_wise_sum)
|
|
71
|
+
s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_rws.shape
|
|
72
|
+
s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_rws)
|
|
73
|
+
|
|
74
|
+
triton_grid = lambda meta: [o_b,
|
|
75
|
+
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
76
|
+
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
77
|
+
|
|
78
|
+
(wrap_triton(softmax_kernel)[triton_grid]
|
|
79
|
+
(x_exp,
|
|
80
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
81
|
+
sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
82
|
+
x_exp_row_wise_sum, s_b, s_b_s, s_r_s, s_c_s,
|
|
83
|
+
s_l_s_b, s_l_s_b_s, s_l_s_r_s,
|
|
84
|
+
sparsity_reverse_lut_rws,
|
|
85
|
+
output,
|
|
86
|
+
sparsity_block_size))
|
|
87
|
+
|
|
88
|
+
return output
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def softmax_backward_wrapper(ctx, grad_output):
|
|
92
|
+
o, sparsity_layout, sparsity_lut = ctx.saved_tensors
|
|
93
|
+
sparsity_block_size = ctx.sparsity_block_size
|
|
94
|
+
|
|
95
|
+
return softmax_backward(grad_output, o, sparsity_lut, sparsity_layout,
|
|
96
|
+
sparsity_block_size), None, None, None, None, None
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
@triton_op("blksprs::softmax_backward", mutates_args={})
|
|
100
|
+
def softmax_backward(grad_output: Tensor, o: Tensor, sparsity_lut: Tensor, sparsity_layout: Tensor,
|
|
101
|
+
sparsity_block_size: int) -> Tensor:
|
|
102
|
+
with torch.no_grad():
|
|
103
|
+
s, sparsity_layout_s = row_wise_sum(grad_output * o, sparsity_layout, sparsity_block_size, flag_slice_only=True)
|
|
131
104
|
|
|
132
105
|
sparsity_layout_s_flat = sparsity_layout_s.reshape(-1)
|
|
133
106
|
sparsity_reverse_lut_s = ((torch.cumsum(sparsity_layout_s_flat, dim=-1) - 1) *
|
|
@@ -143,13 +116,13 @@ class _BlocksparseSoftmax(torch.autograd.Function):
|
|
|
143
116
|
s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
|
|
144
117
|
s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)
|
|
145
118
|
|
|
146
|
-
grad_x = torch.
|
|
119
|
+
grad_x = torch.zeros_like(o, dtype=torch.float)
|
|
147
120
|
|
|
148
121
|
triton_grid = lambda meta: [o_b,
|
|
149
122
|
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
150
123
|
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
151
124
|
|
|
152
|
-
(
|
|
125
|
+
(wrap_triton(softmax_kernel_grad)[triton_grid]
|
|
153
126
|
(grad_output,
|
|
154
127
|
o_b, o_b_s, o_r_s, o_c_s,
|
|
155
128
|
o,
|
|
@@ -161,118 +134,171 @@ class _BlocksparseSoftmax(torch.autograd.Function):
|
|
|
161
134
|
sparsity_reverse_lut_s,
|
|
162
135
|
grad_x,
|
|
163
136
|
o_b, o_b_s, o_r_s, o_c_s,
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
137
|
+
sparsity_block_size))
|
|
138
|
+
|
|
139
|
+
return grad_x
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
# noinspection PyUnusedLocal
|
|
143
|
+
@triton.autotune(
|
|
144
|
+
configs=get_autotune_configs(),
|
|
145
|
+
key=["sparsity_block_size"],
|
|
146
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
147
|
+
reset_to_zero=["o"]
|
|
148
|
+
)
|
|
149
|
+
@triton.jit
|
|
150
|
+
def softmax_kernel(x,
|
|
151
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
152
|
+
s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
153
|
+
s, s_b, s_b_s, s_r_s, s_c_s,
|
|
154
|
+
s_l_s_b, s_l_s_b_s, s_l_s_r_s,
|
|
155
|
+
r_lut_s,
|
|
156
|
+
o,
|
|
157
|
+
sparsity_block_size,
|
|
158
|
+
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
159
|
+
# Get triton block indices
|
|
160
|
+
pid_blk = tl.program_id(axis=0)
|
|
161
|
+
pid_row = tl.program_id(axis=1)
|
|
162
|
+
pid_col = tl.program_id(axis=2)
|
|
163
|
+
|
|
164
|
+
# Get position of current sparsity block consisting of its batch and row index
|
|
165
|
+
spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
|
|
166
|
+
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
|
|
167
|
+
spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
|
|
168
|
+
|
|
169
|
+
spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
|
|
170
|
+
spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
|
|
171
|
+
spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
|
|
172
|
+
|
|
173
|
+
# Get reverse sparsity indices for s
|
|
174
|
+
rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
|
|
175
|
+
spa_row * s_l_s_r_s)
|
|
176
|
+
rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
|
|
177
|
+
rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
|
|
178
|
+
|
|
179
|
+
if rev_idx_spa_s >= 0:
|
|
180
|
+
# Load x block
|
|
181
|
+
blk_x_idx = ((pid_blk * x_b_s) +
|
|
182
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
183
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
184
|
+
blk_x_msk = (blk_x_idx >= 0 and
|
|
185
|
+
blk_x_idx < x_b * x_b_s)
|
|
186
|
+
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
187
|
+
|
|
188
|
+
# Load sum block
|
|
189
|
+
blk_s_idx = (rev_idx_spa_s * s_b_s +
|
|
190
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
|
|
191
|
+
(tl.arange(0, 1) * s_c_s)[None, :])
|
|
192
|
+
blk_s_msk = (blk_s_idx >= 0 and
|
|
193
|
+
blk_s_idx < s_b * s_b_s)
|
|
194
|
+
blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
|
|
195
|
+
|
|
196
|
+
# Compute softmax
|
|
197
|
+
buf = tl.div_rn(blk_x, blk_s)
|
|
198
|
+
|
|
199
|
+
# Store output
|
|
200
|
+
tl.store(o + blk_x_idx, buf, mask=blk_x_msk)
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
# noinspection PyUnusedLocal
|
|
204
|
+
@triton.autotune(
|
|
205
|
+
configs=get_autotune_configs(),
|
|
206
|
+
key=["sparsity_block_size"],
|
|
207
|
+
prune_configs_by={"early_config_prune": prune_autotune_configs},
|
|
208
|
+
reset_to_zero=["o"]
|
|
209
|
+
)
|
|
210
|
+
@triton.jit
|
|
211
|
+
def softmax_kernel_grad(g,
|
|
212
|
+
g_b, g_b_s, g_r_s, g_c_s,
|
|
213
|
+
x,
|
|
214
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
215
|
+
s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
|
|
216
|
+
s,
|
|
217
|
+
s_b, s_b_s, s_r_s, s_c_s,
|
|
218
|
+
s_l_s_b, s_l_s_b_s, s_l_s_r_s,
|
|
219
|
+
r_lut_s,
|
|
220
|
+
o,
|
|
221
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
222
|
+
sparsity_block_size,
|
|
223
|
+
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
224
|
+
# Get triton block indices
|
|
225
|
+
pid_blk = tl.program_id(axis=0)
|
|
226
|
+
pid_row = tl.program_id(axis=1)
|
|
227
|
+
pid_col = tl.program_id(axis=2)
|
|
228
|
+
|
|
229
|
+
# Get position of current sparsity block consisting of its batch and row index
|
|
230
|
+
spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
|
|
231
|
+
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
|
|
232
|
+
spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
|
|
233
|
+
|
|
234
|
+
spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
|
|
235
|
+
spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
|
|
236
|
+
spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
|
|
237
|
+
|
|
238
|
+
rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
|
|
239
|
+
spa_row * s_l_s_r_s)
|
|
240
|
+
rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
|
|
241
|
+
rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
|
|
242
|
+
|
|
243
|
+
if rev_idx_spa_s >= 0:
|
|
244
|
+
blk_s_idx = (rev_idx_spa_s * s_b_s +
|
|
245
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
|
|
246
|
+
(tl.arange(0, 1) * s_c_s)[None, :])
|
|
247
|
+
blk_s_msk = (blk_s_idx >= 0 and
|
|
248
|
+
blk_s_idx < s_b * s_b_s)
|
|
249
|
+
blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
|
|
250
|
+
|
|
251
|
+
blk_g_idx = ((pid_blk * g_b_s) +
|
|
252
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +
|
|
253
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])
|
|
254
|
+
blk_g_msk = (blk_g_idx >= 0 and
|
|
255
|
+
blk_g_idx < g_b * g_b_s)
|
|
256
|
+
blk_g = tl.load(g + blk_g_idx, mask=blk_g_msk)
|
|
257
|
+
|
|
258
|
+
blk_x_idx = ((pid_blk * x_b_s) +
|
|
259
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
260
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
261
|
+
blk_x_msk = (blk_x_idx >= 0 and
|
|
262
|
+
blk_x_idx < x_b * x_b_s)
|
|
263
|
+
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
264
|
+
|
|
265
|
+
buf = blk_x * (blk_g - blk_s)
|
|
266
|
+
|
|
267
|
+
blk_o_idx = ((pid_blk * o_b_s) +
|
|
268
|
+
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
269
|
+
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
270
|
+
blk_o_msk = (blk_o_idx >= 0 and
|
|
271
|
+
blk_o_idx < o_b * o_b_s)
|
|
272
|
+
tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
def softmax_build_lut(lut: dict, sparsity_layout: Tensor):
|
|
276
|
+
if lut is None:
|
|
277
|
+
lut = dict()
|
|
278
|
+
|
|
279
|
+
if "sparsity_lut" not in lut:
|
|
280
|
+
sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
|
|
281
|
+
lut["sparsity_lut"] = sparsity_lut
|
|
282
|
+
|
|
283
|
+
if "sparsity_reverse_lut_rws" not in lut:
|
|
284
|
+
sparsity_layout_rws, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
|
|
285
|
+
sparsity_layout_rws_flat = sparsity_layout_rws.reshape(-1)
|
|
286
|
+
sparsity_reverse_lut_rws = ((torch.cumsum(sparsity_layout_rws_flat, dim=-1) - 1) *
|
|
287
|
+
(sparsity_layout_rws_flat == 1) -
|
|
288
|
+
(1 * (sparsity_layout_rws_flat == 0)))
|
|
289
|
+
lut["sparsity_reverse_lut_rws"] = sparsity_reverse_lut_rws
|
|
290
|
+
|
|
291
|
+
validate_contiguous(sparsity_layout, lut["sparsity_lut"], lut["sparsity_reverse_lut_rws"])
|
|
292
|
+
|
|
293
|
+
return lut
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
# noinspection PyUnusedLocal
|
|
297
|
+
def softmax_setup_context(ctx, inputs, output):
|
|
298
|
+
(_, sparsity_layout, sparsity_lut, _, sparsity_block_size) = inputs
|
|
299
|
+
|
|
300
|
+
ctx.save_for_backward(output, sparsity_layout, sparsity_lut)
|
|
301
|
+
ctx.sparsity_block_size = sparsity_block_size
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
softmax_forward.register_autograd(softmax_backward_wrapper, setup_context=softmax_setup_context)
|