blksprs 1.6.1__py3-none-any.whl → 1.8__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
blksprs/__init__.py CHANGED
@@ -1,22 +1,27 @@
-from blksprs.ops.conversion import to_dense, to_sparse
+from blksprs.ops.conversion import to_dense, to_sparse, from_blksprs, to_blksprs
 from blksprs.ops.distribution import gather, scatter, scatter_reduce
-from blksprs.ops.exp import exp
 from blksprs.ops.matmul import matmul
 from blksprs.ops.softmax import softmax
 from blksprs.ops.transpose import transpose
-from blksprs.ops.partitioning import split, merge
+from blksprs.ops.repeat import repeat, repeat_interleave
+from blksprs.misc.partitioning import split, merge
+
 
 class layout:
     from blksprs.layouting.distribution_layout import build_distribution_layout
-    from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption, build_sparsity_layout_matmul
+    from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption, \
+        build_sparsity_layout_matmul
+
 
 class misc:
     from blksprs.misc.broadcast_ops import broadcast_add, broadcast_sub
-    from blksprs.misc.repeat_interleave import repeat_interleave
+    from blksprs.misc.exp import exp
     from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_add, row_wise_sub
 
+
 class util:
     from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse, disable_validation
 
+
 class experimental:
-    from blksprs.experimental.distribution_mdi import gather_mdi, scatter_reduce_mdi
+    from blksprs.experimental.distribution_mdi import gather_mdi, scatter_reduce_mdi
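Note: the hunk above reorganizes the public namespace. A hedged sketch of the resulting 1.8 surface (only the names come from the diff; everything else is an assumption):

    import blksprs as bs

    bs.to_blksprs, bs.from_blksprs   # new top-level aliases for to_sparse / to_dense
    bs.repeat, bs.repeat_interleave  # repeat ops now exposed at the top level
    bs.split, bs.merge               # now served from blksprs.misc.partitioning
    bs.misc.exp                      # exp moved from blksprs.ops.exp to blksprs.misc.exp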
blksprs/experimental/distribution_mdi.py CHANGED
@@ -3,7 +3,7 @@ import triton
 from torch import Tensor
 from triton import language as tl
 
-from blksprs.utils.tools import get_triton_block_size
+from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
 
@@ -51,15 +51,15 @@ class _BlocksparseGatherMDI(torch.autograd.Function):
         output = torch.empty_like(idx_col, dtype=x.dtype)
 
         x_b, x_r, x_c = x.size()
-        x_b_s, x_r_s, x_c_s = x.stride()
+        x_b_s, x_r_s, x_c_s = stride(x)
         s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
-        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_x.stride()
+        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
         i_b, i_r, i_c = idx_col.size()
-        i_b_s, i_r_s, i_c_s = idx_col.stride()
+        i_b_s, i_r_s, i_c_s = stride(idx_col)
         s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
-        s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
+        s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
         o_b, o_r, o_c = output.size()
-        o_b_s, o_r_s, o_c_s = output.stride()
+        o_b_s, o_r_s, o_c_s = stride(output)
 
         if triton_block_size is None:
             triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -224,15 +224,15 @@ class _BlocksparseScatterReduceMDI(torch.autograd.Function):
                              dtype=x.dtype, device=x.device)
 
         x_b, x_r, x_c = x.size()
-        x_b_s, x_r_s, x_c_s = x.stride()
+        x_b_s, x_r_s, x_c_s = stride(x)
         s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
-        s_lut_x_r_s, s_lut_x_c_s = sparsity_lut_x.stride()
+        s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut_x)
         i_b, i_r, i_c = idx_col.size()
-        i_b_s, i_r_s, i_c_s = idx_col.stride()
+        i_b_s, i_r_s, i_c_s = stride(idx_col)
         o_b, o_r, o_c = output.size()
-        o_b_s, o_r_s, o_c_s = output.stride()
+        o_b_s, o_r_s, o_c_s = stride(output)
         s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
-        s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_o.stride()
+        s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
 
         if triton_block_size is None:
             triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -366,11 +366,11 @@ def build_distribution_layout_mdi(idx_bat: Tensor, idx_row: Tensor, idx_col: Ten
                          dtype=torch.bool, device=idx_col.device)
 
     i_b, i_r, i_c = idx_col.size()
-    i_b_s, i_r_s, i_c_s = idx_col.stride()
+    i_b_s, i_r_s, i_c_s = stride(idx_col)
     s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
-    s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
+    s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
    o_b, o_r, o_c = output.size()
-    o_b_s, o_r_s, o_c_s = output.stride()
+    o_b_s, o_r_s, o_c_s = stride(output)
 
     if triton_block_size is None:
         triton_block_size = get_triton_block_size(sparsity_block_size)
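Note: every file in this diff replaces tensor.stride() with a new stride() helper imported from blksprs.utils.tools, whose implementation is not part of the diff. A purely hypothetical sketch of what such a helper might do (an assumption, not the package's actual code) is to report the strides the tensor would have if it were contiguous, so kernels index blocks consistently even for views:

    from torch import Tensor

    def stride(x: Tensor) -> tuple[int, ...]:
        # Hypothetical: derive contiguous strides from the shape alone,
        # rather than trusting the strides reported by an arbitrary view.
        strides, acc = [], 1
        for dim in reversed(x.size()):
            strides.append(acc)
            acc *= dim
        return tuple(reversed(strides))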
blksprs/layouting/distribution_layout.py CHANGED
@@ -3,7 +3,7 @@ import triton
 from torch import Tensor
 from triton import language as tl
 
-from blksprs.utils.tools import get_triton_block_size
+from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
     validate_contiguous
 
@@ -34,11 +34,11 @@ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
                          dtype=torch.bool, device=indices.device)
 
     i_b, i_r, i_c = indices.size()
-    i_b_s, i_r_s, i_c_s = indices.stride()
+    i_b_s, i_r_s, i_c_s = stride(indices)
     s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
-    s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
+    s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
     o_b, o_r, o_c = output.size()
-    o_b_s, o_r_s, o_c_s = output.stride()
+    o_b_s, o_r_s, o_c_s = stride(output)
 
     if triton_block_size is None:
         triton_block_size = get_triton_block_size(sparsity_block_size)
blksprs/layouting/sparsity_layout.py CHANGED
@@ -5,7 +5,7 @@ import triton
 from torch import Tensor
 from triton import language as tl
 
-from blksprs.utils.tools import get_triton_block_size
+from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
     validate_contiguous, validate_sparsity, validate_sparsity_block_size
 
@@ -30,9 +30,9 @@ def build_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_size
                          dtype=torch.bool, device=x.device)
 
     x_b, x_r, x_c = x.size()
-    x_b_s, x_r_s, x_c_s = x.stride()
+    x_b_s, x_r_s, x_c_s = stride(x)
     o_b, o_r, o_c = output.size()
-    o_b_s, o_r_s, o_c_s = output.stride()
+    o_b_s, o_r_s, o_c_s = stride(output)
 
     if triton_block_size is None:
         triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -120,10 +120,10 @@ def build_sparsity_layout_adaption(x: Tensor, sparsity_layout_from: Tensor,
     output = torch.zeros(o_b, o_r, o_c, dtype=torch.bool, device=x.device)
 
     x_b, x_r, x_c = x.size()
-    x_b_s, x_r_s, x_c_s = x.stride()
+    x_b_s, x_r_s, x_c_s = stride(x)
     s_lut_r, s_lut_c = sparsity_lut.size()
-    s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
-    o_b_s, o_r_s, o_c_s = output.stride()
+    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+    o_b_s, o_r_s, o_c_s = stride(output)
 
     if triton_block_size is None:
         triton_block_size = get_triton_block_size(sparsity_block_size_from)
blksprs/misc/broadcast_ops.py CHANGED
@@ -3,7 +3,7 @@ import triton
 from torch import Tensor
 from triton import language as tl
 
-from blksprs.utils.tools import get_triton_block_size
+from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_contiguous, validate_device, \
     validate_sparsity_block_size, validate_triton_block_size
 
@@ -44,13 +44,13 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
     output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, dtype=x.dtype, device=x.device)
 
     x_b, x_c = x.size()
-    x_b_s, x_c_s = x.stride()
+    x_b_s, x_c_s = stride(x)
     y_b, y_c = y.size()
-    y_b_s, y_c_s = y.stride()
+    y_b_s, y_c_s = stride(y)
     o_b, o_r, o_c = output.size()
-    o_b_s, o_r_s, o_c_s = output.stride()
+    o_b_s, o_r_s, o_c_s = stride(output)
     s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
-    s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_o.stride()
+    s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
 
     if triton_block_size is None:
         triton_block_size = get_triton_block_size(sparsity_block_size)
blksprs/ops/exp.py → blksprs/misc/exp.py RENAMED
@@ -3,7 +3,7 @@ import triton
 from torch import Tensor
 from triton import language as tl
 
-from blksprs.utils.tools import get_triton_block_size
+from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity_block_size, validate_triton_block_size
 
@@ -43,9 +43,9 @@ class _BlocksparseExp(torch.autograd.Function):
         output = torch.empty_like(x)
 
         x_b, x_r, x_c = x.shape
-        x_b_s, x_r_s, x_c_s = x.stride()
+        x_b_s, x_r_s, x_c_s = stride(x)
         o_b, o_r, o_c = output.shape
-        o_b_s, o_r_s, o_c_s = output.stride()
+        o_b_s, o_r_s, o_c_s = stride(output)
 
     if triton_block_size is None:
         triton_block_size = get_triton_block_size(sparsity_block_size)
blksprs/ops/partitioning.py → blksprs/misc/partitioning.py RENAMED
@@ -1,10 +1,7 @@
 import torch
-import triton
-from sympy.utilities.iterables import partitions
 from torch import Tensor
-from triton import language as tl
 
-from blksprs.utils.tools import get_triton_block_size
+from blksprs.ops.repeat import forward_flow
 
 from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
@@ -48,12 +45,11 @@ def split(x: Tensor, sparsity_layout: Tensor, partitions: int,
     sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
                              (sparsity_layout_flat == 1) -
                              (1 * (sparsity_layout_flat == 0)))
-                            .reshape(sparsity_layout.size())
                            .reshape(sparsity_layout.size(0), sparsity_layout.size(1), partitions,
                                     sparsity_layout.size(2) // partitions)
                            .permute(0, 2, 1, 3).reshape(-1).contiguous())
 
-    n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
+    n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
 
     validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
 
@@ -66,10 +62,11 @@ class _BlocksparseSplit(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
                 num_partitions: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+        ctx.save_for_backward(sparsity_layout_o)
         ctx.num_partitions = num_partitions
 
-        return forward_reorder(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
-                               n_sparse_blocks, triton_block_size)
+        return forward_flow(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                            n_sparse_blocks, triton_block_size)
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -126,7 +123,7 @@ def merge(x: Tensor, sparsity_layout: Tensor, partitions: int,
                                     sparsity_layout.size(1), sparsity_layout.size(2) * partitions)
                            .reshape(-1).contiguous())
 
-    n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
+    n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
 
     validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
 
@@ -139,10 +136,11 @@ class _BlocksparseMerge(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
                 num_partitions: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+        ctx.save_for_backward(sparsity_layout_o)
         ctx.num_partitions = num_partitions
 
-        return forward_reorder(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
-                               n_sparse_blocks, triton_block_size)
+        return forward_flow(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                            n_sparse_blocks, triton_block_size)
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -155,90 +153,3 @@ class _BlocksparseMerge(torch.autograd.Function):
                                sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None
 
 
-def forward_reorder(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
-                    sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
-    output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
-                         dtype=x.dtype, device=x.device)
-
-    x_b, x_r, x_c = x.size()
-    x_b_s, x_r_s, x_c_s = x.stride()
-    s_l_b, s_l_r, s_l_c = sparsity_layout_o.size()
-    s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout_o.stride()
-    s_lut_r, s_lut_c = sparsity_lut.shape
-    s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
-    o_b, o_r, o_c = output.size()
-    o_b_s, o_r_s, o_c_s = output.stride()
-
-    if triton_block_size is None:
-        triton_block_size = get_triton_block_size(sparsity_block_size)
-
-    triton_grid = lambda meta: [o_b,
-                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-    (kernel_blocksparse_reorder[triton_grid]
-     (x,
-      x_b, x_b_s, x_r_s, x_c_s,
-      s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
-      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-      sparsity_reverse_lut,
-      output,
-      o_b, o_b_s,
-      triton_block_size))
-
-    # Save for backward pass
-    ctx.save_for_backward(sparsity_layout_o)
-    ctx.sparsity_block_size = sparsity_block_size
-    ctx.triton_block_size = triton_block_size
-
-    return output
-
-
-@triton.jit
-def kernel_blocksparse_reorder(x,
-                               x_b, x_b_s, x_r_s, x_c_s,
-                               s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
-                               s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-                               r_lut,
-                               o,
-                               o_b, o_b_s,
-                               TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-    # Get triton block indices
-    pid_blk = tl.program_id(axis=0)
-    pid_row = tl.program_id(axis=1)
-    pid_col = tl.program_id(axis=2)
-
-    # Get sparsity index of current output block consisting of its batch, row, and column index
-    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-    spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
-    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
-
-    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-    spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
-    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
-
-    spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
-    spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
-    spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
-
-    # Get reverse sparsity index
-    rev_idx_spa_idx = (spa_bat * s_l_b_s +
-                       spa_row * s_l_r_s +
-                       spa_col * s_l_c_s)
-    rev_idx_spa_msk = (rev_idx_spa_idx < s_l_b * s_l_b_s)
-    rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
-
-    if rev_idx_spa == -1:
-        assert False, "Invalid sparsity block"
-
-    blk_x_idx = (rev_idx_spa * x_b_s +
-                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx < x_b * x_b_s)
-    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
-    blk_o_idx = (pid_blk * o_b_s +
-                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_o_msk = (blk_o_idx < o_b * o_b_s)
-    tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
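Note: split() and merge() now count sparse blocks from sparsity_layout_output rather than the input layout, save the output layout for backward inside forward(), and delegate the block reorder to the shared forward_flow helper in blksprs.ops.repeat, retiring the local forward_reorder and its kernel deleted above. The reverse-LUT cumsum trick itself is unchanged apart from dropping a redundant reshape; a small worked example of what it computes:

    import torch

    flat = torch.tensor([1, 0, 1, 1, 0, 1])
    rev = (torch.cumsum(flat, dim=-1) - 1) * (flat == 1) - 1 * (flat == 0)
    # rev == tensor([ 0, -1,  1,  2, -1,  3])
    # Each active block maps to its rank among the active blocks; empty
    # slots map to -1, the sentinel the reorder kernels check against.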
blksprs/misc/row_wise.py CHANGED
@@ -3,7 +3,7 @@ import triton
 from torch import Tensor
 from triton import language as tl
 
-from blksprs.utils.tools import get_triton_block_size
+from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, validate_sparsity, \
     validate_sparsity_block_size, validate_triton_block_size
 
@@ -60,13 +60,13 @@ def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
                          device=x.device)
 
     x_b, x_r, x_c = x.size()
-    x_b_s, x_r_s, x_c_s = x.stride()
+    x_b_s, x_r_s, x_c_s = stride(x)
     s_lut_x_r, s_lut_x_c = sparsity_lut.size()
-    s_lut_x_r_s, s_lut_x_c_s = sparsity_lut.stride()
+    s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
     o_b, o_r, o_c = output.size()
-    o_b_s, o_r_s, o_c_s = output.stride()
+    o_b_s, o_r_s, o_c_s = stride(output)
     s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
-    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_output.stride()
+    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
 
     if triton_block_size is None:
         triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -183,13 +183,13 @@ def row_wise_max(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
                          device=x.device)
 
     x_b, x_r, x_c = x.size()
-    x_b_s, x_r_s, x_c_s = x.stride()
+    x_b_s, x_r_s, x_c_s = stride(x)
     s_lut_x_r, s_lut_x_c = sparsity_lut.size()
-    s_lut_x_r_s, s_lut_x_c_s = sparsity_lut.stride()
+    s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut)
     o_b, o_r, o_c = output.size()
-    o_b_s, o_r_s, o_c_s = output.stride()
+    o_b_s, o_r_s, o_c_s = stride(output)
     s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
-    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_output.stride()
+    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
 
     if triton_block_size is None:
         triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -290,15 +290,15 @@ def row_wise_add(x: Tensor, sparsity_layout_x: Tensor, y: Tensor,
     output = torch.empty_like(x)
 
     x_b, x_r, x_c = x.size()
-    x_b_s, x_r_s, x_c_s = x.stride()
+    x_b_s, x_r_s, x_c_s = stride(x)
     s_lut_r, s_lut_c = sparsity_lut.size()
-    s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
+    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
     y_b, y_r, y_c = y.size()
-    y_b_s, y_r_s, y_c_s = y.stride()
+    y_b_s, y_r_s, y_c_s = stride(y)
     s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_rwm.size()
-    s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = sparsity_layout_rwm.stride()
+    s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_rwm)
     o_b, o_r, o_c = output.size()
-    o_b_s, o_r_s, o_c_s = output.stride()
+    o_b_s, o_r_s, o_c_s = stride(output)
 
     if triton_block_size is None:
         triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -361,7 +361,8 @@ def kernel_blocksparse_row_wise_add(x,
     rev_idx_spa_s = tl.load(r_lut_y + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
 
     if rev_idx_spa_s == -1:
-        assert False, "Invalid sparsity block"
+        tl.device_assert(False)
+        return
 
     # Load x block
     blk_x_idx = ((pid_blk * x_b_s) +
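Note: the kernel change above replaces a plain Python assert, which is not a meaningful runtime guard inside an @triton.jit function, with tl.device_assert (Triton's device-side assertion, checked when debugging is enabled) followed by an explicit return, so the program instance exits instead of indexing with the -1 sentinel. A minimal self-contained sketch of the pattern (the kernel and its names are illustrative, not from blksprs):

    import triton
    import triton.language as tl

    @triton.jit
    def kernel_guarded_copy(r_lut, x, o, x_b_s, o_b_s, BLOCK: tl.constexpr):
        pid = tl.program_id(axis=0)
        # Reverse-LUT entry of -1 marks a block with no source.
        rev_idx = tl.load(r_lut + pid).to(tl.int32)
        if rev_idx == -1:
            # Flag the invalid block on-device, then bail out early.
            tl.device_assert(False)
            return
        offs = tl.arange(0, BLOCK)
        tl.store(o + pid * o_b_s + offs, tl.load(x + rev_idx * x_b_s + offs))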
blksprs/ops/conversion.py CHANGED
@@ -6,9 +6,14 @@ from torch import Tensor
 from triton import language as tl
 
 from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
-from blksprs.utils.tools import get_triton_block_size
+from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-    validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
+    validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_sparsity_dense
+
+
+def from_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
+                 triton_block_size: int = None) -> Tensor:
+    return to_dense(x, sparsity_layout, sparsity_block_size, fill_value, triton_block_size)
 
 
 def to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
@@ -65,11 +70,11 @@ class _BlocksparseToDense(torch.autograd.Function):
                              dtype=x.dtype, device=x.device)
 
         x_b, x_r, x_c = x.shape
-        x_b_s, x_r_s, x_c_s = x.stride()
+        x_b_s, x_r_s, x_c_s = stride(x)
         s_l_b, s_l_r, s_l_c = sparsity_layout.size()
-        s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout.stride()
+        s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout)
        o_b, o_r, o_c = output.size()
-        o_b_s, o_r_s, o_c_s = output.stride()
+        o_b_s, o_r_s, o_c_s = stride(output)
 
         if triton_block_size is None:
             triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -144,6 +149,11 @@ class _BlocksparseToDense(torch.autograd.Function):
         tl.store(o + o_idx, blk, o_msk)
 
 
+def to_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
+               triton_block_size: int = None) -> Tensor:
+    return to_sparse(x, sparsity_layout, sparsity_block_size, triton_block_size)
+
+
 def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
     """Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
     sparsity layout.
@@ -163,6 +173,7 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
     validate_dimensions(x)
     validate_contiguous(x)
     validate_device(x)
+    validate_sparsity_dense(sparsity_block_size, (x, sparsity_layout))
     validate_sparsity_block_size(sparsity_block_size, x)
     validate_triton_block_size(triton_block_size, sparsity_block_size)
 
@@ -190,11 +201,11 @@ class _BlocksparseToSparse(torch.autograd.Function):
                              dtype=x.dtype, device=x.device)
 
         x_b, x_r, x_c = x.size()
-        x_b_s, x_r_s, x_c_s = x.stride()
+        x_b_s, x_r_s, x_c_s = stride(x)
         s_lut_r, s_lut_c = sparsity_lut.size()
-        s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
+        s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
         o_b, o_r, o_c = output.size()
-        o_b_s, o_r_s, o_c_s = output.stride()
+        o_b_s, o_r_s, o_c_s = stride(output)
 
         if triton_block_size is None:
             triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -347,13 +358,13 @@ class _BlocksparseAdaptLayout(torch.autograd.Function):
                              dtype=x.dtype, device=x.device)
 
         x_b, x_r, x_c = x.size()
-        x_b_s, x_r_s, x_c_s = x.stride()
+        x_b_s, x_r_s, x_c_s = stride(x)
         s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_from.size()
-        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_from.stride()
+        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_from)
        o_b, o_r, o_c = output.size()
-        o_b_s, o_r_s, o_c_s = output.stride()
+        o_b_s, o_r_s, o_c_s = stride(output)
         s_lut_o_r, s_lut_o_c = sparsity_lut_to.size()
-        s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_to.stride()
+        s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_to)
 
         if triton_block_size is None:
             triton_block_size = get_triton_block_size(min_sparsity_block_size)
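Note: per the hunks above, from_blksprs and to_blksprs are thin aliases over to_dense and to_sparse, so a roundtrip through a layout that marks every block present should reproduce the input exactly. A hedged sanity-check sketch (the shapes, device, and all-ones layout are assumptions):

    import torch
    import blksprs as bs

    x = torch.randn(1, 32, 32, device="cuda")
    layout = torch.ones(1, 2, 2, dtype=torch.bool, device="cuda")  # 16x16 blocks, all present

    x_sparse = bs.to_blksprs(x, layout, sparsity_block_size=16)
    x_round = bs.from_blksprs(x_sparse, layout, sparsity_block_size=16)
    assert torch.equal(x, x_round)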
blksprs/ops/distribution.py CHANGED
@@ -3,7 +3,7 @@ import triton
 from torch import Tensor
 from triton import language as tl
 
-from blksprs.utils.tools import get_triton_block_size
+from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
 
@@ -59,15 +59,15 @@ class _BlocksparseGather(torch.autograd.Function):
         output = torch.empty_like(i, dtype=x.dtype)
 
         x_b, x_r, x_c = x.size()
-        x_b_s, x_r_s, x_c_s = x.stride()
+        x_b_s, x_r_s, x_c_s = stride(x)
         s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
-        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_x.stride()
+        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
         i_b, i_r, i_c = i.size()
-        i_b_s, i_r_s, i_c_s = i.stride()
+        i_b_s, i_r_s, i_c_s = stride(i)
         s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
-        s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
+        s_lut_i_r_s, s_lut_i_c_s = stride(sparsity_lut_i)
         o_b, o_r, o_c = output.size()
-        o_b_s, o_r_s, o_c_s = output.stride()
+        o_b_s, o_r_s, o_c_s = stride(output)
 
         if triton_block_size is None:
             triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -248,15 +248,15 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
                              dtype=x.dtype, device=x.device)
 
         x_b, x_r, x_c = x.size()
-        x_b_s, x_r_s, x_c_s = x.stride()
+        x_b_s, x_r_s, x_c_s = stride(x)
         s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
-        s_lut_x_r_s, s_lut_x_c_s = sparsity_lut_x.stride()
+        s_lut_x_r_s, s_lut_x_c_s = stride(sparsity_lut_x)
         i_b, i_r, i_c = i.size()
-        i_b_s, i_r_s, i_c_s = i.stride()
+        i_b_s, i_r_s, i_c_s = stride(i)
         o_b, o_r, o_c = output.size()
-        o_b_s, o_r_s, o_c_s = output.stride()
+        o_b_s, o_r_s, o_c_s = stride(output)
         s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
-        s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_o.stride()
+        s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
 
         if triton_block_size is None:
             triton_block_size = get_triton_block_size(sparsity_block_size)
blksprs/ops/matmul.py CHANGED
@@ -4,7 +4,7 @@ from torch import Tensor
 from triton import language as tl
 
 from blksprs.ops.transpose import transpose
-from blksprs.utils.tools import get_triton_block_size
+from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_dtype_float
 
@@ -82,17 +82,17 @@ class _BlocksparseMatmulSSS(torch.autograd.Function):
                              dtype=x.dtype, device=x.device)
 
         x_b, x_r, x_c = x.size()
-        x_b_s, x_r_s, x_c_s = x.stride()
+        x_b_s, x_r_s, x_c_s = stride(x)
         s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
-        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_x.stride()
+        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
         y_b, y_r, y_c = y.size()
-        y_b_s, y_r_s, y_c_s = y.stride()
+        y_b_s, y_r_s, y_c_s = stride(y)
         s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_y.size()
-        s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = sparsity_layout_y.stride()
+        s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_y)
         o_b, o_r, o_c = output.size()
-        o_b_s, o_r_s, o_c_s = output.stride()
+        o_b_s, o_r_s, o_c_s = stride(output)
         s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
-        s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_o.stride()
+        s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
 
         if triton_block_size is None:
             triton_block_size = get_triton_block_size(sparsity_block_size)