blksprs 0.2b4__py3-none-any.whl → 1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/ops/row_wise_sum.py ADDED
@@ -0,0 +1,231 @@
+ import torch
+ import triton
+ from torch import Tensor
+ from triton import language as tl
+
+ from blksprs.utils.tools import get_triton_block_size
+ from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
+     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
+
+
+ def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
+                  flag_slice_only: bool = False, triton_block_size: int = None) -> tuple[Tensor, Tensor]:
+     """Computes the row-wise sum of a block-sparse tensor.
+
+     Returns a block-sparse tensor in compressed form with only one block per row, where the first entry of each row
+     contains the sum of the corresponding row.
+
+     Note:
+         If ``flag_slice_only`` is set, the output will be of shape ``[x.size(0), x.size(1), 1]``.
+
+     Args:
+         x (Tensor): A block-sparse tensor in compressed form.
+         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+         sparsity_block_size (int): The size of the sparsity blocks.
+         flag_slice_only (bool, optional): If set, the output will be of shape ``[x.size(0), x.size(1), 1]``
+             (default ``False``).
+         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+
+     Returns:
+         tuple[Tensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise sum
+             of the input and the sparsity layout of the output tensor.
+
+     """
+     validate_dimensions(x)
+     validate_contiguous(x)
+     validate_device(x)
+     validate_sparsity(sparsity_block_size, (x, sparsity_layout))
+     validate_sparsity_block_size(sparsity_block_size, x)
+     validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+     sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
+     sparsity_layout_flat = sparsity_layout.reshape(-1)
+     sparsity_reverse_lut = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                             (sparsity_layout_flat == 1) -
+                             (1 * (sparsity_layout_flat == 0)))
+
+     sparsity_layout_output, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
+     sparsity_lut_output = torch.nonzero(sparsity_layout_output).contiguous()
+     sparsity_layout_output_flat = sparsity_layout_output.reshape(-1)
+     sparsity_reverse_lut_output = ((torch.cumsum(sparsity_layout_output_flat, dim=-1) - 1) *
+                                    (sparsity_layout_output_flat == 1) -
+                                    (1 * (sparsity_layout_output_flat == 0)))
+
+     n_sparse_blocks_output = torch.sum(sparsity_layout_output.to(torch.int)).item()
+
+     validate_contiguous(sparsity_layout, sparsity_lut, sparsity_reverse_lut,
+                         sparsity_layout_output, sparsity_lut_output, sparsity_reverse_lut_output)
+
+     return (_BlocksparseRowWiseSum.apply(x,
+                                          sparsity_layout, sparsity_lut, sparsity_reverse_lut,
+                                          sparsity_layout_output, sparsity_lut_output, sparsity_reverse_lut_output,
+                                          n_sparse_blocks_output,
+                                          flag_slice_only,
+                                          sparsity_block_size, triton_block_size),
+             sparsity_layout_output)
+
+
+ class _BlocksparseRowWiseSum(torch.autograd.Function):
+     IMPLEMENTATION = "atomic_add"
+
+     @staticmethod
+     def forward(ctx, x: Tensor,
+                 sparsity_layout: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                 sparsity_layout_output: Tensor, sparsity_lut_output: Tensor, sparsity_reverse_lut_output: Tensor,
+                 n_sparse_blocks_output: int,
+                 flag_slice_only: bool,
+                 sparsity_block_size: int, triton_block_size: int) -> Tensor:
+         output = torch.zeros(size=(n_sparse_blocks_output,
+                                    sparsity_block_size,
+                                    1 if flag_slice_only else sparsity_block_size),
+                              device=x.device)
+
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = x.stride()
+         s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout.size()
+         s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout.stride()
+         s_lut_x_r, s_lut_x_c = sparsity_lut.size()
+         s_lut_x_r_s, s_lut_x_c_s = sparsity_lut.stride()
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = output.stride()
+         s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
+         s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_output.stride()
+         s_lut_o_r, s_lut_o_c = sparsity_lut_output.size()
+         s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_output.stride()
+
+         if triton_block_size is None:
+             triton_block_size = get_triton_block_size(sparsity_block_size)
+
+         if _BlocksparseRowWiseSum.IMPLEMENTATION == "basic":
+             triton_grid = lambda meta: [o_b,
+                                         triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"])]
+
+             (_BlocksparseRowWiseSum.kernel_blocksparse_row_wise_sum[triton_grid]
+              (x,
+               x_b, x_b_s, x_r_s, x_c_s,
+               s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
+               sparsity_reverse_lut,
+               output,
+               o_b, o_b_s, o_r_s,
+               sparsity_lut_output, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+               sparsity_block_size,
+               triton_block_size))
+         elif _BlocksparseRowWiseSum.IMPLEMENTATION == "atomic_add":
+             triton_grid = lambda meta: [x_b,
+                                         triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                         triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+             (_BlocksparseRowWiseSum.kernel_blocksparse_row_wise_sum_atomic_add[triton_grid]
+              (x,
+               x_b, x_b_s, x_r_s, x_c_s,
+               sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+               output,
+               o_b, o_b_s, o_r_s,
+               s_l_o_b, s_l_o_b_s, s_l_o_r_s,
+               sparsity_reverse_lut_output,
+               triton_block_size))
+
+         return output
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         raise NotImplementedError
+
+     @staticmethod
+     @triton.jit
+     def kernel_blocksparse_row_wise_sum(x,
+                                         x_b, x_b_s, x_r_s, x_c_s,
+                                         s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
+                                         r_lut_x,
+                                         o,
+                                         o_b, o_b_s, o_r_s,
+                                         s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+                                         sparsity_block_size,
+                                         TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+         pid_blk = tl.program_id(axis=0)
+         pid_row = tl.program_id(axis=1)
+
+         # Get position of current sparsity block consisting of its batch and row index
+         spa_bat_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
+         spa_bat_msk = (spa_bat_idx < s_lut_o_r * s_lut_o_r_s)
+         spa_bat = tl.load(s_lut_o + spa_bat_idx, mask=spa_bat_msk)
+
+         spa_row_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
+         spa_row_msk = (spa_row_idx < s_lut_o_r * s_lut_o_r_s)
+         spa_row = tl.load(s_lut_o + spa_row_idx, mask=spa_row_msk)
+
+         buf = tl.zeros(shape=(TRITON_BLOCK_SIZE, 1), dtype=tl.float32)
+
+         # Slide over triton block sized segments of input tensor
+         for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size, TRITON_BLOCK_SIZE)):
+             # Convert to segment index of sparsity layout
+             i_seg_spa = (i_seg_tri * TRITON_BLOCK_SIZE) // sparsity_block_size
+             # Calculate the triton segment index within a block
+             i_seg_tri_mod = i_seg_tri % (sparsity_block_size // TRITON_BLOCK_SIZE)
+
+             # Load reverse sparsity index for current block
+             rev_idx_spa_idx = (spa_bat * s_l_x_b_s +
+                                spa_row * s_l_x_r_s +
+                                i_seg_spa * s_l_x_c_s)
+             rev_idx_spa_msk = (rev_idx_spa_idx < s_l_x_b * s_l_x_b_s)
+             rev_idx_spa = tl.load(r_lut_x + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
+
+             # If block is present commence operations
+             if rev_idx_spa >= 0:
+                 blk_idx = ((rev_idx_spa * x_b_s) +
+                            ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                            ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
+                              tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+                 blk_msk = (blk_idx < x_b * x_b_s)
+                 blk = tl.load(x + blk_idx, mask=blk_msk)
+
+                 buf = buf + tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
+
+         o_idx = (pid_blk * o_b_s +
+                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                  (tl.arange(0, 1))[None, :])
+         o_msk = (o_idx < o_b * o_b_s)
+         tl.store(o + o_idx, buf, o_msk)
+
+     @staticmethod
+     @triton.jit
+     def kernel_blocksparse_row_wise_sum_atomic_add(x,
+                                                    x_b, x_b_s, x_r_s, x_c_s,
+                                                    s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+                                                    o,
+                                                    o_b, o_b_s, o_r_s,
+                                                    s_l_o_b, s_l_o_b_s, s_l_o_r_s,
+                                                    r_lut_o,
+                                                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+         pid_blk = tl.program_id(axis=0)
+         pid_row = tl.program_id(axis=1)
+         pid_col = tl.program_id(axis=2)
+
+         # Get position of current sparsity block consisting of its batch and row index
+         spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
+         spa_bat_msk = (spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
+         spa_bat = tl.load(s_lut_x + spa_bat_idx, mask=spa_bat_msk)
+
+         spa_row_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
+         spa_row_msk = (spa_row_idx < s_lut_x_r * s_lut_x_r_s)
+         spa_row = tl.load(s_lut_x + spa_row_idx, mask=spa_row_msk)
+
+         # Load reverse sparsity index for current block
+         rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
+                            spa_row * s_l_o_r_s)
+         rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
+         rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
+
+         blk_idx = ((pid_blk * x_b_s) +
+                    ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                    ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+         blk_msk = (blk_idx < x_b * x_b_s)
+         blk = tl.load(x + blk_idx, mask=blk_msk)
+
+         buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
+
+         o_idx = (rev_idx_spa * o_b_s +
+                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                  (tl.arange(0, 1))[None, :])
+         o_msk = (o_idx < o_b * o_b_s)
+         tl.atomic_add(o + o_idx, buf, o_msk)
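
The ``(torch.cumsum(flat, dim=-1) - 1) * (flat == 1) - (1 * (flat == 0))`` construction above is the recurring reverse-LUT idiom of this release: it maps every position of the flattened sparsity layout to the index of the corresponding block in the compressed tensor, or to -1 where no block is stored. A minimal standalone demonstration in plain PyTorch (illustrative only, not part of the package):

import torch

sparsity_layout = torch.tensor([[[1, 0],
                                 [0, 1]]])
flat = sparsity_layout.reshape(-1)                # tensor([1, 0, 0, 1])
reverse_lut = ((torch.cumsum(flat, dim=-1) - 1) *
               (flat == 1) -
               (1 * (flat == 0)))
print(reverse_lut)                                # tensor([ 0, -1, -1,  1])

The kernels above read this table with a masked tl.load, cast the result to tl.int32, and treat -1 as an absent block: the "basic" kernel skips such blocks, while the "atomic_add" kernel uses the output-side table, whose entry is never -1 for a row that contains at least one input block.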
blksprs/ops/softmax.py ADDED
@@ -0,0 +1,263 @@
+ import torch
+ import triton
+ from torch import Tensor
+ from triton import language as tl
+
+ from blksprs.ops.exp import exp
+ from blksprs.ops.row_wise_sum import row_wise_sum
+ from blksprs.utils.tools import get_triton_block_size
+ from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
+     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
+
+
+ def softmax(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+     """Computes the softmax of a block-sparse tensor in compressed form.
+
+     Note:
+         Sparse blocks are not considered in the calculation of the softmax, i.e., their values are treated as ``-inf``.
+
+     Args:
+         x (Tensor): A block-sparse tensor in compressed form.
+         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+         sparsity_block_size (int): The size of the sparsity blocks.
+         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+
+     Returns:
+         Tensor: The result of the softmax operation as a block-sparse tensor in compressed form.
+
+     """
+     validate_dimensions(x)
+     validate_contiguous(x)
+     validate_device(x)
+     validate_sparsity(sparsity_block_size, (x, sparsity_layout))
+     validate_sparsity_block_size(sparsity_block_size, x)
+     validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+     if x.size(0) != 0:
+         max_val = torch.max(x).item()
+     else:
+         max_val = 0
+     x_scaled = x - max_val
+
+     sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
+
+     sparsity_layout_rws, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
+     sparsity_layout_rws_flat = sparsity_layout_rws.reshape(-1)
+     sparsity_reverse_lut_rws = ((torch.cumsum(sparsity_layout_rws_flat, dim=-1) - 1) *
+                                 (sparsity_layout_rws_flat == 1) -
+                                 (1 * (sparsity_layout_rws_flat == 0)))
+
+     validate_contiguous(sparsity_layout, sparsity_lut, sparsity_reverse_lut_rws)
+
+     return _BlocksparseSoftmax.apply(x_scaled, sparsity_layout,
+                                      sparsity_lut,
+                                      sparsity_reverse_lut_rws,
+                                      sparsity_block_size, triton_block_size)
+
+
+ class _BlocksparseSoftmax(torch.autograd.Function):
+
+     @staticmethod
+     def forward(ctx, x: Tensor, sparsity_layout: Tensor,
+                 sparsity_lut: Tensor,
+                 sparsity_reverse_lut_rws: Tensor,
+                 sparsity_block_size: int, triton_block_size: int) -> Tensor:
+         output = torch.empty_like(x)
+
+         x_b, x_r, x_c = x.shape
+         x_b_s, x_r_s, x_c_s = x.stride()
+         s_lut_r, s_lut_c = sparsity_lut.shape
+         s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
+         o_b, o_r, o_c = output.shape
+
+         x_exp = exp(x, sparsity_block_size, triton_block_size=triton_block_size)
+         x_exp_row_wise_sum, sparsity_layout_rws = row_wise_sum(x_exp, sparsity_layout, sparsity_block_size,
+                                                                flag_slice_only=True,
+                                                                triton_block_size=triton_block_size)
+
+         s_b, s_r, s_c = x_exp_row_wise_sum.shape
+         s_b_s, s_r_s, s_c_s = x_exp_row_wise_sum.stride()
+         s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_rws.shape
+         s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = sparsity_layout_rws.stride()
+
+         if triton_block_size is None:
+             triton_block_size = get_triton_block_size(sparsity_block_size)
+
+         triton_grid = lambda meta: [o_b,
+                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (_BlocksparseSoftmax.kernel_blocksparse_softmax[triton_grid]
+          (x_exp,
+           x_b, x_b_s, x_r_s, x_c_s,
+           sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+           x_exp_row_wise_sum, s_b, s_b_s, s_r_s, s_c_s,
+           s_l_s_b, s_l_s_b_s, s_l_s_r_s,
+           sparsity_reverse_lut_rws,
+           output,
+           triton_block_size))
+
+         # Save for backward pass
+         ctx.save_for_backward(output, sparsity_layout, sparsity_lut)
+         ctx.sparsity_block_size = sparsity_block_size
+         ctx.triton_block_size = triton_block_size
+
+         return output
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         o, sparsity_layout, sparsity_lut = ctx.saved_tensors
+         sparsity_block_size = ctx.sparsity_block_size
+         triton_block_size = ctx.triton_block_size
+
+         s, sparsity_layout_s = row_wise_sum(grad_output * o, sparsity_layout, sparsity_block_size, flag_slice_only=True,
+                                             triton_block_size=triton_block_size)
+
+         sparsity_layout_s_flat = sparsity_layout_s.reshape(-1)
+         sparsity_reverse_lut_s = ((torch.cumsum(sparsity_layout_s_flat, dim=-1) - 1) *
+                                   (sparsity_layout_s_flat == 1) -
+                                   (1 * (sparsity_layout_s_flat == 0)))
+
+         o_b, o_r, o_c = o.size()
+         o_b_s, o_r_s, o_c_s = o.stride()
+         s_lut_r, s_lut_c = sparsity_lut.size()
+         s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
+         s_b, s_r, s_c = s.size()
+         s_b_s, s_r_s, s_c_s = s.stride()
+         s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
+         s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = sparsity_layout_s.stride()
+
+         grad_x = torch.empty_like(o)
+
+         triton_grid = lambda meta: [o_b,
+                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (_BlocksparseSoftmax.kernel_blocksparse_softmax_grad_x[triton_grid]
+          (grad_output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           o,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+           s,
+           s_b, s_b_s, s_r_s, s_c_s,
+           s_l_s_b, s_l_s_b_s, s_l_s_r_s,
+           sparsity_reverse_lut_s,
+           grad_x,
+           o_b, o_b_s, o_r_s, o_c_s,
+           triton_block_size
+           ))
+
+         return grad_x, None, None, None, None, None
+
+     @staticmethod
+     @triton.jit
+     def kernel_blocksparse_softmax(x,
+                                    x_b, x_b_s, x_r_s, x_c_s,
+                                    s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                                    s, s_b, s_b_s, s_r_s, s_c_s,
+                                    s_l_s_b, s_l_s_b_s, s_l_s_r_s,
+                                    r_lut_s,
+                                    o,
+                                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+         # Get triton block indices
+         pid_blk = tl.program_id(axis=0)
+         pid_row = tl.program_id(axis=1)
+         pid_col = tl.program_id(axis=2)
+
+         # Get position of current sparsity block consisting of its batch and row index
+         spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+         spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+         spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+         spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+         spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+         spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+         # Get reverse sparsity indices for x
+         rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
+                              spa_row * s_l_s_r_s)
+         rev_idx_spa_s_msk = (rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
+         rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
+
+         if rev_idx_spa_s == -1:
+             assert False, "Invalid sparsity block"
+
+         # Load x block
+         blk_x_idx = ((pid_blk * x_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+         blk_x_msk = (blk_x_idx < x_b * x_b_s)
+         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+         # Load sum block
+         blk_s_idx = (rev_idx_spa_s * s_b_s +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
+                      (tl.arange(0, 1) * s_c_s)[None, :])
+         blk_s_msk = (blk_s_idx < s_b * s_b_s)
+         blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
+
+         # Compute softmax
+         buf = tl.div_rn(blk_x, blk_s)
+
+         # Store output
+         tl.store(o + blk_x_idx, buf, mask=blk_x_msk)
+
+     @staticmethod
+     @triton.jit
+     def kernel_blocksparse_softmax_grad_x(g,
+                                           g_b, g_b_s, g_r_s, g_c_s,
+                                           x,
+                                           x_b, x_b_s, x_r_s, x_c_s,
+                                           s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                                           s,
+                                           s_b, s_b_s, s_r_s, s_c_s,
+                                           s_l_s_b, s_l_s_b_s, s_l_s_r_s,
+                                           r_lut_s,
+                                           o,
+                                           o_b, o_b_s, o_r_s, o_c_s,
+                                           TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+         # Get triton block indices
+         pid_blk = tl.program_id(axis=0)
+         pid_row = tl.program_id(axis=1)
+         pid_col = tl.program_id(axis=2)
+
+         # Get position of current sparsity block consisting of its batch and row index
+         spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+         spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+         spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+         spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+         spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+         spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+         rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
+                              spa_row * s_l_s_r_s)
+         rev_idx_spa_s_msk = (rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
+         rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
+
+         blk_s_idx = (rev_idx_spa_s * s_b_s +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
+                      (tl.arange(0, 1) * s_c_s)[None, :])
+         blk_s_msk = (blk_s_idx < s_b * s_b_s)
+         blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
+
+         blk_g_idx = ((pid_blk * g_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +
+                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])
+         blk_g_msk = (blk_g_idx < g_b * g_b_s)
+         blk_g = tl.load(g + blk_g_idx, mask=blk_g_msk)
+
+         blk_x_idx = ((pid_blk * x_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+         blk_x_msk = (blk_x_idx < x_b * x_b_s)
+         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+         buf = blk_x * (blk_g - blk_s)
+
+         blk_o_idx = ((pid_blk * o_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+         blk_o_msk = (blk_o_idx < o_b * o_b_s)
+         tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
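
For reference, ``kernel_blocksparse_softmax_grad_x`` implements the standard softmax vector-Jacobian product: with ``y = softmax(x)`` and upstream gradient ``g``, ``grad_x = y * (g - rowsum(g * y))``, where the row-wise sum is exactly what ``row_wise_sum(grad_output * o, ...)`` produces above (the forward pass additionally subtracts the global maximum before exponentiation for numerical stability). A dense sanity check of that identity in plain PyTorch (illustrative only, independent of blksprs):

import torch

x = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
y = torch.softmax(x, dim=-1)
g = torch.randn_like(y)

# Identity used by the kernel: grad_x = y * (g - rowsum(g * y))
grad_manual = y * (g - (g * y).sum(dim=-1, keepdim=True))
grad_autograd, = torch.autograd.grad(y, x, grad_outputs=g)
assert torch.allclose(grad_manual, grad_autograd)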
@@ -0,0 +1,154 @@
+ import torch
+ import triton
+ from torch import Tensor
+ from triton import language as tl
+
+ from blksprs.utils.tools import get_triton_block_size
+ from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
+     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
+
+
+ def transpose(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> (
+         Tensor, Tensor):
+     """Transposes a block-sparse tensor in compressed form.
+
+     Note:
+         Returns both the transposed tensor and the sparsity layout of the transposed tensor.
+
+     Args:
+         x (Tensor): A block-sparse tensor in compressed form.
+         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+         sparsity_block_size (int): The size of the sparsity blocks.
+         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+
+     Returns:
+         Tensor: The transposed block-sparse tensor in compressed form.
+         Tensor: The sparsity layout of the transposed tensor.
+
+     """
+ """
29
+ validate_dimensions(x)
30
+ validate_contiguous(x)
31
+ validate_device(x)
32
+ validate_sparsity(sparsity_block_size, (x, sparsity_layout))
33
+ validate_sparsity_block_size(sparsity_block_size, x)
34
+ validate_triton_block_size(triton_block_size, sparsity_block_size)
35
+
36
+ sparsity_layout_t = sparsity_layout.transpose(-1, -2).contiguous()
37
+
38
+ sparsity_lut = torch.nonzero(sparsity_layout_t).contiguous()
39
+
40
+ sparsity_layout_flat = sparsity_layout.reshape(-1)
41
+ sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
42
+ (sparsity_layout_flat == 1) -
43
+ (1 * (sparsity_layout_flat == 0)))
44
+ .reshape(sparsity_layout.size()).transpose(-1, -2).contiguous().reshape(-1))
45
+
46
+ n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
47
+
48
+ validate_contiguous(sparsity_layout_t, sparsity_lut, sparsity_reverse_lut)
49
+
50
+ return _BlocksparseTranspose.apply(x, sparsity_layout_t, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
51
+ n_sparse_blocks, triton_block_size), sparsity_layout_t
52
+
53
+
54
+ class _BlocksparseTranspose(torch.autograd.Function):
55
+
56
+ @staticmethod
57
+ def forward(ctx, x: Tensor,
58
+ sparsity_layout: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor, sparsity_block_size: int,
59
+ n_sparse_blocks: int, triton_block_size: int) -> (Tensor, Tensor):
60
+ output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size), device=x.device)
61
+
62
+ x_b, x_r, x_c = x.size()
63
+ x_b_s, x_r_s, x_c_s = x.stride()
64
+ s_l_b, s_l_r, s_l_c = sparsity_layout.size()
65
+ s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout.stride()
66
+ s_lut_r, s_lut_c = sparsity_lut.shape
67
+ s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
68
+ o_b, o_r, o_c = output.size()
69
+ o_b_s, o_r_s, o_c_s = output.stride()
70
+
71
+ if triton_block_size is None:
72
+ triton_block_size = get_triton_block_size(sparsity_block_size)
73
+
74
+ triton_grid = lambda meta: [o_b,
75
+ triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
76
+ triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
77
+
78
+ (_BlocksparseTranspose.kernel_blocksparse_transpose[triton_grid]
79
+ (x,
80
+ x_b, x_b_s, x_r_s, x_c_s,
81
+ s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
82
+ sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
83
+ sparsity_reverse_lut,
84
+ output,
85
+ o_b, o_b_s,
86
+ triton_block_size))
87
+
88
+ # Save for backward pass
89
+ ctx.save_for_backward(sparsity_layout)
90
+ ctx.sparsity_layout = sparsity_layout
91
+ ctx.sparsity_block_size = sparsity_block_size
92
+ ctx.triton_block_size = triton_block_size
93
+
94
+ return output
95
+
96
+ @staticmethod
97
+ def backward(ctx, grad_output):
98
+ sparsity_layout = ctx.saved_tensors[0]
99
+ sparsity_block_size = ctx.sparsity_block_size
100
+ triton_block_size = ctx.triton_block_size
101
+
102
+ return transpose(grad_output, sparsity_layout, sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None
+
+     @staticmethod
+     @triton.jit
+     def kernel_blocksparse_transpose(x,
+                                      x_b, x_b_s, x_r_s, x_c_s,
+                                      s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+                                      s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                                      r_lut,
+                                      o,
+                                      o_b, o_b_s,
+                                      TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+         # Get triton block indices
+         pid_blk = tl.program_id(axis=0)
+         pid_row = tl.program_id(axis=1)
+         pid_col = tl.program_id(axis=2)
+
+         # Get sparsity index of current output block consisting of its batch, row, and column index
+         spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+         spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+         spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+         spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+         spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+         spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+         spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
+         spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
+         spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+
+         # Get reverse sparsity indices
+         rev_idx_spa_idx = (spa_bat * s_l_b_s +
+                            spa_row * s_l_r_s +
+                            spa_col * s_l_c_s)
+         rev_idx_spa_msk = (rev_idx_spa_idx < s_l_b * s_l_b_s)
+         rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
+
+         if rev_idx_spa == -1:
+             assert False, "Invalid sparsity block"
+
+         blk_x_idx = (rev_idx_spa * x_b_s +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+         blk_x_msk = (blk_x_idx < x_b * x_b_s)
+         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+         blk_x_t = tl.trans(blk_x)
+
+         blk_o_idx = (pid_blk * o_b_s +
+                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+         blk_o_msk = (blk_o_idx < o_b * o_b_s)
+         tl.store(o + blk_o_idx, blk_x_t, mask=blk_o_msk)
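
The transpose builds its reverse LUT from the original layout and then applies the same ``transpose(-1, -2)`` permutation to it as to the layout itself, so that the k-th output block (enumerated in ``torch.nonzero(sparsity_layout_t)`` order) can look up which source block to read and transpose. A small demonstration of that permutation in plain PyTorch (illustrative only, not part of the package):

import torch

sparsity_layout = torch.tensor([[[1, 1],
                                 [0, 1]]])
flat = sparsity_layout.reshape(-1)                        # tensor([1, 1, 0, 1])
reverse_lut = ((torch.cumsum(flat, dim=-1) - 1) *
               (flat == 1) -
               (1 * (flat == 0)))                         # tensor([ 0,  1, -1,  2])
reverse_lut_t = (reverse_lut.reshape(sparsity_layout.size())
                 .transpose(-1, -2).contiguous().reshape(-1))
print(reverse_lut_t)                                      # tensor([ 0, -1,  1,  2])

In the transposed layout, position (0, 0) reads source block 0, position (0, 1) is absent, and positions (1, 0) and (1, 1) read source blocks 1 and 2; the kernel then stores each block with tl.trans applied.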
blksprs/utils/tools.py ADDED
@@ -0,0 +1,20 @@
+ import torch
+ from torch import Tensor, Size
+
+
+ def do_shape_blocksparse(x: Tensor):
+     if x.dim() == 3:
+         return x, x.size()
+
+     return x.reshape(-1, x.size(-2), x.size(-1)), x.size()
+
+
+ def undo_shape_blocksparse(x: Tensor, shape: Size):
+     if x.shape[-2:] == shape[-2:]:
+         return x
+
+     return x.reshape((*shape[:-2], *x.shape[-2:]))
+
+
+ def get_triton_block_size(sparsity_block_size: int, limit: int = 128):
+     return min(sparsity_block_size, limit)
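
A short usage sketch for the helpers above, following directly from their definitions: ``do_shape_blocksparse`` collapses all leading dimensions into a single batch dimension, ``undo_shape_blocksparse`` restores them, and ``get_triton_block_size`` caps the kernel block size at a default limit of 128:

import torch
from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse, get_triton_block_size

x = torch.randn(2, 3, 64, 64)                      # e.g. [batch, heads, rows, cols]
x_flat, original_shape = do_shape_blocksparse(x)   # flattened to [6, 64, 64]
assert x_flat.dim() == 3

x_restored = undo_shape_blocksparse(x_flat, original_shape)
assert x_restored.shape == x.shape

assert get_triton_block_size(64) == 64             # below the limit: used as-is
assert get_triton_block_size(256) == 128           # capped at the default limit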