blksprs 0.2b4__py3-none-any.whl → 1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,114 @@
+ import torch
+ import triton
+ from torch import Tensor
+ from triton import language as tl
+
+ from blksprs.utils.tools import get_triton_block_size
+ from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
+     validate_contiguous
+
+
+ def build_distribution_layout(indices: Tensor, sparsity_layout_indices: Tensor,
+                               size_target: torch.Size,
+                               sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+     """Builds the sparsity layout of either the source of a gather or the target of a scatter operation.
+
+     Args:
+         indices (Tensor): The block-sparse indices tensor in compressed form used for the gather or scatter operation.
+         sparsity_layout_indices (Tensor): The sparsity layout of the indices block-sparse tensor.
+         size_target (torch.Size): The size of the block-sparse target tensor in regular form.
+         sparsity_block_size (int): The size of the sparsity blocks.
+         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+
+     Returns:
+         Tensor: The sparsity layout of the source or target tensor.
+
+     """
+     validate_dimensions(indices)
+     validate_contiguous(indices)
+     validate_device(indices)
+
+     sparsity_lut_i = torch.nonzero(sparsity_layout_indices).contiguous()
+
+     output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size, size_target[2] // sparsity_block_size,
+                          device=indices.device, dtype=torch.int32)
+
+     i_b, i_r, i_c = indices.size()
+     i_b_s, i_r_s, i_c_s = indices.stride()
+     s_l_i_b, s_l_i_r, s_l_i_c = sparsity_layout_indices.size()
+     s_l_i_b_s, s_l_i_r_s, s_l_i_c_s = sparsity_layout_indices.stride()
+     s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
+     s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
+     o_b, o_r, o_c = output.size()
+     o_b_s, o_r_s, o_c_s = output.stride()
+
+     if triton_block_size is None:
+         triton_block_size = get_triton_block_size(sparsity_block_size)
+
+     validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+     triton_grid = lambda meta: [i_b,
+                                 triton.cdiv(i_r, meta["TRITON_BLOCK_SIZE"]),
+                                 triton.cdiv(i_c, meta["TRITON_BLOCK_SIZE"])]
+
+     (kernel_distribution_layout[triton_grid]
+      (indices,
+       i_b, i_b_s, i_r_s, i_c_s,
+       sparsity_layout_indices,
+       s_l_i_b, s_l_i_b_s, s_l_i_r, s_l_i_r_s, s_l_i_c, s_l_i_c_s,
+       sparsity_lut_i,
+       s_lut_i_r, s_lut_i_r_s, s_lut_i_c, s_lut_i_c_s,
+       output,
+       o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
+       sparsity_block_size,
+       triton_block_size))
+
+     return output
+
+
+ @triton.jit
+ def kernel_distribution_layout(i,
+                                i_b, i_b_s, i_r_s, i_c_s,
+                                s_l_i,
+                                s_l_i_b, s_l_i_b_s, s_l_i_r, s_l_i_r_s, s_l_i_c, s_l_i_c_s,
+                                s_lut_i,
+                                s_lut_i_r, s_lut_i_r_s, s_lut_i_c, s_lut_i_c_s,
+                                o,
+                                o_b, o_b_s, o_r, o_r_s, o_c, o_c_s,
+                                sparsity_block_size,
+                                TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     # Get triton block indices
+     pid_blk = tl.program_id(axis=0)
+     pid_row = tl.program_id(axis=1)
+     pid_col = tl.program_id(axis=2)
+
+     # Get position of current sparsity block consisting of its batch, row, and column index
+     spa_bat_i_idx = (pid_blk * s_lut_i_r_s + 0 * s_lut_i_c_s)
+     spa_bat_i_msk = (spa_bat_i_idx < s_lut_i_r * s_lut_i_r_s)
+     spa_bat_i = tl.load(s_lut_i + spa_bat_i_idx, mask=spa_bat_i_msk)
+
+     spa_row_i_idx = (pid_blk * s_lut_i_r_s + 1 * s_lut_i_c_s)
+     spa_row_i_msk = (spa_row_i_idx < s_lut_i_r * s_lut_i_r_s)
+     spa_row_i = tl.load(s_lut_i + spa_row_i_idx, mask=spa_row_i_msk)
+
+     blk_i_idx = (pid_blk * i_b_s +
+                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+     blk_i_msk = (blk_i_idx < i_b * i_b_s)
+     blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)
+
+     blk_i = blk_i // sparsity_block_size
+     blk_v = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), 1, dtype=tl.int32)
+
+     blk_o_idx = ((spa_bat_i * o_b_s) +
+                  (spa_row_i * o_r_s) +
+                  (blk_i * o_c_s))
+     blk_o_msk = (blk_o_idx < o_b * o_b_s)
+     tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
+
+     # if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
+     #     blk_o_idx = (pid_bat * o_b_s +
+     #                  (((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +
+     #                   ((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))
+     #     blk_o_msk = (blk_o_idx < o_b * o_b_s)
+     #     tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
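For orientation, a minimal usage sketch of build_distribution_layout as defined above follows. The shapes, the layouts, and the assumption that the function is importable from the installed package are illustrative only (this diff does not show the package's export paths); a CUDA device is assumed since the kernel is Triton-based.

    # Illustrative sketch only: assumes build_distribution_layout (defined above) is in scope and CUDA is available.
    import torch

    sparsity_block_size = 16
    # Indices in compressed form: two 16 x 16 blocks whose values are column positions in a dense target of width 64.
    indices = torch.randint(0, 64, (2, 16, 16), device="cuda")
    # Sparsity layout of the indices tensor itself: one batch, a 2 x 2 block grid with two blocks present.
    sparsity_layout_indices = torch.tensor([[[1, 0], [0, 1]]], device="cuda")
    # Size of the dense target that the gather or scatter refers to.
    size_target = torch.Size((1, 32, 64))

    layout = build_distribution_layout(indices, sparsity_layout_indices, size_target, sparsity_block_size)
    # layout should have shape (1, 32 // 16, 64 // 16) == (1, 2, 4); entries set to 1 mark blocks touched by the indices.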
@@ -0,0 +1,78 @@
+ import torch
+ import triton
+ from torch import Tensor
+ from triton import language as tl
+
+ from blksprs.utils.tools import get_triton_block_size
+ from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
+     validate_contiguous
+
+
+ def build_sparsity_layout(x: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+     """Builds the sparsity layout of a dense tensor covering its sparse blocks.
+
+     Args:
+         x (Tensor): A block-sparse (or dense) tensor in regular form.
+         sparsity_block_size (int): The size of the sparsity blocks.
+         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+
+     Returns:
+         Tensor: The sparsity layout of the input block-sparse (or dense) tensor.
+
+     """
+     validate_dimensions(x)
+     validate_contiguous(x)
+     validate_device(x)
+
+     output = torch.zeros(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size,
+                          device=x.device, dtype=torch.int32)
+
+     x_b, x_r, x_c = x.size()
+     x_b_s, x_r_s, x_c_s = x.stride()
+     o_b, o_r, o_c = output.size()
+     o_b_s, o_r_s, o_c_s = output.stride()
+
+     if triton_block_size is None:
+         triton_block_size = get_triton_block_size(sparsity_block_size)
+
+     validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+     triton_grid = lambda meta: [x_b,
+                                 triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                 triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+     (kernel_sparsity_layout[triton_grid]
+      (x,
+       x_b, x_b_s, x_r_s, x_c_s,
+       output,
+       o_b, o_b_s, o_r_s, o_c_s,
+       sparsity_block_size,
+       triton_block_size))
+
+     return output
+
+
+ @triton.jit
+ def kernel_sparsity_layout(x,
+                            x_b, x_b_s, x_r_s, x_c_s,
+                            o,
+                            o_b, o_b_s, o_r_s, o_c_s,
+                            sparsity_block_size,
+                            TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     # Get triton block indices
+     pid_bat = tl.program_id(axis=0)
+     pid_row = tl.program_id(axis=1)
+     pid_col = tl.program_id(axis=2)
+
+     blk_x_idx = (pid_bat * x_b_s +
+                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+     blk_x_msk = (blk_x_idx < x_b * x_b_s)
+     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+     if tl.min(blk_x) != 0 or tl.max(blk_x) != 0:
+         blk_o_idx = (pid_bat * o_b_s +
+                      (((pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_r_s +
+                       ((pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size) * o_c_s))
+         blk_o_msk = (blk_o_idx < o_b * o_b_s)
+         tl.store(o + blk_o_idx, 1, mask=blk_o_msk)
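As a quick illustration of build_sparsity_layout above, the sketch below builds the layout of a dense tensor with a single non-zero block. The shapes and values are made up for the example; the function must be imported from the installed package (paths not shown in this diff) and a CUDA device is assumed.

    # Illustrative sketch only: assumes build_sparsity_layout (defined above) is in scope and CUDA is available.
    import torch

    sparsity_block_size = 16
    x = torch.zeros(1, 32, 32, device="cuda")
    x[0, :16, 16:] = 1.0  # populate only the top-right 16 x 16 block

    layout = build_sparsity_layout(x, sparsity_block_size)
    # Expected result of shape (1, 2, 2): only the top-right block should be marked, i.e.
    # tensor([[[0, 1],
    #          [0, 0]]], dtype=torch.int32)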
@@ -0,0 +1,132 @@
+ import torch
+ import triton
+ from torch import Tensor
+ from triton import language as tl
+
+ from blksprs.utils.tools import get_triton_block_size
+ from blksprs.utils.validation import validate_contiguous, validate_device, \
+     validate_sparsity_block_size, validate_triton_block_size
+
+
+ def broadcast_addition(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
+                        sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+     """Performs a broadcast and subsequent addition of two dense tensors x and y. Returns a block-sparse tensor in
+     compressed form.
+
+     Args:
+         x (Tensor): A dense input tensor.
+         y (Tensor): A dense input tensor.
+         sparsity_layout_output (Tensor): The sparsity layout of the output tensor.
+         sparsity_block_size (int): The size of the sparsity blocks.
+         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+
+     Returns:
+         Tensor: The result of the operation as a block-sparse tensor in compressed form. Each element o(i, j) of the
+             output tensor corresponds to x(i) + y(j).
+
+     """
+     validate_device(x, y)
+     validate_contiguous(x, y)
+     if x.size(-1) != y.size(-1):
+         raise ValueError("Dimensions of tensors must match")
+     validate_sparsity_block_size(sparsity_block_size)
+     validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+     sparsity_lut_o = torch.nonzero(sparsity_layout_output).contiguous()
+
+     n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
+
+     validate_contiguous(sparsity_layout_output, sparsity_lut_o)
+
+     output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, device=x.device)
+
+     x_b, x_c = x.size()
+     x_b_s, x_c_s = x.stride()
+     y_b, y_c = y.size()
+     y_b_s, y_c_s = y.stride()
+     o_b, o_r, o_c = output.size()
+     o_b_s, o_r_s, o_c_s = output.stride()
+     s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
+     s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_o.stride()
+
+     if triton_block_size is None:
+         triton_block_size = get_triton_block_size(sparsity_block_size)
+
+     triton_grid = lambda meta: [o_b,
+                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                 triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+     (kernel_broadcast_addition[triton_grid]
+      (x,
+       x_b, x_b_s, x_c_s,
+       y,
+       y_b, y_b_s, y_c_s,
+       output,
+       o_b, o_b_s, o_r_s, o_c_s,
+       sparsity_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+       sparsity_block_size,
+       triton_block_size))
+
+     return output
+
+
+ def broadcast_subtraction(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
+                           sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+     """Wrapper for ``broadcast_addition`` with negated y.
+
+     """
+     return broadcast_addition(x, torch.neg(y), sparsity_layout_output, sparsity_block_size, triton_block_size)
+
+
+ @triton.jit
+ def kernel_broadcast_addition(x,
+                               x_b, x_b_s, x_c_s,
+                               y,
+                               y_b, y_b_s, y_c_s,
+                               o,
+                               o_b, o_b_s, o_r_s, o_c_s,
+                               s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+                               sparsity_block_size,
+                               TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     # Get triton block indices
+     pid_blk = tl.program_id(axis=0)
+     pid_row = tl.program_id(axis=1)
+     pid_col = tl.program_id(axis=2)
+
+     # Get position of current sparsity block consisting of its batch, row, and column index
+     spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
+     spa_bat_o_msk = (spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+     spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
+
+     spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
+     spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+     spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+
+     spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
+     spa_col_o_msk = (spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+     spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+
+     # Load x block
+     blk_x_idx = (spa_bat_o * x_b_s +
+                  ((spa_row_o * sparsity_block_size + pid_row * TRITON_BLOCK_SIZE +
+                    tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+     blk_x_msk = (blk_x_idx < x_b * x_b_s)
+     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+     # Load y block
+     blk_y_idx = (spa_bat_o * y_b_s +
+                  ((spa_col_o * sparsity_block_size + pid_col * TRITON_BLOCK_SIZE +
+                    tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
+     blk_y_msk = (blk_y_idx < y_b * y_b_s)
+     blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
+
+     # Compute sum
+     blk_x, blk_y = tl.broadcast(tl.trans(blk_x), blk_y)
+     buf = blk_x + blk_y
+
+     # Store result
+     blk_o_idx = ((pid_blk * o_b_s) +
+                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+     blk_o_msk = (blk_o_idx < o_b * o_b_s)
+     tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
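A brief usage sketch of broadcast_addition above: given two dense tensors of shape (batch, n), each output block (i, j) holds the outer sum of the i-th row slice of x and the j-th column slice of y, restricted to the blocks marked in the output layout. Shapes and the fully populated layout below are illustrative; imports from the installed package and a CUDA device are assumed.

    # Illustrative sketch only: assumes broadcast_addition (defined above) is in scope and CUDA is available.
    import torch

    sparsity_block_size = 16
    x = torch.randn(1, 32, device="cuda")
    y = torch.randn(1, 32, device="cuda")
    # Keep every block of the 32 x 32 output grid, i.e. a fully populated 2 x 2 layout.
    sparsity_layout_o = torch.ones(1, 2, 2, dtype=torch.int32, device="cuda")

    o = broadcast_addition(x, y, sparsity_layout_o, sparsity_block_size)
    # o is in compressed form with shape (4, 16, 16); block (i, j) should equal
    # x[0, 16 * i:16 * (i + 1)][:, None] + y[0, 16 * j:16 * (j + 1)][None, :].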
@@ -0,0 +1,256 @@
+ import torch
+ import triton
+ from torch import Tensor
+ from triton import language as tl
+
+ from blksprs.utils.tools import get_triton_block_size
+ from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
+     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
+
+
+ def to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
+              triton_block_size: int = None) -> Tensor:
+     """Converts a block-sparse tensor in compressed form to a block-sparse tensor in regular form based on the given
+     sparsity layout.
+
+     Args:
+         x (Tensor): A block-sparse tensor in compressed form.
+         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+         sparsity_block_size (int): The size of the sparsity blocks.
+         fill_value (float): The value to fill the resulting dense tensor with where the block-sparse tensor is not
+             present (default ``0``).
+         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+     Returns:
+         Tensor: The block-sparse tensor converted to regular form.
+
+     """
+     validate_dimensions(x)
+     validate_contiguous(x, sparsity_layout)
+     validate_device(x)
+     validate_sparsity(sparsity_block_size, (x, sparsity_layout))
+     validate_sparsity_block_size(sparsity_block_size, x)
+     validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+     sparsity_layout_flat = sparsity_layout.reshape(-1)
+     sparsity_reverse_lut = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                             (sparsity_layout_flat == 1) -
+                             (1 * (sparsity_layout_flat == 0)))
+
+     validate_contiguous(sparsity_reverse_lut)
+
+     return _BlocksparseToDense.apply(x,
+                                      sparsity_layout, sparsity_reverse_lut,
+                                      sparsity_block_size, fill_value,
+                                      triton_block_size)
+
+
+ class _BlocksparseToDense(torch.autograd.Function):
+
+     @staticmethod
+     def forward(ctx, x: Tensor,
+                 sparsity_layout: Tensor, sparsity_reverse_lut: Tensor,
+                 sparsity_block_size: int, fill_value: float,
+                 triton_block_size: int) -> Tensor:
+         output = torch.full(size=(sparsity_layout.size(0), sparsity_layout.size(1) * sparsity_block_size,
+                                   sparsity_layout.size(2) * sparsity_block_size), fill_value=fill_value,
+                             dtype=x.dtype, device=x.device)
+
+         x_b, x_r, x_c = x.shape
+         x_b_s, x_r_s, x_c_s = x.stride()
+         s_l_b, s_l_r, s_l_c = sparsity_layout.size()
+         s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout.stride()
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = output.stride()
+
+         if triton_block_size is None:
+             triton_block_size = get_triton_block_size(sparsity_block_size)
+
+         triton_grid = lambda meta: [o_b,
+                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (_BlocksparseToDense.kernel_blocksparse_to_dense[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+           sparsity_reverse_lut,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_block_size,
+           triton_block_size))
+
+         ctx.save_for_backward(sparsity_layout)
+         ctx.sparsity_block_size = sparsity_block_size
+         ctx.triton_block_size = triton_block_size
+
+         return output
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         sparsity_layout = ctx.saved_tensors[0]
+         sparsity_block_size = ctx.sparsity_block_size
+         triton_block_size = ctx.triton_block_size
+
+         return to_sparse(grad_output, sparsity_layout, sparsity_block_size,
+                          triton_block_size), None, None, None, None, None
+
+     @staticmethod
+     @triton.jit
+     def kernel_blocksparse_to_dense(x,
+                                     x_b, x_b_s, x_r_s, x_c_s,
+                                     s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+                                     sparsity_reverse_lut,
+                                     o,
+                                     o_b, o_b_s, o_r_s, o_c_s,
+                                     sparsity_block_size,
+                                     TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+         # Get triton block indices
+         pid_blk = tl.program_id(axis=0)
+         pid_row = tl.program_id(axis=1)
+         pid_col = tl.program_id(axis=2)
+
+         # Get sparsity index of current block
+         spa_row = (pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size
+         spa_col = (pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size
+
+         # Get reverse sparsity index for current block
+         rev_idx_spa_idx = (pid_blk * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)
+         rev_idx_spa_msk = (rev_idx_spa_idx < s_l_b * s_l_b_s)
+         rev_idx_spa = tl.load(sparsity_reverse_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
+
+         # If block is present commence operations
+         if rev_idx_spa >= 0:
+             blk_idx = (rev_idx_spa * x_b_s +
+                        (((pid_row % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
+                          tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                        (((pid_col % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
+                          tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+             blk_msk = (blk_idx < x_b * x_b_s)
+             blk = tl.load(x + blk_idx, mask=blk_msk)
+
+             o_idx = (pid_blk * o_b_s +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+             o_msk = (o_idx < o_b * o_b_s)
+             tl.store(o + o_idx, blk, o_msk)
+
+
+ def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+     """Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
+     sparsity layout.
+
+     Args:
+         x (Tensor): A block-sparse tensor in regular form.
+         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+         sparsity_block_size (int): The size of the sparsity blocks.
+         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+     Returns:
+         Tensor: The block-sparse tensor converted to compressed form.
+
+     """
+     validate_dimensions(x)
+     validate_contiguous(x)
+     validate_device(x)
+     validate_sparsity_block_size(sparsity_block_size, x)
+     validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+     sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
+     n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
+
+     validate_contiguous(sparsity_layout, sparsity_lut)
+
+     return _BlocksparseToSparse.apply(x,
+                                       sparsity_layout, sparsity_lut,
+                                       sparsity_block_size, n_sparse_blocks,
+                                       triton_block_size)
+
+
+ class _BlocksparseToSparse(torch.autograd.Function):
+
+     @staticmethod
+     def forward(ctx, x: Tensor,
+                 sparsity_layout: Tensor, sparsity_lut: Tensor,
+                 sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+         output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size), dtype=x.dtype,
+                              device=x.device)
+
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = x.stride()
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = output.stride()
+         s_lut_r, s_lut_c = sparsity_lut.size()
+         s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
+
+         if triton_block_size is None:
+             triton_block_size = get_triton_block_size(sparsity_block_size)
+
+         triton_grid = lambda meta: [o_b,
+                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (_BlocksparseToSparse.kernel_blocksparse_to_sparse[triton_grid]
+          (x, x_b, x_b_s, x_r_s, x_c_s,
+           sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+           output, o_b_s, o_r_s, o_c_s,
+           sparsity_block_size,
+           triton_block_size))
+
+         ctx.save_for_backward(sparsity_layout)
+         ctx.sparsity_block_size = sparsity_block_size
+         ctx.triton_block_size = triton_block_size
+
+         return output
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         sparsity_layout = ctx.saved_tensors[0]
+         sparsity_block_size = ctx.sparsity_block_size
+         triton_block_size = ctx.triton_block_size
+
+         return to_dense(grad_output, sparsity_layout, sparsity_block_size,
+                         triton_block_size=triton_block_size), None, None, None, None, None
+
+     @staticmethod
+     @triton.jit
+     def kernel_blocksparse_to_sparse(x,
+                                      x_b, x_b_s, x_r_s, x_c_s,
+                                      s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                                      o,
+                                      o_b_s, o_r_s, o_c_s,
+                                      sparsity_block_size,
+                                      TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+         # Get triton block indices
+         pid_blk = tl.program_id(axis=0)
+         pid_row = tl.program_id(axis=1)
+         pid_col = tl.program_id(axis=2)
+
+         # Get sparsity index of current output block consisting of its batch, row, and column index
+         spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+         spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+         spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+         spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+         spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+         spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+         spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
+         spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
+         spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+
+         # Load block from dense tensor
+         blk_d_idx = (spa_bat * x_b_s +
+                      ((spa_row * sparsity_block_size + pid_row * TRITON_BLOCK_SIZE +
+                        tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                      ((spa_col * sparsity_block_size + pid_col * TRITON_BLOCK_SIZE +
+                        tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+         blk_d_msk = (blk_d_idx < x_b * x_b_s)
+         blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)
+
+         # Store block in sparse tensor
+         blk_o_idx = ((pid_blk * o_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+         blk_o_msk = (blk_o_idx < (pid_blk + 1) * o_b_s)
+         tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
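Finally, a round-trip sketch for to_sparse and to_dense above: a dense tensor is compressed according to a layout and converted back, with blocks outside the layout replaced by fill_value. Shapes and the layout are illustrative only; imports from the installed package and a CUDA device are assumed.

    # Illustrative sketch only: assumes to_sparse and to_dense (defined above) are in scope and CUDA is available.
    import torch

    sparsity_block_size = 16
    x = torch.randn(1, 32, 32, device="cuda")
    # Keep only the two diagonal blocks of the 2 x 2 block grid.
    sparsity_layout = torch.tensor([[[1, 0], [0, 1]]], device="cuda")

    x_sparse = to_sparse(x, sparsity_layout, sparsity_block_size)                     # shape (2, 16, 16)
    x_dense = to_dense(x_sparse, sparsity_layout, sparsity_block_size, fill_value=0)  # shape (1, 32, 32)
    # x_dense should match x on the two kept blocks and be 0 elsewhere; both conversions are autograd
    # Functions whose backward applies the respective inverse conversion to the incoming gradient.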