blksprs-0.2b4-py3-none-any.whl → blksprs-1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,362 @@
+ import torch
+ import triton
+ from torch import Tensor
+ from triton import language as tl
+
+ from blksprs.utils.tools import get_triton_block_size
+ from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
+     validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
+
+
+ def gather(src: Tensor, sparsity_layout_src: Tensor, idx: Tensor, sparsity_layout_idx: Tensor,
+            sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+     """Applies a gather operation on a block-sparse tensor in compressed form.
+
+     Args:
+         src (Tensor): The source block-sparse tensor in compressed form to gather from.
+         sparsity_layout_src (Tensor): The sparsity layout of the source block-sparse tensor.
+         idx (Tensor): The block-sparse indices tensor in compressed form specifying how to gather from the source tensor.
+         sparsity_layout_idx (Tensor): The sparsity layout of the indices block-sparse tensor.
+         sparsity_block_size (int): The size of the sparsity blocks.
+         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+
+     Returns:
+         Tensor: The result of the gather operation as a block-sparse tensor in compressed form.
+
+     """
+     validate_dimensions(src, idx)
+     validate_contiguous(src, idx)
+     validate_dtype_int(idx)
+     validate_device(src, idx)
+     validate_sparsity(sparsity_block_size, (src, sparsity_layout_src), (idx, sparsity_layout_idx))
+     validate_sparsity_block_size(sparsity_block_size, src, idx)
+     validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+     sparsity_layout_x_flat = sparsity_layout_src.reshape(-1)
+     sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
+                               (sparsity_layout_x_flat == 1) -
+                               (1 * (sparsity_layout_x_flat == 0)))
+
+     sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
+
+     validate_contiguous(sparsity_layout_src, sparsity_reverse_lut_x,
+                         sparsity_layout_idx, sparsity_lut_i)
+
+     return _BlocksparseGather.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
+                                     idx, sparsity_layout_idx, sparsity_lut_i,
+                                     sparsity_block_size, triton_block_size)
+
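
For illustration, a minimal usage sketch of ``gather`` (editorial, not part of the package source; it assumes a CUDA device and that the functions in this file are in scope). With a fully populated 2x2 block layout, the compressed form is simply the 32x32 blocks stacked along the leading dimension, so the inputs can be built directly; values in ``idx`` index the last (dense) dimension of ``src``:

    import torch

    sparsity_block_size = 32
    # Dense logical shape (1, 64, 64) -> a 2x2 grid of 32x32 blocks, all present.
    sparsity_layout = torch.ones(1, 2, 2, dtype=torch.int, device="cuda")

    # Compressed form: the four present blocks stacked along the leading dimension.
    src = torch.randn(4, sparsity_block_size, sparsity_block_size, device="cuda")
    idx = torch.randint(0, 64, (4, sparsity_block_size, sparsity_block_size), device="cuda")

    # In dense terms, out[b, r, c] == src[b, r, idx[b, r, c]].
    out = gather(src, sparsity_layout, idx, sparsity_layout, sparsity_block_size)
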
+
+ class _BlocksparseGather(torch.autograd.Function):
+
+     @staticmethod
+     def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
+                 i: Tensor, sparsity_layout_i: Tensor, sparsity_lut_i: Tensor,
+                 sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+         output = torch.empty_like(i, dtype=x.dtype)
+
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = x.stride()
+         s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
+         s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_x.stride()
+         i_b, i_r, i_c = i.size()
+         i_b_s, i_r_s, i_c_s = i.stride()
+         s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
+         s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = output.stride()
+
+         if triton_block_size is None:
+             triton_block_size = get_triton_block_size(sparsity_block_size)
+
+         triton_grid = lambda meta: [o_b,
+                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (_BlocksparseGather.kernel_blocksparse_gather[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+           sparsity_reverse_lut_x,
+           i,
+           i_b, i_b_s, i_r_s, i_c_s,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_lut_i, s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+           sparsity_block_size,
+           triton_block_size))
+
+         ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_i)
+         ctx.sparsity_block_size = sparsity_block_size
+         ctx.triton_block_size = triton_block_size
+
+         return output
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         sparsity_layout_x, i, sparsity_layout_i = ctx.saved_tensors
+         sparsity_block_size = ctx.sparsity_block_size
+         triton_block_size = ctx.triton_block_size
+
+         return scatter_reduce(grad_output, sparsity_layout_i,
+                               i,
+                               sparsity_layout_x,
+                               sparsity_block_size,
+                               reduce_op="sum",
+                               triton_block_size=triton_block_size), None, None, None, None, None, None, None
+
+     @staticmethod
+     @triton.jit
+     def kernel_blocksparse_gather(x,
+                                   x_b, x_b_s, x_r_s, x_c_s,
+                                   s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+                                   r_lut_x,
+                                   i,
+                                   i_b, i_b_s, i_r_s, i_c_s,
+                                   o,
+                                   o_b, o_b_s, o_r_s, o_c_s,
+                                   s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+                                   sparsity_block_size,
+                                   TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+         # Get triton block indices
+         pid_blk = tl.program_id(axis=0)
+         pid_row = tl.program_id(axis=1)
+         pid_col = tl.program_id(axis=2)
+
+         # Get position of current sparsity block consisting of its batch, row, and column index
+         spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
+         spa_bat_o_msk = (spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+         spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
+
+         spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
+         spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+         spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+
+         # Load index values
+         blk_i_idx = ((pid_blk * i_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+         blk_i_msk = (blk_i_idx < i_b * i_b_s)
+         blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
+
+         # Get positions of sparsity blocks
+         pos_spa_blk_x = blk_i // sparsity_block_size
+         pos_spa_col_x = blk_i % sparsity_block_size
+
+         # Load reverse sparsity indices for x
+         rev_idx_spa_x_idx = ((spa_bat_o * s_l_x_b_s) +
+                              (spa_row_o * s_l_x_r_s) +
+                              (pos_spa_blk_x * s_l_x_c_s))
+         rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+         rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
+
+         # Load x values
+         blk_x_idx = ((rev_idx_spa_x * x_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                      (pos_spa_col_x * x_c_s))
+         blk_x_msk = (blk_x_idx < x_b * x_b_s)
+         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+         # Store output
+         blk_o_idx = ((pid_blk * o_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+         blk_o_msk = (blk_o_idx < o_b * o_b_s)
+         tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+ def scatter(src: Tensor, sparsity_layout_src: Tensor,
+             idx: Tensor,
+             sparsity_layout_tgt: Tensor,
+             sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+     """Wrapper for ``scatter_reduce`` with ``reduce_op="none"``.
+
+     """
+     return scatter_reduce(src, sparsity_layout_src,
+                           idx,
+                           sparsity_layout_tgt,
+                           sparsity_block_size,
+                           reduce_op="none", triton_block_size=triton_block_size)
+
+
+ def scatter_reduce(src: Tensor, sparsity_layout_src: Tensor,
+                    idx: Tensor,
+                    sparsity_layout_tgt: Tensor,
+                    sparsity_block_size: int,
+                    reduce_op: str = "sum", triton_block_size: int = None) -> Tensor:
+     """Applies a scatter operation on a block-sparse tensor in compressed form.
+
+     Args:
+         src (Tensor): The source block-sparse tensor in compressed form to scatter from.
+         sparsity_layout_src (Tensor): The sparsity layout of the source block-sparse tensor.
+         idx (Tensor): The block-sparse indices tensor in compressed form specifying how to scatter to the target tensor.
+         sparsity_layout_tgt (Tensor): The sparsity layout of the target block-sparse tensor.
+         sparsity_block_size (int): The size of the sparsity blocks.
+         reduce_op (str, optional): The reduction operation to apply during the scatter operation (default ``"sum"``).
+             Supported operations are ``"none"`` and ``"sum"``.
+         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+
+     Returns:
+         Tensor: The result of the scatter operation as a block-sparse tensor in compressed form.
+
+     """
+     validate_dimensions(src, idx)
+     validate_contiguous(src, idx)
+     validate_dtype_int(idx)
+     validate_device(src, idx)
+     validate_sparsity(sparsity_block_size, (src, sparsity_layout_src), (idx, sparsity_layout_src))
+     validate_sparsity_block_size(sparsity_block_size, src, idx)
+     validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+     if reduce_op not in ["none", "sum"]:
+         raise ValueError(f"Reduction operation '{reduce_op}' is not supported")
+
+     sparsity_lut_x = torch.nonzero(sparsity_layout_src).contiguous()
+
+     sparsity_layout_o_flat = sparsity_layout_tgt.reshape(-1)
+     sparsity_reverse_lut_o = ((torch.cumsum(sparsity_layout_o_flat, dim=-1) - 1) *
+                               (sparsity_layout_o_flat == 1) -
+                               (1 * (sparsity_layout_o_flat == 0)))
+
+     n_sparse_blocks = torch.sum(sparsity_layout_tgt.to(torch.int)).item()
+
+     validate_contiguous(sparsity_layout_src, sparsity_lut_x,
+                         sparsity_layout_tgt, sparsity_reverse_lut_o)
+
+     return _BlocksparseScatterReduce.apply(src, sparsity_layout_src, sparsity_lut_x,
+                                            idx,
+                                            sparsity_layout_tgt, sparsity_reverse_lut_o,
+                                            sparsity_block_size, n_sparse_blocks,
+                                            reduce_op, triton_block_size)
+
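
A matching sketch for ``scatter_reduce`` under the same assumptions as the ``gather`` sketch above (editorial, not part of the package source). With ``reduce_op="sum"``, values whose indices collide are accumulated via the atomic adds in the kernel below:

    import torch

    sparsity_block_size = 32
    sparsity_layout = torch.ones(1, 2, 2, dtype=torch.int, device="cuda")

    src = torch.randn(4, sparsity_block_size, sparsity_block_size, device="cuda")
    # Each src value is scattered to dense column idx[...] of its target row.
    idx = torch.randint(0, 64, (4, sparsity_block_size, sparsity_block_size), device="cuda")

    out = scatter_reduce(src, sparsity_layout, idx, sparsity_layout,
                         sparsity_block_size, reduce_op="sum")
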
+
+ class _BlocksparseScatterReduce(torch.autograd.Function):
+
+     @staticmethod
+     def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_lut_x: Tensor,
+                 i: Tensor,
+                 sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
+                 sparsity_block_size: int, n_sparse_blocks: int,
+                 reduce_op: str, triton_block_size: int) -> Tensor:
+         output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                              dtype=x.dtype, device=x.device)
+
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = x.stride()
+         s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
+         s_lut_x_r_s, s_lut_x_c_s = sparsity_lut_x.stride()
+         i_b, i_r, i_c = i.size()
+         i_b_s, i_r_s, i_c_s = i.stride()
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = output.stride()
+         s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
+         s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_o.stride()
+
+         if triton_block_size is None:
+             triton_block_size = get_triton_block_size(sparsity_block_size)
+
+         triton_grid = lambda meta: [x_b,
+                                     triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+         reduce_op_ind = 0
+         if reduce_op == "sum":
+             reduce_op_ind = 1
+
+         (_BlocksparseScatterReduce.kernel_blocksparse_scatter[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+           i,
+           i_b, i_b_s, i_r_s, i_c_s,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+           sparsity_reverse_lut_o,
+           reduce_op_ind,
+           sparsity_block_size,
+           triton_block_size))
+
+         ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_o)
+         ctx.sparsity_block_size = sparsity_block_size
+         ctx.reduce_op = reduce_op
+         ctx.triton_block_size = triton_block_size
+
+         return output
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         sparsity_layout_x, i, sparsity_layout_o = ctx.saved_tensors
+         sparsity_block_size = ctx.sparsity_block_size
+         reduce_op = ctx.reduce_op
+         triton_block_size = ctx.triton_block_size
+
+         if reduce_op == "sum":
+             return gather(grad_output, sparsity_layout_o, i, sparsity_layout_x, sparsity_block_size,
+                           triton_block_size=triton_block_size), None, None, None, None, None, None, None, None, None
+         else:
+             raise ValueError(f"Reduction operation '{reduce_op}' does not support backward pass")
+
+     @staticmethod
+     @triton.jit
+     def kernel_blocksparse_scatter(x,
+                                    x_b, x_b_s, x_r_s, x_c_s,
+                                    s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+                                    i,
+                                    i_b, i_b_s, i_r_s, i_c_s,
+                                    o,
+                                    o_b, o_b_s, o_r_s, o_c_s,
+                                    s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+                                    r_lut_o,
+                                    reduce_op_ind,
+                                    sparsity_block_size,
+                                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+         # Get triton block indices
+         pid_blk = tl.program_id(axis=0)
+         pid_row = tl.program_id(axis=1)
+         pid_col = tl.program_id(axis=2)
+
+         # Get position of current sparsity block consisting of its batch, row, and column index
+         spa_bat_x_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
+         spa_bat_x_msk = (spa_bat_x_idx < s_lut_x_r * s_lut_x_r_s)
+         spa_bat_x = tl.load(s_lut_x + spa_bat_x_idx, mask=spa_bat_x_msk)
+
+         spa_row_x_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
+         spa_row_x_msk = (spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
+         spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)
+
+         # Load x values
+         blk_x_idx = ((pid_blk * x_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+         blk_x_msk = (blk_x_idx < x_b * x_b_s)
+         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+         # Load index values
+         blk_i_idx = ((pid_blk * i_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+         blk_i_msk = (blk_i_idx < i_b * i_b_s)
+         blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
+
+         # Get positions of sparsity blocks
+         pos_spa_blk_o = blk_i // sparsity_block_size
+         pos_spa_col_o = blk_i % sparsity_block_size
+
+         # Load reverse sparsity indices for o
+         rev_idx_spa_o_idx = ((spa_bat_x * s_l_o_b_s) +
+                              (spa_row_x * s_l_o_r_s) +
+                              (pos_spa_blk_o * s_l_o_c_s))
+         rev_idx_spa_o_msk = (rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
+         rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
+
+         # Store output
+         blk_o_idx = ((rev_idx_spa_o * o_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                      (pos_spa_col_o * o_c_s))
+         blk_o_msk = (blk_o_idx < o_b * o_b_s)
+
+         if reduce_op_ind == 0:
+             tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+         elif reduce_op_ind == 1:
+             tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
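
Note how the two ``backward`` implementations above mirror each other: the gradient of ``gather`` is ``scatter_reduce`` with ``reduce_op="sum"``, and the gradient of a sum-scatter is a ``gather``; the two operations are adjoint. A quick editorial sanity check of that relationship, under the same assumptions as the sketches above:

    import torch

    sparsity_block_size = 32
    sparsity_layout = torch.ones(1, 2, 2, dtype=torch.int, device="cuda")

    src = torch.randn(4, sparsity_block_size, sparsity_block_size, device="cuda", requires_grad=True)
    idx = torch.randint(0, 64, (4, sparsity_block_size, sparsity_block_size), device="cuda")

    gather(src, sparsity_layout, idx, sparsity_layout, sparsity_block_size).sum().backward()

    # src.grad[b, r, v] now counts how often column v of row r was gathered --
    # exactly what scattering a tensor of ones back with reduce_op="sum" yields.
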
blksprs/ops/exp.py ADDED
@@ -0,0 +1,101 @@
+ import torch
+ import triton
+ from torch import Tensor
+ from triton import language as tl
+
+ from blksprs.utils.tools import get_triton_block_size
+ from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
+     validate_sparsity_block_size, validate_triton_block_size
+
+
+ def exp(x: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+     """Applies the element-wise exponential function to a block-sparse tensor.
+
+     Note:
+         This operation does not consider sparse blocks, i.e., these will not be set to ``e^0``.
+         Consider this when converting back to tensors in regular form.
+
+     Args:
+         x (Tensor): A block-sparse tensor in compressed form.
+         sparsity_block_size (int): The size of the sparsity blocks.
+         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+
+     Returns:
+         Tensor: The exponential function applied to all elements of the input tensor as a block-sparse tensor in
+             compressed form.
+
+     """
+     validate_dimensions(x)
+     validate_contiguous(x)
+     validate_device(x)
+     validate_sparsity_block_size(sparsity_block_size, x)
+     validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+     return _BlocksparseExp.apply(x, sparsity_block_size, triton_block_size)
+
+
+ class _BlocksparseExp(torch.autograd.Function):
+
+     @staticmethod
+     def forward(ctx, x: Tensor, sparsity_block_size: int, triton_block_size: int) -> Tensor:
+         output = torch.empty_like(x)
+
+         x_b, x_r, x_c = x.shape
+         x_b_s, x_r_s, x_c_s = x.stride()
+         o_b, o_r, o_c = output.shape
+         o_b_s, o_r_s, o_c_s = output.stride()
+
+         if triton_block_size is None:
+             triton_block_size = get_triton_block_size(sparsity_block_size)
+
+         triton_grid = lambda meta: [o_b,
+                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (_BlocksparseExp.kernel_blocksparse_exp[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           triton_block_size))
+
+         ctx.save_for_backward(output)
+
+         return output
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         o = ctx.saved_tensors[0]
+
+         # d/dx exp(x) = exp(x), which is exactly the saved forward output
+         grad_x = torch.mul(grad_output, o)
+
+         return grad_x, None, None
+
+     @staticmethod
+     @triton.jit
+     def kernel_blocksparse_exp(x,
+                                x_b, x_b_s, x_r_s, x_c_s,
+                                o,
+                                o_b, o_b_s, o_r_s, o_c_s,
+                                TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+         # Get triton block indices
+         pid_blk = tl.program_id(axis=0)
+         pid_row = tl.program_id(axis=1)
+         pid_col = tl.program_id(axis=2)
+
+         # Load block
+         blk_x_idx = ((pid_blk * x_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+         blk_x_msk = (blk_x_idx < x_b * x_b_s)
+         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+         # Compute exp
+         buf = tl.exp(blk_x)
+
+         # Store block
+         blk_o_idx = ((pid_blk * o_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+         blk_o_msk = (blk_o_idx < o_b * o_b_s)
+         tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
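
A short editorial sketch for ``exp`` (not part of the package source; assumes a CUDA device). It also illustrates the caveat from the docstring: blocks absent from the layout stay zero rather than becoming ``exp(0) = 1``:

    import torch

    sparsity_block_size = 32
    # Block-diagonal layout: two of the four 32x32 blocks are present.
    sparsity_layout = torch.tensor([[[1, 0], [0, 1]]], dtype=torch.int, device="cuda")

    # Compressed form: only the two present blocks are stored.
    x = torch.randn(2, sparsity_block_size, sparsity_block_size, device="cuda")

    y = exp(x, sparsity_block_size)
    # Converting y back to dense form would leave the absent blocks at zero.
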
blksprs/ops/matmul.py ADDED
@@ -0,0 +1,221 @@
+ import torch
+ import triton
+ from torch import Tensor
+ from triton import language as tl
+
+ from blksprs.ops.transpose import transpose
+ from blksprs.utils.tools import get_triton_block_size
+ from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
+     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
+
+
+ def matmul(x: Tensor, sparsity_layout_x: Tensor,
+            y: Tensor, sparsity_layout_y: Tensor,
+            sparsity_layout_output: Tensor,
+            sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+     """Performs matrix multiplication between two block-sparse tensors.
+
+     The sparsity layout of the output tensor is used to only calculate blocks that will be present in the output.
+
+     Args:
+         x (Tensor): The first block-sparse tensor in compressed form.
+         sparsity_layout_x (Tensor): The sparsity layout of the first block-sparse tensor.
+         y (Tensor): The second block-sparse tensor in compressed form.
+         sparsity_layout_y (Tensor): The sparsity layout of the second block-sparse tensor.
+         sparsity_layout_output (Tensor): The sparsity layout of the output tensor.
+         sparsity_block_size (int): The size of the sparsity blocks.
+         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+
+     Returns:
+         Tensor: The result of the matrix multiplication as a block-sparse tensor in compressed form.
+
+     """
+     validate_dimensions(x, y)
+     validate_contiguous(x, y)
+     validate_device(x, y)
+     validate_sparsity(sparsity_block_size, (x, sparsity_layout_x), (y, sparsity_layout_y))
+     if sparsity_layout_x.size(-1) != sparsity_layout_y.size(-2):
+         raise ValueError("Inner dimensions of tensors must match")
+     validate_sparsity_block_size(sparsity_block_size, x, y)
+     validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+     sparsity_layout_x_flat = sparsity_layout_x.reshape(-1)
+     sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
+                               (sparsity_layout_x_flat == 1) -
+                               (1 * (sparsity_layout_x_flat == 0)))
+
+     sparsity_layout_y_flat = sparsity_layout_y.reshape(-1)
+     sparsity_reverse_lut_y = ((torch.cumsum(sparsity_layout_y_flat, dim=-1) - 1) *
+                               (sparsity_layout_y_flat == 1) -
+                               (1 * (sparsity_layout_y_flat == 0)))
+
+     sparsity_lut_o = torch.nonzero(sparsity_layout_output).contiguous()
+
+     n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
+
+     validate_contiguous(sparsity_layout_x, sparsity_reverse_lut_x,
+                         sparsity_layout_y, sparsity_reverse_lut_y,
+                         sparsity_layout_output, sparsity_lut_o)
+
+     return _BlocksparseMatmulSSS.apply(x, y,
+                                        sparsity_layout_x, sparsity_reverse_lut_x,
+                                        sparsity_layout_y, sparsity_reverse_lut_y,
+                                        sparsity_layout_output, sparsity_lut_o,
+                                        sparsity_block_size,
+                                        n_sparse_blocks,
+                                        triton_block_size)
+
+
+ class _BlocksparseMatmulSSS(torch.autograd.Function):
+
+     @staticmethod
+     def forward(ctx, x: Tensor, y: Tensor,
+                 sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
+                 sparsity_layout_y: Tensor, sparsity_reverse_lut_y: Tensor,
+                 sparsity_layout_o: Tensor, sparsity_lut_o: Tensor,
+                 sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+         output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                              dtype=x.dtype, device=x.device)
+
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = x.stride()
+         s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
+         s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_x.stride()
+         y_b, y_r, y_c = y.size()
+         y_b_s, y_r_s, y_c_s = y.stride()
+         s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_y.size()
+         s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = sparsity_layout_y.stride()
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = output.stride()
+         s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
+         s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_o.stride()
+
+         if triton_block_size is None:
+             triton_block_size = get_triton_block_size(sparsity_block_size)
+
+         triton_grid = lambda meta: [o_b,
+                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (_BlocksparseMatmulSSS.kernel_blocksparse_matmul_sss[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
+           sparsity_reverse_lut_x,
+           y,
+           y_b, y_b_s, y_r_s, y_c_s,
+           s_l_y_b, s_l_y_b_s, s_l_y_r_s, s_l_y_c_s,
+           sparsity_reverse_lut_y,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_lut_o,
+           s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+           sparsity_block_size,
+           triton_block_size))
+
+         ctx.save_for_backward(x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o)
+         ctx.sparsity_block_size = sparsity_block_size
+         ctx.triton_block_size = triton_block_size
+
+         return output
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o = ctx.saved_tensors
+         sparsity_block_size = ctx.sparsity_block_size
+         triton_block_size = ctx.triton_block_size
+
+         x_t, sparsity_layout_x_t = transpose(x, sparsity_layout_x, sparsity_block_size, triton_block_size)
+         y_t, sparsity_layout_y_t = transpose(y, sparsity_layout_y, sparsity_block_size, triton_block_size)
+
+         grad_x = matmul(grad_output, sparsity_layout_o, y_t, sparsity_layout_y_t, sparsity_layout_x,
+                         sparsity_block_size, triton_block_size)
+         grad_y = matmul(x_t, sparsity_layout_x_t, grad_output, sparsity_layout_o, sparsity_layout_y,
+                         sparsity_block_size, triton_block_size)
+
+         return grad_x, grad_y, None, None, None, None, None, None, None, None, None
+
+     @staticmethod
+     @triton.jit
+     def kernel_blocksparse_matmul_sss(x,
+                                       x_b, x_b_s, x_r_s, x_c_s,
+                                       s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
+                                       r_lut_x,
+                                       y,
+                                       y_b, y_b_s, y_r_s, y_c_s,
+                                       s_l_y_b, s_l_y_b_s, s_l_y_r_s, s_l_y_c_s,
+                                       r_lut_y,
+                                       o,
+                                       o_b, o_b_s, o_r_s, o_c_s,
+                                       s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+                                       sparsity_block_size,
+                                       TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+         # Get triton block indices
+         pid_blk = tl.program_id(axis=0)
+         pid_row = tl.program_id(axis=1)
+         pid_col = tl.program_id(axis=2)
+
+         # Get position of current sparsity block consisting of its batch, row, and column index
+         spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
+         spa_bat_o_msk = (spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+         spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
+
+         spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
+         spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+         spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+
+         spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
+         spa_col_o_msk = (spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+         spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+
+         # Setup buffer
+         buf = tl.zeros(shape=(TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), dtype=tl.float32)
+
+         # Slide over triton block sized segments of input tensors
+         for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size, TRITON_BLOCK_SIZE)):
+             # Convert to segment index of sparsity layout
+             i_seg_spa = (i_seg_tri * TRITON_BLOCK_SIZE) // sparsity_block_size
+             # Calculate the triton segment index within a block
+             i_seg_tri_mod = i_seg_tri % (sparsity_block_size // TRITON_BLOCK_SIZE)
+
+             # Get reverse sparsity indices for input tensors x and y
+             # These are either -1 if the block is empty or equal to the index of the block in the sparse tensor
+
+             # Get reverse sparsity indices for x
+             rev_idx_spa_x_idx = (spa_bat_o * s_l_x_b_s +
+                                  spa_row_o * s_l_x_r_s +
+                                  i_seg_spa * s_l_x_c_s)
+             rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+             rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
+
+             # Get reverse sparsity indices for y
+             rev_idx_spa_y_idx = (spa_bat_o * s_l_y_b_s + i_seg_spa * s_l_y_r_s + spa_col_o * s_l_y_c_s)
+             rev_idx_spa_y_msk = (rev_idx_spa_y_idx < s_l_y_b * s_l_y_b_s)
+             rev_idx_spa_y = tl.load(r_lut_y + rev_idx_spa_y_idx, mask=rev_idx_spa_y_msk).to(tl.int32)
+
+             # If both blocks are present commence calculation
+             if rev_idx_spa_x >= 0 and rev_idx_spa_y >= 0:
+                 blk_x_idx = ((rev_idx_spa_x * x_b_s) +
+                              ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                              ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
+                                tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+                 blk_x_msk = (blk_x_idx < x_b * x_b_s)
+                 blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+                 blk_y_idx = ((rev_idx_spa_y * y_b_s) +
+                              ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
+                                tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
+                              ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
+                 blk_y_msk = (blk_y_idx < y_b * y_b_s)
+                 blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
+
+                 # Perform matrix multiplication
+                 buf += tl.dot(blk_x, blk_y)
+
+         # Store output
+         blk_o_idx = ((pid_blk * o_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+         blk_o_msk = (blk_o_idx < o_b * o_b_s)
+         tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
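
Finally, an editorial sketch for ``matmul`` (not part of the package source; assumes a CUDA device). As the docstring notes, the output layout is an explicit argument, so callers can request only the blocks they need and the kernel skips the rest:

    import torch

    sparsity_block_size = 32
    # Fully populated 2x2 layouts for both operands.
    layout_x = torch.ones(1, 2, 2, dtype=torch.int, device="cuda")
    layout_y = torch.ones(1, 2, 2, dtype=torch.int, device="cuda")
    # Request only the diagonal blocks of the product; off-diagonal blocks of
    # x @ y are never computed.
    layout_o = torch.tensor([[[1, 0], [0, 1]]], dtype=torch.int, device="cuda")

    x = torch.randn(4, sparsity_block_size, sparsity_block_size, device="cuda")
    y = torch.randn(4, sparsity_block_size, sparsity_block_size, device="cuda")

    out = matmul(x, layout_x, y, layout_y, layout_o, sparsity_block_size)
    # out holds 2 compressed blocks: the (0, 0) and (1, 1) blocks of the grid.
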