blksprs 1.10.2__py3-none-any.whl → 2.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +2 -6
- blksprs/layouting/distribution_layout.py +32 -25
- blksprs/layouting/sparsity_layout.py +65 -52
- blksprs/ops/conversion.py +423 -374
- blksprs/ops/distribution.py +403 -335
- blksprs/ops/flow.py +135 -83
- blksprs/ops/matmul.py +221 -187
- blksprs/ops/misc/broadcast_ops.py +53 -35
- blksprs/ops/misc/row_wise.py +151 -91
- blksprs/ops/partitioning.py +136 -89
- blksprs/ops/repeat.py +115 -108
- blksprs/ops/softmax.py +244 -208
- blksprs/ops/transpose.py +69 -131
- blksprs/utils/benchmarking.py +3 -3
- blksprs/utils/tools.py +31 -4
- blksprs/utils/validation.py +0 -14
- {blksprs-1.10.2.dist-info → blksprs-2.0rc1.dist-info}/METADATA +42 -36
- blksprs-2.0rc1.dist-info/RECORD +22 -0
- {blksprs-1.10.2.dist-info → blksprs-2.0rc1.dist-info}/WHEEL +1 -1
- blksprs/ops/misc/exp.py +0 -104
- blksprs/utils/layout_utils.py +0 -17
- blksprs-1.10.2.dist-info/RECORD +0 -24
- {blksprs-1.10.2.dist-info → blksprs-2.0rc1.dist-info}/top_level.txt +0 -0
blksprs/ops/matmul.py
CHANGED
|
@@ -1,19 +1,20 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
import triton
|
|
3
3
|
from torch import Tensor
|
|
4
|
+
from torch.library import triton_op, wrap_triton
|
|
4
5
|
from triton import language as tl
|
|
5
6
|
|
|
6
7
|
from blksprs.ops.transpose import transpose
|
|
7
8
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
8
|
-
from blksprs.utils.tools import
|
|
9
|
+
from blksprs.utils.tools import stride, get_autotune_configs
|
|
9
10
|
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
|
|
10
|
-
validate_sparsity, validate_sparsity_block_size,
|
|
11
|
+
validate_sparsity, validate_sparsity_block_size, validate_dtype_float
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
|
|
14
15
|
y: BlksprsTensor, sparsity_layout_y: Tensor,
|
|
15
16
|
sparsity_layout_output: Tensor,
|
|
16
|
-
sparsity_block_size: int,
|
|
17
|
+
sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
|
|
17
18
|
"""Performs matrix multiplication between two block-sparse tensors.
|
|
18
19
|
|
|
19
20
|
The sparsity layout of the output tensor is used to only calculate blocks that will be present in the output.
|
|
@@ -25,7 +26,7 @@ def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
|
|
|
25
26
|
sparsity_layout_y (Tensor): The sparsity layout of the second block-sparse tensor.
|
|
26
27
|
sparsity_layout_output (Tensor): The sparsity layout of the output tensor.
|
|
27
28
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
28
|
-
|
|
29
|
+
lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
|
|
29
30
|
|
|
30
31
|
Returns:
|
|
31
32
|
BlksprsTensor: The result of the matrix multiplication as a block-sparse tensor in compressed form.
|
|
@@ -42,186 +43,219 @@ def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
|
|
|
42
43
|
if sparsity_layout_x.size(-1) != sparsity_layout_y.size(-2):
|
|
43
44
|
raise ValueError("Inner dimensions of tensors must match")
|
|
44
45
|
validate_sparsity_block_size(sparsity_block_size, x, y)
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
46
|
+
|
|
47
|
+
lut = matmul_build_lut(lut, sparsity_layout_x, sparsity_layout_y, sparsity_layout_output)
|
|
48
|
+
|
|
49
|
+
return BlksprsTensor(matmul_forward(x, y,
|
|
50
|
+
sparsity_layout_x, lut["sparsity_reverse_lut_x"],
|
|
51
|
+
sparsity_layout_y, lut["sparsity_reverse_lut_y"],
|
|
52
|
+
sparsity_layout_output, lut["sparsity_lut_o"],
|
|
53
|
+
sparsity_block_size, lut["n_sparse_blocks"]))
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@triton_op("blksprs::matmul", mutates_args={})
|
|
57
|
+
def matmul_forward(x: Tensor, y: Tensor,
|
|
58
|
+
sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
|
|
59
|
+
sparsity_layout_y: Tensor, sparsity_reverse_lut_y: Tensor,
|
|
60
|
+
_: Tensor, sparsity_lut_o: Tensor,
|
|
61
|
+
sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
|
|
62
|
+
output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
|
|
63
|
+
dtype=x.dtype, device=x.device)
|
|
64
|
+
|
|
65
|
+
x_b, x_r, x_c = x.size()
|
|
66
|
+
x_b_s, x_r_s, x_c_s = stride(x)
|
|
67
|
+
s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
|
|
68
|
+
s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
|
|
69
|
+
y_b, y_r, y_c = y.size()
|
|
70
|
+
y_b_s, y_r_s, y_c_s = stride(y)
|
|
71
|
+
s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_y.size()
|
|
72
|
+
s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_y)
|
|
73
|
+
o_b, o_r, o_c = output.size()
|
|
74
|
+
o_b_s, o_r_s, o_c_s = stride(output)
|
|
75
|
+
s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
|
|
76
|
+
s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
|
|
77
|
+
|
|
78
|
+
triton_grid = lambda meta: [o_b,
|
|
79
|
+
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
80
|
+
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
81
|
+
|
|
82
|
+
(wrap_triton(matmul_kernel)[triton_grid]
|
|
83
|
+
(x,
|
|
84
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
85
|
+
s_l_x_b, s_l_x_b_s, s_l_x_r_s,
|
|
86
|
+
s_l_x_c, s_l_x_c_s,
|
|
87
|
+
sparsity_reverse_lut_x,
|
|
88
|
+
y,
|
|
89
|
+
y_b, y_b_s, y_r_s, y_c_s,
|
|
90
|
+
s_l_y_b, s_l_y_b_s, s_l_y_r_s,
|
|
91
|
+
s_l_y_c_s,
|
|
92
|
+
sparsity_reverse_lut_y,
|
|
93
|
+
output,
|
|
94
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
95
|
+
sparsity_lut_o,
|
|
96
|
+
s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
|
|
97
|
+
sparsity_block_size))
|
|
98
|
+
|
|
99
|
+
return output
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def matmul_backward(ctx, grad_output):
|
|
103
|
+
x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o = ctx.saved_tensors
|
|
104
|
+
sparsity_block_size = ctx.sparsity_block_size
|
|
105
|
+
|
|
106
|
+
x_t, sparsity_layout_x_t = transpose(x, sparsity_layout_x, sparsity_block_size)
|
|
107
|
+
y_t, sparsity_layout_y_t = transpose(y, sparsity_layout_y, sparsity_block_size)
|
|
108
|
+
|
|
109
|
+
grad_x = matmul(grad_output, sparsity_layout_o, y_t, sparsity_layout_y_t, sparsity_layout_x,
|
|
110
|
+
sparsity_block_size)
|
|
111
|
+
grad_y = matmul(x_t, sparsity_layout_x_t, grad_output, sparsity_layout_o, sparsity_layout_y,
|
|
112
|
+
sparsity_block_size)
|
|
113
|
+
|
|
114
|
+
return grad_x, grad_y, None, None, None, None, None, None, None, None
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
@triton.autotune(
|
|
118
|
+
configs=get_autotune_configs(),
|
|
119
|
+
key=[],
|
|
120
|
+
)
|
|
121
|
+
@triton.jit
|
|
122
|
+
def matmul_kernel(x,
|
|
123
|
+
x_b, x_b_s, x_r_s, x_c_s,
|
|
124
|
+
s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
|
|
125
|
+
r_lut_x,
|
|
126
|
+
y,
|
|
127
|
+
y_b, y_b_s, y_r_s, y_c_s,
|
|
128
|
+
s_l_y_b, s_l_y_b_s, s_l_y_r_s, s_l_y_c_s,
|
|
129
|
+
r_lut_y,
|
|
130
|
+
o,
|
|
131
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
132
|
+
s_lut_o,
|
|
133
|
+
s_lut_o_r, s_lut_o_r_s,
|
|
134
|
+
s_lut_o_c_s,
|
|
135
|
+
sparsity_block_size,
|
|
136
|
+
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
137
|
+
# Get triton block indices
|
|
138
|
+
pid_blk = tl.program_id(axis=0)
|
|
139
|
+
pid_row = tl.program_id(axis=1)
|
|
140
|
+
pid_col = tl.program_id(axis=2)
|
|
141
|
+
|
|
142
|
+
# Get valid triton block size
|
|
143
|
+
val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
|
|
144
|
+
|
|
145
|
+
# Get position of current sparsity block consisting of its batch, row, and column index
|
|
146
|
+
spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
|
|
147
|
+
spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
|
|
148
|
+
spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
|
|
149
|
+
|
|
150
|
+
spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
|
|
151
|
+
spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
|
|
152
|
+
spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
|
|
153
|
+
|
|
154
|
+
spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
|
|
155
|
+
spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
|
|
156
|
+
spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
|
|
157
|
+
|
|
158
|
+
# Setup buffer
|
|
159
|
+
buf = tl.zeros(shape=(TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), dtype=tl.float32)
|
|
160
|
+
|
|
161
|
+
# Slide over triton block sized segments of input tensors
|
|
162
|
+
for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size, val_tbs)):
|
|
163
|
+
# Convert to segment index of sparsity layout
|
|
164
|
+
i_seg_spa = (i_seg_tri * val_tbs) // sparsity_block_size
|
|
165
|
+
# Calculate the triton segment index within a block
|
|
166
|
+
i_seg_tri_mod = i_seg_tri % (sparsity_block_size // val_tbs)
|
|
167
|
+
|
|
168
|
+
# Get reverse sparsity indices for input tensors x and y
|
|
169
|
+
# These are either -1 if the block is empty or equal to the index of the block in the sparse tensor
|
|
170
|
+
|
|
171
|
+
# Get reverse sparsity indices for x
|
|
172
|
+
rev_idx_spa_x_idx = (spa_bat_o * s_l_x_b_s +
|
|
173
|
+
spa_row_o * s_l_x_r_s +
|
|
174
|
+
i_seg_spa * s_l_x_c_s)
|
|
175
|
+
rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
|
|
176
|
+
rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
|
|
177
|
+
|
|
178
|
+
# Get reverse sparsity indices for y
|
|
179
|
+
rev_idx_spa_y_idx = (spa_bat_o * s_l_y_b_s + i_seg_spa * s_l_y_r_s + spa_col_o * s_l_y_c_s)
|
|
180
|
+
rev_idx_spa_y_msk = (rev_idx_spa_y_idx >= 0 and rev_idx_spa_y_idx < s_l_y_b * s_l_y_b_s)
|
|
181
|
+
rev_idx_spa_y = tl.load(r_lut_y + rev_idx_spa_y_idx, mask=rev_idx_spa_y_msk).to(tl.int32)
|
|
182
|
+
|
|
183
|
+
# If both blocks are present commence calculation
|
|
184
|
+
if rev_idx_spa_x >= 0 and rev_idx_spa_y >= 0:
|
|
185
|
+
blk_x_idx = ((rev_idx_spa_x * x_b_s) +
|
|
186
|
+
((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
187
|
+
((i_seg_tri_mod * val_tbs +
|
|
188
|
+
tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
189
|
+
blk_x_msk = ((blk_x_idx >= 0 and
|
|
190
|
+
blk_x_idx < x_b * x_b_s) and
|
|
191
|
+
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
192
|
+
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
193
|
+
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
194
|
+
|
|
195
|
+
blk_y_idx = ((rev_idx_spa_y * y_b_s) +
|
|
196
|
+
((i_seg_tri_mod * val_tbs +
|
|
197
|
+
tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
|
|
198
|
+
((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
|
|
199
|
+
blk_y_msk = ((blk_y_idx >= 0 and
|
|
200
|
+
blk_y_idx < y_b * y_b_s) and
|
|
201
|
+
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
202
|
+
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
203
|
+
blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
|
|
204
|
+
|
|
205
|
+
# Perform matrix multiplication
|
|
206
|
+
buf += tl.dot(blk_x, blk_y)
|
|
207
|
+
|
|
208
|
+
# Store output
|
|
209
|
+
blk_o_idx = ((pid_blk * o_b_s) +
|
|
210
|
+
((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
211
|
+
((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
212
|
+
blk_o_msk = ((blk_o_idx >= 0 and
|
|
213
|
+
blk_o_idx < o_b * o_b_s) and
|
|
214
|
+
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
215
|
+
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
216
|
+
tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def matmul_build_lut(lut: dict, sparsity_layout_x: Tensor, sparsity_layout_y: Tensor, sparsity_layout_output: Tensor):
|
|
220
|
+
if lut is None:
|
|
221
|
+
lut = dict()
|
|
222
|
+
|
|
223
|
+
if "sparsity_reverse_lut_x" not in lut:
|
|
224
|
+
sparsity_layout_x_flat = sparsity_layout_x.reshape(-1)
|
|
225
|
+
sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
|
|
226
|
+
(sparsity_layout_x_flat == 1) -
|
|
227
|
+
(1 * (sparsity_layout_x_flat == 0)))
|
|
228
|
+
lut["sparsity_reverse_lut_x"] = sparsity_reverse_lut_x
|
|
229
|
+
|
|
230
|
+
if "sparsity_reverse_lut_y" not in lut:
|
|
231
|
+
sparsity_layout_y_flat = sparsity_layout_y.reshape(-1)
|
|
232
|
+
sparsity_reverse_lut_y = ((torch.cumsum(sparsity_layout_y_flat, dim=-1) - 1) *
|
|
233
|
+
(sparsity_layout_y_flat == 1) -
|
|
234
|
+
(1 * (sparsity_layout_y_flat == 0)))
|
|
235
|
+
lut["sparsity_reverse_lut_y"] = sparsity_reverse_lut_y
|
|
236
|
+
|
|
237
|
+
if "sparsity_lut_o" not in lut:
|
|
238
|
+
sparsity_lut_o = torch.nonzero(sparsity_layout_output).contiguous()
|
|
239
|
+
lut["sparsity_lut_o"] = sparsity_lut_o
|
|
240
|
+
|
|
241
|
+
if "n_sparse_blocks" not in lut:
|
|
242
|
+
n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
|
|
243
|
+
lut["n_sparse_blocks"] = n_sparse_blocks
|
|
244
|
+
|
|
245
|
+
validate_contiguous(sparsity_layout_x, lut["sparsity_reverse_lut_x"],
|
|
246
|
+
sparsity_layout_y, lut["sparsity_reverse_lut_y"],
|
|
247
|
+
sparsity_layout_output, lut["sparsity_lut_o"])
|
|
248
|
+
|
|
249
|
+
return lut
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
# noinspection PyUnusedLocal
|
|
253
|
+
def matmul_setup_context(ctx, inputs, output):
|
|
254
|
+
(x, y, sparsity_layout_x, _, sparsity_layout_y, _,
|
|
255
|
+
sparsity_layout_o, _, sparsity_block_size, _) = inputs
|
|
256
|
+
|
|
257
|
+
ctx.save_for_backward(x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o)
|
|
258
|
+
ctx.sparsity_block_size = sparsity_block_size
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
matmul_forward.register_autograd(matmul_backward, setup_context=matmul_setup_context)
|
|
@@ -1,16 +1,18 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
import triton
|
|
3
3
|
from torch import Tensor
|
|
4
|
+
from torch._library import triton_op
|
|
5
|
+
from torch._library.triton import wrap_triton
|
|
4
6
|
from triton import language as tl
|
|
5
7
|
|
|
6
8
|
from blksprs.utils.blksprs_tensor import BlksprsTensor
|
|
7
|
-
from blksprs.utils.tools import
|
|
9
|
+
from blksprs.utils.tools import stride, get_autotune_configs
|
|
8
10
|
from blksprs.utils.validation import validate_contiguous, validate_device, \
|
|
9
|
-
validate_sparsity_block_size
|
|
11
|
+
validate_sparsity_block_size
|
|
10
12
|
|
|
11
13
|
|
|
12
14
|
def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
|
|
13
|
-
sparsity_block_size: int
|
|
15
|
+
sparsity_block_size: int) -> BlksprsTensor:
|
|
14
16
|
"""Performs a broadcast and subsequent addition of two dense tensors x and y. Returns a block-sparse tensor in
|
|
15
17
|
compressed form.
|
|
16
18
|
|
|
@@ -19,7 +21,6 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
|
|
|
19
21
|
y (Tensor): A dense input tensor.
|
|
20
22
|
sparsity_layout_output (Tensor): The sparsity layout of the output tensor.
|
|
21
23
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
22
|
-
triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
|
|
23
24
|
|
|
24
25
|
Returns:
|
|
25
26
|
BlksprsTensor: The result of the operation as a block-sparse tensor in compressed form. Each element o(i, j) of the
|
|
@@ -34,7 +35,6 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
|
|
|
34
35
|
if x.size(-1) != y.size(-1):
|
|
35
36
|
raise ValueError("Dimensions of tensors must match")
|
|
36
37
|
validate_sparsity_block_size(sparsity_block_size)
|
|
37
|
-
validate_triton_block_size(triton_block_size, sparsity_block_size)
|
|
38
38
|
|
|
39
39
|
sparsity_lut_o = torch.nonzero(sparsity_layout_output).contiguous()
|
|
40
40
|
|
|
@@ -42,6 +42,21 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
|
|
|
42
42
|
|
|
43
43
|
validate_contiguous(sparsity_layout_output, sparsity_lut_o)
|
|
44
44
|
|
|
45
|
+
return BlksprsTensor(broadcast_add_forward(x, y, sparsity_lut_o, sparsity_block_size, n_sparse_blocks))
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def broadcast_sub(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
|
|
49
|
+
sparsity_block_size: int) -> BlksprsTensor:
|
|
50
|
+
"""Wrapper for ``broadcast_add`` with negated y.
|
|
51
|
+
|
|
52
|
+
"""
|
|
53
|
+
return broadcast_add(x, torch.neg(y), sparsity_layout_output, sparsity_block_size)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@triton_op("blksprs::broadcast_add", mutates_args={})
|
|
57
|
+
def broadcast_add_forward(x: Tensor, y: Tensor,
|
|
58
|
+
sparsity_lut_o: Tensor,
|
|
59
|
+
sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
|
|
45
60
|
output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, dtype=x.dtype, device=x.device)
|
|
46
61
|
|
|
47
62
|
x_b, x_c = x.size()
|
|
@@ -53,14 +68,11 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
|
|
|
53
68
|
s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
|
|
54
69
|
s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
|
|
55
70
|
|
|
56
|
-
if triton_block_size is None:
|
|
57
|
-
triton_block_size = get_triton_block_size(sparsity_block_size)
|
|
58
|
-
|
|
59
71
|
triton_grid = lambda meta: [o_b,
|
|
60
72
|
triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
|
|
61
73
|
triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
|
|
62
74
|
|
|
63
|
-
(
|
|
75
|
+
(wrap_triton(broadcast_add_kernel)[triton_grid]
|
|
64
76
|
(x,
|
|
65
77
|
x_b, x_b_s, x_c_s,
|
|
66
78
|
y,
|
|
@@ -68,35 +80,34 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
|
|
|
68
80
|
output,
|
|
69
81
|
o_b, o_b_s, o_r_s, o_c_s,
|
|
70
82
|
sparsity_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
|
|
71
|
-
sparsity_block_size
|
|
72
|
-
triton_block_size))
|
|
83
|
+
sparsity_block_size))
|
|
73
84
|
|
|
74
85
|
return BlksprsTensor(output)
|
|
75
86
|
|
|
76
87
|
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
return broadcast_add(x, torch.neg(y), sparsity_layout_output, sparsity_block_size, triton_block_size)
|
|
83
|
-
|
|
84
|
-
|
|
88
|
+
@triton.autotune(
|
|
89
|
+
configs=get_autotune_configs(),
|
|
90
|
+
key=[],
|
|
91
|
+
reset_to_zero=["o"]
|
|
92
|
+
)
|
|
85
93
|
@triton.jit
|
|
86
|
-
def
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
94
|
+
def broadcast_add_kernel(x,
|
|
95
|
+
x_b, x_b_s, x_c_s,
|
|
96
|
+
y,
|
|
97
|
+
y_b, y_b_s, y_c_s,
|
|
98
|
+
o,
|
|
99
|
+
o_b, o_b_s, o_r_s, o_c_s,
|
|
100
|
+
s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
|
|
101
|
+
sparsity_block_size,
|
|
102
|
+
TRITON_BLOCK_SIZE: tl.constexpr) -> None:
|
|
95
103
|
# Get triton block indices
|
|
96
104
|
pid_blk = tl.program_id(axis=0)
|
|
97
105
|
pid_row = tl.program_id(axis=1)
|
|
98
106
|
pid_col = tl.program_id(axis=2)
|
|
99
107
|
|
|
108
|
+
# Get valid triton block size
|
|
109
|
+
val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
|
|
110
|
+
|
|
100
111
|
# Get position of current sparsity block consisting of its batch, row, and column index
|
|
101
112
|
spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
|
|
102
113
|
spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
|
|
@@ -112,16 +123,20 @@ def kernel_broadcast_addition(x,
|
|
|
112
123
|
|
|
113
124
|
# Load x block
|
|
114
125
|
blk_x_idx = (spa_bat_o * x_b_s +
|
|
115
|
-
((
|
|
126
|
+
((pid_row * val_tbs + spa_row_o * sparsity_block_size +
|
|
116
127
|
tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
117
|
-
blk_x_msk = (blk_x_idx >= 0 and
|
|
128
|
+
blk_x_msk = ((blk_x_idx >= 0 and
|
|
129
|
+
blk_x_idx < x_b * x_b_s) and
|
|
130
|
+
(tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
118
131
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
119
132
|
|
|
120
133
|
# Load y block
|
|
121
134
|
blk_y_idx = (spa_bat_o * y_b_s +
|
|
122
|
-
((
|
|
135
|
+
((pid_col * val_tbs + spa_col_o * sparsity_block_size +
|
|
123
136
|
tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
|
|
124
|
-
blk_y_msk = (blk_y_idx >= 0 and
|
|
137
|
+
blk_y_msk = ((blk_y_idx >= 0 and
|
|
138
|
+
blk_y_idx < y_b * y_b_s) and
|
|
139
|
+
(tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
125
140
|
blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
|
|
126
141
|
|
|
127
142
|
# Compute sum
|
|
@@ -130,7 +145,10 @@ def kernel_broadcast_addition(x,
|
|
|
130
145
|
|
|
131
146
|
# Store result
|
|
132
147
|
blk_o_idx = ((pid_blk * o_b_s) +
|
|
133
|
-
((pid_row *
|
|
134
|
-
((pid_col *
|
|
135
|
-
blk_o_msk = (blk_o_idx >= 0 and
|
|
148
|
+
((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
149
|
+
((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
150
|
+
blk_o_msk = ((blk_o_idx >= 0 and
|
|
151
|
+
blk_o_idx < o_b * o_b_s) and
|
|
152
|
+
(tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
|
|
153
|
+
tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
|
|
136
154
|
tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
|