blksprs 1.10.2__py3-none-any.whl → 2.0rc1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- blksprs/__init__.py +2 -6
- blksprs/layouting/distribution_layout.py +32 -25
- blksprs/layouting/sparsity_layout.py +65 -52
- blksprs/ops/conversion.py +423 -374
- blksprs/ops/distribution.py +403 -335
- blksprs/ops/flow.py +135 -83
- blksprs/ops/matmul.py +221 -187
- blksprs/ops/misc/broadcast_ops.py +53 -35
- blksprs/ops/misc/row_wise.py +151 -91
- blksprs/ops/partitioning.py +136 -89
- blksprs/ops/repeat.py +115 -108
- blksprs/ops/softmax.py +244 -208
- blksprs/ops/transpose.py +69 -131
- blksprs/utils/benchmarking.py +3 -3
- blksprs/utils/tools.py +31 -4
- blksprs/utils/validation.py +0 -14
- {blksprs-1.10.2.dist-info → blksprs-2.0rc1.dist-info}/METADATA +42 -36
- blksprs-2.0rc1.dist-info/RECORD +22 -0
- {blksprs-1.10.2.dist-info → blksprs-2.0rc1.dist-info}/WHEEL +1 -1
- blksprs/ops/misc/exp.py +0 -104
- blksprs/utils/layout_utils.py +0 -17
- blksprs-1.10.2.dist-info/RECORD +0 -24
- {blksprs-1.10.2.dist-info → blksprs-2.0rc1.dist-info}/top_level.txt +0 -0
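Judging from the softmax diff expanded below, the headline API change in 2.0rc1 is that ops no longer take an explicit `triton_block_size`: kernels are registered as PyTorch custom ops via `torch._library`'s `triton_op`/`wrap_triton` and autotuned through `get_autotune_configs()`, and each op instead accepts an optional `lut` dict caching its precomputed look-up tables. A minimal migration sketch (hypothetical tensor names; assumes `x` is already a compressed block-sparse tensor on a CUDA device):

    # 1.10.2: the caller could pin the Triton block size per call
    out = softmax(x, sparsity_layout, sparsity_block_size, triton_block_size=64)

    # 2.0rc1: block sizes are autotuned; the optional lut dict lets the caller
    # reuse precomputed look-up tables (see the sketch further below)
    out = softmax(x, sparsity_layout, sparsity_block_size, lut=None)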
blksprs/ops/softmax.py
CHANGED
@@ -1,18 +1,18 @@
 import torch
 import triton
 from torch import Tensor
+from torch._library import triton_op
+from torch._library.triton import wrap_triton
 from triton import language as tl

-from blksprs.ops.misc.exp import exp
 from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import stride, get_triton_block_size
+from blksprs.utils.tools import stride, get_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-    validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
+    validate_sparsity, validate_sparsity_block_size


-def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
-            triton_block_size: int = None) -> BlksprsTensor:
+def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
     """Computes the softmax of a block-sparse tensor in compressed form.

     Note:
@@ -22,7 +22,7 @@ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
         x (BlksprsTensor): A block-sparse tensor in compressed form.
         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
-
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

     Returns:
         BlksprsTensor: The result of the softmax operation as a block-sparse tensor in compressed form.
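The `lut` parameter documented above can be prebuilt once and reused whenever the sparsity layout is unchanged. A hedged sketch using the `softmax_build_lut` helper added further down in this diff (the import path follows this file; whether the package re-exports these helpers elsewhere is not visible here):

    from blksprs.ops.softmax import softmax, softmax_build_lut

    # Build the look-up tables once for a fixed sparsity layout...
    lut = softmax_build_lut(None, sparsity_layout)

    # ...then amortize them over repeated calls; keys already present
    # in the dict are not recomputed
    for _ in range(num_steps):
        out = softmax(x, sparsity_layout, sparsity_block_size, lut=lut)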
@@ -35,169 +35,156 @@ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
     validate_device(x)
     validate_sparsity(sparsity_block_size, (x, sparsity_layout))
     validate_sparsity_block_size(sparsity_block_size, x)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)
-
-    sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
-
-    sparsity_layout_rws, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
-    sparsity_layout_rws_flat = sparsity_layout_rws.reshape(-1)
-    sparsity_reverse_lut_rws = ((torch.cumsum(sparsity_layout_rws_flat, dim=-1) - 1) *
-                                (sparsity_layout_rws_flat == 1) -
-                                (1 * (sparsity_layout_rws_flat == 0)))
-
-    validate_contiguous(sparsity_layout, sparsity_lut, sparsity_reverse_lut_rws)
-
-    return BlksprsTensor(_BlocksparseSoftmax.apply(x, sparsity_layout,
-                                                   sparsity_lut,
-                                                   sparsity_reverse_lut_rws,
-                                                   sparsity_block_size, triton_block_size))
-
-
-class _BlocksparseSoftmax(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, x: Tensor, sparsity_layout: Tensor,
-                sparsity_lut: Tensor,
-                sparsity_reverse_lut_rws: Tensor,
-                sparsity_block_size: int, triton_block_size: int) -> Tensor:
-        output = torch.empty_like(x)
-
-        x_b, x_r, x_c = x.size()
-        x_b_s, x_r_s, x_c_s = stride(x)
-        s_lut_r, s_lut_c = sparsity_lut.size()
-        s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-        o_b, o_r, o_c = output.size()
-
-        x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,
-                                                           flag_slice_only=True,
-                                                           triton_block_size=triton_block_size)
-        x_scaled = row_wise_sub(x, sparsity_layout, x_row_wise_max, sparsity_block_size, triton_block_size)
-        x_exp = exp(x_scaled, sparsity_block_size, triton_block_size=triton_block_size)
-        x_exp_row_wise_sum, sparsity_layout_rws = row_wise_sum(x_exp, sparsity_layout, sparsity_block_size,
-                                                               flag_slice_only=True,
-                                                               triton_block_size=triton_block_size)
-
-        s_b, s_r, s_c = x_exp_row_wise_sum.shape
-        s_b_s, s_r_s, s_c_s = stride(x_exp_row_wise_sum)
-        s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_rws.shape
-        s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_rws)
-
-        if triton_block_size is None:
-            triton_block_size = get_triton_block_size(sparsity_block_size)
-
-        triton_grid = lambda meta: [o_b,
-                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-        (_BlocksparseSoftmax.kernel_blocksparse_softmax[triton_grid]
-         (x_exp,
-          x_b, x_b_s, x_r_s, x_c_s,
-          sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-          x_exp_row_wise_sum, s_b, s_b_s, s_r_s, s_c_s,
-          s_l_s_b, s_l_s_b_s, s_l_s_r_s,
-          sparsity_reverse_lut_rws,
-          output,
-          triton_block_size))
-
-        # Save for backward pass
-        ctx.save_for_backward(output, sparsity_layout, sparsity_lut)
-        ctx.sparsity_block_size = sparsity_block_size
-        ctx.triton_block_size = triton_block_size
-
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        o, sparsity_layout, sparsity_lut = ctx.saved_tensors
-        sparsity_block_size = ctx.sparsity_block_size
-        triton_block_size = ctx.triton_block_size
-
-        s, sparsity_layout_s = row_wise_sum(grad_output * o, sparsity_layout, sparsity_block_size, flag_slice_only=True,
-                                            triton_block_size=triton_block_size)
-
-        sparsity_layout_s_flat = sparsity_layout_s.reshape(-1)
-        sparsity_reverse_lut_s = ((torch.cumsum(sparsity_layout_s_flat, dim=-1) - 1) *
-                                  (sparsity_layout_s_flat == 1) -
-                                  (1 * (sparsity_layout_s_flat == 0)))
-
-        o_b, o_r, o_c = o.size()
-        o_b_s, o_r_s, o_c_s = stride(o)
-        s_lut_r, s_lut_c = sparsity_lut.size()
-        s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-        s_b, s_r, s_c = s.size()
-        s_b_s, s_r_s, s_c_s = stride(s)
-        s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
-        s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)
-
-        grad_x = torch.empty_like(o, dtype=torch.float)
-
-        triton_grid = lambda meta: [o_b,
-                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-        (_BlocksparseSoftmax.kernel_blocksparse_softmax_grad_x[triton_grid]
-         (grad_output,
-          o_b, o_b_s, o_r_s, o_c_s,
-          o,
-          o_b, o_b_s, o_r_s, o_c_s,
-          sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-          s,
-          s_b, s_b_s, s_r_s, s_c_s,
-          s_l_s_b, s_l_s_b_s, s_l_s_r_s,
-          sparsity_reverse_lut_s,
-          grad_x,
-          o_b, o_b_s, o_r_s, o_c_s,
-          triton_block_size
-          ))
-
-        return grad_x, None, None, None, None, None
-
-    @staticmethod
-    @triton.jit
-    def kernel_blocksparse_softmax(x,
-                                   x_b, x_b_s, x_r_s, x_c_s,
-                                   s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-                                   s, s_b, s_b_s, s_r_s, s_c_s,
-                                   s_l_s_b, s_l_s_b_s, s_l_s_r_s,
-                                   r_lut_s,
-                                   o,
-                                   TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-        # Get triton block indices
-        pid_blk = tl.program_id(axis=0)
-        pid_row = tl.program_id(axis=1)
-        pid_col = tl.program_id(axis=2)
-
-        # Get position of current sparsity block consisting of its batch and row index
-        spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-        spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
-        spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
-
-        spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-        spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
-        spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
-
-        # Get reverse sparsity indices for s
-        rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
-                             spa_row * s_l_s_r_s)
-        rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
-        rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
-
-        if rev_idx_spa_s == -1:
-            tl.device_assert(False)
-            return

+    lut = softmax_build_lut(lut, sparsity_layout)
+
+    return BlksprsTensor(softmax_forward(x, sparsity_layout,
+                                         lut["sparsity_lut"],
+                                         lut["sparsity_reverse_lut_rws"],
+                                         sparsity_block_size))
+
+
+@triton_op("blksprs::softmax", mutates_args={})
+def softmax_forward(x: Tensor, sparsity_layout: Tensor,
+                    sparsity_lut: Tensor,
+                    sparsity_reverse_lut_rws: Tensor,
+                    sparsity_block_size: int) -> Tensor:
+    output = torch.empty_like(x)
+
+    x_b, x_r, x_c = x.size()
+    x_b_s, x_r_s, x_c_s = stride(x)
+    s_lut_r, s_lut_c = sparsity_lut.size()
+    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+    o_b, o_r, o_c = output.size()
+
+    x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,
+                                                       flag_slice_only=True)
+    x_scaled = row_wise_sub(x, sparsity_layout, x_row_wise_max, sparsity_block_size)
+    x_exp = torch.exp(x_scaled)
+    x_exp_row_wise_sum, sparsity_layout_rws = row_wise_sum(x_exp, sparsity_layout, sparsity_block_size,
+                                                           flag_slice_only=True)
+
+    s_b, s_r, s_c = x_exp_row_wise_sum.shape
+    s_b_s, s_r_s, s_c_s = stride(x_exp_row_wise_sum)
+    s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_rws.shape
+    s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_rws)
+
+    triton_grid = lambda meta: [o_b,
+                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+    (wrap_triton(softmax_kernel)[triton_grid]
+     (x_exp,
+      x_b, x_b_s, x_r_s, x_c_s,
+      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+      x_exp_row_wise_sum, s_b, s_b_s, s_r_s, s_c_s,
+      s_l_s_b, s_l_s_b_s, s_l_s_r_s,
+      sparsity_reverse_lut_rws,
+      output,
+      sparsity_block_size))
+
+    return output
+
+
+def softmax_backward(ctx, grad_output):
+    o, sparsity_layout, sparsity_lut = ctx.saved_tensors
+    sparsity_block_size = ctx.sparsity_block_size
+
+    s, sparsity_layout_s = row_wise_sum(grad_output * o, sparsity_layout, sparsity_block_size, flag_slice_only=True)
+
+    sparsity_layout_s_flat = sparsity_layout_s.reshape(-1)
+    sparsity_reverse_lut_s = ((torch.cumsum(sparsity_layout_s_flat, dim=-1) - 1) *
+                              (sparsity_layout_s_flat == 1) -
+                              (1 * (sparsity_layout_s_flat == 0)))
+
+    o_b, o_r, o_c = o.size()
+    o_b_s, o_r_s, o_c_s = stride(o)
+    s_lut_r, s_lut_c = sparsity_lut.size()
+    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+    s_b, s_r, s_c = s.size()
+    s_b_s, s_r_s, s_c_s = stride(s)
+    s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
+    s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)
+
+    grad_x = torch.empty_like(o, dtype=torch.float)
+
+    triton_grid = lambda meta: [o_b,
+                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+    (wrap_triton(softmax_kernel_grad)[triton_grid]
+     (grad_output,
+      o_b, o_b_s, o_r_s, o_c_s,
+      o,
+      o_b, o_b_s, o_r_s, o_c_s,
+      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+      s,
+      s_b, s_b_s, s_r_s, s_c_s,
+      s_l_s_b, s_l_s_b_s, s_l_s_r_s,
+      sparsity_reverse_lut_s,
+      grad_x,
+      o_b, o_b_s, o_r_s, o_c_s,
+      sparsity_block_size))
+
+    return grad_x, None, None, None, None, None
+
+
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=[]
+)
+@triton.jit
+def softmax_kernel(x,
+                   x_b, x_b_s, x_r_s, x_c_s,
+                   s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                   s, s_b, s_b_s, s_r_s, s_c_s,
+                   s_l_s_b, s_l_s_b_s, s_l_s_r_s,
+                   r_lut_s,
+                   o,
+                   sparsity_block_size,
+                   TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get valid triton block size
+    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
+
+    # Get position of current sparsity block consisting of its batch and row index
+    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+    spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
+    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+    spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
+    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+    # Get reverse sparsity indices for s
+    rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
+                         spa_row * s_l_s_r_s)
+    rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
+    rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
+
+    if rev_idx_spa_s >= 0:
         # Load x block
         blk_x_idx = ((pid_blk * x_b_s) +
-                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
+                     ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                     ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+        blk_x_msk = ((blk_x_idx >= 0 and
+                      blk_x_idx < x_b * x_b_s) and
+                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                      tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

         # Load sum block
         blk_s_idx = (rev_idx_spa_s * s_b_s +
-                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
+                     ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
                      (tl.arange(0, 1) * s_c_s)[None, :])
-        blk_s_msk = (blk_s_idx >= 0 and blk_s_idx < s_b * s_b_s)
+        blk_s_msk = ((blk_s_idx >= 0 and
+                      blk_s_idx < s_b * s_b_s) and
+                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                      tl.arange(0, 1)[None, :] < val_tbs))
         blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)

         # Compute softmax
@@ -206,65 +193,114 @@ class _BlocksparseSoftmax(torch.autograd.Function):
         # Store output
         tl.store(o + blk_x_idx, buf, mask=blk_x_msk)

-    @staticmethod
-    @triton.jit
-    def kernel_blocksparse_softmax_grad_x(g,
-                                          g_b, g_b_s, g_r_s, g_c_s,
-                                          x,
-                                          x_b, x_b_s, x_r_s, x_c_s,
-                                          s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-                                          s,
-                                          s_b, s_b_s, s_r_s, s_c_s,
-                                          s_l_s_b, s_l_s_b_s, s_l_s_r_s,
-                                          r_lut_s,
-                                          o,
-                                          o_b, o_b_s, o_r_s, o_c_s,
-                                          TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-        # Get triton block indices
-        pid_blk = tl.program_id(axis=0)
-        pid_row = tl.program_id(axis=1)
-        pid_col = tl.program_id(axis=2)
-
-        # Get position of current sparsity block consisting of its batch and row index
-        spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-        spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
-        spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
-
-        spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-        spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
-        spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
-
-        rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
-                             spa_row * s_l_s_r_s)
-        rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
-        rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
-
-        if rev_idx_spa_s == -1:
-            tl.device_assert(False)
-            return

+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=[]
+)
+@triton.jit
+def softmax_kernel_grad(g,
+                        g_b, g_b_s, g_r_s, g_c_s,
+                        x,
+                        x_b, x_b_s, x_r_s, x_c_s,
+                        s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                        s,
+                        s_b, s_b_s, s_r_s, s_c_s,
+                        s_l_s_b, s_l_s_b_s, s_l_s_r_s,
+                        r_lut_s,
+                        o,
+                        o_b, o_b_s, o_r_s, o_c_s,
+                        sparsity_block_size,
+                        TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get valid triton block size
+    val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
+
+    # Get position of current sparsity block consisting of its batch and row index
+    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+    spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
+    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+    spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
+    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+    rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
+                         spa_row * s_l_s_r_s)
+    rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
+    rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
+
+    if rev_idx_spa_s >= 0:
         blk_s_idx = (rev_idx_spa_s * s_b_s +
-                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
+                     ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
                      (tl.arange(0, 1) * s_c_s)[None, :])
-        blk_s_msk = (blk_s_idx >= 0 and blk_s_idx < s_b * s_b_s)
+        blk_s_msk = ((blk_s_idx >= 0 and
+                      blk_s_idx < s_b * s_b_s) and
+                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                      tl.arange(0, 1)[None, :] < val_tbs))
         blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)

         blk_g_idx = ((pid_blk * g_b_s) +
-                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +
-                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])
-        blk_g_msk = (blk_g_idx >= 0 and blk_g_idx < g_b * g_b_s)
+                     ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +
+                     ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])
+        blk_g_msk = ((blk_g_idx >= 0 and
+                      blk_g_idx < g_b * g_b_s) and
+                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                      tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
         blk_g = tl.load(g + blk_g_idx, mask=blk_g_msk)

         blk_x_idx = ((pid_blk * x_b_s) +
-                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
+                     ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                     ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+        blk_x_msk = ((blk_x_idx >= 0 and
+                      blk_x_idx < x_b * x_b_s) and
+                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                      tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

         buf = blk_x * (blk_g - blk_s)

         blk_o_idx = ((pid_blk * o_b_s) +
-                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-        blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
+                     ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                     ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+        blk_o_msk = ((blk_o_idx >= 0 and
+                      blk_o_idx < o_b * o_b_s) and
+                     (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                      tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
         tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
+
+
+def softmax_build_lut(lut: dict, sparsity_layout: Tensor):
+    if lut is None:
+        lut = dict()
+
+    if "sparsity_lut" not in lut:
+        sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
+        lut["sparsity_lut"] = sparsity_lut
+
+    if "sparsity_reverse_lut_rws" not in lut:
+        sparsity_layout_rws, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
+        sparsity_layout_rws_flat = sparsity_layout_rws.reshape(-1)
+        sparsity_reverse_lut_rws = ((torch.cumsum(sparsity_layout_rws_flat, dim=-1) - 1) *
+                                    (sparsity_layout_rws_flat == 1) -
+                                    (1 * (sparsity_layout_rws_flat == 0)))
+        lut["sparsity_reverse_lut_rws"] = sparsity_reverse_lut_rws
+
+    validate_contiguous(sparsity_layout, lut["sparsity_lut"], lut["sparsity_reverse_lut_rws"])
+
+    return lut
+
+
+# noinspection PyUnusedLocal
+def softmax_setup_context(ctx, inputs, output):
+    (_, sparsity_layout, sparsity_lut, _, sparsity_block_size) = inputs
+
+    ctx.save_for_backward(output, sparsity_layout, sparsity_lut)
+    ctx.sparsity_block_size = sparsity_block_size
+
+
+softmax_forward.register_autograd(softmax_backward, setup_context=softmax_setup_context)