blksprs-1.11-py3-none-any.whl → blksprs-2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +4 -5
- blksprs/layouting/distribution_layout.py +64 -48
- blksprs/layouting/sparsity_layout.py +96 -72
- blksprs/ops/conversion.py +349 -338
- blksprs/ops/distribution.py +318 -294
- blksprs/ops/flow.py +113 -100
- blksprs/ops/matmul.py +187 -172
- blksprs/ops/misc/broadcast_ops.py +68 -53
- blksprs/ops/misc/row_wise.py +223 -176
- blksprs/ops/partitioning.py +140 -132
- blksprs/ops/repeat.py +118 -120
- blksprs/ops/softmax.py +240 -214
- blksprs/ops/transpose.py +55 -52
- blksprs/utils/autotuning.py +78 -0
- blksprs/utils/benchmarking.py +3 -3
- blksprs/utils/processing.py +2 -1
- blksprs/utils/tools.py +5 -6
- blksprs/utils/validation.py +22 -16
- {blksprs-1.11.dist-info → blksprs-2.0.dist-info}/METADATA +55 -36
- blksprs-2.0.dist-info/RECORD +23 -0
- {blksprs-1.11.dist-info → blksprs-2.0.dist-info}/WHEEL +1 -1
- blksprs/utils/layout_utils.py +0 -17
- blksprs-1.11.dist-info/RECORD +0 -23
- {blksprs-1.11.dist-info → blksprs-2.0.dist-info}/top_level.txt +0 -0
blksprs/ops/matmul.py
CHANGED
@@ -1,19 +1,22 @@
 import torch
 import triton
 from torch import Tensor
+from torch.library import triton_op, wrap_triton
 from triton import language as tl
 
 from blksprs.ops.transpose import transpose
 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import
+from blksprs.utils.tools import stride
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-    validate_sparsity, validate_sparsity_block_size,
+    validate_sparsity, validate_sparsity_block_size, validate_dtype_float
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
            y: BlksprsTensor, sparsity_layout_y: Tensor,
            sparsity_layout_output: Tensor,
-           sparsity_block_size: int,
+           sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
     """Performs matrix multiplication between two block-sparse tensors.
 
     The sparsity layout of the output tensor is used to only calculate blocks that will be present in the output.
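The new `triton_op`/`wrap_triton` imports signal the central change in 2.0: kernels are now registered as PyTorch custom operators so that `torch.compile` can trace through them. A minimal sketch of that pattern with a toy kernel (names hypothetical, not part of blksprs; assumes PyTorch ≥ 2.6):

```python
import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton

@triton.jit
def add_one_kernel(x_ptr, o_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    msk = offs < n
    tl.store(o_ptr + offs, tl.load(x_ptr + offs, mask=msk) + 1, mask=msk)

@triton_op("mylib::add_one", mutates_args={})
def add_one(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK"]),)
    # wrap_triton makes the raw kernel launch traceable by torch.compile
    wrap_triton(add_one_kernel)[grid](x, out, x.numel(), BLOCK=1024)
    return out
```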
@@ -25,7 +28,6 @@ def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
         sparsity_layout_y (Tensor): The sparsity layout of the second block-sparse tensor.
         sparsity_layout_output (Tensor): The sparsity layout of the output tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
-        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
         lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
 
     Returns:
@@ -43,61 +45,24 @@ def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
     if sparsity_layout_x.size(-1) != sparsity_layout_y.size(-2):
         raise ValueError("Inner dimensions of tensors must match")
     validate_sparsity_block_size(sparsity_block_size, x, y)
-        lut = dict()
-
-        if "sparsity_reverse_lut_x" not in lut:
-            sparsity_layout_x_flat = sparsity_layout_x.reshape(-1)
-            sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
-                                      (sparsity_layout_x_flat == 1) -
-                                      (1 * (sparsity_layout_x_flat == 0)))
-            lut["sparsity_reverse_lut_x"] = sparsity_reverse_lut_x
-
-        if "sparsity_reverse_lut_y" not in lut:
-            sparsity_layout_y_flat = sparsity_layout_y.reshape(-1)
-            sparsity_reverse_lut_y = ((torch.cumsum(sparsity_layout_y_flat, dim=-1) - 1) *
-                                      (sparsity_layout_y_flat == 1) -
-                                      (1 * (sparsity_layout_y_flat == 0)))
-            lut["sparsity_reverse_lut_y"] = sparsity_reverse_lut_y
-
-        if "sparsity_lut_o" not in lut:
-            sparsity_lut_o = torch.nonzero(sparsity_layout_output).contiguous()
-            lut["sparsity_lut_o"] = sparsity_lut_o
-
-        if "n_sparse_blocks" not in lut:
-            n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
-            lut["n_sparse_blocks"] = n_sparse_blocks
-
-        validate_contiguous(sparsity_layout_x, lut["sparsity_reverse_lut_x"],
-                            sparsity_layout_y, lut["sparsity_reverse_lut_y"],
-                            sparsity_layout_output, lut["sparsity_lut_o"])
-
-        return lut
-
-    @staticmethod
-    def forward(ctx, x: Tensor, y: Tensor,
-                sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
-                sparsity_layout_y: Tensor, sparsity_reverse_lut_y: Tensor,
-                sparsity_layout_o: Tensor, sparsity_lut_o: Tensor,
-                sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
-        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+
+    lut = matmul_build_lut(lut, sparsity_layout_x, sparsity_layout_y, sparsity_layout_output)
+
+    return BlksprsTensor(matmul_forward(x, y,
+                                        sparsity_layout_x, lut["sparsity_reverse_lut_x"],
+                                        sparsity_layout_y, lut["sparsity_reverse_lut_y"],
+                                        sparsity_layout_output, lut["sparsity_lut_o"],
+                                        sparsity_block_size, lut["n_sparse_blocks"]))
+
+
+@triton_op("blksprs::matmul_forward", mutates_args={})
+def matmul_forward(x: Tensor, y: Tensor,
+                   sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
+                   sparsity_layout_y: Tensor, sparsity_reverse_lut_y: Tensor,
+                   _: Tensor, sparsity_lut_o: Tensor,
+                   sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+    with torch.no_grad():
+        output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                              dtype=x.dtype, device=x.device)
 
         x_b, x_r, x_c = x.size()
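The `torch.autograd.Function` subclass is gone: `matmul` now builds (or reuses) a LUT dictionary and dispatches to the registered `matmul_forward` operator. Since `matmul_build_lut` fills the dictionary in place, callers can amortize LUT construction across repeated calls; a minimal usage sketch under that assumption (tensor names hypothetical):

```python
from blksprs.ops.matmul import matmul

lut = dict()  # filled on the first call, reused afterwards
for x, y in pairs:  # hypothetical stream of compressed block-sparse operands
    out = matmul(x, sparsity_layout_x, y, sparsity_layout_y,
                 sparsity_layout_o, sparsity_block_size, lut=lut)
```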
@@ -113,133 +78,183 @@ class _BlocksparseMatmulSSS(torch.autograd.Function):
         s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
         s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
 
-        if triton_block_size is None:
-            triton_block_size = get_triton_block_size(sparsity_block_size)
-
         triton_grid = lambda meta: [o_b,
                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
 
-        (
+        (wrap_triton(matmul_kernel)[triton_grid]
         (x,
          x_b, x_b_s, x_r_s, x_c_s,
-         s_l_x_b, s_l_x_b_s, s_l_x_r_s,
+         s_l_x_b, s_l_x_b_s, s_l_x_r_s,
+         s_l_x_c, s_l_x_c_s,
          sparsity_reverse_lut_x,
          y,
          y_b, y_b_s, y_r_s, y_c_s,
-         s_l_y_b, s_l_y_b_s, s_l_y_r_s,
+         s_l_y_b, s_l_y_b_s, s_l_y_r_s,
+         s_l_y_c_s,
         sparsity_reverse_lut_y,
         output,
         o_b, o_b_s, o_r_s, o_c_s,
         sparsity_lut_o,
         s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
-         sparsity_block_size
-         triton_block_size))
-
-        ctx.save_for_backward(x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o)
-        ctx.sparsity_block_size = sparsity_block_size
-        ctx.triton_block_size = triton_block_size
+         sparsity_block_size))
 
         return output
 
+
+def matmul_wrapper_backward(ctx, grad_output):
+    x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o = ctx.saved_tensors
+    sparsity_block_size = ctx.sparsity_block_size
+
+    x_t, sparsity_layout_x_t = transpose(x, sparsity_layout_x, sparsity_block_size)
+    y_t, sparsity_layout_y_t = transpose(y, sparsity_layout_y, sparsity_block_size)
+
+    grad_x = matmul(grad_output, sparsity_layout_o, y_t, sparsity_layout_y_t, sparsity_layout_x,
+                    sparsity_block_size)
+    grad_y = matmul(x_t, sparsity_layout_x_t, grad_output, sparsity_layout_o, sparsity_layout_y,
+                    sparsity_block_size)
+
+    return grad_x, grad_y, None, None, None, None, None, None, None, None
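The backward pass applies the standard dense identities dX = dO·Yᵀ and dY = Xᵀ·dO, realized here with block-sparse `transpose` and two further block-sparse matmuls. A quick dense sanity check of those identities (plain PyTorch, unrelated to blksprs internals):

```python
import torch

A = torch.randn(4, 5, requires_grad=True)
B = torch.randn(5, 3, requires_grad=True)
(A @ B).sum().backward()                   # upstream gradient dO is all ones
dO = torch.ones(4, 3)
assert torch.allclose(A.grad, dO @ B.T)    # dX = dO @ Y^T
assert torch.allclose(B.grad, A.T @ dO)    # dY = X^T @ dO
```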
+
+
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
+    reset_to_zero=["o"]
+)
+@triton.jit
+def matmul_kernel(x,
+                  x_b, x_b_s, x_r_s, x_c_s,
+                  s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
+                  r_lut_x,
+                  y,
+                  y_b, y_b_s, y_r_s, y_c_s,
+                  s_l_y_b, s_l_y_b_s, s_l_y_r_s, s_l_y_c_s,
+                  r_lut_y,
+                  o,
+                  o_b, o_b_s, o_r_s, o_c_s,
+                  s_lut_o,
+                  s_lut_o_r, s_lut_o_r_s,
+                  s_lut_o_c_s,
+                  sparsity_block_size,
+                  TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get position of current sparsity block consisting of its batch, row, and column index
+    spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
+    spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
+
+    spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
+    spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+
+    spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
+    spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+    spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+
+    # Setup buffer
+    buf = tl.zeros(shape=(TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), dtype=tl.float32)
+
+    # Slide over triton block sized segments of input tensors
+    for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size, TRITON_BLOCK_SIZE)):
+        # Convert to segment index of sparsity layout
+        i_seg_spa = (i_seg_tri * TRITON_BLOCK_SIZE) // sparsity_block_size
+        # Calculate the triton segment index within a block
+        i_seg_tri_mod = i_seg_tri % (sparsity_block_size // TRITON_BLOCK_SIZE)
+
+        # Get reverse sparsity indices for input tensors x and y
+        # These are either -1 if the block is empty or equal to the index of the block in the sparse tensor
+
+        # Get reverse sparsity indices for x
+        rev_idx_spa_x_idx = (spa_bat_o * s_l_x_b_s +
+                             spa_row_o * s_l_x_r_s +
+                             i_seg_spa * s_l_x_c_s)
+        rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+        rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
+
+        # Get reverse sparsity indices for y
+        rev_idx_spa_y_idx = (spa_bat_o * s_l_y_b_s + i_seg_spa * s_l_y_r_s + spa_col_o * s_l_y_c_s)
+        rev_idx_spa_y_msk = (rev_idx_spa_y_idx >= 0 and rev_idx_spa_y_idx < s_l_y_b * s_l_y_b_s)
+        rev_idx_spa_y = tl.load(r_lut_y + rev_idx_spa_y_idx, mask=rev_idx_spa_y_msk).to(tl.int32)
+
+        # If both blocks are present commence calculation
+        if rev_idx_spa_x >= 0 and rev_idx_spa_y >= 0:
+            blk_x_idx = ((rev_idx_spa_x * x_b_s) +
+                         ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                         ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
+                           tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+            blk_x_msk = (blk_x_idx >= 0 and
+                         blk_x_idx < x_b * x_b_s)
+            blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+            blk_y_idx = ((rev_idx_spa_y * y_b_s) +
+                         ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
+                           tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
+                         ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
+            blk_y_msk = (blk_y_idx >= 0 and
+                         blk_y_idx < y_b * y_b_s)
+            blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
+
+            # Perform matrix multiplication
+            buf += tl.dot(blk_x, blk_y)
+
+    # Cast buffer
+    buf = buf.to(o.dtype.element_ty)
+
+    # Store output
+    blk_o_idx = ((pid_blk * o_b_s) +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+    blk_o_msk = (blk_o_idx >= 0 and
+                 blk_o_idx < o_b * o_b_s)
+    tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
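Two details of the autotuning setup are worth noting. Autotuning benchmarks the kernel several times, so `reset_to_zero=["o"]` restores the zero-initialized output between trials; this pairs with the switch from `torch.empty` to `torch.zeros` for the output buffer. Also, the segment arithmetic above (`sparsity_block_size // TRITON_BLOCK_SIZE`) only works when the candidate block size divides the sparsity block size, so the pruning hook presumably filters configs along these lines (hypothetical sketch only; the real `prune_autotune_configs` lives in `blksprs/utils/autotuning.py`):

```python
# Hypothetical illustration of an early_config_prune hook, not the blksprs implementation.
def early_config_prune(configs, named_args, **kwargs):
    sbs = named_args["sparsity_block_size"]
    # Keep only configs whose TRITON_BLOCK_SIZE divides the sparsity block size
    return [c for c in configs if sbs % c.kwargs["TRITON_BLOCK_SIZE"] == 0]
```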
+
+
+def matmul_build_lut(lut: dict, sparsity_layout_x: Tensor, sparsity_layout_y: Tensor, sparsity_layout_output: Tensor):
+    if lut is None:
+        lut = dict()
+
+    if "sparsity_reverse_lut_x" not in lut:
+        sparsity_layout_x_flat = sparsity_layout_x.reshape(-1)
+        sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
+                                  (sparsity_layout_x_flat == 1) -
+                                  (1 * (sparsity_layout_x_flat == 0)))
+        lut["sparsity_reverse_lut_x"] = sparsity_reverse_lut_x
+
+    if "sparsity_reverse_lut_y" not in lut:
+        sparsity_layout_y_flat = sparsity_layout_y.reshape(-1)
+        sparsity_reverse_lut_y = ((torch.cumsum(sparsity_layout_y_flat, dim=-1) - 1) *
+                                  (sparsity_layout_y_flat == 1) -
+                                  (1 * (sparsity_layout_y_flat == 0)))
+        lut["sparsity_reverse_lut_y"] = sparsity_reverse_lut_y
+
+    if "sparsity_lut_o" not in lut:
+        sparsity_lut_o = torch.nonzero(sparsity_layout_output).contiguous()
+        lut["sparsity_lut_o"] = sparsity_lut_o
+
+    if "n_sparse_blocks" not in lut:
+        n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
+        lut["n_sparse_blocks"] = n_sparse_blocks
+
+    validate_contiguous(sparsity_layout_x, lut["sparsity_reverse_lut_x"],
+                        sparsity_layout_y, lut["sparsity_reverse_lut_y"],
+                        sparsity_layout_output, lut["sparsity_lut_o"])
+
+    return lut
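For intuition, the reverse LUT maps each position of the flattened sparsity layout to the index of the corresponding block in the compressed tensor, or -1 where no block exists. A worked example of the formula in plain PyTorch:

```python
import torch

layout = torch.tensor([1, 0, 1, 1, 0])
rev = ((torch.cumsum(layout, dim=-1) - 1) * (layout == 1) - (1 * (layout == 0)))
print(rev)  # tensor([ 0, -1,  1,  2, -1])
```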
+
+
+# noinspection PyUnusedLocal
+def matmul_setup_context(ctx, inputs, output):
+    (x, y, sparsity_layout_x, _, sparsity_layout_y, _,
+     sparsity_layout_o, _, sparsity_block_size, _) = inputs
+
+    ctx.save_for_backward(x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o)
+    ctx.sparsity_block_size = sparsity_block_size
+
+
+matmul_forward.register_autograd(matmul_wrapper_backward, setup_context=matmul_setup_context)
blksprs/ops/misc/broadcast_ops.py
CHANGED
@@ -1,16 +1,20 @@
 import torch
 import triton
 from torch import Tensor
+from torch._library import triton_op
+from torch._library.triton import wrap_triton
 from triton import language as tl
 
 from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import
+from blksprs.utils.tools import stride
+from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
 from blksprs.utils.validation import validate_contiguous, validate_device, \
-    validate_sparsity_block_size
+    validate_sparsity_block_size
 
 
+@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
 def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
-                  sparsity_block_size: int
+                  sparsity_block_size: int) -> BlksprsTensor:
     """Performs a broadcast and subsequent addition of two dense tensors x and y. Returns a block-sparse tensor in
     compressed form.
 
@@ -19,7 +23,6 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
         y (Tensor): A dense input tensor.
         sparsity_layout_output (Tensor): The sparsity layout of the output tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
-        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
 
     Returns:
         BlksprsTensor: The result of the operation as a block-sparse tensor in compressed form. Each element o(i, j) of the
@@ -34,7 +37,6 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
     if x.size(-1) != y.size(-1):
         raise ValueError("Dimensions of tensors must match")
     validate_sparsity_block_size(sparsity_block_size)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)
 
     sparsity_lut_o = torch.nonzero(sparsity_layout_output).contiguous()
 
@@ -42,56 +44,66 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
 
     validate_contiguous(sparsity_layout_output, sparsity_lut_o)
 
-    x_b, x_c = x.size()
-    x_b_s, x_c_s = stride(x)
-    y_b, y_c = y.size()
-    y_b_s, y_c_s = stride(y)
-    o_b, o_r, o_c = output.size()
-    o_b_s, o_r_s, o_c_s = stride(output)
-    s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
-    s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
-
-    if triton_block_size is None:
-        triton_block_size = get_triton_block_size(sparsity_block_size)
-
-    triton_grid = lambda meta: [o_b,
-                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-    (kernel_broadcast_addition[triton_grid]
-     (x,
-      x_b, x_b_s, x_c_s,
-      y,
-      y_b, y_b_s, y_c_s,
-      output,
-      o_b, o_b_s, o_r_s, o_c_s,
-      sparsity_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
-      sparsity_block_size,
-      triton_block_size))
-
-    return BlksprsTensor(output)
+    return BlksprsTensor(broadcast_add_forward(x, y, sparsity_lut_o, sparsity_block_size, n_sparse_blocks))
 
 
 def broadcast_sub(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
-                  sparsity_block_size: int
+                  sparsity_block_size: int) -> BlksprsTensor:
     """Wrapper for ``broadcast_add`` with negated y.
 
     """
-    return broadcast_add(x, torch.neg(y), sparsity_layout_output, sparsity_block_size
-
-
+    return broadcast_add(x, torch.neg(y), sparsity_layout_output, sparsity_block_size)
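A usage sketch (shapes and layout hypothetical): `broadcast_add` takes two dense 2D tensors and, per the (truncated) docstring, fills each output element o(i, j) from a broadcast combination of one element of x and one of y; only blocks flagged in the output layout are materialized, in compressed form:

```python
import torch
from blksprs.ops.misc.broadcast_ops import broadcast_add  # import path per the file layout above

b, n, bs = 2, 64, 32                       # hypothetical batch size, length, sparsity block size
x = torch.randn(b, n, device="cuda")
y = torch.randn(b, n, device="cuda")
layout = torch.ones(b, n // bs, n // bs, dtype=torch.int, device="cuda")  # fully dense layout
out = broadcast_add(x, y, layout, bs)      # compressed (n_sparse_blocks, bs, bs) result
```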
+
+
+@triton_op("blksprs::broadcast_add_forward", mutates_args={})
+def broadcast_add_forward(x: Tensor, y: Tensor,
+                          sparsity_lut_o: Tensor,
+                          sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+    with torch.no_grad():
+        output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, dtype=x.dtype, device=x.device)
+
+        x_b, x_c = x.size()
+        x_b_s, x_c_s = stride(x)
+        y_b, y_c = y.size()
+        y_b_s, y_c_s = stride(y)
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = stride(output)
+        s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
+        s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
+
+        triton_grid = lambda meta: [o_b,
+                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+        (wrap_triton(broadcast_add_kernel)[triton_grid]
+         (x,
+          x_b, x_b_s, x_c_s,
+          y,
+          y_b, y_b_s, y_c_s,
+          output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          sparsity_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+          sparsity_block_size))
+
+        return output
+
+
+@triton.autotune(
+    configs=get_autotune_configs(),
+    key=["sparsity_block_size"],
+    prune_configs_by={"early_config_prune": prune_autotune_configs},
+    reset_to_zero=["o"]
+)
 @triton.jit
-def
+def broadcast_add_kernel(x,
+                         x_b, x_b_s, x_c_s,
+                         y,
+                         y_b, y_b_s, y_c_s,
+                         o,
+                         o_b, o_b_s, o_r_s, o_c_s,
+                         s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+                         sparsity_block_size,
+                         TRITON_BLOCK_SIZE: tl.constexpr) -> None:
     # Get triton block indices
     pid_blk = tl.program_id(axis=0)
     pid_row = tl.program_id(axis=1)
@@ -112,16 +124,18 @@ def kernel_broadcast_addition(x,
 
     # Load x block
     blk_x_idx = (spa_bat_o * x_b_s +
-                 ((
+                 ((pid_row * TRITON_BLOCK_SIZE + spa_row_o * sparsity_block_size +
                    tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx >= 0 and
+    blk_x_msk = (blk_x_idx >= 0 and
+                 blk_x_idx < x_b * x_b_s)
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Load y block
     blk_y_idx = (spa_bat_o * y_b_s +
-                 ((
+                 ((pid_col * TRITON_BLOCK_SIZE + spa_col_o * sparsity_block_size +
                    tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
-    blk_y_msk = (blk_y_idx >= 0 and
+    blk_y_msk = (blk_y_idx >= 0 and
+                 blk_y_idx < y_b * y_b_s)
     blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
 
     # Compute sum
@@ -132,5 +146,6 @@ def kernel_broadcast_addition(x,
     blk_o_idx = ((pid_blk * o_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-    blk_o_msk = (blk_o_idx >= 0 and
+    blk_o_msk = (blk_o_idx >= 0 and
+                 blk_o_idx < o_b * o_b_s)
     tl.store(o + blk_o_idx, buf, mask=blk_o_msk)