blksprs-1.0-py3-none-any.whl → blksprs-1.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/layouting/distribution_layout.py +114 -0
- blksprs/layouting/sparsity_layout.py +129 -7
- blksprs/misc/broadcast_addition.py +132 -0
- blksprs/ops/conversion.py +237 -17
- blksprs/ops/distribution.py +362 -0
- blksprs/ops/exp.py +18 -8
- blksprs/ops/{matmul_sss.py → matmul.py} +28 -26
- blksprs/ops/row_wise_sum.py +21 -5
- blksprs/ops/softmax.py +23 -12
- blksprs/ops/transpose.py +19 -7
- blksprs/utils/tools.py +1 -28
- blksprs/utils/validation.py +53 -1
- {blksprs-1.0.dist-info → blksprs-1.2.dist-info}/METADATA +39 -14
- blksprs-1.2.dist-info/RECORD +17 -0
- {blksprs-1.0.dist-info → blksprs-1.2.dist-info}/WHEEL +1 -1
- blksprs-1.0.dist-info/RECORD +0 -14
- {blksprs-1.0.dist-info → blksprs-1.2.dist-info}/top_level.txt +0 -0
blksprs/ops/conversion.py
CHANGED
@@ -1,24 +1,39 @@
+from typing import Any
+
 import torch
 import triton
 from torch import Tensor
 from triton import language as tl
 
+from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
 from blksprs.utils.tools import get_triton_block_size
-from blksprs.utils.validation import validate_contiguous, validate_dimensions,
+from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
+    validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
 
 
 def to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
             triton_block_size: int = None) -> Tensor:
-    """Converts a
+    """Converts a block-sparse tensor in compressed form to a block-sparse tensor in regular form based on the given
+    sparsity layout.
+
+    Args:
+        x (Tensor): A block-sparse tensor in compressed form.
+        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+        sparsity_block_size (int): The size of the sparsity blocks.
+        fill_value (float): The value to fill the resulting dense tensor with where the block-sparse tensor is not
+            present (default ``0``).
+        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
 
-
-
+    Returns:
+        Tensor: The block-sparse tensor converted to regular form.
 
     """
     validate_dimensions(x)
     validate_contiguous(x, sparsity_layout)
-    validate_dtype_float(x)
     validate_device(x)
+    validate_sparsity(sparsity_block_size, (x, sparsity_layout))
+    validate_sparsity_block_size(sparsity_block_size, x)
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
 
     sparsity_layout_flat = sparsity_layout.reshape(-1)
     sparsity_reverse_lut = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
@@ -27,6 +42,9 @@ def to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_
 
     validate_contiguous(sparsity_reverse_lut)
 
+    if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
+        return x
+
     return _BlocksparseToDense.apply(x,
                                     sparsity_layout, sparsity_reverse_lut,
                                     sparsity_block_size, fill_value,
@@ -68,7 +86,7 @@ class _BlocksparseToDense(torch.autograd.Function):
          sparsity_block_size,
          triton_block_size))
 
-        ctx.sparsity_layout
+        ctx.save_for_backward(sparsity_layout)
         ctx.sparsity_block_size = sparsity_block_size
         ctx.triton_block_size = triton_block_size
 
@@ -76,11 +94,12 @@ class _BlocksparseToDense(torch.autograd.Function):
 
    @staticmethod
    def backward(ctx, grad_output):
-        sparsity_layout = ctx.
+        sparsity_layout = ctx.saved_tensors[0]
        sparsity_block_size = ctx.sparsity_block_size
        triton_block_size = ctx.triton_block_size
 
-        return to_sparse(grad_output, sparsity_layout, sparsity_block_size,
+        return to_sparse(grad_output, sparsity_layout, sparsity_block_size,
+                         triton_block_size), None, None, None, None, None
 
    @staticmethod
    @triton.jit
@@ -124,18 +143,32 @@ class _BlocksparseToDense(torch.autograd.Function):
 
 
 def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
-    """Converts a
+    """Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
+    sparsity layout.
+
+    Args:
+        x (Tensor): A block-sparse tensor in regular form.
+        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+        sparsity_block_size (int): The size of the sparsity blocks.
+        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+    Returns:
+        Tensor: The block-sparse tensor converted to compressed form.
 
     """
     validate_dimensions(x)
-    validate_contiguous(x
-    validate_dtype_float(x)
+    validate_contiguous(x)
     validate_device(x)
+    validate_sparsity_block_size(sparsity_block_size, x)
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
 
     sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
     n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
 
-    validate_contiguous(sparsity_lut)
+    validate_contiguous(sparsity_layout, sparsity_lut)
+
+    if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
+        return x
 
     return _BlocksparseToSparse.apply(x,
                                      sparsity_layout, sparsity_lut,
@@ -149,14 +182,15 @@ class _BlocksparseToSparse(torch.autograd.Function):
    def forward(ctx, x: Tensor,
                sparsity_layout: Tensor, sparsity_lut: Tensor,
                sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
-        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size), dtype=x.dtype,
+                             device=x.device)
 
        x_b, x_r, x_c = x.size()
        x_b_s, x_r_s, x_c_s = x.stride()
-        o_b, o_r, o_c = output.size()
-        o_b_s, o_r_s, o_c_s = output.stride()
        s_lut_r, s_lut_c = sparsity_lut.size()
        s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = output.stride()
 
        if triton_block_size is None:
            triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -172,7 +206,7 @@ class _BlocksparseToSparse(torch.autograd.Function):
          sparsity_block_size,
          triton_block_size))
 
-        ctx.sparsity_layout
+        ctx.save_for_backward(sparsity_layout)
        ctx.sparsity_block_size = sparsity_block_size
        ctx.triton_block_size = triton_block_size
 
@@ -180,7 +214,7 @@ class _BlocksparseToSparse(torch.autograd.Function):
 
    @staticmethod
    def backward(ctx, grad_output):
-        sparsity_layout = ctx.
+        sparsity_layout = ctx.saved_tensors[0]
        sparsity_block_size = ctx.sparsity_block_size
        triton_block_size = ctx.triton_block_size
 
@@ -229,3 +263,189 @@ class _BlocksparseToSparse(torch.autograd.Function):
                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
        blk_o_msk = (blk_o_idx < (pid_blk + 1) * o_b_s)
        tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
+
+
+def adapt_layout(x: Tensor, sparsity_layout_from: Tensor, sparsity_block_size_from: int, sparsity_block_size_to: int,
+                 preprocess_data: dict = None, triton_block_size: int = None) -> Tensor:
+    """Adapts the sparsity layout of a block-sparse tensor, resulting in a new block-sparse tensor in compressed form
+    conforming to the new sparsity layout (and sparsity block size) definition.
+
+    Args:
+        x (Tensor): A block-sparse tensor in compressed form.
+        sparsity_layout_from (Tensor): The sparsity layout of the input block-sparse tensor.
+        sparsity_block_size_from (int): The size of the sparsity blocks of the input sparsity layout.
+        sparsity_block_size_to (int): The size of the sparsity blocks of the output sparsity layout.
+        preprocess_data (dict): A dictionary containing data otherwise computed by the function (default ``None``).
+        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+    Returns:
+        Tensor: The block-sparse tensor in compressed form with the adapted sparsity layout and sparsity block size.
+
+    """
+    validate_dimensions(x)
+    validate_contiguous(x, sparsity_layout_from)
+    validate_device(x)
+    validate_sparsity(sparsity_block_size_from, (x, sparsity_layout_from))
+    validate_sparsity_block_size(sparsity_block_size_from, x)
+    validate_sparsity_block_size(sparsity_block_size_to)
+    min_sparsity_block_size = min(sparsity_block_size_from, sparsity_block_size_to)
+    validate_triton_block_size(triton_block_size, min_sparsity_block_size)
+
+    if preprocess_data is None:
+        preprocess_data = {}
+
+    if "sparsity_reverse_lut_from" not in preprocess_data:
+        sparsity_layout_from_flat = sparsity_layout_from.reshape(-1)
+        sparsity_reverse_lut_from = ((torch.cumsum(sparsity_layout_from_flat, dim=-1) - 1) *
+                                     (sparsity_layout_from_flat == 1) -
+                                     (1 * (sparsity_layout_from_flat == 0)))
+    else:
+        sparsity_reverse_lut_from = preprocess_data["sparsity_reverse_lut_from"]
+
+    if "sparsity_layout_to" not in preprocess_data:
+        sparsity_layout_to = build_sparsity_layout_adaption(x, sparsity_layout_from,
+                                                            sparsity_block_size_from, sparsity_block_size_to,
+                                                            triton_block_size)
+    else:
+        sparsity_layout_to = preprocess_data["sparsity_layout_to"]
+
+    if "sparsity_lut_to" not in preprocess_data:
+        sparsity_lut_to = torch.nonzero(sparsity_layout_to).contiguous()
+    else:
+        sparsity_lut_to = preprocess_data["sparsity_lut_to"]
+
+    if "n_sparse_blocks_to" not in preprocess_data:
+        n_sparse_blocks_to = torch.sum(sparsity_layout_to.to(torch.int)).item()
+    else:
+        n_sparse_blocks_to = preprocess_data["n_sparse_blocks_to"]
+
+    validate_contiguous(sparsity_layout_to, sparsity_reverse_lut_from, sparsity_lut_to)
+
+    if (sparsity_block_size_from == sparsity_block_size_to) and torch.equal(sparsity_layout_from, sparsity_layout_to):
+        return x
+
+    return _BlocksparseAdaptLayout.apply(x,
+                                         sparsity_layout_from, sparsity_reverse_lut_from, sparsity_block_size_from,
+                                         sparsity_layout_to, sparsity_lut_to, sparsity_block_size_to,
+                                         n_sparse_blocks_to, min_sparsity_block_size, triton_block_size)
+
+
+class _BlocksparseAdaptLayout(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x: Tensor,
+                sparsity_layout_from: Tensor, sparsity_reverse_lut_from: Tensor, sparsity_block_size_from: int,
+                sparsity_layout_to: Tensor, sparsity_lut_to: Tensor, sparsity_block_size_to: int,
+                n_sparse_blocks_to: int, min_sparsity_block_size: int, triton_block_size: int) -> Tensor:
+        output = torch.zeros(size=(n_sparse_blocks_to, sparsity_block_size_to, sparsity_block_size_to),
+                             dtype=x.dtype, device=x.device)
+
+        x_b, x_r, x_c = x.size()
+        x_b_s, x_r_s, x_c_s = x.stride()
+        s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_from.size()
+        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_from.stride()
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = output.stride()
+        s_lut_o_r, s_lut_o_c = sparsity_lut_to.size()
+        s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_to.stride()
+
+        if triton_block_size is None:
+            triton_block_size = get_triton_block_size(min_sparsity_block_size)
+
+        triton_grid = lambda meta: [o_b,
+                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+        (_BlocksparseAdaptLayout.kernel_adapt_layout[triton_grid]
+         (x,
+          x_b, x_b_s, x_r_s, x_c_s,
+          s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+          sparsity_reverse_lut_from,
+          output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          sparsity_lut_to, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+          sparsity_block_size_from,
+          sparsity_block_size_to,
+          triton_block_size))
+
+        ctx.save_for_backward(x, sparsity_layout_from, sparsity_layout_to)
+        ctx.sparsity_block_size_from = sparsity_block_size_from
+        ctx.sparsity_block_size_to = sparsity_block_size_to
+        ctx.triton_block_size = triton_block_size
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        x, sparsity_layout_from, sparsity_layout_to = ctx.saved_tensors
+        sparsity_block_size_from = ctx.sparsity_block_size_from
+        sparsity_block_size_to = ctx.sparsity_block_size_to
+        triton_block_size = ctx.triton_block_size
+
+        return adapt_layout(grad_output, sparsity_layout_to, sparsity_block_size_to, sparsity_block_size_from,
+                            preprocess_data={"sparsity_layout_to": sparsity_layout_from},
+                            triton_block_size=triton_block_size), None, None, None, None, None, None, None, None, None
+
+    @staticmethod
+    @triton.jit
+    def kernel_adapt_layout(x,
+                            x_b, x_b_s, x_r_s, x_c_s,
+                            s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+                            r_lut_x,
+                            o,
+                            o_b, o_b_s, o_r_s, o_c_s,
+                            s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+                            sparsity_block_size_from,
+                            sparsity_block_size_to,
+                            TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+        # Get triton block indices
+        pid_blk = tl.program_id(axis=0)
+        pid_row = tl.program_id(axis=1)
+        pid_col = tl.program_id(axis=2)
+
+        # Get position of current sparsity block consisting of its batch, row, and column index
+        spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
+        spa_bat_o_msk = (spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+        spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
+
+        spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
+        spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+        spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+
+        spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
+        spa_col_o_msk = (spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+        spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+
+        # Get equivalent sparsity block in from layout
+        spa_bat_x = spa_bat_o
+        spa_row_x = (spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size_from
+        spa_col_x = (spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size_from
+
+        # Get reverse sparsity indices for x
+        rev_idx_spa_x_idx = (spa_bat_x * s_l_x_b_s +
+                             spa_row_x * s_l_x_r_s +
+                             spa_col_x * s_l_x_c_s)
+        rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+        rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
+
+        # If block is present commence operations
+        if rev_idx_spa_x >= 0:
+            # Calculate triton block size shifts
+            shift_row_x = ((spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE)
+                           % sparsity_block_size_from) // TRITON_BLOCK_SIZE
+            shift_col_x = ((spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE)
+                           % sparsity_block_size_from) // TRITON_BLOCK_SIZE
+
+            # Load x values
+            blk_x_idx = ((rev_idx_spa_x * x_b_s) +
+                         ((shift_row_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                         ((shift_col_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+            blk_x_msk = (blk_x_idx < x_b * x_b_s)
+            blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+            # Store output
+            blk_o_idx = ((pid_blk * o_b_s) +
+                         ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                         ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
            blk_o_msk = (blk_o_idx < o_b * o_b_s)
+            tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)