blksprs 1.8.3.tar.gz → 1.9.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. {blksprs-1.8.3 → blksprs-1.9.1}/PKG-INFO +7 -3
  2. {blksprs-1.8.3 → blksprs-1.9.1}/README.md +6 -2
  3. {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/__init__.py +4 -1
  4. {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/layouting/distribution_layout.py +23 -5
  5. {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/ops/conversion.py +26 -34
  6. {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/ops/distribution.py +94 -35
  7. {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/ops/experimental/distribution_mdi.py +8 -0
  8. blksprs-1.9.1/blksprs/ops/flow.py +147 -0
  9. {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/ops/misc/row_wise.py +8 -0
  10. {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/ops/partitioning.py +3 -3
  11. {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/ops/repeat.py +8 -147
  12. {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/ops/softmax.py +4 -0
  13. blksprs-1.9.1/blksprs/utils/layout_utils.py +17 -0
  14. {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/utils/processing.py +35 -2
  15. {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/utils/validation.py +2 -1
  16. {blksprs-1.8.3 → blksprs-1.9.1}/blksprs.egg-info/PKG-INFO +7 -3
  17. {blksprs-1.8.3 → blksprs-1.9.1}/blksprs.egg-info/SOURCES.txt +2 -0
  18. {blksprs-1.8.3 → blksprs-1.9.1}/pyproject.toml +1 -1
  19. {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/layouting/sparsity_layout.py +0 -0
  20. {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/ops/matmul.py +0 -0
  21. {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/ops/misc/broadcast_ops.py +0 -0
  22. {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/ops/misc/exp.py +0 -0
  23. {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/ops/transpose.py +0 -0
  24. {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/utils/benchmarking.py +0 -0
  25. {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/utils/blksprs_tensor.py +0 -0
  26. {blksprs-1.8.3 → blksprs-1.9.1}/blksprs/utils/tools.py +0 -0
  27. {blksprs-1.8.3 → blksprs-1.9.1}/blksprs.egg-info/dependency_links.txt +0 -0
  28. {blksprs-1.8.3 → blksprs-1.9.1}/blksprs.egg-info/requires.txt +0 -0
  29. {blksprs-1.8.3 → blksprs-1.9.1}/blksprs.egg-info/top_level.txt +0 -0
  30. {blksprs-1.8.3 → blksprs-1.9.1}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: blksprs
- Version: 1.8.3
+ Version: 1.9.1
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -64,8 +64,12 @@ Further helpful operations (included in the ``bs.ops.misc`` module) that do **no
  - Row-wise sum, max, addition, and subtraction
  - Broadcast addition and subtraction between slices

- Furthermore, the library provides a set of utility functions for the creation of sparsity layouts based on existing
- dense tensors and for the scatter operation (module ``bs.layouting``), as well as utility functions to apply linear layers,
+ Furthermore, the library provides a set of utility functions
+
+ - for the creation of sparsity layouts based on existing
+ dense tensors and for the scatter operation (module ``bs.layouting``),
+ - for the application of ``nn.Linear``, ``nn.Dropout``, and ``nn.LayerNorm`` layers to block-sparse tensors,
+ - as well as utility functions to apply linear layers,
  ensure correct input dimensionality, and validate input (module ``bs.utils``).

  ## Installation
@@ -45,8 +45,12 @@ Further helpful operations (included in the ``bs.ops.misc`` module) that do **no
  - Row-wise sum, max, addition, and subtraction
  - Broadcast addition and subtraction between slices

- Furthermore, the library provides a set of utility functions for the creation of sparsity layouts based on existing
- dense tensors and for the scatter operation (module ``bs.layouting``), as well as utility functions to apply linear layers,
+ Furthermore, the library provides a set of utility functions
+
+ - for the creation of sparsity layouts based on existing
+ dense tensors and for the scatter operation (module ``bs.layouting``),
+ - for the application of ``nn.Linear``, ``nn.Dropout``, and ``nn.LayerNorm`` layers to block-sparse tensors,
+ - as well as utility functions to apply linear layers,
  ensure correct input dimensionality, and validate input (module ``bs.utils``).

  ## Installation
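For orientation, a minimal sketch of how the new ``bs.utils`` helpers from this release can be combined (shapes, device, and the sparsity block size of 32 are illustrative assumptions, not taken from the package):

```python
import torch
from torch import nn
import blksprs as bs

sparsity_block_size = 32
x_dense = torch.randn(2, 64, 64, device="cuda")  # (batch, rows, cols), sizes assumed

# Full (all-blocks-present) layout via the new helper, then compress.
layout = bs.layouting.build_full_sparsity_layout(x_dense, sparsity_block_size)
x = bs.ops.to_sparse(x_dense, layout, sparsity_block_size)

# Apply regular torch layers directly to the block-sparse tensor.
x_lin, layout_lin = bs.utils.apply_torch_linear(x, layout, sparsity_block_size, nn.Linear(64, 64, device="cuda"))
x_drop = bs.utils.apply_torch_dropout(x_lin, layout_lin, sparsity_block_size, nn.Dropout(p=0.1))
x_norm = bs.utils.apply_torch_normalisation(x_drop, layout_lin, sparsity_block_size, nn.LayerNorm(64, device="cuda"))
```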
@@ -1,5 +1,6 @@
  from blksprs.utils.blksprs_tensor import BlksprsTensor

+
  class ops:
  from blksprs.ops.conversion import to_dense, to_sparse, from_blksprs, to_blksprs, adapt_layout
  from blksprs.ops.distribution import gather, scatter, scatter_reduce
@@ -22,13 +23,15 @@ class layouting:
  from blksprs.layouting.distribution_layout import build_distribution_layout
  from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption, \
  build_sparsity_layout_matmul, build_sparsity_layout_matmul_fast
+ from blksprs.utils.layout_utils import build_full_sparsity_layout

  class experimental:
  from blksprs.ops.experimental.distribution_mdi import build_distribution_layout_mdi


  class utils:
- from blksprs.utils.processing import apply_torch_linear
+ from blksprs.utils.processing import apply_torch_linear, apply_torch_normalisation, apply_torch_dropout, \
+ apply_function_applicable_row_wise
  from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
  from blksprs.utils.validation import disable_validation

@@ -10,13 +10,14 @@ from blksprs.utils.validation import validate_triton_block_size, validate_dimens


  def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: Tensor,
- size_target: torch.Size,
+ dim: int, size_target: torch.Size,
  sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
  """Builds the sparsity layout of either the source of a gather or the target of a scatter operation.

  Args:
  indices (BlksprsTensor): The block-sparse indices tensor in compressed form used for the gather or scatter operation.
  sparsity_layout_indices (Tensor): The sparsity layout of the indices block-sparse tensor.
+ dim (int): The dimension along which the operation is conducted.
  size_target (torch.Size): The size of the block-sparse target tensor in regular form.
  sparsity_block_size (int): The size of the sparsity blocks.
  triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
@@ -31,6 +32,8 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T

  sparsity_lut_i = torch.nonzero(sparsity_layout_indices).contiguous()

+ adjusted_dim = dim % 3
+
  output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size, size_target[2] // sparsity_block_size,
  dtype=torch.bool, device=indices.device)

@@ -55,6 +58,7 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
  i_b, i_b_s, i_r_s, i_c_s,
  sparsity_lut_i,
  s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+ adjusted_dim,
  output,
  o_b, o_b_s, o_r_s, o_c_s,
  sparsity_block_size,
@@ -68,6 +72,7 @@ def kernel_distribution_layout(i,
  i_b, i_b_s, i_r_s, i_c_s,
  s_lut_i,
  s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+ dim,
  o,
  o_b, o_b_s, o_r_s, o_c_s,
  sparsity_block_size,
@@ -86,17 +91,30 @@ def kernel_distribution_layout(i,
  spa_row_i_msk = (spa_row_i_idx < s_lut_i_r * s_lut_i_r_s)
  spa_row_i = tl.load(s_lut_i + spa_row_i_idx, mask=spa_row_i_msk)

+ spa_col_i_idx = (pid_blk * s_lut_i_r_s + 2 * s_lut_i_c_s)
+ spa_col_i_msk = (spa_col_i_idx < s_lut_i_r * s_lut_i_r_s)
+ spa_col_i = tl.load(s_lut_i + spa_col_i_idx, mask=spa_col_i_msk)
+
  blk_i_idx = (pid_blk * i_b_s +
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
  blk_i_msk = (blk_i_idx < i_b * i_b_s)
  blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)

- blk_i = blk_i // sparsity_block_size
+ dst_bat_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_i, dtype=tl.int32)
+ dst_row_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_i, dtype=tl.int32)
+ dst_col_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_i, dtype=tl.int32)
+ if dim == 0:
+ dst_bat_idx = blk_i
+ elif dim == 1:
+ dst_row_idx = blk_i // sparsity_block_size
+ elif dim == 2:
+ dst_col_idx = blk_i // sparsity_block_size
+
  blk_v = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), 1, dtype=tl.int32)

- blk_o_idx = ((spa_bat_i * o_b_s) +
- (spa_row_i * o_r_s) +
- (blk_i * o_c_s))
+ blk_o_idx = ((dst_bat_idx * o_b_s) +
+ (dst_row_idx * o_r_s) +
+ (dst_col_idx * o_c_s))
  blk_o_msk = (blk_o_idx < o_b * o_b_s)
  tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
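A sketch of how the new ``dim`` argument of ``build_distribution_layout`` might be used to derive the layout of a gather source along the column dimension (tensor shapes, index ranges, and the full index layout are illustrative assumptions):

```python
import torch
import blksprs as bs

sparsity_block_size = 32
src_size = torch.Size((2, 64, 64))
# Dense index tensor: for each output position, which column of the source to read.
idx_dense = torch.randint(0, 64, size=(2, 64, 64), device="cuda")

layout_idx = bs.layouting.build_full_sparsity_layout(idx_dense, sparsity_block_size)
idx = bs.ops.to_sparse(idx_dense, layout_idx, sparsity_block_size)

# Layout of the gather source implied by the indices along dim=2 (columns).
layout_src = bs.layouting.build_distribution_layout(
    idx, layout_idx, 2, src_size, sparsity_block_size)
```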
@@ -289,8 +289,8 @@ class _BlocksparseToSparse(torch.autograd.Function):


  def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_size_from: int,
- sparsity_block_size_to: int,
- preprocess_data: dict = None, triton_block_size: int = None) -> BlksprsTensor:
+ sparsity_block_size_to: int, sparsity_layout_to: Tensor = None,
+ triton_block_size: int = None) -> (BlksprsTensor, Tensor):
  """Adapts the sparsity layout of a block-sparse tensor, resulting in a new block-sparse tensor in compressed form
  conforming to the new sparsity layout (and sparsity block size) definition.

@@ -299,11 +299,12 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
  sparsity_layout_from (Tensor): The sparsity layout of the input block-sparse tensor.
  sparsity_block_size_from (int): The size of the sparsity blocks of the input sparsity layout.
  sparsity_block_size_to (int): The size of the sparsity blocks of the output sparsity layout.
- preprocess_data (dict): A dictionary containing data otherwise computed by the function (default ``None``).
+ sparsity_layout_to (Tensor): The sparsity layout of the output block-sparse tensor (default ``None``).
  triton_block_size (int): The block size to use for the triton kernel (default ``None``).

  Returns:
  BlksprsTensor: The block-sparse tensor in compressed form with the adapted sparsity layout and sparsity block size.
+ Tensor: The sparsity layout of the resulting output tensor.

  """
  x = x.contiguous()
@@ -317,52 +318,42 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
  min_sparsity_block_size = min(sparsity_block_size_from, sparsity_block_size_to)
  validate_triton_block_size(triton_block_size, min_sparsity_block_size)

- if preprocess_data is None:
- preprocess_data = {}
+ sparsity_layout_from_flat = sparsity_layout_from.reshape(-1)
+ sparsity_reverse_lut_from = ((torch.cumsum(sparsity_layout_from_flat, dim=-1) - 1) *
+ (sparsity_layout_from_flat == 1) -
+ (1 * (sparsity_layout_from_flat == 0)))

- if "sparsity_reverse_lut_from" not in preprocess_data:
- sparsity_layout_from_flat = sparsity_layout_from.reshape(-1)
- sparsity_reverse_lut_from = ((torch.cumsum(sparsity_layout_from_flat, dim=-1) - 1) *
- (sparsity_layout_from_flat == 1) -
- (1 * (sparsity_layout_from_flat == 0)))
- else:
- sparsity_reverse_lut_from = preprocess_data["sparsity_reverse_lut_from"]
-
- if "sparsity_layout_to" not in preprocess_data:
+ if sparsity_layout_to is None:
  sparsity_layout_to = build_sparsity_layout_adaption(x, sparsity_layout_from,
  sparsity_block_size_from, sparsity_block_size_to,
  triton_block_size)
- else:
- sparsity_layout_to = preprocess_data["sparsity_layout_to"]

- if "sparsity_lut_to" not in preprocess_data:
- sparsity_lut_to = torch.nonzero(sparsity_layout_to).contiguous()
- else:
- sparsity_lut_to = preprocess_data["sparsity_lut_to"]
+ sparsity_lut_to = torch.nonzero(sparsity_layout_to).contiguous()

- if "n_sparse_blocks_to" not in preprocess_data:
- n_sparse_blocks_to = torch.sum(sparsity_layout_to.to(torch.int)).item()
- else:
- n_sparse_blocks_to = preprocess_data["n_sparse_blocks_to"]
+ n_sparse_blocks_to = torch.sum(sparsity_layout_to.to(torch.int)).item()

- validate_contiguous(sparsity_layout_to, sparsity_reverse_lut_from, sparsity_lut_to)
+ validate_contiguous(sparsity_reverse_lut_from, sparsity_layout_to, sparsity_lut_to)

  if (sparsity_block_size_from == sparsity_block_size_to) and torch.equal(sparsity_layout_from, sparsity_layout_to):
- return BlksprsTensor(x)
+ return BlksprsTensor(x), sparsity_layout_to

  return BlksprsTensor(_BlocksparseAdaptLayout.apply(x,
  sparsity_layout_from, sparsity_reverse_lut_from,
  sparsity_block_size_from,
- sparsity_layout_to, sparsity_lut_to, sparsity_block_size_to,
- n_sparse_blocks_to, min_sparsity_block_size, triton_block_size))
+ sparsity_layout_to, sparsity_lut_to,
+ sparsity_block_size_to,
+ n_sparse_blocks_to, min_sparsity_block_size,
+ triton_block_size)), sparsity_layout_to


  class _BlocksparseAdaptLayout(torch.autograd.Function):

  @staticmethod
  def forward(ctx, x: Tensor,
- sparsity_layout_from: Tensor, sparsity_reverse_lut_from: Tensor, sparsity_block_size_from: int,
- sparsity_layout_to: Tensor, sparsity_lut_to: Tensor, sparsity_block_size_to: int,
+ sparsity_layout_from: Tensor, sparsity_reverse_lut_from: Tensor,
+ sparsity_block_size_from: int,
+ sparsity_layout_to: Tensor, sparsity_lut_to: Tensor,
+ sparsity_block_size_to: int,
  n_sparse_blocks_to: int, min_sparsity_block_size: int, triton_block_size: int) -> Tensor:
  output = torch.zeros(size=(n_sparse_blocks_to, sparsity_block_size_to, sparsity_block_size_to),
  dtype=x.dtype, device=x.device)
@@ -409,9 +400,10 @@ class _BlocksparseAdaptLayout(torch.autograd.Function):
  sparsity_block_size_to = ctx.sparsity_block_size_to
  triton_block_size = ctx.triton_block_size

- return adapt_layout(grad_output, sparsity_layout_to, sparsity_block_size_to, sparsity_block_size_from,
- preprocess_data={"sparsity_layout_to": sparsity_layout_from},
- triton_block_size=triton_block_size), None, None, None, None, None, None, None, None, None
+ return adapt_layout(
+ grad_output, sparsity_layout_to, sparsity_block_size_to, sparsity_block_size_from,
+ sparsity_layout_to=sparsity_layout_from,
+ triton_block_size=triton_block_size)[0], None, None, None, None, None, None, None, None, None

  @staticmethod
  @triton.jit
@@ -448,7 +440,7 @@ class _BlocksparseAdaptLayout(torch.autograd.Function):
  spa_row_x = (spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size_from
  spa_col_x = (spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size_from

- # # Get reverse sparsity indices for x
+ # Get reverse sparsity indices for x
  rev_idx_spa_x_idx = (spa_bat_x * s_l_x_b_s +
  spa_row_x * s_l_x_r_s +
  spa_col_x * s_l_x_c_s)
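A short sketch of the reworked ``adapt_layout`` call, which now returns the adapted tensor together with the output layout and accepts an optional precomputed ``sparsity_layout_to`` in place of the former ``preprocess_data`` dict (shapes and block sizes are assumptions for illustration):

```python
import torch
import blksprs as bs

x_dense = torch.randn(2, 64, 64, device="cuda")
layout_from = bs.layouting.build_full_sparsity_layout(x_dense, sparsity_block_size=32)
x = bs.ops.to_sparse(x_dense, layout_from, 32)

# Re-block from 32x32 to 16x16 sparsity blocks; the output layout is returned as well.
x_adapted, layout_to = bs.ops.adapt_layout(x, layout_from, sparsity_block_size_from=32,
                                           sparsity_block_size_to=16)

# Passing layout_to back in skips its recomputation on a second call.
x_again, _ = bs.ops.adapt_layout(x, layout_from, 32, 16, sparsity_layout_to=layout_to)
```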
@@ -3,19 +3,23 @@ import triton
  from torch import Tensor
  from triton import language as tl

+ from blksprs.ops.conversion import to_dense
  from blksprs.utils.blksprs_tensor import BlksprsTensor
  from blksprs.utils.tools import get_triton_block_size, stride
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
  validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size


- def gather(src: BlksprsTensor, sparsity_layout_src: Tensor, idx: BlksprsTensor, sparsity_layout_idx: Tensor,
+ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
+ dim: int,
+ idx: BlksprsTensor, sparsity_layout_idx: Tensor,
  sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
  """Applies a gather operation on a block-sparse tensor in compressed form.

  Args:
  src (BlksprsTensor): The source block-sparse tensor in compressed form to gather from.
  sparsity_layout_src (Tensor): The sparsity layout of the source block-sparse tensor.
+ dim (int): The dimension along which to gather.
  idx (BlksprsTensor): The block-sparse indices tensor in compressed form specifying how to gather from the source tensor.
  sparsity_layout_idx (Tensor): The sparsity layout of the indices block-sparse tensor.
  sparsity_block_size (int): The size of the sparsity blocks.
@@ -46,16 +50,18 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor, idx: BlksprsTensor,
  validate_contiguous(sparsity_layout_src, sparsity_reverse_lut_x,
  sparsity_layout_idx, sparsity_lut_i)

+ adjusted_dim = dim % 3
+
  return BlksprsTensor(_BlocksparseGather.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
- idx, sparsity_layout_idx, sparsity_lut_i,
- sparsity_block_size, triton_block_size))
+ adjusted_dim, idx, sparsity_layout_idx, sparsity_lut_i,
+ sparsity_block_size, triton_block_size))


  class _BlocksparseGather(torch.autograd.Function):

  @staticmethod
  def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
- i: Tensor, sparsity_layout_i: Tensor, sparsity_lut_i: Tensor,
+ dim: int, i: Tensor, sparsity_layout_i: Tensor, sparsity_lut_i: Tensor,
  sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
  output = torch.empty_like(i, dtype=x.dtype)

@@ -82,6 +88,7 @@ class _BlocksparseGather(torch.autograd.Function):
  x_b, x_b_s, x_r_s, x_c_s,
  s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
  sparsity_reverse_lut_x,
+ dim,
  i,
  i_b, i_b_s, i_r_s, i_c_s,
  output,
@@ -91,6 +98,7 @@ class _BlocksparseGather(torch.autograd.Function):
  triton_block_size))

  ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_i)
+ ctx.dim = dim
  ctx.sparsity_block_size = sparsity_block_size
  ctx.triton_block_size = triton_block_size

@@ -99,15 +107,15 @@ class _BlocksparseGather(torch.autograd.Function):
  @staticmethod
  def backward(ctx, grad_output):
  sparsity_layout_x, i, sparsity_layout_i = ctx.saved_tensors
+ dim = ctx.dim
  sparsity_block_size = ctx.sparsity_block_size
  triton_block_size = ctx.triton_block_size

  return scatter_reduce(grad_output, sparsity_layout_i,
- i,
- sparsity_layout_x,
- sparsity_block_size,
+ dim, i,
+ sparsity_layout_x, sparsity_block_size,
  reduce_op="sum",
- triton_block_size=triton_block_size), None, None, None, None, None, None, None
+ triton_block_size=triton_block_size), None, None, None, None, None, None, None, None

  @staticmethod
  @triton.jit
@@ -115,6 +123,7 @@ class _BlocksparseGather(torch.autograd.Function):
  x_b, x_b_s, x_r_s, x_c_s,
  s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
  r_lut_x,
+ dim,
  i,
  i_b, i_b_s, i_r_s, i_c_s,
  o,
@@ -136,6 +145,10 @@ class _BlocksparseGather(torch.autograd.Function):
  spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
  spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)

+ spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
+ spa_col_o_msk = (spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+ spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+
  # Load index values
  blk_i_idx = ((pid_blk * i_b_s) +
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
@@ -143,33 +156,50 @@ class _BlocksparseGather(torch.autograd.Function):
  blk_i_msk = (blk_i_idx < i_b * i_b_s)
  blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)

- # Get positions of sparsity blocks
+ # Get indices of sparsity blocks and positions within the blocks
  pos_spa_blk_x = blk_i // sparsity_block_size
- pos_spa_col_x = blk_i % sparsity_block_size
+ pos_spa_int_x = blk_i % sparsity_block_size
+
+ rev_dst_bat_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_o, dtype=tl.int32)
+ rev_dst_row_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_o, dtype=tl.int32)
+ rev_dst_col_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_o, dtype=tl.int32)
+ dst_row_x = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
+ .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+ dst_col_x = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
+ .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+ if dim == 0:
+ rev_dst_bat_x = blk_i
+ elif dim == 1:
+ rev_dst_row_x = pos_spa_blk_x
+ dst_row_x = pos_spa_int_x * x_r_s
+ elif dim == 2:
+ rev_dst_col_x = pos_spa_blk_x
+ dst_col_x = pos_spa_int_x * x_c_s

  # Load reverse sparsity indices for x
- rev_idx_spa_x_idx = ((spa_bat_o * s_l_x_b_s) +
- (spa_row_o * s_l_x_r_s) +
- (pos_spa_blk_x * s_l_x_c_s))
+ rev_idx_spa_x_idx = ((rev_dst_bat_x * s_l_x_b_s) +
+ (rev_dst_row_x * s_l_x_r_s) +
+ (rev_dst_col_x * s_l_x_c_s))
  rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
  rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)

  # Load x values
  blk_x_idx = ((rev_idx_spa_x * x_b_s) +
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
- (pos_spa_col_x * x_c_s))
- blk_x_msk = (blk_x_idx < x_b * x_b_s)
+ dst_row_x +
+ dst_col_x)
+ blk_x_msk = ((blk_x_idx < x_b * x_b_s) & rev_idx_spa_x_msk != -1)
  blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

  # Store output
  blk_o_idx = ((pid_blk * o_b_s) +
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
- blk_o_msk = (blk_o_idx < o_b * o_b_s)
+ blk_o_msk = ((blk_o_idx < o_b * o_b_s) & rev_idx_spa_x_msk != -1)
  tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)


  def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
+ dim: int,
  idx: BlksprsTensor,
  sparsity_layout_tgt: Tensor,
  sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
@@ -177,6 +207,7 @@ def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,

  """
  return scatter_reduce(src, sparsity_layout_src,
+ dim,
  idx,
  sparsity_layout_tgt,
  sparsity_block_size,
@@ -184,6 +215,7 @@ def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,


  def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
+ dim: int,
  idx: BlksprsTensor,
  sparsity_layout_tgt: Tensor,
  sparsity_block_size: int,
@@ -193,6 +225,7 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
  Args:
  src (BlksprsTensor): The source block-sparse tensor in compressed form to scatter from.
  sparsity_layout_src (Tensor): The sparsity layout of the source block-sparse tensor.
+ dim (int): The dimension along which to scatter.
  idx (BlksprsTensor): The block-sparse indices tensor in compressed form specifying how to scatter to the target tensor.
  sparsity_layout_tgt (Tensor): The sparsity layout of the target block-sparse tensor.
  sparsity_block_size (int): The size of the sparsity blocks.
@@ -230,18 +263,20 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
  validate_contiguous(sparsity_layout_src, sparsity_lut_x,
  sparsity_layout_tgt, sparsity_reverse_lut_o)

+ adjusted_dim = dim % 3
+
  return BlksprsTensor(_BlocksparseScatterReduce.apply(src, sparsity_layout_src, sparsity_lut_x,
- idx,
- sparsity_layout_tgt, sparsity_reverse_lut_o,
- sparsity_block_size, n_sparse_blocks,
- reduce_op, triton_block_size))
+ adjusted_dim, idx,
+ sparsity_layout_tgt, sparsity_reverse_lut_o,
+ sparsity_block_size, n_sparse_blocks,
+ reduce_op, triton_block_size))


  class _BlocksparseScatterReduce(torch.autograd.Function):

  @staticmethod
  def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_lut_x: Tensor,
- i: Tensor,
+ dim: int, i: Tensor,
  sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
  sparsity_block_size: int, n_sparse_blocks: int,
  reduce_op: str, triton_block_size: int) -> Tensor:
@@ -274,10 +309,11 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
  (x,
  x_b, x_b_s, x_r_s, x_c_s,
  sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+ dim,
  i,
  i_b, i_b_s, i_r_s, i_c_s,
  output,
- o_b, o_b_s, o_r_s, o_c_s,
+ o_b, o_b_s,
  s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
  sparsity_reverse_lut_o,
  reduce_op_ind,
@@ -285,6 +321,7 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
  triton_block_size))

  ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_o)
+ ctx.dim = dim
  ctx.sparsity_block_size = sparsity_block_size
  ctx.reduce_op = reduce_op
  ctx.triton_block_size = triton_block_size
@@ -294,13 +331,14 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
  @staticmethod
  def backward(ctx, grad_output):
  sparsity_layout_x, i, sparsity_layout_o = ctx.saved_tensors
+ dim = ctx.dim
  sparsity_block_size = ctx.sparsity_block_size
  reduce_op = ctx.reduce_op
  triton_block_size = ctx.triton_block_size

  if reduce_op == "sum":
- return gather(grad_output, sparsity_layout_o, i, sparsity_layout_x, sparsity_block_size,
- triton_block_size=triton_block_size), None, None, None, None, None, None, None, None, None
+ return gather(grad_output, sparsity_layout_o, dim, i, sparsity_layout_x, sparsity_block_size,
+ triton_block_size=triton_block_size), None, None, None, None, None, None, None, None, None, None
  else:
  raise ValueError(f"Reduction operation '{reduce_op}' does not support backward pass")

@@ -309,10 +347,11 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
  def kernel_blocksparse_scatter(x,
  x_b, x_b_s, x_r_s, x_c_s,
  s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+ dim,
  i,
  i_b, i_b_s, i_r_s, i_c_s,
  o,
- o_b, o_b_s, o_r_s, o_c_s,
+ o_b, o_b_s,
  s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
  r_lut_o,
  reduce_op_ind,
@@ -332,6 +371,10 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
  spa_row_x_msk = (spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
  spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)

+ spa_col_x_idx = (pid_blk * s_lut_x_r_s + 2 * s_lut_x_c_s)
+ spa_col_x_msk = (spa_col_x_idx < s_lut_x_r * s_lut_x_r_s)
+ spa_col_x = tl.load(s_lut_x + spa_col_x_idx, mask=spa_col_x_msk)
+
  # Load x values
  blk_x_idx = ((pid_blk * x_b_s) +
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
@@ -346,22 +389,38 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
  blk_i_msk = (blk_i_idx < i_b * i_b_s)
  blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)

- # Get positions of sparsity blocks
- pos_spa_blk_o = blk_i // sparsity_block_size
- pos_spa_col_o = blk_i % sparsity_block_size
+ # Get indices of sparsity blocks and positions within the blocks
+ pos_spa_blk_x = blk_i // sparsity_block_size
+ pos_spa_int_x = blk_i % sparsity_block_size
+
+ rev_dst_bat_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_x, dtype=tl.int32)
+ rev_dst_row_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_x, dtype=tl.int32)
+ rev_dst_col_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_x, dtype=tl.int32)
+ dst_row_o = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
+ .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+ dst_col_o = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
+ .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+ if dim == 0:
+ rev_dst_bat_o = blk_i
+ elif dim == 1:
+ rev_dst_row_o = pos_spa_blk_x
+ dst_row_o = pos_spa_int_x * x_r_s
+ elif dim == 2:
+ rev_dst_col_o = pos_spa_blk_x
+ dst_col_o = pos_spa_int_x * x_c_s

  # Load reverse sparsity indices for o
- rev_idx_spa_o_idx = ((spa_bat_x * s_l_o_b_s) +
- (spa_row_x * s_l_o_r_s) +
- (pos_spa_blk_o * s_l_o_c_s))
+ rev_idx_spa_o_idx = ((rev_dst_bat_o * s_l_o_b_s) +
+ (rev_dst_row_o * s_l_o_r_s) +
+ (rev_dst_col_o * s_l_o_c_s))
  rev_idx_spa_o_msk = (rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
  rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)

  # Store output
  blk_o_idx = ((rev_idx_spa_o * o_b_s) +
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
- (pos_spa_col_o * o_c_s))
- blk_o_msk = (blk_o_idx < o_b * o_b_s)
+ dst_row_o +
+ dst_col_o)
+ blk_o_msk = ((blk_o_idx < o_b * o_b_s) & rev_idx_spa_o_msk != -1)

  if reduce_op_ind == 0:
  tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
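To illustrate the extended signatures, a hedged usage sketch of ``gather``/``scatter`` with the explicit ``dim`` argument introduced here (full layouts, shapes, and the index range are assumptions):

```python
import torch
import blksprs as bs

sparsity_block_size = 32
src_dense = torch.randn(2, 64, 64, device="cuda")
idx_dense = torch.randint(0, 64, size=(2, 64, 64), device="cuda")

layout = bs.layouting.build_full_sparsity_layout(src_dense, sparsity_block_size)
src = bs.ops.to_sparse(src_dense, layout, sparsity_block_size)
idx = bs.ops.to_sparse(idx_dense, layout, sparsity_block_size)

# The dimension to gather along is now passed explicitly (here: columns, dim=2).
out = bs.ops.gather(src, layout, 2, idx, layout, sparsity_block_size)

# scatter mirrors the argument order; dim precedes the index tensor.
tgt = bs.ops.scatter(out, layout, 2, idx, layout, sparsity_block_size)
```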
@@ -153,6 +153,10 @@ class _BlocksparseGatherMDI(torch.autograd.Function):
  rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
  rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)

+ if rev_idx_spa_x == -1:
+ tl.device_assert(False)
+ return
+
  # Load x values
  blk_x_idx = ((rev_idx_spa_x * x_b_s) +
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
@@ -342,6 +346,10 @@ class _BlocksparseScatterReduceMDI(torch.autograd.Function):
  rev_idx_spa_o_msk = (rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
  rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)

+ if rev_idx_spa_o == -1:
+ tl.device_assert(False)
+ return
+
  # Store output
  blk_o_idx = ((rev_idx_spa_o * o_b_s) +
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
@@ -0,0 +1,147 @@
+ import torch
+ import triton
+ from torch import Tensor
+ from triton import language as tl
+
+ from blksprs.utils.tools import stride, get_triton_block_size
+
+
+ @triton.jit
+ def kernel_blocksparse_flow_pull(x,
+ x_b, x_b_s, x_r_s, x_c_s,
+ o,
+ o_b, o_b_s, o_r_s, o_c_s,
+ s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+ s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+ r_lut,
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+ # Get triton block indices
+ pid_blk = tl.program_id(axis=0)
+ pid_row = tl.program_id(axis=1)
+ pid_col = tl.program_id(axis=2)
+
+ # Get sparsity index of current output block consisting of its batch, row, and column index
+ spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+ spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+ spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+ spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+ spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+ spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+ spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
+ spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
+ spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+
+ # Get reverse sparsity index
+ rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
+ spa_row * s_l_o_r_s +
+ spa_col * s_l_o_c_s)
+ rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
+ rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
+
+ if rev_idx_spa == -1:
+ tl.device_assert(False)
+ return
+
+ blk_x_idx = (rev_idx_spa * x_b_s +
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+ blk_x_msk = (blk_x_idx < x_b * x_b_s)
+ blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+ blk_o_idx = (pid_blk * o_b_s +
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+ blk_o_msk = (blk_o_idx < o_b * o_b_s)
+ tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+ @triton.jit
+ def kernel_blocksparse_flow_push(x,
+ x_b, x_b_s, x_r_s, x_c_s,
+ s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+ s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+ r_lut,
+ o,
+ o_b, o_b_s, o_r_s, o_c_s,
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+ # Get triton block indices
+ pid_blk = tl.program_id(axis=0)
+ pid_row = tl.program_id(axis=1)
+ pid_col = tl.program_id(axis=2)
+
+ # Get sparsity index of current input block consisting of its batch, row, and column index
+ spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+ spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+ spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+ spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+ spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+ spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+ spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
+ spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
+ spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+
+ # Get reverse sparsity index
+ rev_idx_spa_idx = (spa_bat * s_l_x_b_s +
+ spa_row * s_l_x_r_s +
+ spa_col * s_l_x_c_s)
+ rev_idx_spa_msk = (rev_idx_spa_idx < s_l_x_b * s_l_x_b_s)
+ rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
+
+ if rev_idx_spa == -1:
+ tl.device_assert(False)
+ return
+
+ blk_x_idx = (pid_blk * x_b_s +
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+ blk_x_msk = (blk_x_idx < x_b * x_b_s)
+ blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+ blk_o_idx = (rev_idx_spa * o_b_s +
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+ blk_o_msk = (blk_o_idx < o_b * o_b_s)
+ tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+ def flow_forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+ sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+ output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+ dtype=x.dtype, device=x.device)
+ output = torch.zeros_like(output)
+
+ x_b, x_r, x_c = x.size()
+ x_b_s, x_r_s, x_c_s = stride(x)
+ o_b, o_r, o_c = output.size()
+ o_b_s, o_r_s, o_c_s = stride(output)
+ s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
+ s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
+ s_lut_r, s_lut_c = sparsity_lut.size()
+ s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+
+ if triton_block_size is None:
+ triton_block_size = get_triton_block_size(sparsity_block_size)
+
+ triton_grid = lambda meta: [o_b,
+ triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+ triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+ (kernel_blocksparse_flow_pull[triton_grid]
+ (x,
+ x_b, x_b_s, x_r_s, x_c_s,
+ output,
+ o_b, o_b_s, o_r_s, o_c_s,
+ s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+ sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+ sparsity_reverse_lut,
+ triton_block_size))
+
+ # Save for backward pass
+ ctx.sparsity_block_size = sparsity_block_size
+ ctx.triton_block_size = triton_block_size
+
+ return output
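For context, a minimal sketch of how callers such as ``_BlocksparseRepeat`` and the partitioning ops drive this shared forward path (the ``SimpleNamespace`` stand-in for the autograd context, the full layout, and the shapes are assumptions made purely for illustration):

```python
import types
import torch
import blksprs as bs
from blksprs.ops.flow import flow_forward

sparsity_block_size = 32
x_dense = torch.randn(2, 64, 64, device="cuda")
layout = bs.layouting.build_full_sparsity_layout(x_dense, sparsity_block_size)
x = bs.ops.to_sparse(x_dense, layout, sparsity_block_size)

# Look-up tables in the form built by the callers (see partitioning.py / repeat.py).
sparsity_lut = torch.nonzero(layout).contiguous()
layout_flat = layout.reshape(-1)
sparsity_reverse_lut = ((torch.cumsum(layout_flat, dim=-1) - 1) * (layout_flat == 1)
                        - (1 * (layout_flat == 0))).contiguous()
n_sparse_blocks = int(torch.sum(layout.to(torch.int)).item())

ctx = types.SimpleNamespace()  # stand-in for torch.autograd.Function's ctx
out = flow_forward(ctx, x, layout, sparsity_lut, sparsity_reverse_lut,
                   sparsity_block_size, n_sparse_blocks, None)  # pull copy; identity here
```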
@@ -117,6 +117,10 @@ def kernel_blocksparse_row_wise_sum(x,
  rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
  rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

+ if rev_idx_spa == -1:
+ tl.device_assert(False)
+ return
+
  blk_idx = ((pid_blk * x_b_s) +
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
@@ -240,6 +244,10 @@ def kernel_blocksparse_row_wise_max(x,
  rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
  rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

+ if rev_idx_spa == -1:
+ tl.device_assert(False)
+ return
+
  blk_idx = ((pid_blk * x_b_s) +
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
@@ -1,7 +1,7 @@
  import torch
  from torch import Tensor

- from blksprs.ops.repeat import forward_flow
+ from blksprs.ops.flow import flow_forward
  from blksprs.utils.blksprs_tensor import BlksprsTensor

  from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
@@ -66,7 +66,7 @@ class _BlocksparseSplit(torch.autograd.Function):
  ctx.save_for_backward(sparsity_layout_o)
  ctx.num_partitions = num_partitions

- return forward_flow(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+ return flow_forward(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
  n_sparse_blocks, triton_block_size)

  @staticmethod
@@ -140,7 +140,7 @@ class _BlocksparseMerge(torch.autograd.Function):
  ctx.save_for_backward(sparsity_layout_o)
  ctx.num_partitions = num_partitions

- return forward_flow(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+ return flow_forward(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
  n_sparse_blocks, triton_block_size)

  @staticmethod
@@ -1,8 +1,8 @@
  import torch
  import triton
- from triton import language as tl
  from torch import Tensor

+ from blksprs.ops.flow import kernel_blocksparse_flow_push, flow_forward
  from blksprs.utils.blksprs_tensor import BlksprsTensor
  from blksprs.utils.tools import get_triton_block_size, stride
  from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
@@ -64,8 +64,9 @@ def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int,

  validate_contiguous(sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)

- return BlksprsTensor(_BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
- sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_o
+ return BlksprsTensor(
+ _BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
+ sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_o


  def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
@@ -122,8 +123,9 @@ def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,

  validate_contiguous(sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)

- return BlksprsTensor(_BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
- sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_o
+ return BlksprsTensor(
+ _BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
+ sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_o


  class _BlocksparseRepeat(torch.autograd.Function):
@@ -137,7 +139,7 @@ class _BlocksparseRepeat(torch.autograd.Function):
  ctx.x_size = x.size()
  ctx.x_stride = stride(x)

- return forward_flow(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+ return flow_forward(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
  n_sparse_blocks, triton_block_size)

  @staticmethod
@@ -180,144 +182,3 @@ class _BlocksparseRepeat(torch.autograd.Function):
  triton_block_size))

  return output, None, None, None, None, None, None, None
-
-
- @triton.jit
- def kernel_blocksparse_flow_pull(x,
- x_b, x_b_s, x_r_s, x_c_s,
- o,
- o_b, o_b_s, o_r_s, o_c_s,
- s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
- s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
- r_lut,
- TRITON_BLOCK_SIZE: tl.constexpr) -> None:
- # Get triton block indices
- pid_blk = tl.program_id(axis=0)
- pid_row = tl.program_id(axis=1)
- pid_col = tl.program_id(axis=2)
-
- # Get sparsity index of current output block consisting of its batch, row, and column index
- spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
- spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
- spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
-
- spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
- spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
- spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
-
- spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
- spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
- spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
-
- # Get reverse sparsity index
- rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
- spa_row * s_l_o_r_s +
- spa_col * s_l_o_c_s)
- rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
- rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
-
- if rev_idx_spa == -1:
- tl.device_assert(False)
- return
-
- blk_x_idx = (rev_idx_spa * x_b_s +
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
- blk_x_msk = (blk_x_idx < x_b * x_b_s)
- blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
- blk_o_idx = (pid_blk * o_b_s +
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
- blk_o_msk = (blk_o_idx < o_b * o_b_s)
- tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
-
-
- @triton.jit
- def kernel_blocksparse_flow_push(x,
- x_b, x_b_s, x_r_s, x_c_s,
- s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
- s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
- r_lut,
- o,
- o_b, o_b_s, o_r_s, o_c_s,
- TRITON_BLOCK_SIZE: tl.constexpr) -> None:
- # Get triton block indices
- pid_blk = tl.program_id(axis=0)
- pid_row = tl.program_id(axis=1)
- pid_col = tl.program_id(axis=2)
-
- # Get sparsity index of current input block consisting of its batch, row, and column index
- spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
- spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
- spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
-
- spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
- spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
- spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
-
- spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
- spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
- spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
-
- # Get reverse sparsity index
- rev_idx_spa_idx = (spa_bat * s_l_x_b_s +
- spa_row * s_l_x_r_s +
- spa_col * s_l_x_c_s)
- rev_idx_spa_msk = (rev_idx_spa_idx < s_l_x_b * s_l_x_b_s)
- rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
-
- if rev_idx_spa == -1:
- tl.device_assert(False)
- return
-
- blk_x_idx = (pid_blk * x_b_s +
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
- blk_x_msk = (blk_x_idx < x_b * x_b_s)
- blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
- blk_o_idx = (rev_idx_spa * o_b_s +
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
- blk_o_msk = (blk_o_idx < o_b * o_b_s)
- tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
-
-
- def forward_flow(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
- sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
- output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
- dtype=x.dtype, device=x.device)
- output = torch.zeros_like(output)
-
- x_b, x_r, x_c = x.size()
- x_b_s, x_r_s, x_c_s = stride(x)
- o_b, o_r, o_c = output.size()
- o_b_s, o_r_s, o_c_s = stride(output)
- s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
- s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
- s_lut_r, s_lut_c = sparsity_lut.size()
- s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-
- if triton_block_size is None:
- triton_block_size = get_triton_block_size(sparsity_block_size)
-
- triton_grid = lambda meta: [o_b,
- triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
- triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
- (kernel_blocksparse_flow_pull[triton_grid]
- (x,
- x_b, x_b_s, x_r_s, x_c_s,
- output,
- o_b, o_b_s, o_r_s, o_c_s,
- s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
- sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
- sparsity_reverse_lut,
- triton_block_size))
-
- # Save for backward pass
- ctx.sparsity_block_size = sparsity_block_size
- ctx.triton_block_size = triton_block_size
-
- return output
@@ -238,6 +238,10 @@ class _BlocksparseSoftmax(torch.autograd.Function):
  rev_idx_spa_s_msk = (rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
  rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)

+ if rev_idx_spa_s == -1:
+ tl.device_assert(False)
+ return
+
  blk_s_idx = (rev_idx_spa_s * s_b_s +
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
  (tl.arange(0, 1) * s_c_s)[None, :])
@@ -0,0 +1,17 @@
+ import math
+
+ import torch
+ import triton
+ from torch import Tensor
+ from torch.xpu import device
+ from triton import language as tl
+
+ from blksprs.utils.blksprs_tensor import BlksprsTensor
+ from blksprs.utils.tools import get_triton_block_size, stride
+ from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
+ validate_contiguous, validate_sparsity, validate_sparsity_block_size
+
+
+ def build_full_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
+ return torch.ones(size=(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size),
+ dtype=torch.bool, device=x.device)
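A small usage sketch of the new helper (input sizes are illustrative; the tensor only supplies shape and device):

```python
import torch
from blksprs.utils.layout_utils import build_full_sparsity_layout

x = torch.randn(2, 128, 64)  # (batch, rows, cols) with sizes divisible by the block size
layout = build_full_sparsity_layout(x, sparsity_block_size=32)
print(layout.shape, layout.dtype)  # torch.Size([2, 4, 2]) torch.bool
```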
@@ -1,7 +1,9 @@
+ from collections.abc import Callable
+
  import torch
  from torch import Tensor, nn
- from triton.language import dtype

+ import blksprs as bs
  from blksprs.layouting.sparsity_layout import build_sparsity_layout_matmul_fast
  from blksprs.ops.conversion import to_sparse
  from blksprs.ops.matmul import matmul
@@ -10,7 +12,7 @@ from blksprs.utils.blksprs_tensor import BlksprsTensor


  def apply_torch_linear(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
- linear: nn.Linear) -> (BlksprsTensor, Tensor):
+ linear: nn.Linear, bias: nn.Parameter = None) -> (BlksprsTensor, Tensor):
  # Extract weight and bias
  w = linear.weight
  b = linear.bias
@@ -27,6 +29,8 @@ def apply_torch_linear(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block
  interim = xw

  # Apply bias
+ if bias is not None:
+ b = bias
  if b is not None:
  b_slice = b.unsqueeze(0).unsqueeze(0).repeat(1, sparsity_block_size, 1)
  sparsity_layout_b_slice = torch.ones(size=(1, b_slice.size(1) // sparsity_block_size,
@@ -39,3 +43,32 @@ def apply_torch_linear(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block
  interim = interim + b_bs

  return interim, sparsity_layout_xw
+
+
+ def apply_torch_normalisation(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
+ normalisation: nn.Module) -> BlksprsTensor:
+ return apply_function_applicable_row_wise(x, sparsity_layout, sparsity_block_size, normalisation)
+
+
+ def apply_torch_dropout(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
+ dropout: nn.Dropout) -> BlksprsTensor:
+ return apply_function_applicable_row_wise(x, sparsity_layout, sparsity_block_size, dropout)
+
+
+ def apply_function_applicable_row_wise(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
+ function: Callable) -> BlksprsTensor:
+ sparsity_layout_packed = _pack_layout(sparsity_layout)
+ blksprs_pseudo_dense = bs.ops.to_dense(x, sparsity_layout_packed, sparsity_block_size)
+ normalisation_out = function(blksprs_pseudo_dense)
+ blksprs_sparse = bs.ops.to_sparse(normalisation_out, sparsity_layout_packed, sparsity_block_size)
+
+ return blksprs_sparse
+
+
+ def _pack_layout(sparsity_layout: Tensor) -> BlksprsTensor:
+ sparsity_layout_resized = sparsity_layout.resize(1, sparsity_layout.size(0) * sparsity_layout.size(1),
+ sparsity_layout.size(2))
+ non_zero_rows = torch.any(sparsity_layout_resized, dim=-1)
+ sparsity_layout_filtered = sparsity_layout_resized[non_zero_rows].unsqueeze(0)
+
+ return sparsity_layout_filtered
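Beyond the two wrappers above, ``apply_function_applicable_row_wise`` can push any callable that operates row-wise (or element-wise) on the pseudo-dense view through a block-sparse tensor; a hedged sketch with an assumed GELU activation and illustrative shapes:

```python
import torch
from torch import nn
import blksprs as bs

sparsity_block_size = 32
x_dense = torch.randn(2, 64, 64, device="cuda")
layout = bs.layouting.build_full_sparsity_layout(x_dense, sparsity_block_size)
x = bs.ops.to_sparse(x_dense, layout, sparsity_block_size)

# Element-wise functions trivially qualify as "applicable row-wise".
x_act = bs.utils.apply_function_applicable_row_wise(x, layout, sparsity_block_size, nn.GELU())
```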
@@ -36,7 +36,8 @@ def validate_dtype_int(*tensors: Tensor) -> None:
  return

  for tensor in tensors:
- if tensor.dtype != torch.int32 and tensor.dtype != torch.int64:
+ if (tensor.dtype !=
+ torch.int32 and tensor.dtype != torch.int64):
  raise ValueError("Tensor must have int32 or int64 dtype")


@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: blksprs
- Version: 1.8.3
+ Version: 1.9.1
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -64,8 +64,12 @@ Further helpful operations (included in the ``bs.ops.misc`` module) that do **no
  - Row-wise sum, max, addition, and subtraction
  - Broadcast addition and subtraction between slices

- Furthermore, the library provides a set of utility functions for the creation of sparsity layouts based on existing
- dense tensors and for the scatter operation (module ``bs.layouting``), as well as utility functions to apply linear layers,
+ Furthermore, the library provides a set of utility functions
+
+ - for the creation of sparsity layouts based on existing
+ dense tensors and for the scatter operation (module ``bs.layouting``),
+ - for the application of ``nn.Linear``, ``nn.Dropout``, and ``nn.LayerNorm`` layers to block-sparse tensors,
+ - as well as utility functions to apply linear layers,
  ensure correct input dimensionality, and validate input (module ``bs.utils``).

  ## Installation
@@ -10,6 +10,7 @@ blksprs/layouting/distribution_layout.py
  blksprs/layouting/sparsity_layout.py
  blksprs/ops/conversion.py
  blksprs/ops/distribution.py
+ blksprs/ops/flow.py
  blksprs/ops/matmul.py
  blksprs/ops/partitioning.py
  blksprs/ops/repeat.py
@@ -21,6 +22,7 @@ blksprs/ops/misc/exp.py
  blksprs/ops/misc/row_wise.py
  blksprs/utils/benchmarking.py
  blksprs/utils/blksprs_tensor.py
+ blksprs/utils/layout_utils.py
  blksprs/utils/processing.py
  blksprs/utils/tools.py
  blksprs/utils/validation.py
@@ -1,6 +1,6 @@
  [project]
  name = "blksprs"
- version = "1.8.3"
+ version = "1.9.1"
  authors = [{ name = "Felix Schön", email = "schoen@kr.tuwien.ac.at" }]
  description = "A lightweight library for operations on blocksparse matrices in PyTorch."
  readme = "README.md"