blksprs 1.6.1__py3-none-any.whl → 1.8__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as published.
blksprs/misc/repeat_interleave.py

@@ -1,132 +0,0 @@
- import torch
- import triton
- from torch import Tensor
- from triton import language as tl
-
- from blksprs.utils.tools import get_triton_block_size
- from blksprs.utils.validation import validate_contiguous, validate_device, \
-     validate_sparsity_block_size, validate_triton_block_size, validate_dimensions
-
-
- def repeat_interleave(x: Tensor, sparsity_layout: Tensor, repeats: int,
-                       sparsity_block_size: int, triton_block_size: int = None) -> tuple[Tensor, Tensor]:
-     """Repeats and interleaves the block-sparse tensor in compressed form.
-
-     Repeats each matrix contained in the tensor ``repeats`` times and places the copies consecutively in the
-     output tensor.
-
-     Args:
-         x (Tensor): A block-sparse tensor in compressed form.
-         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
-         repeats (int): The number of times to repeat the matrices.
-         sparsity_block_size (int): The size of the sparsity blocks.
-         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
-
-     Returns:
-         Tensor: A block-sparse tensor in compressed form containing the repeated and interleaved matrices.
-         Tensor: The sparsity layout of the resulting output tensor.
-
-     """
-     x = x.contiguous()
-
-     validate_dimensions(x)
-     validate_contiguous(x)
-     validate_device(x)
-     validate_sparsity_block_size(sparsity_block_size, x)
-     validate_triton_block_size(triton_block_size, sparsity_block_size)
-
-     sparsity_layout_output = torch.repeat_interleave(sparsity_layout, repeats, dim=0).contiguous()
-
-     sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
-
-     sparsity_layout_output_flat = sparsity_layout_output.reshape(-1)
-     sparsity_output_reverse_lut = ((torch.cumsum(sparsity_layout_output_flat, dim=-1) - 1) *
-                                    (sparsity_layout_output_flat == 1) -
-                                    (1 * (sparsity_layout_output_flat == 0)))
-
-     n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
-
-     validate_contiguous(sparsity_layout, sparsity_lut, sparsity_layout_output, sparsity_output_reverse_lut)
-
-     output = torch.empty(n_sparse_blocks * repeats, sparsity_block_size, sparsity_block_size,
-                          dtype=x.dtype, device=x.device)
-
-     x_b, x_r, x_c = x.size()
-     x_b_s, x_r_s, x_c_s = x.stride()
-     s_lut_r, s_lut_c = sparsity_lut.size()
-     s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = output.stride()
-     s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
-     s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_output.stride()
-
-     if triton_block_size is None:
-         triton_block_size = get_triton_block_size(sparsity_block_size)
-
-     triton_grid = lambda meta: [x_b,
-                                 triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
-
-     (kernel_repeat_interleave[triton_grid]
-      (x,
-       x_b, x_b_s, x_r_s, x_c_s,
-       sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-       output,
-       o_b, o_b_s, o_r_s, o_c_s,
-       s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
-       sparsity_output_reverse_lut,
-       repeats,
-       triton_block_size))
-
-     return output, sparsity_layout_output
-
-
- @triton.jit
- def kernel_repeat_interleave(x,
-                              x_b, x_b_s, x_r_s, x_c_s,
-                              s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-                              o,
-                              o_b, o_b_s, o_r_s, o_c_s,
-                              s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
-                              r_lut_o,
-                              repeats,
-                              TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-     # Get triton block indices
-     pid_blk = tl.program_id(axis=0)
-     pid_row = tl.program_id(axis=1)
-     pid_col = tl.program_id(axis=2)
-
-     # Get sparsity index of the current input block, consisting of its batch, row, and column index
-     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-     spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
-     spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
-
-     spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-     spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
-     spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
-
-     spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
-     spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
-     spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
-
-     # Load block
-     blk_x_idx = ((pid_blk * x_b_s) +
-                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-     blk_x_msk = (blk_x_idx < x_b * x_b_s)
-     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
-     for repeat in range(repeats):
-         # Get reverse sparsity index
-         rev_idx_spa_idx = ((spa_bat * repeats + repeat) * s_l_o_b_s +
-                            spa_row * s_l_o_r_s +
-                            spa_col * s_l_o_c_s)
-         rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
-         rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
-
-         # Store block
-         blk_o_idx = ((rev_idx_spa * o_b_s) +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-         blk_o_msk = (blk_o_idx < o_b * o_b_s)
-         tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
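
A note on the indexing trick in the deleted module: sparsity_output_reverse_lut maps each flattened position of the output sparsity layout to the index of the corresponding block in the compressed output tensor, and to -1 wherever the layout is zero; the kernel then looks up this table to find the destination block for each repeat. A minimal standalone sketch of the same cumsum arithmetic (the layout values below are illustrative only):

    import torch

    # Flattened 0/1 layout: non-zero positions receive their running block
    # index (cumsum - 1); zero positions are forced to -1.
    flat = torch.tensor([1, 0, 1, 1, 0])
    reverse_lut = (torch.cumsum(flat, dim=-1) - 1) * (flat == 1) - 1 * (flat == 0)
    print(reverse_lut)  # tensor([ 0, -1,  1,  2, -1])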
blksprs-1.6.1.dist-info/RECORD

@@ -1,21 +0,0 @@
- blksprs/__init__.py,sha256=FpvHMo1W6XvuiA1PMDp2_EJz-Xwc15cHz7WeIYXQJC4,1019
- blksprs/experimental/distribution_mdi.py,sha256=shu-3Nt7nkaLIb4O2kSajC8Lh7IWFXO9rsjzP14ASYA,20088
- blksprs/layouting/distribution_layout.py,sha256=Zv-b2t5VOvW6-ejdX42kUV7X1yYsvDCY_PXFE_wKwi0,4165
- blksprs/layouting/sparsity_layout.py,sha256=-cwP6Qq51-hPgKkYa0Fp0tHv-4J9p0ALmusVgf9nXVk,9683
- blksprs/misc/broadcast_ops.py,sha256=ahm7_lI12bJ6VTKRuSkwEeaEYWRY-BeMIOhtei35zpQ,5323
- blksprs/misc/repeat_interleave.py,sha256=uZmjjEfG6neoebvFTqp0vNZXWhjVRvLrw-LTPsW7nzo,5674
- blksprs/misc/row_wise.py,sha256=1UtjLplrGx1FkxhzQ2hjSBBY11ToLQs0JiLaXKRAkL4,16893
- blksprs/ops/conversion.py,sha256=vuiNwrwyuGI6H4PKrS_UHI7OKWJwNZd2i3LSjf6RetU,21332
- blksprs/ops/distribution.py,sha256=KhtHRVcv4_woyNlldAjIWF-7021-KX-xyIcN6rE-UgE,16879
- blksprs/ops/exp.py,sha256=CVWVq_emO2CnS_xk6Unx67P7EI7IL26dwtsmBJZOLzQ,3698
- blksprs/ops/matmul.py,sha256=743XeD5M4iUv28sYf7q6mVXDd4jZpV04JAx8bF7hWkw,11254
- blksprs/ops/partitioning.py,sha256=CYVUTK6NHS0CeYdYNPAMNxrMggPRhavPhDLrKVhibKs,11289
- blksprs/ops/softmax.py,sha256=cs1utM6UCzHhdJpf-ZysBr6CwbjI-5aQG0ahYY37Zy0,11991
- blksprs/ops/transpose.py,sha256=RtfK_GbpgzatJs4obfUFIMBizWdPl4GxSby55KYeFFU,6753
- blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
- blksprs/utils/tools.py,sha256=JAuwsLISr_hcvxIgUVvKz5ZPf9M5ycquplsBU5dVfDc,596
- blksprs/utils/validation.py,sha256=h2oki3xC5qLWZR4-W5QIna-wVSXvRehQEH-ynrOciVE,3467
- blksprs-1.6.1.dist-info/METADATA,sha256=sArC97eyknolW2KNe8UdHpDisiYELRYknYudLFkjMKM,7666
- blksprs-1.6.1.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
- blksprs-1.6.1.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
- blksprs-1.6.1.dist-info/RECORD,,
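
For context on what this release removes, here is a hedged sketch of how the deleted repeat_interleave was called in 1.6.1. The signature and output shapes follow the code above; the concrete layout values, block size, and CUDA device are assumptions for illustration:

    import torch

    from blksprs.misc.repeat_interleave import repeat_interleave

    sparsity_block_size = 32

    # Assumed example: a batch of two 64x64 matrices tiled into 32x32 blocks,
    # with four non-zero blocks overall (layout values are made up).
    sparsity_layout = torch.tensor([[[1, 0],
                                     [0, 1]],
                                    [[1, 1],
                                     [0, 0]]], device="cuda")

    # Compressed form: one 32x32 tile per non-zero layout entry.
    x = torch.randn(int(sparsity_layout.sum()), sparsity_block_size, sparsity_block_size,
                    device="cuda")

    # Each matrix is repeated twice; its copies are placed consecutively.
    out, sparsity_layout_out = repeat_interleave(x, sparsity_layout, repeats=2,
                                                 sparsity_block_size=sparsity_block_size)

    assert out.size(0) == x.size(0) * 2                                # n_sparse_blocks * repeats
    assert sparsity_layout_out.size(0) == sparsity_layout.size(0) * 2  # batch dimension doubled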