blksprs-1.0-py3-none-any.whl → blksprs-1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,362 @@
+ import torch
+ import triton
+ from torch import Tensor
+ from triton import language as tl
+
+ from blksprs.utils.tools import get_triton_block_size
+ from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
+     validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
+
+
+ def gather(src: Tensor, sparsity_layout_src: Tensor, idx: Tensor, sparsity_layout_idx: Tensor,
+            sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+     """Applies a gather operation on a block-sparse tensor in compressed form.
+
+     Args:
+         src (Tensor): The source block-sparse tensor in compressed form to gather from.
+         sparsity_layout_src (Tensor): The sparsity layout of the source block-sparse tensor.
+         idx (Tensor): The block-sparse indices tensor in compressed form specifying how to gather from the source tensor.
+         sparsity_layout_idx (Tensor): The sparsity layout of the indices block-sparse tensor.
+         sparsity_block_size (int): The size of the sparsity blocks.
+         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+
+     Returns:
+         Tensor: The result of the gather operation as a block-sparse tensor in compressed form.
+
+     """
+     validate_dimensions(src, idx)
+     validate_contiguous(src, idx)
+     validate_dtype_int(idx)
+     validate_device(src, idx)
+     validate_sparsity(sparsity_block_size, (src, sparsity_layout_src), (idx, sparsity_layout_idx))
+     validate_sparsity_block_size(sparsity_block_size, src, idx)
+     validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+     sparsity_layout_x_flat = sparsity_layout_src.reshape(-1)
+     sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
+                               (sparsity_layout_x_flat == 1) -
+                               (1 * (sparsity_layout_x_flat == 0)))
+
+     sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
+
+     validate_contiguous(sparsity_layout_src, sparsity_reverse_lut_x,
+                         sparsity_layout_idx, sparsity_lut_i)
+
+     return _BlocksparseGather.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
+                                     idx, sparsity_layout_idx, sparsity_lut_i,
+                                     sparsity_block_size, triton_block_size)
+
+
+ class _BlocksparseGather(torch.autograd.Function):
+
+     @staticmethod
+     def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
+                 i: Tensor, sparsity_layout_i: Tensor, sparsity_lut_i: Tensor,
+                 sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+         output = torch.empty_like(i, dtype=x.dtype)
+
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = x.stride()
+         s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
+         s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_x.stride()
+         i_b, i_r, i_c = i.size()
+         i_b_s, i_r_s, i_c_s = i.stride()
+         s_lut_i_r, s_lut_i_c = sparsity_lut_i.size()
+         s_lut_i_r_s, s_lut_i_c_s = sparsity_lut_i.stride()
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = output.stride()
+
+         if triton_block_size is None:
+             triton_block_size = get_triton_block_size(sparsity_block_size)
+
+         triton_grid = lambda meta: [o_b,
+                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (_BlocksparseGather.kernel_blocksparse_gather[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+           sparsity_reverse_lut_x,
+           i,
+           i_b, i_b_s, i_r_s, i_c_s,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_lut_i, s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+           sparsity_block_size,
+           triton_block_size))
+
+         ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_i)
+         ctx.sparsity_block_size = sparsity_block_size
+         ctx.triton_block_size = triton_block_size
+
+         return output
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         sparsity_layout_x, i, sparsity_layout_i = ctx.saved_tensors
+         sparsity_block_size = ctx.sparsity_block_size
+         triton_block_size = ctx.triton_block_size
+
+         return scatter_reduce(grad_output, sparsity_layout_i,
+                               i,
+                               sparsity_layout_x,
+                               sparsity_block_size,
+                               reduce_op="sum",
+                               triton_block_size=triton_block_size), None, None, None, None, None, None, None
+
+     @staticmethod
+     @triton.jit
+     def kernel_blocksparse_gather(x,
+                                   x_b, x_b_s, x_r_s, x_c_s,
+                                   s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+                                   r_lut_x,
+                                   i,
+                                   i_b, i_b_s, i_r_s, i_c_s,
+                                   o,
+                                   o_b, o_b_s, o_r_s, o_c_s,
+                                   s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+                                   sparsity_block_size,
+                                   TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+         # Get triton block indices
+         pid_blk = tl.program_id(axis=0)
+         pid_row = tl.program_id(axis=1)
+         pid_col = tl.program_id(axis=2)
+
+         # Get position of current sparsity block consisting of its batch, row, and column index
+         spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
+         spa_bat_o_msk = (spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+         spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
+
+         spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
+         spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+         spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+
+         # Load index values
+         blk_i_idx = ((pid_blk * i_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+         blk_i_msk = (blk_i_idx < i_b * i_b_s)
+         blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
+
+         # Get positions of sparsity blocks
+         pos_spa_blk_x = blk_i // sparsity_block_size
+         pos_spa_col_x = blk_i % sparsity_block_size
+
+         # Load reverse sparsity indices for x
+         rev_idx_spa_x_idx = ((spa_bat_o * s_l_x_b_s) +
+                              (spa_row_o * s_l_x_r_s) +
+                              (pos_spa_blk_x * s_l_x_c_s))
+         rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+         rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
+
+         # Load x values
+         blk_x_idx = ((rev_idx_spa_x * x_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                      (pos_spa_col_x * x_c_s))
+         blk_x_msk = (blk_x_idx < x_b * x_b_s)
+         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+         # Store output
+         blk_o_idx = ((pid_blk * o_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+         blk_o_msk = (blk_o_idx < o_b * o_b_s)
+         tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+ def scatter(src: Tensor, sparsity_layout_src: Tensor,
+             idx: Tensor,
+             sparsity_layout_tgt: Tensor,
+             sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+     """Wrapper for ``scatter_reduce`` with ``reduce_op="none"``.
+
+     """
+     return scatter_reduce(src, sparsity_layout_src,
+                           idx,
+                           sparsity_layout_tgt,
+                           sparsity_block_size,
+                           reduce_op="none", triton_block_size=triton_block_size)
+
+
+ def scatter_reduce(src: Tensor, sparsity_layout_src: Tensor,
+                    idx: Tensor,
+                    sparsity_layout_tgt: Tensor,
+                    sparsity_block_size: int,
+                    reduce_op: str = "sum", triton_block_size: int = None) -> Tensor:
+     """Applies a scatter operation on a block-sparse tensor in compressed form.
+
+     Args:
+         src (Tensor): The source block-sparse tensor in compressed form to scatter from.
+         sparsity_layout_src (Tensor): The sparsity layout of the source block-sparse tensor.
+         idx (Tensor): The block-sparse indices tensor in compressed form specifying how to scatter to the target tensor.
+         sparsity_layout_tgt (Tensor): The sparsity layout of the target block-sparse tensor.
+         sparsity_block_size (int): The size of the sparsity blocks.
+         reduce_op (str, optional): The reduction operation to apply during the scatter operation (default ``"sum"``).
+             Supported operations are ``"none"`` and ``"sum"``.
+         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+
+     Returns:
+         Tensor: The result of the scatter operation as a block-sparse tensor in compressed form.
+
+     """
+     validate_dimensions(src, idx)
+     validate_contiguous(src, idx)
+     validate_dtype_int(idx)
+     validate_device(src, idx)
+     validate_sparsity(sparsity_block_size, (src, sparsity_layout_src), (idx, sparsity_layout_src))
+     validate_sparsity_block_size(sparsity_block_size, src, idx)
+     validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+     if reduce_op not in ["none", "sum"]:
+         raise ValueError(f"Reduction operation '{reduce_op}' is not supported")
+
+     sparsity_lut_x = torch.nonzero(sparsity_layout_src).contiguous()
+
+     sparsity_layout_o_flat = sparsity_layout_tgt.reshape(-1)
+     sparsity_reverse_lut_o = ((torch.cumsum(sparsity_layout_o_flat, dim=-1) - 1) *
+                               (sparsity_layout_o_flat == 1) -
+                               (1 * (sparsity_layout_o_flat == 0)))
+
+     n_sparse_blocks = torch.sum(sparsity_layout_tgt.to(torch.int)).item()
+
+     validate_contiguous(sparsity_layout_src, sparsity_lut_x,
+                         sparsity_layout_tgt, sparsity_reverse_lut_o)
+
+     return _BlocksparseScatterReduce.apply(src, sparsity_layout_src, sparsity_lut_x,
+                                            idx,
+                                            sparsity_layout_tgt, sparsity_reverse_lut_o,
+                                            sparsity_block_size, n_sparse_blocks,
+                                            reduce_op, triton_block_size)
+
+
+ class _BlocksparseScatterReduce(torch.autograd.Function):
+
+     @staticmethod
+     def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_lut_x: Tensor,
+                 i: Tensor,
+                 sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
+                 sparsity_block_size: int, n_sparse_blocks: int,
+                 reduce_op: str, triton_block_size: int) -> Tensor:
+         output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                              dtype=x.dtype, device=x.device)
+
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = x.stride()
+         s_lut_x_r, s_lut_x_c = sparsity_lut_x.size()
+         s_lut_x_r_s, s_lut_x_c_s = sparsity_lut_x.stride()
+         i_b, i_r, i_c = i.size()
+         i_b_s, i_r_s, i_c_s = i.stride()
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = output.stride()
+         s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
+         s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_o.stride()
+
+         if triton_block_size is None:
+             triton_block_size = get_triton_block_size(sparsity_block_size)
+
+         triton_grid = lambda meta: [x_b,
+                                     triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+         reduce_op_ind = 0
+         if reduce_op == "sum":
+             reduce_op_ind = 1
+
+         (_BlocksparseScatterReduce.kernel_blocksparse_scatter[triton_grid]
+          (x,
+           x_b, x_b_s, x_r_s, x_c_s,
+           sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+           i,
+           i_b, i_b_s, i_r_s, i_c_s,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+           sparsity_reverse_lut_o,
+           reduce_op_ind,
+           sparsity_block_size,
+           triton_block_size))
+
+         ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_o)
+         ctx.sparsity_block_size = sparsity_block_size
+         ctx.reduce_op = reduce_op
+         ctx.triton_block_size = triton_block_size
+
+         return output
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         sparsity_layout_x, i, sparsity_layout_o = ctx.saved_tensors
+         sparsity_block_size = ctx.sparsity_block_size
+         reduce_op = ctx.reduce_op
+         triton_block_size = ctx.triton_block_size
+
+         if reduce_op == "sum":
+             return gather(grad_output, sparsity_layout_o, i, sparsity_layout_x, sparsity_block_size,
+                           triton_block_size=triton_block_size), None, None, None, None, None, None, None, None, None
+         else:
+             raise ValueError(f"Reduction operation '{reduce_op}' does not support backward pass")
+
+     @staticmethod
+     @triton.jit
+     def kernel_blocksparse_scatter(x,
+                                    x_b, x_b_s, x_r_s, x_c_s,
+                                    s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+                                    i,
+                                    i_b, i_b_s, i_r_s, i_c_s,
+                                    o,
+                                    o_b, o_b_s, o_r_s, o_c_s,
+                                    s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+                                    r_lut_o,
+                                    reduce_op_ind,
+                                    sparsity_block_size,
+                                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+         # Get triton block indices
+         pid_blk = tl.program_id(axis=0)
+         pid_row = tl.program_id(axis=1)
+         pid_col = tl.program_id(axis=2)
+
+         # Get position of current sparsity block consisting of its batch, row, and column index
+         spa_bat_x_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
+         spa_bat_x_msk = (spa_bat_x_idx < s_lut_x_r * s_lut_x_r_s)
+         spa_bat_x = tl.load(s_lut_x + spa_bat_x_idx, mask=spa_bat_x_msk)
+
+         spa_row_x_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
+         spa_row_x_msk = (spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
+         spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)
+
+         # Load x values
+         blk_x_idx = ((pid_blk * x_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+         blk_x_msk = (blk_x_idx < x_b * x_b_s)
+         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+         # Load index values
+         blk_i_idx = ((pid_blk * i_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
+                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
+         blk_i_msk = (blk_i_idx < i_b * i_b_s)
+         blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
+
+         # Get positions of sparsity blocks
+         pos_spa_blk_o = blk_i // sparsity_block_size
+         pos_spa_col_o = blk_i % sparsity_block_size
+
+         # Load reverse sparsity indices for o
+         rev_idx_spa_o_idx = ((spa_bat_x * s_l_o_b_s) +
+                              (spa_row_x * s_l_o_r_s) +
+                              (pos_spa_blk_o * s_l_o_c_s))
+         rev_idx_spa_o_msk = (rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
+         rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
+
+         # Store output
+         blk_o_idx = ((rev_idx_spa_o * o_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                      (pos_spa_col_o * o_c_s))
+         blk_o_msk = (blk_o_idx < o_b * o_b_s)
+
+         if reduce_op_ind == 0:
+             tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+         elif reduce_op_ind == 1:
+             tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
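The gather/scatter module is new in 1.1. Below is a minimal usage sketch that is not part of the package: the shapes, the fully dense layout, and the index semantics are assumptions read off the kernels above (an index value addresses a column of the uncompressed source; idx // sparsity_block_size selects the block column and idx % sparsity_block_size the column within that block).

    import torch

    # Hypothetical example, assuming gather and scatter_reduce from the new
    # module above are in scope and a CUDA device is available.
    sparsity_block_size = 4
    b, r, c = 2, 8, 8  # uncompressed batch, row, and column sizes
    layout = torch.ones(b, r // sparsity_block_size, c // sparsity_block_size,
                        dtype=torch.int, device="cuda")  # fully dense layout
    n_blocks = int(layout.sum())

    # Compressed form: one (block_size x block_size) tile per non-zero layout entry.
    src = torch.randn(n_blocks, sparsity_block_size, sparsity_block_size, device="cuda")

    # Index values address columns of the uncompressed tensor, i.e. lie in [0, c).
    idx = torch.randint(0, c, src.shape, dtype=torch.int32, device="cuda")

    gathered = gather(src, layout, idx, layout, sparsity_block_size)
    scattered = scatter_reduce(gathered, layout, idx, layout,
                               sparsity_block_size, reduce_op="sum")

Note that the two backward passes mirror each other: gather's gradient is a summing scatter, and scatter_reduce's gradient (for "sum") is a gather, exactly as implemented above.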
blksprs/ops/exp.py CHANGED
@@ -1,25 +1,35 @@
  import torch
  import triton
- from triton import language as tl
  from torch import Tensor
+ from triton import language as tl

  from blksprs.utils.tools import get_triton_block_size
- from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_dtype_float, validate_device
+ from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
+     validate_sparsity_block_size, validate_triton_block_size


  def exp(x: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
-     """Applies the element-wise exponential function to the input tensor.
+     """Applies the element-wise exponential function to a block-sparse tensor.
+
+     Note:
+         This operation does not consider sparse blocks, i.e., these will not be set to ``e^0``.
+         Consider this when converting back to tensors in regular form.
+
+     Args:
+         x (Tensor): A block-sparse tensor in compressed form.
+         sparsity_block_size (int): The size of the sparsity blocks.
+         triton_block_size (int): The block size to use for the triton kernel (default ``None``).

-     Returns a new tensor with the exponential of the elements of the input tensor.
+     Returns:
+         Tensor: The exponential function applied to all elements of the input tensor as a block-sparse tensor in
+             compressed form.

-     Note:
-         This operation does not consider sparse blocks, i.e., these will not be set to ``e^0``.
-         Consider this when converting back to dense tensors.
      """
      validate_dimensions(x)
      validate_contiguous(x)
-     validate_dtype_float(x)
      validate_device(x)
+     validate_sparsity_block_size(sparsity_block_size, x)
+     validate_triton_block_size(triton_block_size, sparsity_block_size)

      return _BlocksparseExp.apply(x, sparsity_block_size, triton_block_size)

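Since the kernel only visits stored blocks, zeros inside a stored block become e^0 = 1 while absent blocks remain zero after conversion back to regular form. A hypothetical sketch (not from the package) of that asymmetry:

    import torch

    # Assumes exp from blksprs.ops.exp and a CUDA device.
    sparsity_block_size = 4
    x = torch.zeros(2, sparsity_block_size, sparsity_block_size, device="cuda")

    y = exp(x, sparsity_block_size)  # stored zeros become exactly 1.0
    # Blocks missing from the sparsity layout are never touched, so they still
    # read as 0 (not 1) once the tensor is converted back to regular form.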
@@ -5,25 +5,39 @@ from triton import language as tl

  from blksprs.ops.transpose import transpose
  from blksprs.utils.tools import get_triton_block_size
- from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_dtype_float, validate_device, \
-     validate_sparsity
+ from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
+     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size


- def matmul_sss(x: Tensor, y: Tensor,
-                sparsity_layout_x: Tensor, sparsity_layout_y: Tensor, sparsity_layout_output: Tensor,
-                sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
-     """Performs matrix multiplication between two blocksparse tensors.
+ def matmul(x: Tensor, sparsity_layout_x: Tensor,
+            y: Tensor, sparsity_layout_y: Tensor,
+            sparsity_layout_output: Tensor,
+            sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+     """Performs matrix multiplication between two block-sparse tensors.

-     The desired sparsity layout of the output tensor is used to only calculate blocks that will be present in the output.
+     The sparsity layout of the output tensor is used to only calculate blocks that will be present in the output.
+
+     Args:
+         x (Tensor): A block-sparse tensor in compressed form.
+         y (Tensor): A block-sparse tensor in compressed form.
+         sparsity_layout_x (Tensor): The sparsity layout of the first block-sparse tensor.
+         sparsity_layout_y (Tensor): The sparsity layout of the second block-sparse tensor.
+         sparsity_layout_output (Tensor): The sparsity layout of the output tensor.
+         sparsity_block_size (int): The size of the sparsity blocks.
+         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+
+     Returns:
+         Tensor: The result of the matrix multiplication as a block-sparse tensor in compressed form.

      """
      validate_dimensions(x, y)
      validate_contiguous(x, y)
-     validate_dtype_float(x, y)
      validate_device(x, y)
      validate_sparsity(sparsity_block_size, (x, sparsity_layout_x), (y, sparsity_layout_y))
      if sparsity_layout_x.size(-1) != sparsity_layout_y.size(-2):
          raise ValueError("Inner dimensions of tensors must match")
+     validate_sparsity_block_size(sparsity_block_size, x, y)
+     validate_triton_block_size(triton_block_size, sparsity_block_size)

      sparsity_layout_x_flat = sparsity_layout_x.reshape(-1)
      sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
@@ -98,10 +112,7 @@ class _BlocksparseMatmulSSS(torch.autograd.Function):
            sparsity_block_size,
            triton_block_size))

-         ctx.save_for_backward(x, y)
-         ctx.sparsity_layout_x = sparsity_layout_x
-         ctx.sparsity_layout_y = sparsity_layout_y
-         ctx.sparsity_layout_o = sparsity_layout_o
+         ctx.save_for_backward(x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o)
          ctx.sparsity_block_size = sparsity_block_size
          ctx.triton_block_size = triton_block_size

@@ -109,26 +120,17 @@ class _BlocksparseMatmulSSS(torch.autograd.Function):

      @staticmethod
      def backward(ctx, grad_output):
-         x, y = ctx.saved_tensors
-         sparsity_layout_x = ctx.sparsity_layout_x
-         sparsity_layout_y = ctx.sparsity_layout_y
-         sparsity_layout_o = ctx.sparsity_layout_o
+         x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o = ctx.saved_tensors
          sparsity_block_size = ctx.sparsity_block_size
          triton_block_size = ctx.triton_block_size

          x_t, sparsity_layout_x_t = transpose(x, sparsity_layout_x, sparsity_block_size, triton_block_size)
          y_t, sparsity_layout_y_t = transpose(y, sparsity_layout_y, sparsity_block_size, triton_block_size)

-         grad_x = matmul_sss(grad_output, y_t,
-                             sparsity_layout_o,
-                             sparsity_layout_y_t,
-                             sparsity_layout_x,
-                             sparsity_block_size, triton_block_size)
-         grad_y = matmul_sss(x_t, grad_output,
-                             sparsity_layout_x_t,
-                             sparsity_layout_o,
-                             sparsity_layout_y,
-                             sparsity_block_size, triton_block_size)
+         grad_x = matmul(grad_output, sparsity_layout_o, y_t, sparsity_layout_y_t, sparsity_layout_x,
+                         sparsity_block_size, triton_block_size)
+         grad_y = matmul(x_t, sparsity_layout_x_t, grad_output, sparsity_layout_o, sparsity_layout_y,
+                         sparsity_block_size, triton_block_size)

          return grad_x, grad_y, None, None, None, None, None, None, None, None, None

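The rename from matmul_sss to matmul also reorders the signature so that each operand is followed directly by its sparsity layout. A before/after sketch of a call site (illustrative; the variable names are assumed):

    # blksprs 1.0
    out = matmul_sss(x, y,
                     sparsity_layout_x, sparsity_layout_y, sparsity_layout_output,
                     sparsity_block_size)

    # blksprs 1.1: each operand is immediately followed by its layout
    out = matmul(x, sparsity_layout_x,
                 y, sparsity_layout_y,
                 sparsity_layout_output,
                 sparsity_block_size)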
blksprs/ops/row_wise_sum.py CHANGED
@@ -4,23 +4,39 @@ from torch import Tensor
  from triton import language as tl

  from blksprs.utils.tools import get_triton_block_size
- from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_dtype_float, validate_device
+ from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
+     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size


  def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
                   flag_slice_only: bool = False, triton_block_size: int = None) -> tuple[Tensor, Tensor]:
-     """Computes the row-wise sum of a blocksparse tensor.
+     """Computes the row-wise sum of a block-sparse tensor.

-     Returns a blocksparse tensor with only one block per row, where the first entry is the sum of the corresponding row.
+     Returns a block-sparse tensor in compressed form with only one block per row, where the first entry contains the sum
+     of the corresponding row.

      Note:
-         If ``flag_slice_only`` is set the output will be of shape ``[batch_size, row_size, 1]``.
+         If ``flag_slice_only`` is set the output will be of shape ``[x.size(0), x.size(1), 1]``.
+
+     Args:
+         x (Tensor): A block-sparse tensor in compressed form.
+         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+         sparsity_block_size (int): The size of the sparsity blocks.
+         flag_slice_only (bool, optional): If set the output will be of shape ``[x.size(0), x.size(1), 1]``
+             (default ``False``).
+         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+     Returns:
+         tuple[Tensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise sum
+             of the input and the sparsity layout of the output tensor.

      """
      validate_dimensions(x)
      validate_contiguous(x)
-     validate_dtype_float(x)
      validate_device(x)
+     validate_sparsity(sparsity_block_size, (x, sparsity_layout))
+     validate_sparsity_block_size(sparsity_block_size, x)
+     validate_triton_block_size(triton_block_size, sparsity_block_size)

      sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
      sparsity_layout_flat = sparsity_layout.reshape(-1)
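A usage sketch for the two output modes (illustrative; the shapes and the fully dense layout are assumptions based on the docstring above):

    import torch

    # Assumes row_wise_sum from blksprs.ops.row_wise_sum and a CUDA device.
    sparsity_block_size = 4
    layout = torch.ones(1, 2, 2, dtype=torch.int, device="cuda")  # 2x2 block grid
    n_blocks = int(layout.sum())
    x = torch.randn(n_blocks, sparsity_block_size, sparsity_block_size, device="cuda")

    # Default: one block per row in compressed form, plus its sparsity layout.
    sums, sums_layout = row_wise_sum(x, layout, sparsity_block_size)

    # flag_slice_only=True: a [x.size(0), x.size(1), 1] slice instead.
    slices, _ = row_wise_sum(x, layout, sparsity_block_size, flag_slice_only=True)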
blksprs/ops/softmax.py CHANGED
@@ -6,22 +6,37 @@ from triton import language as tl
  from blksprs.ops.exp import exp
  from blksprs.ops.row_wise_sum import row_wise_sum
  from blksprs.utils.tools import get_triton_block_size
- from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_dtype_float, validate_device
+ from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
+     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size


  def softmax(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
-     """Computes the softmax of a blocksparse tensor.
+     """Computes the softmax of a block-sparse tensor in compressed form.

      Note:
-         Sparse blocks are not considered for the calculation of the softmax, i.e., assumed to be ``-inf``.
+         Sparse blocks are not considered for the calculation of the softmax, i.e., all values are assumed to be ``-inf``.
+
+     Args:
+         x (Tensor): A block-sparse tensor in compressed form.
+         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+         sparsity_block_size (int): The size of the sparsity blocks.
+         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+     Returns:
+         Tensor: The result of the softmax operation as a block-sparse tensor in compressed form.

      """
      validate_dimensions(x)
      validate_contiguous(x)
-     validate_dtype_float(x)
      validate_device(x)
-
-     max_val = torch.max(x).item()
+     validate_sparsity(sparsity_block_size, (x, sparsity_layout))
+     validate_sparsity_block_size(sparsity_block_size, x)
+     validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+     if x.size(0) != 0:
+         max_val = torch.max(x).item()
+     else:
+         max_val = 0
      x_scaled = x - max_val

      sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
@@ -83,9 +98,7 @@ class _BlocksparseSoftmax(torch.autograd.Function):
            triton_block_size))

          # Save for backward pass
-         ctx.save_for_backward(output)
-         ctx.sparsity_layout = sparsity_layout
-         ctx.sparsity_lut = sparsity_lut
+         ctx.save_for_backward(output, sparsity_layout, sparsity_lut)
          ctx.sparsity_block_size = sparsity_block_size
          ctx.triton_block_size = triton_block_size

@@ -93,9 +106,7 @@ class _BlocksparseSoftmax(torch.autograd.Function):

      @staticmethod
      def backward(ctx, grad_output):
-         o = ctx.saved_tensors[0]
-         sparsity_layout = ctx.sparsity_layout
-         sparsity_lut = ctx.sparsity_lut
+         o, sparsity_layout, sparsity_lut = ctx.saved_tensors
          sparsity_block_size = ctx.sparsity_block_size
          triton_block_size = ctx.triton_block_size

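The new size guard avoids calling torch.max on a tensor with no stored blocks. A sketch of the behaviour the guard encodes (illustrative, not from the package):

    import torch

    x = torch.empty(0, 4, 4)  # compressed tensor with zero sparsity blocks
    # blksprs 1.0 called torch.max(x).item() unconditionally, which raises a
    # RuntimeError when x is empty; blksprs 1.1 falls back to 0 instead:
    max_val = torch.max(x).item() if x.size(0) != 0 else 0
    x_scaled = x - max_val  # still a valid (empty) tensor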
blksprs/ops/transpose.py CHANGED
@@ -1,26 +1,37 @@
- from typing import Any
-
  import torch
  import triton
- from triton import language as tl
  from torch import Tensor
+ from triton import language as tl

  from blksprs.utils.tools import get_triton_block_size
- from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_dtype_float, validate_device
+ from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
+     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size


  def transpose(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> (
          Tensor, Tensor):
-     """Transposes a blocksparse tensor.
+     """Transposes a block-sparse tensor in compressed form.

      Note:
          Returns the transposed tensor and the sparsity layout of the transposed tensor.

+     Args:
+         x (Tensor): A block-sparse tensor in compressed form.
+         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+         sparsity_block_size (int): The size of the sparsity blocks.
+         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+     Returns:
+         Tensor: The transposed block-sparse tensor in compressed form.
+         Tensor: The sparsity layout of the transposed tensor.
+
      """
      validate_dimensions(x)
      validate_contiguous(x)
-     validate_dtype_float(x)
      validate_device(x)
+     validate_sparsity(sparsity_block_size, (x, sparsity_layout))
+     validate_sparsity_block_size(sparsity_block_size, x)
+     validate_triton_block_size(triton_block_size, sparsity_block_size)

      sparsity_layout_t = sparsity_layout.transpose(-1, -2).contiguous()

@@ -75,6 +86,7 @@ class _BlocksparseTranspose(torch.autograd.Function):
            triton_block_size))

          # Save for backward pass
+         ctx.save_for_backward(sparsity_layout)
          ctx.sparsity_layout = sparsity_layout
          ctx.sparsity_block_size = sparsity_block_size
          ctx.triton_block_size = triton_block_size

@@ -83,7 +95,7 @@ class _BlocksparseTranspose(torch.autograd.Function):

      @staticmethod
      def backward(ctx, grad_output):
-         sparsity_layout = ctx.sparsity_layout
+         sparsity_layout = ctx.saved_tensors[0]
          sparsity_block_size = ctx.sparsity_block_size
          triton_block_size = ctx.triton_block_size

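A usage sketch (illustrative; the non-square layout is an assumption chosen to make the transposed layout shape visible):

    import torch

    # Assumes transpose from blksprs.ops.transpose and a CUDA device.
    sparsity_block_size = 4
    layout = torch.ones(1, 2, 3, dtype=torch.int, device="cuda")  # 2x3 block grid
    n_blocks = int(layout.sum())
    x = torch.randn(n_blocks, sparsity_block_size, sparsity_block_size, device="cuda")

    x_t, layout_t = transpose(x, layout, sparsity_block_size)
    assert layout_t.shape == (1, 3, 2)  # layout is flipped along the last two dims

The switch to ctx.save_for_backward(sparsity_layout) in 1.1 lets autograd track the saved layout tensor properly instead of stashing it as a plain attribute on ctx.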