blksprs 0.2b4__py3-none-any.whl → 1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/layouting/distribution_layout.py +114 -0
- blksprs/layouting/sparsity_layout.py +78 -0
- blksprs/misc/broadcast_addition.py +132 -0
- blksprs/ops/conversion.py +256 -0
- blksprs/ops/distribution.py +362 -0
- blksprs/ops/exp.py +101 -0
- blksprs/ops/matmul.py +221 -0
- blksprs/ops/row_wise_sum.py +231 -0
- blksprs/ops/softmax.py +263 -0
- blksprs/ops/transpose.py +154 -0
- blksprs/utils/tools.py +20 -0
- blksprs/utils/validation.py +97 -0
- blksprs-1.1.dist-info/METADATA +164 -0
- blksprs-1.1.dist-info/RECORD +17 -0
- {blksprs-0.2b4.dist-info → blksprs-1.1.dist-info}/WHEEL +1 -1
- blksprs/ops/blocksparse.py +0 -589
- blksprs-0.2b4.dist-info/METADATA +0 -26
- blksprs-0.2b4.dist-info/RECORD +0 -6
- {blksprs-0.2b4.dist-info → blksprs-1.1.dist-info}/top_level.txt +0 -0
blksprs/ops/distribution.py
ADDED
@@ -0,0 +1,362 @@
import torch
import triton
from torch import Tensor
from triton import language as tl

from blksprs.utils.tools import get_triton_block_size
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
    validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size


def gather(src: Tensor, sparsity_layout_src: Tensor, idx: Tensor, sparsity_layout_idx: Tensor,
           sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
    """Applies a gather operation on a block-sparse tensor in compressed form.

    Args:
        src (Tensor): The source block-sparse tensor in compressed form to gather from.
        sparsity_layout_src (Tensor): The sparsity layout of the source block-sparse tensor.
        idx (Tensor): The block-sparse indices tensor in compressed form specifying how to gather from the source tensor.
        sparsity_layout_idx (Tensor): The sparsity layout of the indices block-sparse tensor.
        sparsity_block_size (int): The size of the sparsity blocks.
        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).

    Returns:
        Tensor: The result of the gather operation as a block-sparse tensor in compressed form.

    """
    validate_dimensions(src, idx)
    validate_contiguous(src, idx)
    validate_dtype_int(idx)
    validate_device(src, idx)
    validate_sparsity(sparsity_block_size, (src, sparsity_layout_src), (idx, sparsity_layout_idx))
    validate_sparsity_block_size(sparsity_block_size, src, idx)
    validate_triton_block_size(triton_block_size, sparsity_block_size)

    sparsity_layout_x_flat = sparsity_layout_src.reshape(-1)
    sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
                              (sparsity_layout_x_flat == 1) -
                              (1 * (sparsity_layout_x_flat == 0)))

    sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()

    validate_contiguous(sparsity_layout_src, sparsity_reverse_lut_x,
                        sparsity_layout_idx, sparsity_lut_i)

    return _BlocksparseGather.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
                                    idx, sparsity_layout_idx, sparsity_lut_i,
                                    sparsity_block_size, triton_block_size)


class _BlocksparseGather(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
                i: Tensor, sparsity_layout_i: Tensor, sparsity_lut_i: Tensor,
                sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
        output = torch.empty_like(i, dtype=x.dtype)

        x_b, x_r, x_c = x.size()
        x_b_s, x_r_s, x_c_s = x.stride()
        s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_x.stride()
        i_b, i_r, i_c = i.size()
        i_b_s, i_r_s, i_c_s = i.stride()
        s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
        s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
        o_b, o_r, o_c = output.size()
        o_b_s, o_r_s, o_c_s = output.stride()

        if triton_block_size is None:
            triton_block_size = get_triton_block_size(sparsity_block_size)

        triton_grid = lambda meta: [o_b,
                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

        (_BlocksparseGather.kernel_blocksparse_gather[triton_grid]
         (x,
          x_b, x_b_s, x_r_s, x_c_s,
          s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
          sparsity_reverse_lut_x,
          i,
          i_b, i_b_s, i_r_s, i_c_s,
          output,
          o_b, o_b_s, o_r_s, o_c_s,
          sparsity_lut_i, s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
          sparsity_block_size,
          triton_block_size))

        ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_i)
        ctx.sparsity_block_size = sparsity_block_size
        ctx.triton_block_size = triton_block_size

        return output

    @staticmethod
    def backward(ctx, grad_output):
        sparsity_layout_x, i, sparsity_layout_i = ctx.saved_tensors
        sparsity_block_size = ctx.sparsity_block_size
        triton_block_size = ctx.triton_block_size

        return scatter_reduce(grad_output, sparsity_layout_i,
                              i,
                              sparsity_layout_x,
                              sparsity_block_size,
                              reduce_op="sum",
                              triton_block_size=triton_block_size), None, None, None, None, None, None, None

    @staticmethod
    @triton.jit
    def kernel_blocksparse_gather(x,
                                  x_b, x_b_s, x_r_s, x_c_s,
                                  s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
                                  r_lut_x,
                                  i,
                                  i_b, i_b_s, i_r_s, i_c_s,
                                  o,
                                  o_b, o_b_s, o_r_s, o_c_s,
                                  s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
                                  sparsity_block_size,
                                  TRITON_BLOCK_SIZE: tl.constexpr) -> None:
        # Get triton block indices
        pid_blk = tl.program_id(axis=0)
        pid_row = tl.program_id(axis=1)
        pid_col = tl.program_id(axis=2)

        # Get position of current sparsity block consisting of its batch, row, and column index
        spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
        spa_bat_o_msk = (spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
        spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)

        spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
        spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
        spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)

        # Load index values
        blk_i_idx = ((pid_blk * i_b_s) +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
        blk_i_msk = (blk_i_idx < i_b * i_b_s)
        blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)

        # Get positions of sparsity blocks
        pos_spa_blk_x = blk_i // sparsity_block_size
        pos_spa_col_x = blk_i % sparsity_block_size

        # Load reverse sparsity indices for x
        rev_idx_spa_x_idx = ((spa_bat_o * s_l_x_b_s) +
                             (spa_row_o * s_l_x_r_s) +
                             (pos_spa_blk_x * s_l_x_c_s))
        rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
        rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)

        # Load x values
        blk_x_idx = ((rev_idx_spa_x * x_b_s) +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                     (pos_spa_col_x * x_c_s))
        blk_x_msk = (blk_x_idx < x_b * x_b_s)
        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

        # Store output
        blk_o_idx = ((pid_blk * o_b_s) +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
        blk_o_msk = (blk_o_idx < o_b * o_b_s)
        tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)


def scatter(src: Tensor, sparsity_layout_src: Tensor,
            idx: Tensor,
            sparsity_layout_tgt: Tensor,
            sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
    """Wrapper for ``scatter_reduce`` with ``reduce_op="none"``.

    """
    return scatter_reduce(src, sparsity_layout_src,
                          idx,
                          sparsity_layout_tgt,
                          sparsity_block_size,
                          reduce_op="none", triton_block_size=triton_block_size)


def scatter_reduce(src: Tensor, sparsity_layout_src: Tensor,
                   idx: Tensor,
                   sparsity_layout_tgt: Tensor,
                   sparsity_block_size: int,
                   reduce_op: str = "sum", triton_block_size: int = None) -> Tensor:
    """Applies a scatter operation on a block-sparse tensor in compressed form.

    Args:
        src (Tensor): The source block-sparse tensor in compressed form to scatter from.
        sparsity_layout_src (Tensor): The sparsity layout of the source block-sparse tensor.
        idx (Tensor): The block-sparse indices tensor in compressed form specifying how to scatter to the target tensor.
        sparsity_layout_tgt (Tensor): The sparsity layout of the target block-sparse tensor.
        sparsity_block_size (int): The size of the sparsity blocks.
        reduce_op (str, optional): The reduction operation to apply during the scatter operation (default ``"sum"``).
            Supported operations are ``"none"`` and ``"sum"``.
        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).

    Returns:
        Tensor: The result of the scatter operation as a block-sparse tensor in compressed form.

    """
    validate_dimensions(src, idx)
    validate_contiguous(src, idx)
    validate_dtype_int(idx)
    validate_device(src, idx)
    validate_sparsity(sparsity_block_size, (src, sparsity_layout_src), (idx, sparsity_layout_src))
    validate_sparsity_block_size(sparsity_block_size, src, idx)
    validate_triton_block_size(triton_block_size, sparsity_block_size)

    if reduce_op not in ["none", "sum"]:
        raise ValueError(f"Reduction operation '{reduce_op}' is not supported")

    sparsity_lut_x = torch.nonzero(sparsity_layout_src).contiguous()

    sparsity_layout_o_flat = sparsity_layout_tgt.reshape(-1)
    sparsity_reverse_lut_o = ((torch.cumsum(sparsity_layout_o_flat, dim=-1) - 1) *
                              (sparsity_layout_o_flat == 1) -
                              (1 * (sparsity_layout_o_flat == 0)))

    n_sparse_blocks = torch.sum(sparsity_layout_tgt.to(torch.int)).item()

    validate_contiguous(sparsity_layout_src, sparsity_lut_x,
                        sparsity_layout_tgt, sparsity_reverse_lut_o)

    return _BlocksparseScatterReduce.apply(src, sparsity_layout_src, sparsity_lut_x,
                                           idx,
                                           sparsity_layout_tgt, sparsity_reverse_lut_o,
                                           sparsity_block_size, n_sparse_blocks,
                                           reduce_op, triton_block_size)


class _BlocksparseScatterReduce(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_lut_x: Tensor,
                i: Tensor,
                sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
                sparsity_block_size: int, n_sparse_blocks: int,
                reduce_op: str, triton_block_size: int) -> Tensor:
        output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                             dtype=x.dtype, device=x.device)

        x_b, x_r, x_c = x.size()
        x_b_s, x_r_s, x_c_s = x.stride()
        s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
        s_lut_x_r_s, s_lut_x_c_s = sparsity_lut_x.stride()
        i_b, i_r, i_c = i.size()
        i_b_s, i_r_s, i_c_s = i.stride()
        o_b, o_r, o_c = output.size()
        o_b_s, o_r_s, o_c_s = output.stride()
        s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
        s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_o.stride()

        if triton_block_size is None:
            triton_block_size = get_triton_block_size(sparsity_block_size)

        triton_grid = lambda meta: [x_b,
                                    triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
                                    triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]

        reduce_op_ind = 0
        if reduce_op == "sum":
            reduce_op_ind = 1

        (_BlocksparseScatterReduce.kernel_blocksparse_scatter[triton_grid]
         (x,
          x_b, x_b_s, x_r_s, x_c_s,
          sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
          i,
          i_b, i_b_s, i_r_s, i_c_s,
          output,
          o_b, o_b_s, o_r_s, o_c_s,
          s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
          sparsity_reverse_lut_o,
          reduce_op_ind,
          sparsity_block_size,
          triton_block_size))

        ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_o)
        ctx.sparsity_block_size = sparsity_block_size
        ctx.reduce_op = reduce_op
        ctx.triton_block_size = triton_block_size

        return output

    @staticmethod
    def backward(ctx, grad_output):
        sparsity_layout_x, i, sparsity_layout_o = ctx.saved_tensors
        sparsity_block_size = ctx.sparsity_block_size
        reduce_op = ctx.reduce_op
        triton_block_size = ctx.triton_block_size

        if reduce_op == "sum":
            return gather(grad_output, sparsity_layout_o, i, sparsity_layout_x, sparsity_block_size,
                          triton_block_size=triton_block_size), None, None, None, None, None, None, None, None, None
        else:
            raise ValueError(f"Reduction operation '{reduce_op}' does not support backward pass")

    @staticmethod
    @triton.jit
    def kernel_blocksparse_scatter(x,
                                   x_b, x_b_s, x_r_s, x_c_s,
                                   s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
                                   i,
                                   i_b, i_b_s, i_r_s, i_c_s,
                                   o,
                                   o_b, o_b_s, o_r_s, o_c_s,
                                   s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
                                   r_lut_o,
                                   reduce_op_ind,
                                   sparsity_block_size,
                                   TRITON_BLOCK_SIZE: tl.constexpr) -> None:
        # Get triton block indices
        pid_blk = tl.program_id(axis=0)
        pid_row = tl.program_id(axis=1)
        pid_col = tl.program_id(axis=2)

        # Get position of current sparsity block consisting of its batch, row, and column index
        spa_bat_x_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
        spa_bat_x_msk = (spa_bat_x_idx < s_lut_x_r * s_lut_x_r_s)
        spa_bat_x = tl.load(s_lut_x + spa_bat_x_idx, mask=spa_bat_x_msk)

        spa_row_x_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
        spa_row_x_msk = (spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
        spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)

        # Load x values
        blk_x_idx = ((pid_blk * x_b_s) +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
        blk_x_msk = (blk_x_idx < x_b * x_b_s)
        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

        # Load index values
        blk_i_idx = ((pid_blk * i_b_s) +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
        blk_i_msk = (blk_i_idx < i_b * i_b_s)
        blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)

        # Get positions of sparsity blocks
        pos_spa_blk_o = blk_i // sparsity_block_size
        pos_spa_col_o = blk_i % sparsity_block_size

        # Load reverse sparsity indices for o
        rev_idx_spa_o_idx = ((spa_bat_x * s_l_o_b_s) +
                             (spa_row_x * s_l_o_r_s) +
                             (pos_spa_blk_o * s_l_o_c_s))
        rev_idx_spa_o_msk = (rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
        rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)

        # Store output
        blk_o_idx = ((rev_idx_spa_o * o_b_s) +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                     (pos_spa_col_o * o_c_s))
        blk_o_msk = (blk_o_idx < o_b * o_b_s)

        if reduce_op_ind == 0:
            tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
        elif reduce_op_ind == 1:
            tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
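The `gather` and `scatter_reduce` wrappers above are adjoint operations: gather's backward pass is a summed scatter, and the summed scatter's backward pass is a gather. Both build a "reverse" lookup table with a cumulative sum over the flattened sparsity layout, mapping every block position either to its index in the compressed tensor or to -1 if the block is absent. The sketch below is not taken from the package; the shapes, device, and index values are illustrative assumptions, and the inputs are assumed to already be in compressed block-sparse form.

```python
import torch
from blksprs.ops.distribution import gather, scatter_reduce

# Reverse LUT construction used by the wrappers (illustrative): present blocks map to
# their position in the compressed tensor, absent blocks map to -1.
layout_flat = torch.tensor([1, 0, 1, 1])
reverse_lut = ((torch.cumsum(layout_flat, dim=-1) - 1) * (layout_flat == 1)
               - (1 * (layout_flat == 0)))
# reverse_lut == tensor([ 0, -1,  1,  2])

# Hypothetical toy setup: batch of 2, a 4x4 grid of 16x16 blocks (dense shape 2x64x64),
# with every block present.
sparsity_block_size = 16
sparsity_layout = torch.ones(2, 4, 4, dtype=torch.int, device="cuda")
n_blocks = int(sparsity_layout.sum())
src = torch.randn(n_blocks, sparsity_block_size, sparsity_block_size, device="cuda")

# Index values are column positions in the dense (uncompressed) representation;
# gathering happens along the last dimension.
idx = torch.randint(0, 4 * sparsity_block_size,
                    (n_blocks, sparsity_block_size, sparsity_block_size),
                    dtype=torch.int32, device="cuda")

out = gather(src, sparsity_layout, idx, sparsity_layout, sparsity_block_size)

# scatter_reduce with "sum" is the adjoint used in gather's backward pass.
back = scatter_reduce(out, sparsity_layout, idx, sparsity_layout,
                      sparsity_block_size, reduce_op="sum")
```

Note that only `reduce_op="sum"` is differentiable; `"none"` raises a `ValueError` in `_BlocksparseScatterReduce.backward`.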
blksprs/ops/exp.py
ADDED
@@ -0,0 +1,101 @@
import torch
import triton
from torch import Tensor
from triton import language as tl

from blksprs.utils.tools import get_triton_block_size
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
    validate_sparsity_block_size, validate_triton_block_size


def exp(x: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
    """Applies the element-wise exponential function to a block-sparse tensor.

    Note:
        This operation does not consider sparse blocks, i.e., these will not be set to ``e^0``.
        Consider this when converting back to tensors in regular form.

    Args:
        x (Tensor): A block-sparse tensor in compressed form.
        sparsity_block_size (int): The size of the sparsity blocks.
        triton_block_size (int): The block size to use for the triton kernel (default ``None``).

    Returns:
        Tensor: The exponential function applied to all elements of the input tensor as a block-sparse tensor in
            compressed form.

    """
    validate_dimensions(x)
    validate_contiguous(x)
    validate_device(x)
    validate_sparsity_block_size(sparsity_block_size, x)
    validate_triton_block_size(triton_block_size, sparsity_block_size)

    return _BlocksparseExp.apply(x, sparsity_block_size, triton_block_size)


class _BlocksparseExp(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x: Tensor, sparsity_block_size: int, triton_block_size: int) -> Tensor:
        output = torch.empty_like(x)

        x_b, x_r, x_c = x.shape
        x_b_s, x_r_s, x_c_s = x.stride()
        o_b, o_r, o_c = output.shape
        o_b_s, o_r_s, o_c_s = output.stride()

        if triton_block_size is None:
            triton_block_size = get_triton_block_size(sparsity_block_size)

        triton_grid = lambda meta: [o_b,
                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

        (_BlocksparseExp.kernel_blocksparse_exp[triton_grid]
         (x,
          x_b, x_b_s, x_r_s, x_c_s,
          output,
          o_b, o_b_s, o_r_s, o_c_s,
          triton_block_size))

        ctx.save_for_backward(output)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        o = ctx.saved_tensors[0]

        grad_x = torch.mul(grad_output, o)

        return grad_x, None, None

    @staticmethod
    @triton.jit
    def kernel_blocksparse_exp(x,
                               x_b, x_b_s, x_r_s, x_c_s,
                               o,
                               o_b, o_b_s, o_r_s, o_c_s,
                               TRITON_BLOCK_SIZE: tl.constexpr) -> None:
        # Get triton block indices
        pid_blk = tl.program_id(axis=0)
        pid_row = tl.program_id(axis=1)
        pid_col = tl.program_id(axis=2)

        # Load block
        blk_x_idx = ((pid_blk * x_b_s) +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
        blk_x_msk = (blk_x_idx < x_b * x_b_s)
        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

        # Compute exp
        buf = tl.exp(blk_x)

        # Store block
        blk_o_idx = ((pid_blk * o_b_s) +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
        blk_o_msk = (blk_o_idx < o_b * o_b_s)
        tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
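As the docstring notes, `exp` only touches the blocks that are actually stored, so absent blocks do not become `e^0 = 1` when the result is converted back to regular form. A minimal usage sketch with hypothetical shapes, assuming the input is already in compressed block-sparse form on a CUDA device:

```python
import torch
from blksprs.ops.exp import exp

# Hypothetical toy input: 6 stored blocks of size 32x32, already in compressed form.
sparsity_block_size = 32
x = torch.randn(6, sparsity_block_size, sparsity_block_size,
                device="cuda", requires_grad=True)

y = exp(x, sparsity_block_size)  # element-wise e^x on the stored blocks only
y.sum().backward()               # backward multiplies grad_output by the saved output (d/dx e^x = e^x)
```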
blksprs/ops/matmul.py
ADDED
@@ -0,0 +1,221 @@
import torch
import triton
from torch import Tensor
from triton import language as tl

from blksprs.ops.transpose import transpose
from blksprs.utils.tools import get_triton_block_size
from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
    validate_sparsity, validate_sparsity_block_size, validate_triton_block_size


def matmul(x: Tensor, sparsity_layout_x: Tensor,
           y: Tensor, sparsity_layout_y: Tensor,
           sparsity_layout_output: Tensor,
           sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
    """Performs matrix multiplication between two block-sparse tensors.

    The sparsity layout of the output tensor is used to only calculate blocks that will be present in the output.

    Args:
        x (Tensor): A block-sparse tensor in compressed form.
        y (Tensor): A block-sparse tensor in compressed form.
        sparsity_layout_x (Tensor): The sparsity layout of the first block-sparse tensor.
        sparsity_layout_y (Tensor): The sparsity layout of the second block-sparse tensor.
        sparsity_layout_output (Tensor): The sparsity layout of the output tensor.
        sparsity_block_size (int): The size of the sparsity blocks.
        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).

    Returns:
        Tensor: The result of the matrix multiplication as a block-sparse tensor in compressed form.

    """
    validate_dimensions(x, y)
    validate_contiguous(x, y)
    validate_device(x, y)
    validate_sparsity(sparsity_block_size, (x, sparsity_layout_x), (y, sparsity_layout_y))
    if sparsity_layout_x.size(-1) != sparsity_layout_y.size(-2):
        raise ValueError("Inner dimensions of tensors must match")
    validate_sparsity_block_size(sparsity_block_size, x, y)
    validate_triton_block_size(triton_block_size, sparsity_block_size)

    sparsity_layout_x_flat = sparsity_layout_x.reshape(-1)
    sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
                              (sparsity_layout_x_flat == 1) -
                              (1 * (sparsity_layout_x_flat == 0)))

    sparsity_layout_y_flat = sparsity_layout_y.reshape(-1)
    sparsity_reverse_lut_y = ((torch.cumsum(sparsity_layout_y_flat, dim=-1) - 1) *
                              (sparsity_layout_y_flat == 1) -
                              (1 * (sparsity_layout_y_flat == 0)))

    sparsity_lut_o = torch.nonzero(sparsity_layout_output).contiguous()

    n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()

    validate_contiguous(sparsity_layout_x, sparsity_reverse_lut_x,
                        sparsity_layout_y, sparsity_reverse_lut_y,
                        sparsity_layout_output, sparsity_lut_o)

    return _BlocksparseMatmulSSS.apply(x, y,
                                       sparsity_layout_x, sparsity_reverse_lut_x,
                                       sparsity_layout_y, sparsity_reverse_lut_y,
                                       sparsity_layout_output, sparsity_lut_o,
                                       sparsity_block_size,
                                       n_sparse_blocks,
                                       triton_block_size)


class _BlocksparseMatmulSSS(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x: Tensor, y: Tensor,
                sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
                sparsity_layout_y: Tensor, sparsity_reverse_lut_y: Tensor,
                sparsity_layout_o: Tensor, sparsity_lut_o: Tensor,
                sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size), device=x.device)

        x_b, x_r, x_c = x.size()
        x_b_s, x_r_s, x_c_s = x.stride()
        s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_x.stride()
        y_b, y_r, y_c = y.size()
        y_b_s, y_r_s, y_c_s = y.stride()
        s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_y.size()
        s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = sparsity_layout_y.stride()
        o_b, o_r, o_c = output.size()
        o_b_s, o_r_s, o_c_s = output.stride()
        s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
        s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_o.stride()

        if triton_block_size is None:
            triton_block_size = get_triton_block_size(sparsity_block_size)

        triton_grid = lambda meta: [o_b,
                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

        (_BlocksparseMatmulSSS.kernel_blocksparse_matmul_sss[triton_grid]
         (x,
          x_b, x_b_s, x_r_s, x_c_s,
          s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
          sparsity_reverse_lut_x,
          y,
          y_b, y_b_s, y_r_s, y_c_s,
          s_l_y_b, s_l_y_b_s, s_l_y_r_s, s_l_y_c_s,
          sparsity_reverse_lut_y,
          output,
          o_b, o_b_s, o_r_s, o_c_s,
          sparsity_lut_o,
          s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
          sparsity_block_size,
          triton_block_size))

        ctx.save_for_backward(x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o)
        ctx.sparsity_block_size = sparsity_block_size
        ctx.triton_block_size = triton_block_size

        return output

    @staticmethod
    def backward(ctx, grad_output):
        x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o = ctx.saved_tensors
        sparsity_block_size = ctx.sparsity_block_size
        triton_block_size = ctx.triton_block_size

        x_t, sparsity_layout_x_t = transpose(x, sparsity_layout_x, sparsity_block_size, triton_block_size)
        y_t, sparsity_layout_y_t = transpose(y, sparsity_layout_y, sparsity_block_size, triton_block_size)

        grad_x = matmul(grad_output, sparsity_layout_o, y_t, sparsity_layout_y_t, sparsity_layout_x,
                        sparsity_block_size, triton_block_size)
        grad_y = matmul(x_t, sparsity_layout_x_t, grad_output, sparsity_layout_o, sparsity_layout_y,
                        sparsity_block_size, triton_block_size)

        return grad_x, grad_y, None, None, None, None, None, None, None, None, None

    @staticmethod
    @triton.jit
    def kernel_blocksparse_matmul_sss(x,
                                      x_b, x_b_s, x_r_s, x_c_s,
                                      s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
                                      r_lut_x,
                                      y,
                                      y_b, y_b_s, y_r_s, y_c_s,
                                      s_l_y_b, s_l_y_b_s, s_l_y_r_s, s_l_y_c_s,
                                      r_lut_y,
                                      o,
                                      o_b, o_b_s, o_r_s, o_c_s,
                                      s_lut_o,
                                      s_lut_o_r, s_lut_o_r_s,
                                      s_lut_o_c_s,
                                      sparsity_block_size,
                                      TRITON_BLOCK_SIZE: tl.constexpr) -> None:
        # Get triton block indices
        pid_blk = tl.program_id(axis=0)
        pid_row = tl.program_id(axis=1)
        pid_col = tl.program_id(axis=2)

        # Get position of current sparsity block consisting of its batch, row, and column index
        spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
        spa_bat_o_msk = (spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
        spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)

        spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
        spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
        spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)

        spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
        spa_col_o_msk = (spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
        spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)

        # Setup buffer
        buf = tl.zeros(shape=(TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), dtype=tl.float32)

        # Slide over triton block sized segments of input tensors
        for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size, TRITON_BLOCK_SIZE)):
            # Convert to segment index of sparsity layout
            i_seg_spa = (i_seg_tri * TRITON_BLOCK_SIZE) // sparsity_block_size
            # Calculate the triton segment index within a block
            i_seg_tri_mod = i_seg_tri % (sparsity_block_size // TRITON_BLOCK_SIZE)

            # Get reverse sparsity indices for input tensors x and y
            # These are either -1 if the block is empty or equal to the index of the block in the sparse tensor

            # Get reverse sparsity indices for x
            rev_idx_spa_x_idx = (spa_bat_o * s_l_x_b_s +
                                 spa_row_o * s_l_x_r_s +
                                 i_seg_spa * s_l_x_c_s)
            rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
            rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)

            # Get reverse sparsity indices for y
            rev_idx_spa_y_idx = (spa_bat_o * s_l_y_b_s + i_seg_spa * s_l_y_r_s + spa_col_o * s_l_y_c_s)
            rev_idx_spa_y_msk = (rev_idx_spa_y_idx < s_l_y_b * s_l_y_b_s)
            rev_idx_spa_y = tl.load(r_lut_y + rev_idx_spa_y_idx, mask=rev_idx_spa_y_msk).to(tl.int32)

            # If both blocks are present commence calculation
            if rev_idx_spa_x >= 0 and rev_idx_spa_y >= 0:
                blk_x_idx = ((rev_idx_spa_x * x_b_s) +
                             ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                             ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
                               tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
                blk_x_msk = (blk_x_idx < x_b * x_b_s)
                blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

                blk_y_idx = ((rev_idx_spa_y * y_b_s) +
                             ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
                               tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
                             ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
                blk_y_msk = (blk_y_idx < y_b * y_b_s)
                blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)

                # Perform matrix multiplication
                buf += tl.dot(blk_x, blk_y)

        # Store output
        blk_o_idx = ((pid_blk * o_b_s) +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
        blk_o_msk = (blk_o_idx < o_b * o_b_s)
        tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
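`matmul` computes only the output blocks marked in `sparsity_layout_output`, and its backward pass reuses `transpose` plus two further block-sparse matmuls (grad_x from grad_output and y transposed, grad_y from x transposed and grad_output). A minimal usage sketch with hypothetical shapes, assuming fully dense layouts and inputs already in compressed block-sparse form:

```python
import torch
from blksprs.ops.matmul import matmul

# Hypothetical toy setup: batch 1, a 2x2 grid of 16x16 blocks, every block present in
# all three layouts (x, y, and output).
sparsity_block_size = 16
layout = torch.ones(1, 2, 2, dtype=torch.int, device="cuda")
n_blocks = int(layout.sum())

x = torch.randn(n_blocks, sparsity_block_size, sparsity_block_size,
                device="cuda", requires_grad=True)
y = torch.randn(n_blocks, sparsity_block_size, sparsity_block_size,
                device="cuda", requires_grad=True)

# Only the blocks marked in the output layout are computed.
out = matmul(x, layout, y, layout, layout, sparsity_block_size)
out.sum().backward()
```

Sparsifying the output layout (setting some of its entries to 0) skips the corresponding output blocks entirely, which is where the savings over a dense matmul come from.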