blksprs 1.8.3__tar.gz → 1.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {blksprs-1.8.3 → blksprs-1.9}/PKG-INFO +1 -1
- {blksprs-1.8.3 → blksprs-1.9}/blksprs/layouting/distribution_layout.py +23 -5
- {blksprs-1.8.3 → blksprs-1.9}/blksprs/ops/distribution.py +93 -35
- {blksprs-1.8.3 → blksprs-1.9}/blksprs/ops/experimental/distribution_mdi.py +8 -0
- {blksprs-1.8.3 → blksprs-1.9}/blksprs/ops/misc/row_wise.py +8 -0
- {blksprs-1.8.3 → blksprs-1.9}/blksprs/ops/softmax.py +4 -0
- {blksprs-1.8.3 → blksprs-1.9}/blksprs.egg-info/PKG-INFO +1 -1
- {blksprs-1.8.3 → blksprs-1.9}/pyproject.toml +1 -1
- {blksprs-1.8.3 → blksprs-1.9}/README.md +0 -0
- {blksprs-1.8.3 → blksprs-1.9}/blksprs/__init__.py +0 -0
- {blksprs-1.8.3 → blksprs-1.9}/blksprs/layouting/sparsity_layout.py +0 -0
- {blksprs-1.8.3 → blksprs-1.9}/blksprs/ops/conversion.py +0 -0
- {blksprs-1.8.3 → blksprs-1.9}/blksprs/ops/matmul.py +0 -0
- {blksprs-1.8.3 → blksprs-1.9}/blksprs/ops/misc/broadcast_ops.py +0 -0
- {blksprs-1.8.3 → blksprs-1.9}/blksprs/ops/misc/exp.py +0 -0
- {blksprs-1.8.3 → blksprs-1.9}/blksprs/ops/partitioning.py +0 -0
- {blksprs-1.8.3 → blksprs-1.9}/blksprs/ops/repeat.py +0 -0
- {blksprs-1.8.3 → blksprs-1.9}/blksprs/ops/transpose.py +0 -0
- {blksprs-1.8.3 → blksprs-1.9}/blksprs/utils/benchmarking.py +0 -0
- {blksprs-1.8.3 → blksprs-1.9}/blksprs/utils/blksprs_tensor.py +0 -0
- {blksprs-1.8.3 → blksprs-1.9}/blksprs/utils/processing.py +0 -0
- {blksprs-1.8.3 → blksprs-1.9}/blksprs/utils/tools.py +0 -0
- {blksprs-1.8.3 → blksprs-1.9}/blksprs/utils/validation.py +0 -0
- {blksprs-1.8.3 → blksprs-1.9}/blksprs.egg-info/SOURCES.txt +0 -0
- {blksprs-1.8.3 → blksprs-1.9}/blksprs.egg-info/dependency_links.txt +0 -0
- {blksprs-1.8.3 → blksprs-1.9}/blksprs.egg-info/requires.txt +0 -0
- {blksprs-1.8.3 → blksprs-1.9}/blksprs.egg-info/top_level.txt +0 -0
- {blksprs-1.8.3 → blksprs-1.9}/setup.cfg +0 -0
{blksprs-1.8.3 → blksprs-1.9}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.8.3
+Version: 1.9
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
{blksprs-1.8.3 → blksprs-1.9}/blksprs/layouting/distribution_layout.py

@@ -10,13 +10,14 @@ from blksprs.utils.validation import validate_triton_block_size, validate_dimens
 
 
 def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: Tensor,
-                              size_target: torch.Size,
+                              dim: int, size_target: torch.Size,
                               sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
     """Builds the sparsity layout of either the source of a gather or the target of a scatter operation.
 
     Args:
         indices (BlksprsTensor): The block-sparse indices tensor in compressed form used for the gather or scatter operation.
         sparsity_layout_indices (Tensor): The sparsity layout of the indices block-sparse tensor.
+        dim (int): The dimension along which the operation is conducted.
         size_target (torch.Size): The size of the block-sparse target tensor in regular form.
         sparsity_block_size (int): The size of the sparsity blocks.
         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
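Note: the hunk above changes the public signature of build_distribution_layout, inserting the new dim parameter before size_target. A minimal call sketch under stated assumptions (the input tensors and shapes are illustrative placeholders, not taken from the package):

    import torch
    from blksprs.layouting.distribution_layout import build_distribution_layout

    # Hypothetical setup: `indices` in compressed block form and its layout
    # would come from an existing block-sparse pipeline.
    indices = ...
    sparsity_layout_indices = ...

    layout = build_distribution_layout(indices, sparsity_layout_indices,
                                       2,                        # dim (new in 1.9)
                                       torch.Size((2, 64, 64)),  # size_target
                                       32)                       # sparsity_block_size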
@@ -31,6 +32,8 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
 
     sparsity_lut_i = torch.nonzero(sparsity_layout_indices).contiguous()
 
+    adjusted_dim = dim % 3
+
     output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size, size_target[2] // sparsity_block_size,
                          dtype=torch.bool, device=indices.device)
 
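The `dim % 3` normalization added above maps negative dimension indices onto the three axes (batch, row, column) via Python's modulo semantics. A plain-Python check:

    for dim in (-3, -2, -1, 0, 1, 2):
        print(dim, "->", dim % 3)
    # -3 -> 0, -2 -> 1, -1 -> 2, 0 -> 0, 1 -> 1, 2 -> 2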
@@ -55,6 +58,7 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
         i_b, i_b_s, i_r_s, i_c_s,
         sparsity_lut_i,
         s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+        adjusted_dim,
         output,
         o_b, o_b_s, o_r_s, o_c_s,
         sparsity_block_size,
@@ -68,6 +72,7 @@ def kernel_distribution_layout(i,
                                i_b, i_b_s, i_r_s, i_c_s,
                                s_lut_i,
                                s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+                               dim,
                                o,
                                o_b, o_b_s, o_r_s, o_c_s,
                                sparsity_block_size,
@@ -86,17 +91,30 @@ def kernel_distribution_layout(i,
     spa_row_i_msk = (spa_row_i_idx < s_lut_i_r * s_lut_i_r_s)
     spa_row_i = tl.load(s_lut_i + spa_row_i_idx, mask=spa_row_i_msk)
 
+    spa_col_i_idx = (pid_blk * s_lut_i_r_s + 2 * s_lut_i_c_s)
+    spa_col_i_msk = (spa_col_i_idx < s_lut_i_r * s_lut_i_r_s)
+    spa_col_i = tl.load(s_lut_i + spa_col_i_idx, mask=spa_col_i_msk)
+
     blk_i_idx = (pid_blk * i_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
     blk_i_msk = (blk_i_idx < i_b * i_b_s)
     blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)
 
-
+    dst_bat_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_i, dtype=tl.int32)
+    dst_row_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_i, dtype=tl.int32)
+    dst_col_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_i, dtype=tl.int32)
+    if dim == 0:
+        dst_bat_idx = blk_i
+    elif dim == 1:
+        dst_row_idx = blk_i // sparsity_block_size
+    elif dim == 2:
+        dst_col_idx = blk_i // sparsity_block_size
+
     blk_v = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), 1, dtype=tl.int32)
 
-    blk_o_idx = ((
-                 (
-                 (
+    blk_o_idx = ((dst_bat_idx * o_b_s) +
+                 (dst_row_idx * o_r_s) +
+                 (dst_col_idx * o_c_s))
     blk_o_msk = (blk_o_idx < o_b * o_b_s)
     tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
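The per-dim branch added above routes each loaded index value either to a batch coordinate (dim 0) or to a sparsity-block coordinate obtained by integer division (dims 1 and 2); the gather/scatter kernels further below also use the modulo as the offset inside the block. A plain-Python illustration of that split:

    sparsity_block_size = 32
    blk_i = 70                           # index value along the chosen dimension
    print(blk_i // sparsity_block_size)  # 2 -> which sparsity block it falls into
    print(blk_i % sparsity_block_size)   # 6 -> offset inside that block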
{blksprs-1.8.3 → blksprs-1.9}/blksprs/ops/distribution.py

@@ -3,19 +3,23 @@ import triton
 from torch import Tensor
 from triton import language as tl
 
+from blksprs.ops.conversion import to_dense
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
 
 
-def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
+def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
+           dim: int,
+           idx: BlksprsTensor, sparsity_layout_idx: Tensor,
            sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
     """Applies a gather operation on a block-sparse tensor in compressed form.
 
     Args:
         src (BlksprsTensor): The source block-sparse tensor in compressed form to gather from.
         sparsity_layout_src (Tensor): The sparsity layout of the source block-sparse tensor.
+        dim (int): The dimension along which to gather.
         idx (BlksprsTensor): The block-sparse indices tensor in compressed form specifying how to gather from the source tensor.
         sparsity_layout_idx (Tensor): The sparsity layout of the indices block-sparse tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
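Note: 1.9 moves `dim` into the positional argument list of `gather`, between the source layout and the indices. A hypothetical call sketch (variable setup assumed, not shown in the diff; the old call order is inferred from the removed signature line and the next hunk's context):

    # blksprs 1.8.3:
    # out = gather(src, sparsity_layout_src, idx, sparsity_layout_idx, sparsity_block_size)

    # blksprs 1.9:
    out = gather(src, sparsity_layout_src,
                 2,                       # dim: dimension along which to gather
                 idx, sparsity_layout_idx,
                 sparsity_block_size)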
@@ -46,16 +50,18 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor, idx: BlksprsTensor,
     validate_contiguous(sparsity_layout_src, sparsity_reverse_lut_x,
                         sparsity_layout_idx, sparsity_lut_i)
 
+    adjusted_dim = dim % 3
+
     return BlksprsTensor(_BlocksparseGather.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
-
-
+                                                  adjusted_dim, idx, sparsity_layout_idx, sparsity_lut_i,
+                                                  sparsity_block_size, triton_block_size))
 
 
 class _BlocksparseGather(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
-                i: Tensor, sparsity_layout_i: Tensor, sparsity_lut_i: Tensor,
+                dim: int, i: Tensor, sparsity_layout_i: Tensor, sparsity_lut_i: Tensor,
                 sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
         output = torch.empty_like(i, dtype=x.dtype)
 
@@ -82,6 +88,7 @@ class _BlocksparseGather(torch.autograd.Function):
             x_b, x_b_s, x_r_s, x_c_s,
             s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
             sparsity_reverse_lut_x,
+            dim,
             i,
             i_b, i_b_s, i_r_s, i_c_s,
             output,
@@ -91,6 +98,7 @@ class _BlocksparseGather(torch.autograd.Function):
             triton_block_size))
 
         ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_i)
+        ctx.dim = dim
         ctx.sparsity_block_size = sparsity_block_size
         ctx.triton_block_size = triton_block_size
 
@@ -99,15 +107,15 @@ class _BlocksparseGather(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
         sparsity_layout_x, i, sparsity_layout_i = ctx.saved_tensors
+        dim = ctx.dim
         sparsity_block_size = ctx.sparsity_block_size
         triton_block_size = ctx.triton_block_size
 
         return scatter_reduce(grad_output, sparsity_layout_i,
-                              i,
-                              sparsity_layout_x,
-                              sparsity_block_size,
+                              dim, i,
+                              sparsity_layout_x, sparsity_block_size,
                               reduce_op="sum",
-                              triton_block_size=triton_block_size), None, None, None, None, None, None, None
+                              triton_block_size=triton_block_size), None, None, None, None, None, None, None, None
 
     @staticmethod
     @triton.jit
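The extra trailing `None` in `backward` above mirrors the new `dim` argument: `torch.autograd.Function.backward` must return one gradient slot per `forward` input, with `None` for non-differentiable arguments. A minimal standalone illustration of that rule (generic example, not blksprs code):

    import torch

    class Scale(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, factor):    # two inputs ...
            ctx.factor = factor
            return x * factor

        @staticmethod
        def backward(ctx, grad_output):
            # ... so exactly two return slots, None for the scalar `factor`
            return grad_output * ctx.factor, None

    x = torch.ones(3, requires_grad=True)
    Scale.apply(x, 2.0).sum().backward()
    print(x.grad)  # tensor([2., 2., 2.])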
@@ -115,6 +123,7 @@ class _BlocksparseGather(torch.autograd.Function):
                                   x_b, x_b_s, x_r_s, x_c_s,
                                   s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
                                   r_lut_x,
+                                  dim,
                                   i,
                                   i_b, i_b_s, i_r_s, i_c_s,
                                   o,
@@ -136,6 +145,10 @@ class _BlocksparseGather(torch.autograd.Function):
        spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
        spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
 
+       spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
+       spa_col_o_msk = (spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+       spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+
        # Load index values
        blk_i_idx = ((pid_blk * i_b_s) +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
@@ -143,33 +156,50 @@ class _BlocksparseGather(torch.autograd.Function):
        blk_i_msk = (blk_i_idx < i_b * i_b_s)
        blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
 
-       # Get
+       # Get indices of sparsity blocks and positions within the blocks
        pos_spa_blk_x = blk_i // sparsity_block_size
-
+       pos_spa_int_x = blk_i % sparsity_block_size
+
+       rev_dst_bat_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_o, dtype=tl.int32)
+       rev_dst_row_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_o, dtype=tl.int32)
+       rev_dst_col_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_o, dtype=tl.int32)
+       dst_row_x = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
+                    .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+       dst_col_x = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
+                    .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+       if dim == 0:
+           rev_dst_bat_x = blk_i
+       elif dim == 1:
+           rev_dst_row_x = pos_spa_blk_x
+           dst_row_x = pos_spa_int_x * x_r_s
+       elif dim == 2:
+           rev_dst_col_x = pos_spa_blk_x
+           dst_col_x = pos_spa_int_x * x_c_s
 
        # Load reverse sparsity indices for x
-       rev_idx_spa_x_idx = ((
-                            (
-                            (
+       rev_idx_spa_x_idx = ((rev_dst_bat_x * s_l_x_b_s) +
+                            (rev_dst_row_x * s_l_x_r_s) +
+                            (rev_dst_col_x * s_l_x_c_s))
        rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
        rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
 
        # Load x values
        blk_x_idx = ((rev_idx_spa_x * x_b_s) +
-
-
-       blk_x_msk = (blk_x_idx < x_b * x_b_s)
+                    dst_row_x +
+                    dst_col_x)
+       blk_x_msk = ((blk_x_idx < x_b * x_b_s) & rev_idx_spa_x_msk != -1)
        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
        # Store output
        blk_o_idx = ((pid_blk * o_b_s) +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-       blk_o_msk = (blk_o_idx < o_b * o_b_s)
+       blk_o_msk = ((blk_o_idx < o_b * o_b_s) & rev_idx_spa_x_msk != -1)
        tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
 
 
 def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
+            dim: int,
             idx: BlksprsTensor,
             sparsity_layout_tgt: Tensor,
             sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
@@ -184,6 +214,7 @@ def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
 
 
 def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
+                   dim: int,
                    idx: BlksprsTensor,
                    sparsity_layout_tgt: Tensor,
                    sparsity_block_size: int,
@@ -193,6 +224,7 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
     Args:
         src (BlksprsTensor): The source block-sparse tensor in compressed form to scatter from.
         sparsity_layout_src (Tensor): The sparsity layout of the source block-sparse tensor.
+        dim (int): The dimension along which to scatter.
         idx (BlksprsTensor): The block-sparse indices tensor in compressed form specifying how to scatter to the target tensor.
         sparsity_layout_tgt (Tensor): The sparsity layout of the target block-sparse tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
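As with `gather`, `scatter_reduce` now takes `dim` right after the source layout. A hypothetical call sketch (variables assumed), consistent with the call visible in `_BlocksparseGather.backward` above:

    out = scatter_reduce(src, sparsity_layout_src,
                         1,                       # dim: dimension along which to scatter
                         idx, sparsity_layout_tgt,
                         sparsity_block_size,
                         reduce_op="sum")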
@@ -230,18 +262,20 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
     validate_contiguous(sparsity_layout_src, sparsity_lut_x,
                         sparsity_layout_tgt, sparsity_reverse_lut_o)
 
+    adjusted_dim = dim % 3
+
     return BlksprsTensor(_BlocksparseScatterReduce.apply(src, sparsity_layout_src, sparsity_lut_x,
-
-
-
-
+                                                         adjusted_dim, idx,
+                                                         sparsity_layout_tgt, sparsity_reverse_lut_o,
+                                                         sparsity_block_size, n_sparse_blocks,
+                                                         reduce_op, triton_block_size))
 
 
 class _BlocksparseScatterReduce(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_lut_x: Tensor,
-                i: Tensor,
+                dim: int, i: Tensor,
                 sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
                 sparsity_block_size: int, n_sparse_blocks: int,
                 reduce_op: str, triton_block_size: int) -> Tensor:
@@ -274,10 +308,11 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
            (x,
             x_b, x_b_s, x_r_s, x_c_s,
             sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+            dim,
             i,
             i_b, i_b_s, i_r_s, i_c_s,
             output,
-            o_b, o_b_s,
+            o_b, o_b_s,
             s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
             sparsity_reverse_lut_o,
             reduce_op_ind,
@@ -285,6 +320,7 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
             triton_block_size))
 
         ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_o)
+        ctx.dim = dim
         ctx.sparsity_block_size = sparsity_block_size
         ctx.reduce_op = reduce_op
         ctx.triton_block_size = triton_block_size
@@ -294,13 +330,14 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
         sparsity_layout_x, i, sparsity_layout_o = ctx.saved_tensors
+        dim = ctx.dim
         sparsity_block_size = ctx.sparsity_block_size
         reduce_op = ctx.reduce_op
         triton_block_size = ctx.triton_block_size
 
         if reduce_op == "sum":
-            return gather(grad_output, sparsity_layout_o, i, sparsity_layout_x, sparsity_block_size,
-                          triton_block_size=triton_block_size), None, None, None, None, None, None, None, None, None
+            return gather(grad_output, sparsity_layout_o, dim, i, sparsity_layout_x, sparsity_block_size,
+                          triton_block_size=triton_block_size), None, None, None, None, None, None, None, None, None, None
         else:
             raise ValueError(f"Reduction operation '{reduce_op}' does not support backward pass")
 
@@ -309,10 +346,11 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
     def kernel_blocksparse_scatter(x,
                                    x_b, x_b_s, x_r_s, x_c_s,
                                    s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+                                   dim,
                                    i,
                                    i_b, i_b_s, i_r_s, i_c_s,
                                    o,
-                                   o_b, o_b_s,
+                                   o_b, o_b_s,
                                    s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
                                    r_lut_o,
                                    reduce_op_ind,
@@ -332,6 +370,10 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
        spa_row_x_msk = (spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
        spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)
 
+       spa_col_x_idx = (pid_blk * s_lut_x_r_s + 2 * s_lut_x_c_s)
+       spa_col_x_msk = (spa_col_x_idx < s_lut_x_r * s_lut_x_r_s)
+       spa_col_x = tl.load(s_lut_x + spa_col_x_idx, mask=spa_col_x_msk)
+
        # Load x values
        blk_x_idx = ((pid_blk * x_b_s) +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
@@ -346,22 +388,38 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
        blk_i_msk = (blk_i_idx < i_b * i_b_s)
        blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
 
-       # Get
-
-
+       # Get indices of sparsity blocks and positions within the blocks
+       pos_spa_blk_x = blk_i // sparsity_block_size
+       pos_spa_int_x = blk_i % sparsity_block_size
+
+       rev_dst_bat_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_x, dtype=tl.int32)
+       rev_dst_row_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_x, dtype=tl.int32)
+       rev_dst_col_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_x, dtype=tl.int32)
+       dst_row_o = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
+                    .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+       dst_col_o = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
+                    .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+       if dim == 0:
+           rev_dst_bat_o = blk_i
+       elif dim == 1:
+           rev_dst_row_o = pos_spa_blk_x
+           dst_row_o = pos_spa_int_x * x_r_s
+       elif dim == 2:
+           rev_dst_col_o = pos_spa_blk_x
+           dst_col_o = pos_spa_int_x * x_c_s
 
        # Load reverse sparsity indices for o
-       rev_idx_spa_o_idx = ((
-                            (
-                            (
+       rev_idx_spa_o_idx = ((rev_dst_bat_o * s_l_o_b_s) +
+                            (rev_dst_row_o * s_l_o_r_s) +
+                            (rev_dst_col_o * s_l_o_c_s))
        rev_idx_spa_o_msk = (rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
        rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
 
        # Store output
        blk_o_idx = ((rev_idx_spa_o * o_b_s) +
-
-
-       blk_o_msk = (blk_o_idx < o_b * o_b_s)
+                    dst_row_o +
+                    dst_col_o)
+       blk_o_msk = ((blk_o_idx < o_b * o_b_s) & rev_idx_spa_o_msk != -1)
 
        if reduce_op_ind == 0:
            tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
{blksprs-1.8.3 → blksprs-1.9}/blksprs/ops/experimental/distribution_mdi.py

@@ -153,6 +153,10 @@ class _BlocksparseGatherMDI(torch.autograd.Function):
        rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
        rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
 
+       if rev_idx_spa_x == -1:
+           tl.device_assert(False)
+           return
+
        # Load x values
        blk_x_idx = ((rev_idx_spa_x * x_b_s) +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
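The same `-1` guard recurs in the remaining hunks below (scatter-reduce MDI, row-wise sum/max, softmax): a reverse sparsity LUT yields `-1` when the looked-up block is absent, and the kernels now trap on that instead of computing with a bogus offset. A self-contained miniature of the pattern (hypothetical kernel, not from the package):

    import triton
    import triton.language as tl

    @triton.jit
    def guarded_kernel(r_lut, x, BLOCK: tl.constexpr):
        pid = tl.program_id(axis=0)
        # Reverse LUT: logical block -> physical slot, -1 means "not present".
        rev_idx = tl.load(r_lut + pid).to(tl.int32)
        if rev_idx == -1:
            tl.device_assert(False)  # fail loudly when device assertions are enabled
            return                   # otherwise skip this block entirely
        offs = rev_idx * BLOCK + tl.arange(0, BLOCK)
        vals = tl.load(x + offs)
        tl.store(x + offs, vals + 1)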
@@ -342,6 +346,10 @@ class _BlocksparseScatterReduceMDI(torch.autograd.Function):
        rev_idx_spa_o_msk = (rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
        rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
 
+       if rev_idx_spa_o == -1:
+           tl.device_assert(False)
+           return
+
        # Store output
        blk_o_idx = ((rev_idx_spa_o * o_b_s) +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
{blksprs-1.8.3 → blksprs-1.9}/blksprs/ops/misc/row_wise.py

@@ -117,6 +117,10 @@ def kernel_blocksparse_row_wise_sum(x,
    rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
    rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
+   if rev_idx_spa == -1:
+       tl.device_assert(False)
+       return
+
    blk_idx = ((pid_blk * x_b_s) +
               ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
               ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
@@ -240,6 +244,10 @@ def kernel_blocksparse_row_wise_max(x,
    rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
    rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
+   if rev_idx_spa == -1:
+       tl.device_assert(False)
+       return
+
    blk_idx = ((pid_blk * x_b_s) +
               ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
               ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
{blksprs-1.8.3 → blksprs-1.9}/blksprs/ops/softmax.py

@@ -238,6 +238,10 @@ class _BlocksparseSoftmax(torch.autograd.Function):
        rev_idx_spa_s_msk = (rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
        rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
 
+       if rev_idx_spa_s == -1:
+           tl.device_assert(False)
+           return
+
        blk_s_idx = (rev_idx_spa_s * s_b_s +
                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
                     (tl.arange(0, 1) * s_c_s)[None, :])
{blksprs-1.8.3 → blksprs-1.9}/blksprs.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.8.3
+Version: 1.9
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
All remaining files are unchanged (the entries marked +0 -0 in the file list above).