blksprs 1.1__py3-none-any.whl → 1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/layouting/sparsity_layout.py +114 -2
- blksprs/ops/conversion.py +197 -2
- {blksprs-1.1.dist-info → blksprs-1.2.dist-info}/METADATA +8 -2
- {blksprs-1.1.dist-info → blksprs-1.2.dist-info}/RECORD +6 -6
- {blksprs-1.1.dist-info → blksprs-1.2.dist-info}/WHEEL +0 -0
- {blksprs-1.1.dist-info → blksprs-1.2.dist-info}/top_level.txt +0 -0

blksprs/layouting/sparsity_layout.py
CHANGED

@@ -1,3 +1,5 @@
+import math
+
 import torch
 import triton
 from torch import Tensor
@@ -5,11 +7,11 @@ from triton import language as tl
 
 from blksprs.utils.tools import get_triton_block_size
 from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
-    validate_contiguous
+    validate_contiguous, validate_sparsity, validate_sparsity_block_size
 
 
 def build_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
-    """Builds the sparsity layout of a dense tensor covering its sparse blocks.
+    """Builds the sparsity layout of a dense tensor in regular form covering its sparse blocks.
 
     Args:
         x (Tensor): A block-sparse (or dense) tensor in regular form.
@@ -64,15 +66,125 @@ def kernel_sparsity_layout(x,
     pid_row = tl.program_id(axis=1)
     pid_col = tl.program_id(axis=2)
 
+    # Load x values
     blk_x_idx = (pid_bat * x_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
     blk_x_msk = (blk_x_idx < x_b * x_b_s)
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
+    # Store sparsity layout value
     if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
         blk_o_idx = (pid_bat * o_b_s +
                      (((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +
                       ((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))
         blk_o_msk = (blk_o_idx < o_b * o_b_s)
         tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
+
+
+def build_sparsity_layout_adaption(x: Tensor, sparsity_layout_from: Tensor,
+                                   sparsity_block_size_from: int, sparsity_block_size_to: int,
+                                   triton_block_size: int = None) -> Tensor:
+    """Builds the sparsity layout of a block-sparse tensor in compressed form if a different sparsity block size were
+    used.
+
+    Args:
+        x (Tensor): A block-sparse tensor in compressed form.
+        sparsity_layout_from (Tensor): The sparsity layout of the input block-sparse tensor.
+        sparsity_block_size_from (int): The size of the sparsity blocks of the input tensor.
+        sparsity_block_size_to (int): The desired size of the sparsity blocks for the resulting layout.
+        triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+
+    Returns:
+        Tensor: The sparsity layout in regular form using the new sparsity block size of the input block-sparse tensor
+            in compressed form.
+
+    """
+    validate_dimensions(x)
+    validate_contiguous(x, sparsity_layout_from)
+    validate_device(x)
+    validate_sparsity(sparsity_block_size_from, (x, sparsity_layout_from))
+    validate_sparsity_block_size(sparsity_block_size_from, x)
+    validate_sparsity_block_size(sparsity_block_size_to)
+    min_sparsity_block_size = min(sparsity_block_size_from, sparsity_block_size_to)
+    validate_triton_block_size(triton_block_size, min_sparsity_block_size)
+
+    sparsity_lut = torch.nonzero(sparsity_layout_from).contiguous()
+
+    validate_contiguous(sparsity_layout_from, sparsity_lut)
+
+    o_b = sparsity_layout_from.size(0)
+    o_r = math.ceil(sparsity_layout_from.size(1) * sparsity_block_size_from // sparsity_block_size_to)
+    o_c = math.ceil(sparsity_layout_from.size(2) * sparsity_block_size_from // sparsity_block_size_to)
+
+    output = torch.zeros(o_b, o_r, o_c, device=x.device, dtype=torch.int32)
+
+    x_b, x_r, x_c = x.size()
+    x_b_s, x_r_s, x_c_s = x.stride()
+    s_lut_r, s_lut_c = sparsity_lut.size()
+    s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
+    o_b_s, o_r_s, o_c_s = output.stride()
+
+    if triton_block_size is None:
+        triton_block_size = get_triton_block_size(sparsity_block_size_from)
+
+    triton_grid = lambda meta: [x_b,
+                                triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+    (kernel_sparsity_layout_adaption[triton_grid]
+     (x,
+      x_b, x_b_s, x_r_s, x_c_s,
+      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+      output,
+      o_b, o_b_s, o_r_s, o_c_s,
+      sparsity_block_size_from,
+      sparsity_block_size_to,
+      triton_block_size))
+
+    return output
+
+
+@triton.jit
+def kernel_sparsity_layout_adaption(x,
+                                    x_b, x_b_s, x_r_s, x_c_s,
+                                    s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                                    o,
+                                    o_b, o_b_s, o_r_s, o_c_s,
+                                    sparsity_block_size_from,
+                                    sparsity_block_size_to,
+                                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get sparsity index of current output block consisting of its batch, row, and column index
+    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+    spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+    spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+    spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
+    spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
+    spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+
+    # Load x values
+    blk_x_idx = ((pid_blk * x_b_s) +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_x_msk = (blk_x_idx < x_b * x_b_s)
+    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+    # Store sparsity layout value
+    if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
+        blk_o_idx = ((spa_bat * o_b_s) +
+                     (((spa_row * sparsity_block_size_from + pid_row * TRITON_BLOCK_SIZE)
+                       // sparsity_block_size_to) * o_r_s) +
+                     (((spa_col * sparsity_block_size_from + pid_col * TRITON_BLOCK_SIZE)
+                       // sparsity_block_size_to) * o_c_s))
+        blk_o_msk = (blk_o_idx < o_b * o_b_s)
+        tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
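For reference, a minimal usage sketch of the new layout-adaption helper added above; it is not part of the diff. The import paths and signatures follow the hunks shown here, while the tensor shape, block sizes, and CUDA device are illustrative assumptions.

```python
# Hedged sketch, not taken from the package: exercises build_sparsity_layout and
# the new build_sparsity_layout_adaption using assumed shapes and block sizes.
import torch

from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption
from blksprs.ops.conversion import to_sparse

# Dense input in regular form (the Triton kernels are assumed to run on a CUDA device)
x = torch.randn(2, 128, 128, device="cuda")

# Mark which 64x64 blocks of x are non-zero, then compress x under that layout
sparsity_layout_64 = build_sparsity_layout(x, sparsity_block_size=64)
x_compressed = to_sparse(x, sparsity_layout_64, 64)

# Layout the same data would have if 32x32 sparsity blocks were used instead
sparsity_layout_32 = build_sparsity_layout_adaption(x_compressed, sparsity_layout_64,
                                                    sparsity_block_size_from=64,
                                                    sparsity_block_size_to=32)
```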
blksprs/ops/conversion.py
CHANGED
@@ -1,8 +1,11 @@
+from typing import Any
+
 import torch
 import triton
 from torch import Tensor
 from triton import language as tl
 
+from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
 from blksprs.utils.tools import get_triton_block_size
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
@@ -39,6 +42,9 @@ def to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_
 
     validate_contiguous(sparsity_reverse_lut)
 
+    if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
+        return x
+
     return _BlocksparseToDense.apply(x,
                                      sparsity_layout, sparsity_reverse_lut,
                                      sparsity_block_size, fill_value,
@@ -161,6 +167,9 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
 
     validate_contiguous(sparsity_layout, sparsity_lut)
 
+    if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
+        return x
+
     return _BlocksparseToSparse.apply(x,
                                       sparsity_layout, sparsity_lut,
                                       sparsity_block_size, n_sparse_blocks,
@@ -178,10 +187,10 @@ class _BlocksparseToSparse(torch.autograd.Function):
 
         x_b, x_r, x_c = x.size()
         x_b_s, x_r_s, x_c_s = x.stride()
-        o_b, o_r, o_c = output.size()
-        o_b_s, o_r_s, o_c_s = output.stride()
         s_lut_r, s_lut_c = sparsity_lut.size()
         s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = output.stride()
 
         if triton_block_size is None:
             triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -254,3 +263,189 @@ class _BlocksparseToSparse(torch.autograd.Function):
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
         blk_o_msk = (blk_o_idx < (pid_blk + 1) * o_b_s)
         tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
+
+
+def adapt_layout(x: Tensor, sparsity_layout_from: Tensor, sparsity_block_size_from: int, sparsity_block_size_to: int,
+                 preprocess_data: dict = None, triton_block_size: int = None) -> Tensor:
+    """Adapts the sparsity layout of a block-sparse tensor, resulting in a new block-sparse tensor in compressed form
+    conforming to the new sparsity layout (and sparsity block size) definition.
+
+    Args:
+        x (Tensor): A block-sparse tensor in compressed form.
+        sparsity_layout_from (Tensor): The sparsity layout of the input block-sparse tensor.
+        sparsity_block_size_from (int): The size of the sparsity blocks of the input sparsity layout.
+        sparsity_block_size_to (int): The size of the sparsity blocks of the output sparsity layout.
+        preprocess_data (dict): A dictionary containing data otherwise computed by the function (default ``None``).
+        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+    Returns:
+        Tensor: The block-sparse tensor in compressed form with the adapted sparsity layout and sparsity block size.
+
+    """
+    validate_dimensions(x)
+    validate_contiguous(x, sparsity_layout_from)
+    validate_device(x)
+    validate_sparsity(sparsity_block_size_from, (x, sparsity_layout_from))
+    validate_sparsity_block_size(sparsity_block_size_from, x)
+    validate_sparsity_block_size(sparsity_block_size_to)
+    min_sparsity_block_size = min(sparsity_block_size_from, sparsity_block_size_to)
+    validate_triton_block_size(triton_block_size, min_sparsity_block_size)
+
+    if preprocess_data is None:
+        preprocess_data = {}
+
+    if "sparsity_reverse_lut_from" not in preprocess_data:
+        sparsity_layout_from_flat = sparsity_layout_from.reshape(-1)
+        sparsity_reverse_lut_from = ((torch.cumsum(sparsity_layout_from_flat, dim=-1) - 1) *
+                                     (sparsity_layout_from_flat == 1) -
+                                     (1 * (sparsity_layout_from_flat == 0)))
+    else:
+        sparsity_reverse_lut_from = preprocess_data["sparsity_reverse_lut_from"]
+
+    if "sparsity_layout_to" not in preprocess_data:
+        sparsity_layout_to = build_sparsity_layout_adaption(x, sparsity_layout_from,
+                                                            sparsity_block_size_from, sparsity_block_size_to,
+                                                            triton_block_size)
+    else:
+        sparsity_layout_to = preprocess_data["sparsity_layout_to"]
+
+    if "sparsity_lut_to" not in preprocess_data:
+        sparsity_lut_to = torch.nonzero(sparsity_layout_to).contiguous()
+    else:
+        sparsity_lut_to = preprocess_data["sparsity_lut_to"]
+
+    if "n_sparse_blocks_to" not in preprocess_data:
+        n_sparse_blocks_to = torch.sum(sparsity_layout_to.to(torch.int)).item()
+    else:
+        n_sparse_blocks_to = preprocess_data["n_sparse_blocks_to"]
+
+    validate_contiguous(sparsity_layout_to, sparsity_reverse_lut_from, sparsity_lut_to)
+
+    if (sparsity_block_size_from == sparsity_block_size_to) and torch.equal(sparsity_layout_from, sparsity_layout_to):
+        return x
+
+    return _BlocksparseAdaptLayout.apply(x,
+                                         sparsity_layout_from, sparsity_reverse_lut_from, sparsity_block_size_from,
+                                         sparsity_layout_to, sparsity_lut_to, sparsity_block_size_to,
+                                         n_sparse_blocks_to, min_sparsity_block_size, triton_block_size)
+
+
+class _BlocksparseAdaptLayout(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x: Tensor,
+                sparsity_layout_from: Tensor, sparsity_reverse_lut_from: Tensor, sparsity_block_size_from: int,
+                sparsity_layout_to: Tensor, sparsity_lut_to: Tensor, sparsity_block_size_to: int,
+                n_sparse_blocks_to: int, min_sparsity_block_size: int, triton_block_size: int) -> Tensor:
+        output = torch.zeros(size=(n_sparse_blocks_to, sparsity_block_size_to, sparsity_block_size_to),
+                             dtype=x.dtype, device=x.device)
+
+        x_b, x_r, x_c = x.size()
+        x_b_s, x_r_s, x_c_s = x.stride()
+        s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_from.size()
+        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_from.stride()
+        o_b, o_r, o_c = output.size()
+        o_b_s, o_r_s, o_c_s = output.stride()
+        s_lut_o_r, s_lut_o_c = sparsity_lut_to.size()
+        s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_to.stride()
+
+        if triton_block_size is None:
+            triton_block_size = get_triton_block_size(min_sparsity_block_size)
+
+        triton_grid = lambda meta: [o_b,
+                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+        (_BlocksparseAdaptLayout.kernel_adapt_layout[triton_grid]
+         (x,
+          x_b, x_b_s, x_r_s, x_c_s,
+          s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+          sparsity_reverse_lut_from,
+          output,
+          o_b, o_b_s, o_r_s, o_c_s,
+          sparsity_lut_to, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+          sparsity_block_size_from,
+          sparsity_block_size_to,
+          triton_block_size))
+
+        ctx.save_for_backward(x, sparsity_layout_from, sparsity_layout_to)
+        ctx.sparsity_block_size_from = sparsity_block_size_from
+        ctx.sparsity_block_size_to = sparsity_block_size_to
+        ctx.triton_block_size = triton_block_size
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        x, sparsity_layout_from, sparsity_layout_to = ctx.saved_tensors
+        sparsity_block_size_from = ctx.sparsity_block_size_from
+        sparsity_block_size_to = ctx.sparsity_block_size_to
+        triton_block_size = ctx.triton_block_size
+
+        return adapt_layout(grad_output, sparsity_layout_to, sparsity_block_size_to, sparsity_block_size_from,
+                            preprocess_data={"sparsity_layout_to": sparsity_layout_from},
+                            triton_block_size=triton_block_size), None, None, None, None, None, None, None, None, None
+
+    @staticmethod
+    @triton.jit
+    def kernel_adapt_layout(x,
+                            x_b, x_b_s, x_r_s, x_c_s,
+                            s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+                            r_lut_x,
+                            o,
+                            o_b, o_b_s, o_r_s, o_c_s,
+                            s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+                            sparsity_block_size_from,
+                            sparsity_block_size_to,
+                            TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+        # Get triton block indices
+        pid_blk = tl.program_id(axis=0)
+        pid_row = tl.program_id(axis=1)
+        pid_col = tl.program_id(axis=2)
+
+        # Get position of current sparsity block consisting of its batch, row, and column index
+        spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
+        spa_bat_o_msk = (spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+        spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
+
+        spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
+        spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+        spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+
+        spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
+        spa_col_o_msk = (spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+        spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+
+        # Get equivalent sparsity block in from layout
+        spa_bat_x = spa_bat_o
+        spa_row_x = (spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size_from
+        spa_col_x = (spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size_from
+
+        # Get reverse sparsity indices for x
+        rev_idx_spa_x_idx = (spa_bat_x * s_l_x_b_s +
+                             spa_row_x * s_l_x_r_s +
+                             spa_col_x * s_l_x_c_s)
+        rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+        rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
+
+        # If block is present commence operations
+        if rev_idx_spa_x >= 0:
+            # Calculate triton block size shifts
+            shift_row_x = ((spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE)
+                           % sparsity_block_size_from) // TRITON_BLOCK_SIZE
+            shift_col_x = ((spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE)
+                           % sparsity_block_size_from) // TRITON_BLOCK_SIZE
+
+            # Load x values
+            blk_x_idx = ((rev_idx_spa_x * x_b_s) +
+                         ((shift_row_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                         ((shift_col_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+            blk_x_msk = (blk_x_idx < x_b * x_b_s)
+            blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+            # Store output
+            blk_o_idx = ((pid_blk * o_b_s) +
+                         ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                         ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+            blk_o_msk = (blk_o_idx < o_b * o_b_s)
+            tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
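Besides the reordered stride bookkeeping in `_BlocksparseToSparse` and the new shortcut in `to_dense` and `to_sparse` (a layout with a single, fully populated block per batch is returned unchanged), the main addition is the public `adapt_layout` entry point. Below is a hedged sketch, not taken from the diff, of how it might be combined with the existing conversions; the shapes, block sizes, and CUDA device are assumptions.

```python
# Hedged sketch: re-blocks a compressed tensor from 64x64 to 32x32 sparsity
# blocks and converts the result back to regular (dense) form.
import torch

from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption
from blksprs.ops.conversion import to_sparse, to_dense, adapt_layout

x = torch.randn(2, 128, 128, device="cuda")
layout_64 = build_sparsity_layout(x, sparsity_block_size=64)
x_64 = to_sparse(x, layout_64, 64)

# New in 1.2: change the sparsity block size of the compressed representation
x_32 = adapt_layout(x_64, layout_64, sparsity_block_size_from=64, sparsity_block_size_to=32)

# The matching 32-block layout is needed to interpret (or decompress) the result
layout_32 = build_sparsity_layout_adaption(x_64, layout_64,
                                           sparsity_block_size_from=64,
                                           sparsity_block_size_to=32)
x_roundtrip = to_dense(x_32, layout_32, sparsity_block_size=32)
```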
{blksprs-1.1.dist-info → blksprs-1.2.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.1
+Version: 1.2
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -33,7 +33,8 @@ Currently supported operations (includes gradient calculation):
 - Transposition
 - Gather
 - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
-- Conversion
+- Conversion to and from sparse form
+- Conversion to different sparsity layouts and different sparsity block sizes
 
 As with this library sparse matrices are represented using a tuple of `(matrix, sparsity_layout, sparsity_block_size)`,
 any element-wise operations can be applied in regular torch-like fashion.
@@ -59,6 +60,11 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u
 
 ```pip install blksprs```
 
+### Dependencies
+
+- [PyTorch](https://pytorch.org/) (built with v2.4.0)
+- _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_
+
 ## Changelog
 
 See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.md) for a detailed changelog.
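The README fragment above states that block-sparse matrices are carried around as a `(matrix, sparsity_layout, sparsity_block_size)` tuple and that element-wise operations work "in regular torch-like fashion". A small hedged illustration of that claim follows; the tensor here is a stand-in, not code from the package.

```python
# Hedged illustration: the compressed form is an ordinary torch.Tensor, so
# element-wise torch operations apply directly, while the accompanying
# sparsity_layout and sparsity_block_size of the tuple stay unchanged.
import torch

x_compressed = torch.randn(8, 64, 64, device="cuda")  # stand-in for a compressed block-sparse tensor
y_compressed = torch.relu(x_compressed) * 2            # same layout and block size still describe the result
```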
{blksprs-1.1.dist-info → blksprs-1.2.dist-info}/RECORD
CHANGED

@@ -1,7 +1,7 @@
 blksprs/layouting/distribution_layout.py,sha256=GQ-ZRXbeImiLcbaqnL2FuUZ6DoFwmB0naT_YrOpD84Q,4940
-blksprs/layouting/sparsity_layout.py,sha256=
+blksprs/layouting/sparsity_layout.py,sha256=TtADT_WWcZpW3zyGy6KAgkAo44gDryXZqdJLZGEX2V8,7895
 blksprs/misc/broadcast_addition.py,sha256=vf1Hdqz9Uyqykto3DCjmdyepMzpMXL238SpANQqRAwI,5297
-blksprs/ops/conversion.py,sha256
+blksprs/ops/conversion.py,sha256=-AOzj_j3WrBLGIgd2oVPvYS8XKfzlvGtSIWzW_qP1lk,21260
 blksprs/ops/distribution.py,sha256=_fQb6fWpLxocAh86D74ATahChi0EK0eBb4eUOUEBVps,16769
 blksprs/ops/exp.py,sha256=qs8fVtCzxl4CKT4GepaqurjEL62jyi8VjMY12JFrFAU,3674
 blksprs/ops/matmul.py,sha256=x3lrYg4g8fIf5PeMtZY_SEpi11kP9RFcRoemCIxcSDE,11086
@@ -11,7 +11,7 @@ blksprs/ops/transpose.py,sha256=DVEXoxo2MoTNL3NZrjxsukMDrzk2vnEXL1uRnKFWkn0,6722
 blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
 blksprs/utils/tools.py,sha256=P2UALvccRjJJ7w05YGuaxB3qmNObgct4idfM0jlE2wg,465
 blksprs/utils/validation.py,sha256=gJYZO5C48YUrXV3Fy_Z_lCaOpiFj951FT-Od7sKfprg,3007
-blksprs-1.
-blksprs-1.
-blksprs-1.
-blksprs-1.
+blksprs-1.2.dist-info/METADATA,sha256=4sbWg-lZK8DuRnkh3kh8toQRGMcBK9UlQtNLh4cU6mY,7209
+blksprs-1.2.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+blksprs-1.2.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
+blksprs-1.2.dist-info/RECORD,,

{blksprs-1.1.dist-info → blksprs-1.2.dist-info}/WHEEL
File without changes

{blksprs-1.1.dist-info → blksprs-1.2.dist-info}/top_level.txt
File without changes