blksprs 1.9.3__py3-none-any.whl → 1.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +0 -6
- blksprs/layouting/distribution_layout.py +6 -6
- blksprs/layouting/sparsity_layout.py +7 -7
- blksprs/ops/conversion.py +19 -21
- blksprs/ops/distribution.py +14 -14
- blksprs/ops/flow.py +12 -12
- blksprs/ops/matmul.py +8 -8
- blksprs/ops/misc/broadcast_ops.py +6 -6
- blksprs/ops/misc/exp.py +2 -2
- blksprs/ops/misc/row_wise.py +16 -19
- blksprs/ops/partitioning.py +24 -10
- blksprs/ops/softmax.py +17 -16
- blksprs/ops/transpose.py +9 -8
- {blksprs-1.9.3.dist-info → blksprs-1.10.1.dist-info}/METADATA +18 -14
- blksprs-1.10.1.dist-info/RECORD +24 -0
- blksprs/ops/experimental/distribution_mdi.py +0 -447
- blksprs-1.9.3.dist-info/RECORD +0 -25
- {blksprs-1.9.3.dist-info → blksprs-1.10.1.dist-info}/WHEEL +0 -0
- {blksprs-1.9.3.dist-info → blksprs-1.10.1.dist-info}/top_level.txt +0 -0
blksprs/ops/misc/row_wise.py
CHANGED
|
@@ -104,17 +104,17 @@ def kernel_blocksparse_row_wise_sum(x,
|
|
|
104
104
|
|
|
105
105
|
# Get position of current sparsity block consisting of its batch and row index
|
|
106
106
|
spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
|
|
107
|
-
spa_bat_msk = (spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
|
|
107
|
+
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
|
|
108
108
|
spa_bat = tl.load(s_lut_x + spa_bat_idx, mask=spa_bat_msk)
|
|
109
109
|
|
|
110
110
|
spa_row_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
|
|
111
|
-
spa_row_msk = (spa_row_idx < s_lut_x_r * s_lut_x_r_s)
|
|
111
|
+
spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_x_r * s_lut_x_r_s)
|
|
112
112
|
spa_row = tl.load(s_lut_x + spa_row_idx, mask=spa_row_msk)
|
|
113
113
|
|
|
114
114
|
# Load reverse sparsity index for current block
|
|
115
115
|
rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
|
|
116
116
|
spa_row * s_l_o_r_s)
|
|
117
|
-
rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
|
|
117
|
+
rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
|
|
118
118
|
rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
|
|
119
119
|
|
|
120
120
|
if rev_idx_spa == -1:
|
|
@@ -124,7 +124,7 @@ def kernel_blocksparse_row_wise_sum(x,
|
|
|
124
124
|
blk_idx = ((pid_blk * x_b_s) +
|
|
125
125
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
126
126
|
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
127
|
-
blk_msk = (blk_idx < x_b * x_b_s)
|
|
127
|
+
blk_msk = (blk_idx >= 0 and blk_idx < x_b * x_b_s)
|
|
128
128
|
blk = tl.load(x + blk_idx, mask=blk_msk)
|
|
129
129
|
|
|
130
130
|
buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
|
|
@@ -132,7 +132,7 @@ def kernel_blocksparse_row_wise_sum(x,
|
|
|
132
132
|
o_idx = (rev_idx_spa * o_b_s +
|
|
133
133
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
134
134
|
(tl.arange(0, 1))[None, :])
|
|
135
|
-
o_msk = (o_idx < o_b * o_b_s)
|
|
135
|
+
o_msk = (o_idx >= 0 and o_idx < o_b * o_b_s)
|
|
136
136
|
tl.atomic_add(o + o_idx, buf, o_msk)
|
|
137
137
|
|
|
138
138
|
|
|
@@ -231,17 +231,17 @@ def kernel_blocksparse_row_wise_max(x,
|
|
|
231
231
|
|
|
232
232
|
# Get position of current sparsity block consisting of its batch and row index
|
|
233
233
|
spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
|
|
234
|
-
spa_bat_msk = (spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
|
|
234
|
+
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
|
|
235
235
|
spa_bat = tl.load(s_lut_x + spa_bat_idx, mask=spa_bat_msk)
|
|
236
236
|
|
|
237
237
|
spa_row_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
|
|
238
|
-
spa_row_msk = (spa_row_idx < s_lut_x_r * s_lut_x_r_s)
|
|
238
|
+
spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_x_r * s_lut_x_r_s)
|
|
239
239
|
spa_row = tl.load(s_lut_x + spa_row_idx, mask=spa_row_msk)
|
|
240
240
|
|
|
241
241
|
# Load reverse sparsity index for current block
|
|
242
242
|
rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
|
|
243
243
|
spa_row * s_l_o_r_s)
|
|
244
|
-
rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
|
|
244
|
+
rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
|
|
245
245
|
rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
|
|
246
246
|
|
|
247
247
|
if rev_idx_spa == -1:
|
|
@@ -251,7 +251,7 @@ def kernel_blocksparse_row_wise_max(x,
|
|
|
251
251
|
blk_idx = ((pid_blk * x_b_s) +
|
|
252
252
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
253
253
|
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
254
|
-
blk_msk = (blk_idx < x_b * x_b_s)
|
|
254
|
+
blk_msk = (blk_idx >= 0 and blk_idx < x_b * x_b_s)
|
|
255
255
|
blk = tl.load(x + blk_idx, mask=blk_msk)
|
|
256
256
|
|
|
257
257
|
buf = tl.reshape(tl.max(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
|
|
@@ -259,7 +259,7 @@ def kernel_blocksparse_row_wise_max(x,
|
|
|
259
259
|
o_idx = (rev_idx_spa * o_b_s +
|
|
260
260
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
261
261
|
(tl.arange(0, 1))[None, :])
|
|
262
|
-
o_msk = (o_idx < o_b * o_b_s)
|
|
262
|
+
o_msk = (o_idx >= 0 and o_idx < o_b * o_b_s)
|
|
263
263
|
tl.atomic_max(o + o_idx, buf, o_msk)
|
|
264
264
|
|
|
265
265
|
|
|
@@ -356,17 +356,17 @@ def kernel_blocksparse_row_wise_add(x,
|
|
|
356
356
|
|
|
357
357
|
# Get position of current sparsity block consisting of its batch and row index
|
|
358
358
|
spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
|
|
359
|
-
spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
|
|
359
|
+
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
|
|
360
360
|
spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
|
|
361
361
|
|
|
362
362
|
spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
|
|
363
|
-
spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
|
|
363
|
+
spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
|
|
364
364
|
spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
|
|
365
365
|
|
|
366
366
|
# Get reverse sparsity indices for s
|
|
367
367
|
rev_idx_spa_s_idx = (spa_bat * s_l_y_b_s +
|
|
368
368
|
spa_row * s_l_y_r_s)
|
|
369
|
-
rev_idx_spa_s_msk = (rev_idx_spa_s_idx < s_l_y_b * s_l_y_b_s)
|
|
369
|
+
rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_y_b * s_l_y_b_s)
|
|
370
370
|
rev_idx_spa_s = tl.load(r_lut_y + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
|
|
371
371
|
|
|
372
372
|
if rev_idx_spa_s == -1:
|
|
@@ -377,25 +377,22 @@ def kernel_blocksparse_row_wise_add(x,
|
|
|
377
377
|
blk_x_idx = ((pid_blk * x_b_s) +
|
|
378
378
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
379
379
|
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
380
|
-
blk_x_msk = (blk_x_idx < x_b * x_b_s)
|
|
380
|
+
blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
|
|
381
381
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
382
382
|
|
|
383
383
|
# Load sum block
|
|
384
384
|
blk_s_idx = (rev_idx_spa_s * y_b_s +
|
|
385
385
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
|
|
386
386
|
(tl.arange(0, 1) * y_c_s)[None, :])
|
|
387
|
-
blk_s_msk = (blk_s_idx < y_b * y_b_s)
|
|
387
|
+
blk_s_msk = (blk_s_idx >= 0 and blk_s_idx < y_b * y_b_s)
|
|
388
388
|
blk_s = tl.load(y + blk_s_idx, mask=blk_s_msk)
|
|
389
389
|
|
|
390
390
|
# Compute exp
|
|
391
391
|
buf = blk_x + tl.broadcast_to(blk_s, (TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE))
|
|
392
392
|
|
|
393
|
-
# debug
|
|
394
|
-
asdf = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), 1.0, dtype=tl.float32)
|
|
395
|
-
|
|
396
393
|
# Store block
|
|
397
394
|
blk_o_idx = ((pid_blk * o_b_s) +
|
|
398
395
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
399
396
|
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
400
|
-
blk_o_msk = (blk_o_idx < o_b * o_b_s)
|
|
397
|
+
blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
|
|
401
398
|
tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
|
blksprs/ops/partitioning.py
CHANGED
|
@@ -9,13 +9,14 @@ from blksprs.utils.validation import validate_dimensions, validate_contiguous, v
|
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
|
|
12
|
-
sparsity_block_size: int, triton_block_size: int = None) -> (BlksprsTensor, Tensor):
|
|
12
|
+
dim: int, sparsity_block_size: int, triton_block_size: int = None) -> (BlksprsTensor, Tensor):
|
|
13
13
|
"""Splits a block-sparse tensor in compressed form along the last dimension into partitions.
|
|
14
14
|
|
|
15
15
|
Args:
|
|
16
16
|
x (BlksprsTensor): A block-sparse tensor in compressed form.
|
|
17
17
|
sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
|
|
18
18
|
partitions (int): The number of partitions to split the block-sparse tensor into.
|
|
19
|
+
dim (int): The dimension along which to split the tensor. Currently only supports dim=2.
|
|
19
20
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
20
21
|
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
21
22
|
|
|
@@ -54,17 +55,22 @@ def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
|
|
|
54
55
|
|
|
55
56
|
validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
|
|
56
57
|
|
|
58
|
+
adjusted_dim = dim % 3
|
|
59
|
+
if adjusted_dim != 2:
|
|
60
|
+
raise NotImplementedError("Currently only supports dim=2")
|
|
61
|
+
|
|
57
62
|
return BlksprsTensor(_BlocksparseSplit.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
|
|
58
|
-
sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_output
|
|
63
|
+
adjusted_dim, sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_output
|
|
59
64
|
|
|
60
65
|
|
|
61
66
|
class _BlocksparseSplit(torch.autograd.Function):
|
|
62
67
|
|
|
63
68
|
@staticmethod
|
|
64
69
|
def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
65
|
-
num_partitions: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
|
|
70
|
+
num_partitions: int, dim: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
|
|
66
71
|
ctx.save_for_backward(sparsity_layout_o)
|
|
67
72
|
ctx.num_partitions = num_partitions
|
|
73
|
+
ctx.dim = dim
|
|
68
74
|
|
|
69
75
|
return flow_forward(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
|
|
70
76
|
n_sparse_blocks, triton_block_size)
|
|
@@ -73,21 +79,23 @@ class _BlocksparseSplit(torch.autograd.Function):
|
|
|
73
79
|
def backward(ctx, grad_output):
|
|
74
80
|
sparsity_layout = ctx.saved_tensors[0]
|
|
75
81
|
num_partitions = ctx.num_partitions
|
|
82
|
+
dim = ctx.dim
|
|
76
83
|
sparsity_block_size = ctx.sparsity_block_size
|
|
77
84
|
triton_block_size = ctx.triton_block_size
|
|
78
85
|
|
|
79
|
-
return merge(grad_output, sparsity_layout, num_partitions,
|
|
80
|
-
sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None
|
|
86
|
+
return merge(grad_output, sparsity_layout, num_partitions, dim,
|
|
87
|
+
sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None, None
|
|
81
88
|
|
|
82
89
|
|
|
83
90
|
def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
|
|
84
|
-
sparsity_block_size: int, triton_block_size: int = None) -> (BlksprsTensor, Tensor):
|
|
91
|
+
dim: int, sparsity_block_size: int, triton_block_size: int = None) -> (BlksprsTensor, Tensor):
|
|
85
92
|
"""Merges the specified partitions of a block-sparse tensor in compressed form along the last dimension.
|
|
86
93
|
|
|
87
94
|
Args:
|
|
88
95
|
x (BlksprsTensor): A block-sparse tensor in compressed form.
|
|
89
96
|
sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
|
|
90
97
|
partitions (int): The number of partitions to be merged.
|
|
98
|
+
dim (int): The dimension along which to merge the tensor. Currently only supports dim=2.
|
|
91
99
|
sparsity_block_size (int): The size of the sparsity blocks.
|
|
92
100
|
triton_block_size (int): The block size to use for the triton kernel (default ``None``).
|
|
93
101
|
|
|
@@ -128,17 +136,22 @@ def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
|
|
|
128
136
|
|
|
129
137
|
validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
|
|
130
138
|
|
|
139
|
+
adjusted_dim = dim % 3
|
|
140
|
+
if adjusted_dim != 2:
|
|
141
|
+
raise NotImplementedError("Currently only supports dim=2")
|
|
142
|
+
|
|
131
143
|
return BlksprsTensor(_BlocksparseMerge.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
|
|
132
|
-
sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_output
|
|
144
|
+
adjusted_dim, sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_output
|
|
133
145
|
|
|
134
146
|
|
|
135
147
|
class _BlocksparseMerge(torch.autograd.Function):
|
|
136
148
|
|
|
137
149
|
@staticmethod
|
|
138
150
|
def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
|
|
139
|
-
num_partitions: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
|
|
151
|
+
num_partitions: int, dim: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
|
|
140
152
|
ctx.save_for_backward(sparsity_layout_o)
|
|
141
153
|
ctx.num_partitions = num_partitions
|
|
154
|
+
ctx.dim = dim
|
|
142
155
|
|
|
143
156
|
return flow_forward(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
|
|
144
157
|
n_sparse_blocks, triton_block_size)
|
|
@@ -147,10 +160,11 @@ class _BlocksparseMerge(torch.autograd.Function):
|
|
|
147
160
|
def backward(ctx, grad_output):
|
|
148
161
|
sparsity_layout = ctx.saved_tensors[0]
|
|
149
162
|
num_partitions = ctx.num_partitions
|
|
163
|
+
dim = ctx.dim
|
|
150
164
|
sparsity_block_size = ctx.sparsity_block_size
|
|
151
165
|
triton_block_size = ctx.triton_block_size
|
|
152
166
|
|
|
153
|
-
return split(grad_output, sparsity_layout, num_partitions,
|
|
154
|
-
sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None
|
|
167
|
+
return split(grad_output, sparsity_layout, num_partitions, dim,
|
|
168
|
+
sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None, None
|
|
155
169
|
|
|
156
170
|
|
blksprs/ops/softmax.py
CHANGED
|
@@ -11,7 +11,8 @@ from blksprs.utils.validation import validate_contiguous, validate_dimensions, v
|
|
|
11
11
|
validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
|
|
12
12
|
|
|
13
13
|
|
|
14
|
-
def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
14
|
+
def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
15
|
+
triton_block_size: int = None) -> BlksprsTensor:
|
|
15
16
|
"""Computes the softmax of a block-sparse tensor in compressed form.
|
|
16
17
|
|
|
17
18
|
Note:
|
|
@@ -47,9 +48,9 @@ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
|
|
|
47
48
|
validate_contiguous(sparsity_layout, sparsity_lut, sparsity_reverse_lut_rws)
|
|
48
49
|
|
|
49
50
|
return BlksprsTensor(_BlocksparseSoftmax.apply(x, sparsity_layout,
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
51
|
+
sparsity_lut,
|
|
52
|
+
sparsity_reverse_lut_rws,
|
|
53
|
+
sparsity_block_size, triton_block_size))
|
|
53
54
|
|
|
54
55
|
|
|
55
56
|
class _BlocksparseSoftmax(torch.autograd.Function):
|
|
@@ -168,17 +169,17 @@ class _BlocksparseSoftmax(torch.autograd.Function):
|
|
|
168
169
|
|
|
169
170
|
# Get position of current sparsity block consisting of its batch and row index
|
|
170
171
|
spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
|
|
171
|
-
spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
|
|
172
|
+
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
|
|
172
173
|
spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
|
|
173
174
|
|
|
174
175
|
spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
|
|
175
|
-
spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
|
|
176
|
+
spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
|
|
176
177
|
spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
|
|
177
178
|
|
|
178
179
|
# Get reverse sparsity indices for s
|
|
179
180
|
rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
|
|
180
181
|
spa_row * s_l_s_r_s)
|
|
181
|
-
rev_idx_spa_s_msk = (rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
|
|
182
|
+
rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
|
|
182
183
|
rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
|
|
183
184
|
|
|
184
185
|
if rev_idx_spa_s == -1:
|
|
@@ -189,14 +190,14 @@ class _BlocksparseSoftmax(torch.autograd.Function):
|
|
|
189
190
|
blk_x_idx = ((pid_blk * x_b_s) +
|
|
190
191
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
191
192
|
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
192
|
-
blk_x_msk = (blk_x_idx < x_b * x_b_s)
|
|
193
|
+
blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
|
|
193
194
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
194
195
|
|
|
195
196
|
# Load sum block
|
|
196
197
|
blk_s_idx = (rev_idx_spa_s * s_b_s +
|
|
197
198
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
|
|
198
199
|
(tl.arange(0, 1) * s_c_s)[None, :])
|
|
199
|
-
blk_s_msk = (blk_s_idx < s_b * s_b_s)
|
|
200
|
+
blk_s_msk = (blk_s_idx >= 0 and blk_s_idx < s_b * s_b_s)
|
|
200
201
|
blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
|
|
201
202
|
|
|
202
203
|
# Compute softmax
|
|
@@ -226,16 +227,16 @@ class _BlocksparseSoftmax(torch.autograd.Function):
|
|
|
226
227
|
|
|
227
228
|
# Get position of current sparsity block consisting of its batch and row index
|
|
228
229
|
spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
|
|
229
|
-
spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
|
|
230
|
+
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
|
|
230
231
|
spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
|
|
231
232
|
|
|
232
233
|
spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
|
|
233
|
-
spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
|
|
234
|
+
spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
|
|
234
235
|
spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
|
|
235
236
|
|
|
236
237
|
rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
|
|
237
238
|
spa_row * s_l_s_r_s)
|
|
238
|
-
rev_idx_spa_s_msk = (rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
|
|
239
|
+
rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
|
|
239
240
|
rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
|
|
240
241
|
|
|
241
242
|
if rev_idx_spa_s == -1:
|
|
@@ -245,19 +246,19 @@ class _BlocksparseSoftmax(torch.autograd.Function):
|
|
|
245
246
|
blk_s_idx = (rev_idx_spa_s * s_b_s +
|
|
246
247
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
|
|
247
248
|
(tl.arange(0, 1) * s_c_s)[None, :])
|
|
248
|
-
blk_s_msk = (blk_s_idx < s_b * s_b_s)
|
|
249
|
+
blk_s_msk = (blk_s_idx >= 0 and blk_s_idx < s_b * s_b_s)
|
|
249
250
|
blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
|
|
250
251
|
|
|
251
252
|
blk_g_idx = ((pid_blk * g_b_s) +
|
|
252
253
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +
|
|
253
254
|
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])
|
|
254
|
-
blk_g_msk = (blk_g_idx < g_b * g_b_s)
|
|
255
|
+
blk_g_msk = (blk_g_idx >= 0 and blk_g_idx < g_b * g_b_s)
|
|
255
256
|
blk_g = tl.load(g + blk_g_idx, mask=blk_g_msk)
|
|
256
257
|
|
|
257
258
|
blk_x_idx = ((pid_blk * x_b_s) +
|
|
258
259
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
259
260
|
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
260
|
-
blk_x_msk = (blk_x_idx < x_b * x_b_s)
|
|
261
|
+
blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
|
|
261
262
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
262
263
|
|
|
263
264
|
buf = blk_x * (blk_g - blk_s)
|
|
@@ -265,5 +266,5 @@ class _BlocksparseSoftmax(torch.autograd.Function):
|
|
|
265
266
|
blk_o_idx = ((pid_blk * o_b_s) +
|
|
266
267
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
|
|
267
268
|
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
|
|
268
|
-
blk_o_msk = (blk_o_idx < o_b * o_b_s)
|
|
269
|
+
blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
|
|
269
270
|
tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
|
blksprs/ops/transpose.py
CHANGED
|
@@ -50,8 +50,9 @@ def transpose(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: in
|
|
|
50
50
|
|
|
51
51
|
validate_contiguous(sparsity_layout_t, sparsity_lut, sparsity_reverse_lut)
|
|
52
52
|
|
|
53
|
-
return BlksprsTensor(
|
|
54
|
-
|
|
53
|
+
return BlksprsTensor(
|
|
54
|
+
_BlocksparseTranspose.apply(x, sparsity_layout_t, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
|
|
55
|
+
n_sparse_blocks, triton_block_size)), sparsity_layout_t
|
|
55
56
|
|
|
56
57
|
|
|
57
58
|
class _BlocksparseTranspose(torch.autograd.Function):
|
|
@@ -122,22 +123,22 @@ class _BlocksparseTranspose(torch.autograd.Function):
|
|
|
122
123
|
|
|
123
124
|
# Get sparsity index of current output block consisting of its batch, row, and column index
|
|
124
125
|
spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
|
|
125
|
-
spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
|
|
126
|
+
spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
|
|
126
127
|
spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
|
|
127
128
|
|
|
128
129
|
spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
|
|
129
|
-
spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
|
|
130
|
+
spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
|
|
130
131
|
spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
|
|
131
132
|
|
|
132
133
|
spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
|
|
133
|
-
spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
|
|
134
|
+
spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
|
|
134
135
|
spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
|
|
135
136
|
|
|
136
137
|
# Get reverse sparsity index
|
|
137
138
|
rev_idx_spa_idx = (spa_bat * s_l_b_s +
|
|
138
139
|
spa_row * s_l_r_s +
|
|
139
140
|
spa_col * s_l_c_s)
|
|
140
|
-
rev_idx_spa_msk = (rev_idx_spa_idx < s_l_b * s_l_b_s)
|
|
141
|
+
rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_b * s_l_b_s)
|
|
141
142
|
rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
|
|
142
143
|
|
|
143
144
|
if rev_idx_spa == -1:
|
|
@@ -147,7 +148,7 @@ class _BlocksparseTranspose(torch.autograd.Function):
|
|
|
147
148
|
blk_x_idx = (rev_idx_spa * x_b_s +
|
|
148
149
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
149
150
|
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
150
|
-
blk_x_msk = (blk_x_idx < x_b * x_b_s)
|
|
151
|
+
blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
|
|
151
152
|
blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
|
|
152
153
|
|
|
153
154
|
blk_x_t = tl.trans(blk_x)
|
|
@@ -155,5 +156,5 @@ class _BlocksparseTranspose(torch.autograd.Function):
|
|
|
155
156
|
blk_o_idx = (pid_blk * o_b_s +
|
|
156
157
|
((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
|
|
157
158
|
((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
|
|
158
|
-
blk_o_msk = (blk_o_idx < o_b * o_b_s)
|
|
159
|
+
blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
|
|
159
160
|
tl.store(o + blk_o_idx, blk_x_t, mask=blk_o_msk)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: blksprs
|
|
3
|
-
Version: 1.9.3
|
|
3
|
+
Version: 1.10.1
|
|
4
4
|
Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
|
|
5
5
|
Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
|
|
6
6
|
Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
|
|
@@ -23,14 +23,6 @@ Requires-Dist: build; extra == "build"
|
|
|
23
23
|
[](https://github.com/FelixSchoen/blksprs/releases)
|
|
24
24
|
[](https://www.python.org/downloads/release/python-3119/)
|
|
25
25
|
|
|
26
|
-
## Important Notice
|
|
27
|
-
|
|
28
|
-
🚨 **Non-Final API** 🚨
|
|
29
|
-
|
|
30
|
-
Although it already supports a wide variety of functions, this library is still under active development and the API is
|
|
31
|
-
subject to change. For feature requests or bug reports, please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
|
|
32
|
-
We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
|
|
33
|
-
|
|
34
26
|
## Overview
|
|
35
27
|
|
|
36
28
|
A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
|
|
@@ -44,7 +36,7 @@ Currently supported operations (includes gradient calculation):
|
|
|
44
36
|
- Scatter (_supports either no reduction or summation, gradients are only available for summation_)
|
|
45
37
|
- Repeat (_supports target sparsity layout_)
|
|
46
38
|
- Repeat Interleave (_supports target sparsity layout_)
|
|
47
|
-
- Splitting and merging of matrices along the last dimension
|
|
39
|
+
- Splitting and merging of matrices (_currently* only supports splitting and merging along the last dimension_)
|
|
48
40
|
- Conversion to and from sparse form
|
|
49
41
|
- Conversion to different sparsity layouts and different sparsity block sizes
|
|
50
42
|
|
|
@@ -70,13 +62,15 @@ Furthermore, the library provides a set of utility functions
|
|
|
70
62
|
- for the creation of sparsity layouts based on existing
|
|
71
63
|
dense tensors and for the scatter operation (module ``bs.layouting``),
|
|
72
64
|
- for the application of ``nn.Linear``, ``nn.Dropout``, and ``nn.LayerNorm`` layers to block-sparse tensors,
|
|
73
|
-
- as well as utility functions to
|
|
74
|
-
|
|
65
|
+
- as well as utility functions to ensure correct input dimensionality, and validate input (module ``bs.utils``).
|
|
66
|
+
|
|
67
|
+
_* see the [Roadmap](#roadmap) section for more information_
|
|
75
68
|
|
|
76
69
|
## Installation
|
|
77
70
|
|
|
78
|
-
Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is only compatible with
|
|
79
|
-
the Linux platform
|
|
71
|
+
Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible with
|
|
72
|
+
the Linux platform**.
|
|
73
|
+
Keep track of this [issue](https://github.com/triton-lang/triton/issues/1640) for updates.
|
|
80
74
|
|
|
81
75
|
We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) using pip:
|
|
82
76
|
|
|
@@ -92,6 +86,16 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u
|
|
|
92
86
|
|
|
93
87
|
See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.md) for a detailed changelog.
|
|
94
88
|
|
|
89
|
+
## Roadmap
|
|
90
|
+
|
|
91
|
+
Note that since this library covers all our current needs it is in a **bugfix-only** state.
|
|
92
|
+
This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and ``merge`` operations.
|
|
93
|
+
We will continue to maintain the library and fix any issues that arise.
|
|
94
|
+
Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
|
|
95
|
+
We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
|
|
96
|
+
|
|
97
|
+
It might be that this changes with future projects, but as of December 2024, we are content with the current state of the library.
|
|
98
|
+
|
|
95
99
|
## Usage
|
|
96
100
|
|
|
97
101
|
We provide an example below to demonstrate the usage of the library.
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
blksprs/__init__.py,sha256=wnpk-20jXq7xV0xa-WpHfPQuauI2gEZz9sH-0blKxP0,1766
|
|
2
|
+
blksprs/layouting/distribution_layout.py,sha256=xDGY5-J7uSD8oenlf8bEJ2amMiQG3NBf2klTTydbTJE,5140
|
|
3
|
+
blksprs/layouting/sparsity_layout.py,sha256=IVtHc_nN3ZM2y4GFcys70PqDWmWc7tkHlVGlToErANk,9894
|
|
4
|
+
blksprs/ops/conversion.py,sha256=NK5uXMepPJ9yYh0vnxKwx5_Ffj_bAvhqPVogf_7PY0g,22248
|
|
5
|
+
blksprs/ops/distribution.py,sha256=qK5t5XgQSJxXPced8RohprqCtUMMTaEP2pFm3KU1c8o,20267
|
|
6
|
+
blksprs/ops/flow.py,sha256=SWHDQ5zx0cjnPR0CcAcRNZdSusSAHSU840SwDNUr24g,6437
|
|
7
|
+
blksprs/ops/matmul.py,sha256=LAQyPNwWVmBMRnAex3msLSPD_aG5SblLCMiutJWqvus,11632
|
|
8
|
+
blksprs/ops/partitioning.py,sha256=ugKnpvH36ND7qeJQp56M74qqfACkzcTVuXebzw__28Y,8286
|
|
9
|
+
blksprs/ops/repeat.py,sha256=RCa-dITomA5v12K5Oxa5_ReA361zS7WHPNNHxSp9PGw,8578
|
|
10
|
+
blksprs/ops/softmax.py,sha256=i8NJhvPRYya94AzpN6qiki6_G9KfDrtPifhWd7wbYzk,12496
|
|
11
|
+
blksprs/ops/transpose.py,sha256=oAtUu7QzQnNAH3lvRs_MIvIKpBu9h74f9Sk07AxKnDM,6991
|
|
12
|
+
blksprs/ops/misc/broadcast_ops.py,sha256=pv0nssSDOdDbQFttpqUIs2ZXShqfm2RYCfJH-C5x3H0,5544
|
|
13
|
+
blksprs/ops/misc/exp.py,sha256=ygfw7oD6ALdPwNQX_HelKgO8I3-LCgIXH_x0gWzkUN8,3840
|
|
14
|
+
blksprs/ops/misc/row_wise.py,sha256=DnV5-xEJUbqZlK2fETwHiPQDUMwT-lkc0VUhBlnJ5Y0,17458
|
|
15
|
+
blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
|
|
16
|
+
blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
|
|
17
|
+
blksprs/utils/layout_utils.py,sha256=49ZdPS_gMn_IrWty3FARbi2rda5a8g5DmAEL8LOrC30,670
|
|
18
|
+
blksprs/utils/processing.py,sha256=WLuMJQ8v-YovXwcDjhlDn3N31WMZXrtyeeyKSgq_zn4,3642
|
|
19
|
+
blksprs/utils/tools.py,sha256=r7Y4C37vfSWUyQTGwa8NyRqgovmsq9hMufkenqYHOxo,539
|
|
20
|
+
blksprs/utils/validation.py,sha256=CbxBbeQWJo8wox5eMoVzaTlP9FVBwt3-gxUOmi3EUgw,4213
|
|
21
|
+
blksprs-1.10.1.dist-info/METADATA,sha256=5in6lYCZo1bd8urYR0wkTxIiTTAIAANukLpKeZfGasY,9107
|
|
22
|
+
blksprs-1.10.1.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
|
23
|
+
blksprs-1.10.1.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
|
|
24
|
+
blksprs-1.10.1.dist-info/RECORD,,
|