blksprs 1.7__py3-none-any.whl → 1.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/__init__.py CHANGED
@@ -1,22 +1,27 @@
-from blksprs.ops.conversion import to_dense, to_sparse
+from blksprs.ops.conversion import to_dense, to_sparse, from_blksprs, to_blksprs
 from blksprs.ops.distribution import gather, scatter, scatter_reduce
-from blksprs.ops.exp import exp
 from blksprs.ops.matmul import matmul
 from blksprs.ops.softmax import softmax
 from blksprs.ops.transpose import transpose
-from blksprs.ops.partitioning import split, merge
+from blksprs.ops.repeat import repeat, repeat_interleave
+from blksprs.misc.partitioning import split, merge
+
 
 class layout:
     from blksprs.layouting.distribution_layout import build_distribution_layout
-    from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption, build_sparsity_layout_matmul
+    from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption, \
+        build_sparsity_layout_matmul, build_sparsity_layout_matmul_fast
+
 
 class misc:
     from blksprs.misc.broadcast_ops import broadcast_add, broadcast_sub
-    from blksprs.misc.repeat_interleave import repeat_interleave
+    from blksprs.misc.exp import exp
     from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_add, row_wise_sub
 
+
 class util:
     from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse, disable_validation
 
+
 class experimental:
-    from blksprs.experimental.distribution_mdi import gather_mdi, scatter_reduce_mdi
+    from blksprs.experimental.distribution_mdi import gather_mdi, scatter_reduce_mdi
blksprs/ops/conversion.py CHANGED
@@ -8,7 +8,12 @@ from triton import language as tl
 from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
 from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-    validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
+    validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_sparsity_dense
+
+
+def from_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
+                 triton_block_size: int = None) -> Tensor:
+    return to_dense(x, sparsity_layout, sparsity_block_size, fill_value, triton_block_size)
 
 
 def to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
@@ -144,6 +149,11 @@ class _BlocksparseToDense(torch.autograd.Function):
         tl.store(o + o_idx, blk, o_msk)
 
 
+def to_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
+               triton_block_size: int = None) -> Tensor:
+    return to_sparse(x, sparsity_layout, sparsity_block_size, triton_block_size)
+
+
 def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
     """Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
     sparsity layout.
@@ -163,6 +173,7 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, trit
     validate_dimensions(x)
     validate_contiguous(x)
     validate_device(x)
+    validate_sparsity_dense(sparsity_block_size, (x, sparsity_layout))
     validate_sparsity_block_size(sparsity_block_size, x)
     validate_triton_block_size(triton_block_size, sparsity_block_size)
 
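The new `from_blksprs` and `to_blksprs` entry points are thin aliases for `to_dense` and `to_sparse`, as the diff above shows. A minimal round-trip sketch; the shapes, the all-ones layout, and the CUDA device are illustrative assumptions, not part of the diff:

```python
import torch
import blksprs as bs

sparsity_block_size = 32
# Dense input of shape (batch, rows, cols); both spatial dims divisible by the block size.
x = torch.randn(2, 64, 64, device="cuda")
# Layout of shape (batch, rows // block_size, cols // block_size); 1 marks a non-sparse block.
sparsity_layout = torch.ones(2, 2, 2, dtype=torch.int, device="cuda")

x_cmp = bs.to_blksprs(x, sparsity_layout, sparsity_block_size)        # alias of to_sparse
x_dns = bs.from_blksprs(x_cmp, sparsity_layout, sparsity_block_size)  # alias of to_dense
assert torch.allclose(x, x_dns)
```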
blksprs/ops/repeat.py CHANGED
@@ -11,6 +11,30 @@ from blksprs.utils.validation import validate_dimensions, validate_contiguous, v
 def repeat(x: Tensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
            sparsity_block_size: int, sparsity_layout_output: Tensor = None, triton_block_size: int = None) -> (
         Tensor, Tensor):
+    """Repeats a block-sparse tensor in compressed form according to the given repeats.
+
+    Repeats is a 3-tuple of integers, where each integer represents the number of times the tensor should be repeated in
+    the first, second and third dimension respectively.
+
+    Note:
+        An output sparsity layout can be provided, in which case only the indicated blocks are filled. This may result
+        in blocks not being present in the output that were present in the input if the output sparsity layout indicates
+        them to be sparse.
+
+    Args:
+        x (Tensor): A block-sparse tensor in compressed form.
+        sparsity_layout_x (Tensor): The sparsity layout of the block-sparse tensor.
+        repeats (tuple[int, int, int]): The number of times the tensor should be repeated in the first, second and
+            third dimension respectively.
+        sparsity_block_size (int): The size of the sparsity blocks.
+        sparsity_layout_output (Tensor): The desired sparsity layout of the output tensor (default ``None``).
+        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+    Returns:
+        Tensor: A block-sparse tensor in compressed form containing the repeated values.
+        Tensor: The sparsity layout of the resulting output tensor.
+
+    """
     x = x.contiguous()
 
     validate_dimensions(x)
@@ -43,6 +67,64 @@ def repeat(x: Tensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
                                     sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_o
 
 
+def repeat_interleave(x: Tensor, sparsity_layout_x: Tensor, repeats: int,
+                      sparsity_block_size: int, sparsity_layout_output: Tensor = None,
+                      triton_block_size: int = None) -> (
+        Tensor, Tensor):
+    """Repeats and interleaves the block-sparse tensor in compressed form.
+
+    Repeats each matrix contained in the tensor ``repeats`` times and places the copies consecutively in the output
+    tensor.
+
+    Note:
+        In similar fashion to the regular ``repeat``, an output sparsity layout can be provided. In this case only
+        non-sparse blocks will be filled.
+
+    Args:
+        x (Tensor): A block-sparse tensor in compressed form.
+        sparsity_layout_x (Tensor): The sparsity layout of the block-sparse tensor.
+        repeats (int): The number of times to repeat the matrices.
+        sparsity_block_size (int): The size of the sparsity blocks.
+        sparsity_layout_output (Tensor): The desired sparsity layout of the output tensor (default ``None``).
+        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+    Returns:
+        Tensor: A block-sparse tensor in compressed form containing the repeated and interleaved matrices.
+        Tensor: The sparsity layout of the resulting output tensor.
+
+    """
+    x = x.contiguous()
+
+    validate_dimensions(x)
+    validate_contiguous(x)
+    validate_device(x)
+    validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
+    validate_sparsity_block_size(sparsity_block_size, x)
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+    sparsity_layout_o = torch.repeat_interleave(sparsity_layout_x, repeats, dim=0).contiguous()
+
+    if sparsity_layout_output is not None:
+        sparsity_layout_o = torch.logical_and(sparsity_layout_o, sparsity_layout_output)
+
+    sparsity_lut = torch.nonzero(sparsity_layout_o).contiguous()
+
+    sparsity_layout_flat = sparsity_layout_x.reshape(-1)
+    sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                             (sparsity_layout_flat == 1) -
+                             (1 * (sparsity_layout_flat == 0)))
+                            .reshape(sparsity_layout_x.size())
+                            .repeat_interleave(repeats, dim=0)
+                            .reshape(-1).contiguous())
+
+    n_sparse_blocks = torch.sum(sparsity_layout_o.to(torch.int)).item()
+
+    validate_contiguous(sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
+
+    return _BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
+                                    sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_o
+
+
 class _BlocksparseRepeat(torch.autograd.Function):
 
     @staticmethod
@@ -215,7 +297,6 @@ def forward_flow(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor
     s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
     s_lut_r, s_lut_c = sparsity_lut.size()
     s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-    asdf = torch.tensor(sparsity_lut).stride()
 
     if triton_block_size is None:
         triton_block_size = get_triton_block_size(sparsity_block_size)
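As the docstrings above describe, `repeat` tiles a compressed tensor along all three dimensions, while the new top-level `repeat_interleave` places the copies of each matrix consecutively along the batch dimension. A short hedged sketch; the layout, shapes, and CUDA device are illustrative assumptions:

```python
import torch
import blksprs as bs

sparsity_block_size = 32
# One 64x64 matrix with a diagonal block layout -> 2 non-sparse blocks in compressed form.
sparsity_layout = torch.tensor([[[1, 0], [0, 1]]], dtype=torch.int, device="cuda")
x = torch.randn(2, 32, 32, device="cuda")

# Repeat the matrix twice, interleaving the copies along the batch dimension.
out, sparsity_layout_out = bs.repeat_interleave(x, sparsity_layout, 2, sparsity_block_size)
# out now holds 4 compressed blocks; sparsity_layout_out has shape (2, 2, 2).
```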
blksprs/ops/softmax.py CHANGED
@@ -3,7 +3,7 @@ import triton
 from torch import Tensor
 from triton import language as tl
 
-from blksprs.ops.exp import exp
+from blksprs.misc.exp import exp
 from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
 from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
blksprs/utils/validation.py CHANGED
@@ -3,6 +3,7 @@ from torch import Tensor
 
 VALIDATION = True
 
+
 def validate_dimensions(*tensors: Tensor, dims=3) -> None:
     if _check_skip_validation():
         return
@@ -71,10 +72,25 @@ def validate_sparsity(sparsity_block_size: int, *tensor_sparsity_layout_tuples:
             raise ValueError("Mismatch between sparsity layout and blocks")
 
 
+def validate_sparsity_dense(sparsity_block_size: int, *tensor_sparsity_layout_tuples: tuple[Tensor, Tensor]) -> None:
+    if _check_skip_validation():
+        return
+
+    for (tensor, sparsity_layout) in tensor_sparsity_layout_tuples:
+        _validate_sparsity_layout_values(sparsity_layout)
+
+        if not sparsity_layout.dim() == 3:
+            raise ValueError("Sparsity layout must have exactly 3 dimensions")
+        if not (tensor.size(-1) // sparsity_block_size == sparsity_layout.size(-1) and
+                tensor.size(-2) // sparsity_block_size == sparsity_layout.size(-2)):
+            raise ValueError("Tensor not conforming to sparsity layout")
+
+
 def _validate_sparsity_layout_values(sparsity_layout: Tensor):
     if not torch.all(torch.logical_or(sparsity_layout == 0, sparsity_layout == 1)):
         raise ValueError("Sparsity layout values must be either 0 or 1")
 
+
 def validate_sparsity_block_size(sparsity_block_size: int, *tensors):
     if _check_skip_validation():
         return
@@ -86,6 +102,7 @@ def validate_sparsity_block_size(sparsity_block_size: int, *tensors):
         if not (tensor.size(-1) % sparsity_block_size == 0 and tensor.size(-2) % sparsity_block_size == 0):
             raise ValueError("Tensor sizes must be divisible by sparsity block size")
 
+
 def validate_triton_block_size(triton_block_size: int, sparsity_block_size: int):
     if _check_skip_validation():
         return
@@ -99,9 +116,11 @@ def validate_triton_block_size(triton_block_size: int, sparsity_block_size: int)
     if triton_block_size > sparsity_block_size:
         raise ValueError("Triton block size cannot be larger than sparsity block size")
 
+
 def _check_skip_validation():
     return not VALIDATION
 
+
 def _set_skip_validation(skip_validation: bool):
     global VALIDATION
-    VALIDATION = not skip_validation
+    VALIDATION = not skip_validation
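The new `validate_sparsity_dense` check, used by `to_sparse` above, verifies that a dense tensor's spatial dimensions divided by the sparsity block size match the layout's dimensions. A small hedged sketch of what it accepts and rejects; the shapes are chosen purely for illustration, and the check runs on CPU tensors since it only inspects shapes and values:

```python
import torch
from blksprs.utils.validation import validate_sparsity_dense

block_size = 32
x = torch.randn(1, 64, 96)  # dense tensor of shape (batch, rows, cols)

layout_ok = torch.ones(1, 2, 3, dtype=torch.int)   # (1, 64 // 32, 96 // 32)
validate_sparsity_dense(block_size, (x, layout_ok))  # passes silently

layout_bad = torch.ones(1, 2, 2, dtype=torch.int)   # wrong number of block columns
try:
    validate_sparsity_dense(block_size, (x, layout_bad))
except ValueError as err:
    print(err)  # "Tensor not conforming to sparsity layout"
```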
blksprs-1.7.dist-info/METADATA → blksprs-1.8.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.7
+Version: 1.8.1
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -28,13 +28,13 @@ A lightweight and efficient library for operations on block-sparse matrices in P
 
 Currently supported operations (includes gradient calculation):
 
-- Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
-  for `sparse = sparse @ sparse` matmul_)
+- Matrix multiplication
 - Softmax
 - Transpose
 - Gather
 - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
 - Repeat (_supports target sparsity layout_)
+- Repeat Interleave (_supports target sparsity layout_)
 - Splitting and merging of matrices along the last dimension
 - Conversion to and from sparse form
 - Conversion to different sparsity layouts and different sparsity block sizes
@@ -51,8 +51,14 @@ These include, e.g.,
 Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
 match.
 
+Further helpful operations (included in the ``bs.misc`` module) that do **not** support gradient calculation include:
+
+- Row-wise sum, max, addition, and subtraction
+- Broadcast addition and subtraction between slices
+
 Furthermore, the library provides a set of utility functions for the creation of sparsity layouts based on existing
-dense tensors.
+dense tensors and for the scatter operation (module ``bs.layout``), as well as utility functions to ensure correct input
+dimensionality (module ``bs.util``).
 
 ## Installation
 
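The README note above states that element-wise operations between two compressed tensors require matching sparsity layouts. A hedged sketch of why that works: with identical layouts the compressed blocks line up one-for-one, so plain tensor operators apply block-wise. The shapes, layout, and CUDA device below are illustrative assumptions:

```python
import torch
import blksprs as bs

sparsity_block_size = 32
layout = torch.ones(1, 2, 2, dtype=torch.int, device="cuda")  # all blocks non-sparse

a = torch.randn(1, 64, 64, device="cuda")
b = torch.randn(1, 64, 64, device="cuda")

a_s = bs.to_blksprs(a, layout, sparsity_block_size)
b_s = bs.to_blksprs(b, layout, sparsity_block_size)

# Identical layouts -> the compressed blocks align, so element-wise ops work directly.
c_s = a_s + b_s
c = bs.from_blksprs(c_s, layout, sparsity_block_size)
assert torch.allclose(c, a + b)
```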
blksprs-1.8.1.dist-info/RECORD ADDED
@@ -0,0 +1,21 @@
+blksprs/__init__.py,sha256=np0msosWMaZNVVfuFGt8rE6HZURyIald391dKAs1dSQ,1093
+blksprs/experimental/distribution_mdi.py,sha256=jE_SbB0SMGPcCoeM0699cceWAikBkBpGM_8Oo5A7Ets,20083
+blksprs/layouting/distribution_layout.py,sha256=nCR3fCF6eNpi97DI6MMLF2hq_9Lwyo6_aUCIZiJfiX4,4170
+blksprs/layouting/sparsity_layout.py,sha256=JNwbJ4L-418zCiCbt-vLfZ0xU7ReP0zr_tLHs_dytRA,9686
+blksprs/misc/broadcast_ops.py,sha256=0RLnLMYV7GAPI2YL8RotcxjIUSBZKGxdVcsGaJFeL_I,5327
+blksprs/misc/exp.py,sha256=cdF0s93Q9iucIXuEE3howsB0N6D60xgvem7C-a-yiGI,3704
+blksprs/misc/partitioning.py,sha256=nBRZzfi3XYAhDLEBzYflQkvGa3MIZ-qNeIlrZ16k44g,7533
+blksprs/misc/row_wise.py,sha256=0vDJA8uCocmebSIPIbFeND5_PQIE10pUj3DBOQXlTvE,16888
+blksprs/ops/conversion.py,sha256=9xVdCrj38m1cMh43LQs-GrXZ5pNRjhQyKx6paaw3C6A,21898
+blksprs/ops/distribution.py,sha256=V3TK5SlNT_JdGHNaDNl-U4U5vwAYsgkAOg4eTmYxbuA,16877
+blksprs/ops/matmul.py,sha256=uqVe6Dz2aaCbCglM1uS2eRHVKh7PQcuecaIBWFubPEw,11256
+blksprs/ops/repeat.py,sha256=OSsa2rj6BHL3Kedfu3wr0D82mn4HmbJ1l7XEmT-6ehg,14423
+blksprs/ops/softmax.py,sha256=5nAgeT68nucgOugjtCy1aBIMa7Kyk1KNN-j8fgmeVuk,11996
+blksprs/ops/transpose.py,sha256=67pDdCEb7r-Xifupl82fBKAYsxKcCUDy--cPPfduRvU,6761
+blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
+blksprs/utils/tools.py,sha256=S3836Zuc-BMigv-5mLTjRznCzuaF6oYW-Ir9zzUnr3o,655
+blksprs/utils/validation.py,sha256=WzihRPibXYzss3PMkhDt5_d3Q3NHA_d1TzTz3CoGPGg,4136
+blksprs-1.8.1.dist-info/METADATA,sha256=UDXUjS8PHyD4Zm-gWF4maXzY1k2SjKHMQllu-uOwLIA,8009
+blksprs-1.8.1.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+blksprs-1.8.1.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
+blksprs-1.8.1.dist-info/RECORD,,
blksprs/misc/repeat_interleave.py DELETED
@@ -1,132 +0,0 @@
-import torch
-import triton
-from torch import Tensor
-from triton import language as tl
-
-from blksprs.utils.tools import get_triton_block_size, stride
-from blksprs.utils.validation import validate_contiguous, validate_device, \
-    validate_sparsity_block_size, validate_triton_block_size, validate_dimensions
-
-
-def repeat_interleave(x: Tensor, sparsity_layout: Tensor, repeats: int,
-                      sparsity_block_size: int, triton_block_size: int = None) -> tuple[Tensor, Tensor]:
-    """Repeats and interleaves the block-sparse tensor in compressed form.
-
-    Repeats each matrix contained in the tensors by ``repeats`` amount and places them consecutively in the output
-    tensor.
-
-    Args:
-        x (Tensor): A block-sparse tensor in compressed form.
-        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
-        repeats (int): The number of times to repeat the matrices.
-        sparsity_block_size (int): The size of the sparsity blocks.
-        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
-
-    Returns:
-        Tensor: A block-sparse tensor in compressed form containing the repeated and interleaved matrices.
-        Tensor: The sparsity layout of the resulting output tensor.
-
-    """
-    x = x.contiguous()
-
-    validate_dimensions(x)
-    validate_contiguous(x)
-    validate_device(x)
-    validate_sparsity_block_size(sparsity_block_size, x)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)
-
-    sparsity_layout_output = torch.repeat_interleave(sparsity_layout, repeats, dim=0).contiguous()
-
-    sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
-
-    sparsity_layout_output_flat = sparsity_layout_output.reshape(-1)
-    sparsity_output_reverse_lut = ((torch.cumsum(sparsity_layout_output_flat, dim=-1) - 1) *
-                                   (sparsity_layout_output_flat == 1) -
-                                   (1 * (sparsity_layout_output_flat == 0)))
-
-    n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
-
-    validate_contiguous(sparsity_layout, sparsity_lut, sparsity_layout_output, sparsity_output_reverse_lut)
-
-    output = torch.empty(n_sparse_blocks * repeats, sparsity_block_size, sparsity_block_size,
-                         dtype=x.dtype, device=x.device)
-
-    x_b, x_r, x_c = x.size()
-    x_b_s, x_r_s, x_c_s = stride(x)
-    s_lut_r, s_lut_c = sparsity_lut.size()
-    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-    o_b, o_r, o_c = output.size()
-    o_b_s, o_r_s, o_c_s = stride(output)
-    s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
-    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_output)
-
-    if triton_block_size is None:
-        triton_block_size = get_triton_block_size(sparsity_block_size)
-
-    triton_grid = lambda meta: [x_b,
-                                triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
-                                triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
-
-    (kernel_repeat_interleave[triton_grid]
-     (x,
-      x_b, x_b_s, x_r_s, x_c_s,
-      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-      output,
-      o_b, o_b_s, o_r_s, o_c_s,
-      s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
-      sparsity_output_reverse_lut,
-      repeats,
-      triton_block_size))
-
-    return output, sparsity_layout_output
-
-
-@triton.jit
-def kernel_repeat_interleave(x,
-                             x_b, x_b_s, x_r_s, x_c_s,
-                             s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-                             o,
-                             o_b, o_b_s, o_r_s, o_c_s,
-                             s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
-                             r_lut_o,
-                             repeats,
-                             TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-    # Get triton block indices
-    pid_blk = tl.program_id(axis=0)
-    pid_row = tl.program_id(axis=1)
-    pid_col = tl.program_id(axis=2)
-
-    # Get sparsity index of current output block consisting of its batch, row, and column index
-    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-    spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
-    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
-
-    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-    spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
-    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
-
-    spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
-    spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
-    spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
-
-    # Load block
-    blk_x_idx = ((pid_blk * x_b_s) +
-                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx < x_b * x_b_s)
-    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
-    for repeat in range(repeats):
-        # Get reverse sparsity index
-        rev_idx_spa_idx = ((spa_bat * repeats + repeat) * s_l_o_b_s +
-                           spa_row * s_l_o_r_s +
-                           spa_col * s_l_o_c_s)
-        rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
-        rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
-
-        # Store block
-        blk_o_idx = ((rev_idx_spa * o_b_s) +
-                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-        blk_o_msk = (blk_o_idx < o_b * o_b_s)
-        tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
blksprs-1.7.dist-info/RECORD DELETED
@@ -1,22 +0,0 @@
-blksprs/__init__.py,sha256=FpvHMo1W6XvuiA1PMDp2_EJz-Xwc15cHz7WeIYXQJC4,1019
-blksprs/experimental/distribution_mdi.py,sha256=jE_SbB0SMGPcCoeM0699cceWAikBkBpGM_8Oo5A7Ets,20083
-blksprs/layouting/distribution_layout.py,sha256=nCR3fCF6eNpi97DI6MMLF2hq_9Lwyo6_aUCIZiJfiX4,4170
-blksprs/layouting/sparsity_layout.py,sha256=JNwbJ4L-418zCiCbt-vLfZ0xU7ReP0zr_tLHs_dytRA,9686
-blksprs/misc/broadcast_ops.py,sha256=0RLnLMYV7GAPI2YL8RotcxjIUSBZKGxdVcsGaJFeL_I,5327
-blksprs/misc/repeat_interleave.py,sha256=P5gfsZXuemLiAijUZfFkBFgMjlU9rlPEzai1xeGOFnw,5678
-blksprs/misc/row_wise.py,sha256=0vDJA8uCocmebSIPIbFeND5_PQIE10pUj3DBOQXlTvE,16888
-blksprs/ops/conversion.py,sha256=iyKIlkWGrK6q55KNRM8N6rY1k4b9k8QUkUl158yZUDA,21330
-blksprs/ops/distribution.py,sha256=V3TK5SlNT_JdGHNaDNl-U4U5vwAYsgkAOg4eTmYxbuA,16877
-blksprs/ops/exp.py,sha256=cdF0s93Q9iucIXuEE3howsB0N6D60xgvem7C-a-yiGI,3704
-blksprs/ops/matmul.py,sha256=uqVe6Dz2aaCbCglM1uS2eRHVKh7PQcuecaIBWFubPEw,11256
-blksprs/ops/partitioning.py,sha256=nBRZzfi3XYAhDLEBzYflQkvGa3MIZ-qNeIlrZ16k44g,7533
-blksprs/ops/repeat.py,sha256=6Wa6GG9Cx6rJXuFpvmOe5hHwYd3l9UYMosKEDsbh9XI,10408
-blksprs/ops/softmax.py,sha256=2dMLbkHNH18jSJmkgOJvZOKwWHhuUogAVCWv2Bwc3oQ,11995
-blksprs/ops/transpose.py,sha256=67pDdCEb7r-Xifupl82fBKAYsxKcCUDy--cPPfduRvU,6761
-blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
-blksprs/utils/tools.py,sha256=S3836Zuc-BMigv-5mLTjRznCzuaF6oYW-Ir9zzUnr3o,655
-blksprs/utils/validation.py,sha256=h2oki3xC5qLWZR4-W5QIna-wVSXvRehQEH-ynrOciVE,3467
-blksprs-1.7.dist-info/METADATA,sha256=raZ3ycSMUEAW71bwm-807d_dse44qKdSkWMhH4GI2Qg,7709
-blksprs-1.7.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
-blksprs-1.7.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
-blksprs-1.7.dist-info/RECORD,,