blksprs 1.6.1__py3-none-any.whl → 1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/ops/repeat.py ADDED
@@ -0,0 +1,322 @@
1
+ import torch
2
+ import triton
3
+ from triton import language as tl
4
+ from torch import Tensor
5
+
6
+ from blksprs.utils.tools import get_triton_block_size, stride
7
+ from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
8
+ validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
9
+
10
+
11
def repeat(x: Tensor, sparsity_layout_x: Tensor, repeats: tuple[int, int, int],
           sparsity_block_size: int, sparsity_layout_output: Tensor = None,
           triton_block_size: int = None) -> tuple[Tensor, Tensor]:
    """Repeats a block-sparse tensor in compressed form according to the given repeats.

    ``repeats`` is a 3-tuple of integers, where each integer represents the number of times the tensor should be
    repeated in the first, second, and third dimension respectively.

    Note:
        An output sparsity layout can be provided, in which case only the indicated blocks are filled. This may result
        in blocks not being present in the output that were present in the input if the output sparsity layout
        indicates them to be sparse.

    Args:
        x (Tensor): A block-sparse tensor in compressed form.
        sparsity_layout_x (Tensor): The sparsity layout of the block-sparse tensor.
        repeats (tuple[int, int, int]): The number of times the tensor should be repeated in the first, second and
            third dimension respectively.
        sparsity_block_size (int): The size of the sparsity blocks.
        sparsity_layout_output (Tensor): The desired sparsity layout of the output tensor (default ``None``).
        triton_block_size (int): The block size to use for the triton kernel (default ``None``).

    Returns:
        Tensor: A block-sparse tensor in compressed form containing the repeated values.
        Tensor: The sparsity layout of the resulting output tensor.

    """
    x = x.contiguous()

    validate_dimensions(x)
    validate_contiguous(x)
    validate_device(x)
    validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
    validate_sparsity_block_size(sparsity_block_size, x)
    validate_triton_block_size(triton_block_size, sparsity_block_size)

    # Layout of the repeated output before applying any caller-supplied restriction.
    sparsity_layout_o = sparsity_layout_x.repeat(*repeats)

    # Restrict the output: only blocks non-sparse in BOTH layouts are produced.
    if sparsity_layout_output is not None:
        sparsity_layout_o = torch.logical_and(sparsity_layout_o, sparsity_layout_output)

    # LUT mapping each compressed output block index to its (batch, row, col) layout position.
    sparsity_lut = torch.nonzero(sparsity_layout_o).contiguous()

    # Reverse LUT: for every repeated layout position, the index of the source block in the
    # compressed input, or -1 where the source position is sparse.
    sparsity_layout_flat = sparsity_layout_x.reshape(-1)
    sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
                             (sparsity_layout_flat == 1) -
                             (1 * (sparsity_layout_flat == 0)))
                            .reshape(sparsity_layout_x.size())
                            .repeat(*repeats)
                            .reshape(-1).contiguous())

    n_sparse_blocks = torch.sum(sparsity_layout_o.to(torch.int)).item()

    validate_contiguous(sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)

    return _BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
                                    sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_o
68
+
69
+
70
def repeat_interleave(x: Tensor, sparsity_layout_x: Tensor, repeats: int,
                      sparsity_block_size: int, sparsity_layout_output: Tensor = None,
                      triton_block_size: int = None) -> tuple[Tensor, Tensor]:
    """Repeats and interleaves the block-sparse tensor in compressed form.

    Repeats each matrix contained in the tensor ``repeats`` times and places the copies consecutively along the
    batch dimension of the output tensor.

    Note:
        In similar fashion to the regular ``repeat`` an output sparsity layout can be provided. In this case only
        non-sparse blocks will be filled.

    Args:
        x (Tensor): A block-sparse tensor in compressed form.
        sparsity_layout_x (Tensor): The sparsity layout of the block-sparse tensor.
        repeats (int): The number of times to repeat the matrices.
        sparsity_block_size (int): The size of the sparsity blocks.
        sparsity_layout_output (Tensor): The desired sparsity layout of the output tensor (default ``None``).
        triton_block_size (int): The block size to use for the triton kernel (default ``None``).

    Returns:
        Tensor: A block-sparse tensor in compressed form containing the repeated and interleaved matrices.
        Tensor: The sparsity layout of the resulting output tensor.

    """
    x = x.contiguous()

    validate_dimensions(x)
    validate_contiguous(x)
    validate_device(x)
    validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
    validate_sparsity_block_size(sparsity_block_size, x)
    validate_triton_block_size(triton_block_size, sparsity_block_size)

    # Interleave the layout along the batch dimension.
    sparsity_layout_o = torch.repeat_interleave(sparsity_layout_x, repeats, dim=0).contiguous()

    # Restrict the output: only blocks non-sparse in BOTH layouts are produced.
    if sparsity_layout_output is not None:
        sparsity_layout_o = torch.logical_and(sparsity_layout_o, sparsity_layout_output)

    # LUT mapping each compressed output block index to its (batch, row, col) layout position.
    sparsity_lut = torch.nonzero(sparsity_layout_o).contiguous()

    # Reverse LUT: for every interleaved layout position, the index of the source block in the
    # compressed input, or -1 where the source position is sparse.
    sparsity_layout_flat = sparsity_layout_x.reshape(-1)
    sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
                             (sparsity_layout_flat == 1) -
                             (1 * (sparsity_layout_flat == 0)))
                            .reshape(sparsity_layout_x.size())
                            .repeat_interleave(repeats, dim=0)
                            .reshape(-1).contiguous())

    n_sparse_blocks = torch.sum(sparsity_layout_o.to(torch.int)).item()

    validate_contiguous(sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)

    return _BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
                                    sparsity_block_size, n_sparse_blocks, triton_block_size), sparsity_layout_o
126
+
127
+
128
class _BlocksparseRepeat(torch.autograd.Function):
    """Autograd wrapper shared by ``repeat`` and ``repeat_interleave``.

    The forward pass pulls every output block from its source block via ``forward_flow``;
    the backward pass pushes (accumulates) output-block gradients back onto the source
    blocks with ``kernel_blocksparse_flow_push``.
    """

    @staticmethod
    def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor,
                sparsity_reverse_lut: Tensor,
                sparsity_block_size: int, n_sparse_blocks: int,
                triton_block_size: int) -> Tensor:
        """Computes the repeated tensor in compressed form and records state for backward."""
        ctx.save_for_backward(sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
        # Size/stride of the compressed input are needed to shape the gradient in backward.
        ctx.x_size = x.size()
        ctx.x_stride = stride(x)

        # NOTE: forward_flow also stores sparsity_block_size and triton_block_size on ctx.
        return forward_flow(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
                            n_sparse_blocks, triton_block_size)

    @staticmethod
    def backward(ctx, grad_output):
        """Scatters the gradient of each repeated block back onto its source block."""
        sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut = ctx.saved_tensors
        x_size = ctx.x_size
        x_stride = ctx.x_stride
        # Set on ctx by forward_flow during the forward pass.
        sparsity_block_size = ctx.sparsity_block_size
        triton_block_size = ctx.triton_block_size

        # Number of non-sparse blocks of the ORIGINAL (un-repeated) input.
        n_sparse_blocks = torch.sum(sparsity_layout_x.to(torch.int)).item()

        # Zero-initialised because the push kernel accumulates with atomic adds.
        output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                             dtype=grad_output.dtype, device=grad_output.device)

        x_b, x_r, x_c = grad_output.size()
        x_b_s, x_r_s, x_c_s = stride(grad_output)
        s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_o.size()
        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_o)
        s_lut_r, s_lut_c = sparsity_lut.size()
        s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
        o_b, o_r, o_c = x_size
        o_b_s, o_r_s, o_c_s = x_stride

        if triton_block_size is None:
            triton_block_size = get_triton_block_size(sparsity_block_size)

        # One program per (gradient block, tile row, tile column).
        triton_grid = lambda meta: [x_b,
                                    triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
                                    triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]

        (kernel_blocksparse_flow_push[triton_grid]
         (grad_output,
          x_b, x_b_s, x_r_s, x_c_s,
          s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
          sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
          sparsity_reverse_lut,
          output,
          o_b, o_b_s, o_r_s, o_c_s,
          triton_block_size))

        # One gradient slot per forward argument; only x receives a gradient.
        return output, None, None, None, None, None, None, None
182
+
183
+
184
@triton.jit
def kernel_blocksparse_flow_pull(x,
                                 x_b, x_b_s, x_r_s, x_c_s,
                                 o,
                                 o_b, o_b_s, o_r_s, o_c_s,
                                 s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
                                 s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
                                 r_lut,
                                 TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    """Copies ("pulls") each output block of ``o`` from its source block in ``x``.

    One program instance handles one TRITON_BLOCK_SIZE x TRITON_BLOCK_SIZE tile of one
    output block: it looks up the block's (batch, row, col) layout position in ``s_lut``,
    maps that position to the source block index in compressed ``x`` via ``r_lut``, and
    copies the tile.
    """
    # Get triton block indices
    pid_blk = tl.program_id(axis=0)
    pid_row = tl.program_id(axis=1)
    pid_col = tl.program_id(axis=2)

    # Get sparsity index of current output block consisting of its batch, row, and column index
    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
    spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)

    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
    spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)

    spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
    spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
    spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)

    # Get reverse sparsity index
    rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
                       spa_row * s_l_o_r_s +
                       spa_col * s_l_o_c_s)
    rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
    rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

    # -1 marks a sparse source position; unreachable for a consistent LUT pair.
    if rev_idx_spa == -1:
        tl.device_assert(False)
        return

    # Load the tile from the source block of x.
    blk_x_idx = (rev_idx_spa * x_b_s +
                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
    blk_x_msk = (blk_x_idx < x_b * x_b_s)
    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

    # Store the tile into the current output block.
    blk_o_idx = (pid_blk * o_b_s +
                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
    blk_o_msk = (blk_o_idx < o_b * o_b_s)
    tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
233
+
234
+
235
@triton.jit
def kernel_blocksparse_flow_push(x,
                                 x_b, x_b_s, x_r_s, x_c_s,
                                 s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
                                 s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
                                 r_lut,
                                 o,
                                 o_b, o_b_s, o_r_s, o_c_s,
                                 TRITON_BLOCK_SIZE: tl.constexpr) -> None:
    """Accumulates ("pushes") each input block of ``x`` onto its target block in ``o``.

    One program instance handles one TRITON_BLOCK_SIZE x TRITON_BLOCK_SIZE tile of one
    input block. Because several input blocks can map to the same target block (repeated
    copies of one source), the store uses ``tl.atomic_add``; ``o`` must be zero-initialised.
    """
    # Get triton block indices
    pid_blk = tl.program_id(axis=0)
    pid_row = tl.program_id(axis=1)
    pid_col = tl.program_id(axis=2)

    # Get sparsity index of current input block consisting of its batch, row, and column index
    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
    spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)

    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
    spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)

    spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
    spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
    spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)

    # Get reverse sparsity index
    rev_idx_spa_idx = (spa_bat * s_l_x_b_s +
                       spa_row * s_l_x_r_s +
                       spa_col * s_l_x_c_s)
    rev_idx_spa_msk = (rev_idx_spa_idx < s_l_x_b * s_l_x_b_s)
    rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

    # -1 marks a sparse target position; unreachable for a consistent LUT pair.
    if rev_idx_spa == -1:
        tl.device_assert(False)
        return

    # Load the tile from the current input block.
    blk_x_idx = (pid_blk * x_b_s +
                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
    blk_x_msk = (blk_x_idx < x_b * x_b_s)
    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

    # Accumulate into the target block; atomic because targets may be shared.
    blk_o_idx = (rev_idx_spa * o_b_s +
                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
    blk_o_msk = (blk_o_idx < o_b * o_b_s)
    tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
284
+
285
+
286
def forward_flow(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
                 sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
    """Launches the pull kernel that gathers every output block from its source block in ``x``.

    Also stashes ``sparsity_block_size`` and the resolved ``triton_block_size`` on ``ctx``
    for use in the backward pass.

    Args:
        ctx: The autograd context of the calling ``torch.autograd.Function``.
        x (Tensor): The compressed input tensor.
        sparsity_layout_o (Tensor): The sparsity layout of the output tensor.
        sparsity_lut (Tensor): LUT mapping output block indices to layout positions.
        sparsity_reverse_lut (Tensor): Reverse LUT mapping layout positions to input block indices.
        sparsity_block_size (int): The size of the sparsity blocks.
        n_sparse_blocks (int): The number of non-sparse blocks in the output.
        triton_block_size (int): The block size to use for the triton kernel (may be ``None``).

    Returns:
        Tensor: The compressed output tensor.

    """
    # Allocate directly zero-initialised; the previous torch.empty followed by
    # torch.zeros_like performed a redundant intermediate allocation.
    output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                         dtype=x.dtype, device=x.device)

    x_b, x_r, x_c = x.size()
    x_b_s, x_r_s, x_c_s = stride(x)
    o_b, o_r, o_c = output.size()
    o_b_s, o_r_s, o_c_s = stride(output)
    s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
    s_lut_r, s_lut_c = sparsity_lut.size()
    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)

    if triton_block_size is None:
        triton_block_size = get_triton_block_size(sparsity_block_size)

    # One program per (output block, tile row, tile column).
    triton_grid = lambda meta: [o_b,
                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

    (kernel_blocksparse_flow_pull[triton_grid]
     (x,
      x_b, x_b_s, x_r_s, x_c_s,
      output,
      o_b, o_b_s, o_r_s, o_c_s,
      s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
      sparsity_reverse_lut,
      triton_block_size))

    # Save for backward pass
    ctx.sparsity_block_size = sparsity_block_size
    ctx.triton_block_size = triton_block_size

    return output
blksprs/ops/softmax.py CHANGED
@@ -3,9 +3,9 @@ import triton
3
3
  from torch import Tensor
4
4
  from triton import language as tl
5
5
 
6
- from blksprs.ops.exp import exp
6
+ from blksprs.misc.exp import exp
7
7
  from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
8
- from blksprs.utils.tools import get_triton_block_size
8
+ from blksprs.utils.tools import get_triton_block_size, stride
9
9
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
10
10
  validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
11
11
 
@@ -61,9 +61,9 @@ class _BlocksparseSoftmax(torch.autograd.Function):
61
61
  output = torch.empty_like(x)
62
62
 
63
63
  x_b, x_r, x_c = x.size()
64
- x_b_s, x_r_s, x_c_s = x.stride()
64
+ x_b_s, x_r_s, x_c_s = stride(x)
65
65
  s_lut_r, s_lut_c = sparsity_lut.size()
66
- s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
66
+ s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
67
67
  o_b, o_r, o_c = output.size()
68
68
 
69
69
  x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,
@@ -76,9 +76,9 @@ class _BlocksparseSoftmax(torch.autograd.Function):
76
76
  triton_block_size=triton_block_size)
77
77
 
78
78
  s_b, s_r, s_c = x_exp_row_wise_sum.shape
79
- s_b_s, s_r_s, s_c_s = x_exp_row_wise_sum.stride()
79
+ s_b_s, s_r_s, s_c_s = stride(x_exp_row_wise_sum)
80
80
  s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_rws.shape
81
- s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = sparsity_layout_rws.stride()
81
+ s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_rws)
82
82
 
83
83
  if triton_block_size is None:
84
84
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -119,13 +119,13 @@ class _BlocksparseSoftmax(torch.autograd.Function):
119
119
  (1 * (sparsity_layout_s_flat == 0)))
120
120
 
121
121
  o_b, o_r, o_c = o.size()
122
- o_b_s, o_r_s, o_c_s = o.stride()
122
+ o_b_s, o_r_s, o_c_s = stride(o)
123
123
  s_lut_r, s_lut_c = sparsity_lut.size()
124
- s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
124
+ s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
125
125
  s_b, s_r, s_c = s.size()
126
- s_b_s, s_r_s, s_c_s = s.stride()
126
+ s_b_s, s_r_s, s_c_s = stride(s)
127
127
  s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
128
- s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = sparsity_layout_s.stride()
128
+ s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)
129
129
 
130
130
  grad_x = torch.empty_like(o, dtype=torch.float)
131
131
 
@@ -181,7 +181,8 @@ class _BlocksparseSoftmax(torch.autograd.Function):
181
181
  rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
182
182
 
183
183
  if rev_idx_spa_s == -1:
184
- assert False, "Invalid sparsity block"
184
+ tl.device_assert(False)
185
+ return
185
186
 
186
187
  # Load x block
187
188
  blk_x_idx = ((pid_blk * x_b_s) +
blksprs/ops/transpose.py CHANGED
@@ -3,7 +3,7 @@ import triton
3
3
  from torch import Tensor
4
4
  from triton import language as tl
5
5
 
6
- from blksprs.utils.tools import get_triton_block_size
6
+ from blksprs.utils.tools import get_triton_block_size, stride
7
7
  from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
8
8
  validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
9
9
 
@@ -63,13 +63,13 @@ class _BlocksparseTranspose(torch.autograd.Function):
63
63
  dtype=x.dtype, device=x.device)
64
64
 
65
65
  x_b, x_r, x_c = x.size()
66
- x_b_s, x_r_s, x_c_s = x.stride()
66
+ x_b_s, x_r_s, x_c_s = stride(x)
67
67
  s_l_b, s_l_r, s_l_c = sparsity_layout_o.size()
68
- s_l_b_s, s_l_r_s, s_l_c_s = sparsity_layout_o.stride()
68
+ s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout_o)
69
69
  s_lut_r, s_lut_c = sparsity_lut.shape
70
- s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
70
+ s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
71
71
  o_b, o_r, o_c = output.size()
72
- o_b_s, o_r_s, o_c_s = output.stride()
72
+ o_b_s, o_r_s, o_c_s = stride(output)
73
73
 
74
74
  if triton_block_size is None:
75
75
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -140,7 +140,8 @@ class _BlocksparseTranspose(torch.autograd.Function):
140
140
  rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
141
141
 
142
142
  if rev_idx_spa == -1:
143
- assert False, "Invalid sparsity block"
143
+ tl.device_assert(False)
144
+ return
144
145
 
145
146
  blk_x_idx = (rev_idx_spa * x_b_s +
146
147
  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
blksprs/utils/tools.py CHANGED
@@ -23,3 +23,6 @@ def get_triton_block_size(sparsity_block_size: int, limit: int = 128):
23
23
 
24
24
  def disable_validation():
25
25
  _set_skip_validation(True)
26
+
27
def stride(x: Tensor):
    """Return the stride tuple of ``x`` after routing it through a same-shape ``view``.

    The intermediate ``view`` call is deliberate and preserved from the original helper:
    it requires ``x`` to be viewable and yields the strides of the viewed tensor.
    """
    viewed = x.view(x.size())
    return viewed.stride()
@@ -3,6 +3,7 @@ from torch import Tensor
3
3
 
4
4
  VALIDATION = True
5
5
 
6
+
6
7
  def validate_dimensions(*tensors: Tensor, dims=3) -> None:
7
8
  if _check_skip_validation():
8
9
  return
@@ -71,10 +72,25 @@ def validate_sparsity(sparsity_block_size: int, *tensor_sparsity_layout_tuples:
71
72
  raise ValueError("Mismatch between sparsity layout and blocks")
72
73
 
73
74
 
75
+ def validate_sparsity_dense(sparsity_block_size: int, *tensor_sparsity_layout_tuples: tuple[Tensor, Tensor]) -> None:
76
+ if _check_skip_validation():
77
+ return
78
+
79
+ for (tensor, sparsity_layout) in tensor_sparsity_layout_tuples:
80
+ _validate_sparsity_layout_values(sparsity_layout)
81
+
82
+ if not sparsity_layout.dim() == 3:
83
+ raise ValueError("Sparsity layout must have exactly 3 dimensions")
84
+ if not (tensor.size(-1) // sparsity_block_size == sparsity_layout.size(-1) and
85
+ tensor.size(-2) // sparsity_block_size == sparsity_layout.size(-2)):
86
+ raise ValueError("Tensor not conforming to sparsity layout")
87
+
88
+
74
89
  def _validate_sparsity_layout_values(sparsity_layout: Tensor):
75
90
  if not torch.all(torch.logical_or(sparsity_layout == 0, sparsity_layout == 1)):
76
91
  raise ValueError("Sparsity layout values must be either 0 or 1")
77
92
 
93
+
78
94
  def validate_sparsity_block_size(sparsity_block_size: int, *tensors):
79
95
  if _check_skip_validation():
80
96
  return
@@ -86,6 +102,7 @@ def validate_sparsity_block_size(sparsity_block_size: int, *tensors):
86
102
  if not (tensor.size(-1) % sparsity_block_size == 0 and tensor.size(-2) % sparsity_block_size == 0):
87
103
  raise ValueError("Tensor sizes must be divisible by sparsity block size")
88
104
 
105
+
89
106
  def validate_triton_block_size(triton_block_size: int, sparsity_block_size: int):
90
107
  if _check_skip_validation():
91
108
  return
@@ -99,9 +116,11 @@ def validate_triton_block_size(triton_block_size: int, sparsity_block_size: int)
99
116
  if triton_block_size > sparsity_block_size:
100
117
  raise ValueError("Triton block size cannot be larger than sparsity block size")
101
118
 
119
+
102
120
  def _check_skip_validation():
103
121
  return not VALIDATION
104
122
 
123
+
105
124
  def _set_skip_validation(skip_validation: bool):
106
125
  global VALIDATION
107
- VALIDATION = not skip_validation
126
+ VALIDATION = not skip_validation
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: blksprs
3
- Version: 1.6.1
3
+ Version: 1.8
4
4
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
5
5
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
6
6
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -28,12 +28,13 @@ A lightweight and efficient library for operations on block-sparse matrices in P
28
28
 
29
29
  Currently supported operations (includes gradient calculation):
30
30
 
31
- - Sparse matrix multiplication (_supports any combination of sparse and dense matrices due to support
32
- for `sparse = sparse @ sparse` matmul_)
31
+ - Matrix multiplication
33
32
  - Softmax
34
33
  - Transpose
35
34
  - Gather
36
35
  - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
36
+ - Repeat (_supports target sparsity layout_)
37
+ - Repeat Interleave (_supports target sparsity layout_)
37
38
  - Splitting and merging of matrices along the last dimension
38
39
  - Conversion to and from sparse form
39
40
  - Conversion to different sparsity layouts and different sparsity block sizes
@@ -50,8 +51,14 @@ These include, e.g.,
50
51
  Note that in order to correctly apply element-wise operations between two sparse tensors their sparsity layouts have to
51
52
  match.
52
53
 
54
+ Further helpful operations (included in the ``bs.misc`` module) that do **not** support gradient calculation include:
55
+
56
+ - Row-wise sum, max, addition, and subtraction
57
+ - Broadcast addition and subtraction between slices
58
+
53
59
  Furthermore, the library provides a set of utility functions for the creation of sparsity layouts based on existing
54
- dense tensors.
60
+ dense tensors and for the scatter operation (module ``bs.layout``), as well as utility functions to ensure correct input
61
+ dimensionality (module ``bs.util``).
55
62
 
56
63
  ## Installation
57
64
 
@@ -64,7 +71,7 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u
64
71
 
65
72
  ### Dependencies
66
73
 
67
- - [PyTorch](https://pytorch.org/) (built with v2.4.0)
74
+ - [PyTorch](https://pytorch.org/) (built with v2.5.0)
68
75
  - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_
69
76
 
70
77
  ## Changelog
@@ -0,0 +1,21 @@
1
+ blksprs/__init__.py,sha256=qDqoB-X5vo5_3PlrN54sp59XR5hg6EanIsADS67QnH0,1058
2
+ blksprs/experimental/distribution_mdi.py,sha256=jE_SbB0SMGPcCoeM0699cceWAikBkBpGM_8Oo5A7Ets,20083
3
+ blksprs/layouting/distribution_layout.py,sha256=nCR3fCF6eNpi97DI6MMLF2hq_9Lwyo6_aUCIZiJfiX4,4170
4
+ blksprs/layouting/sparsity_layout.py,sha256=JNwbJ4L-418zCiCbt-vLfZ0xU7ReP0zr_tLHs_dytRA,9686
5
+ blksprs/misc/broadcast_ops.py,sha256=0RLnLMYV7GAPI2YL8RotcxjIUSBZKGxdVcsGaJFeL_I,5327
6
+ blksprs/misc/exp.py,sha256=cdF0s93Q9iucIXuEE3howsB0N6D60xgvem7C-a-yiGI,3704
7
+ blksprs/misc/partitioning.py,sha256=nBRZzfi3XYAhDLEBzYflQkvGa3MIZ-qNeIlrZ16k44g,7533
8
+ blksprs/misc/row_wise.py,sha256=0vDJA8uCocmebSIPIbFeND5_PQIE10pUj3DBOQXlTvE,16888
9
+ blksprs/ops/conversion.py,sha256=9xVdCrj38m1cMh43LQs-GrXZ5pNRjhQyKx6paaw3C6A,21898
10
+ blksprs/ops/distribution.py,sha256=V3TK5SlNT_JdGHNaDNl-U4U5vwAYsgkAOg4eTmYxbuA,16877
11
+ blksprs/ops/matmul.py,sha256=uqVe6Dz2aaCbCglM1uS2eRHVKh7PQcuecaIBWFubPEw,11256
12
+ blksprs/ops/repeat.py,sha256=OSsa2rj6BHL3Kedfu3wr0D82mn4HmbJ1l7XEmT-6ehg,14423
13
+ blksprs/ops/softmax.py,sha256=5nAgeT68nucgOugjtCy1aBIMa7Kyk1KNN-j8fgmeVuk,11996
14
+ blksprs/ops/transpose.py,sha256=67pDdCEb7r-Xifupl82fBKAYsxKcCUDy--cPPfduRvU,6761
15
+ blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
16
+ blksprs/utils/tools.py,sha256=S3836Zuc-BMigv-5mLTjRznCzuaF6oYW-Ir9zzUnr3o,655
17
+ blksprs/utils/validation.py,sha256=WzihRPibXYzss3PMkhDt5_d3Q3NHA_d1TzTz3CoGPGg,4136
18
+ blksprs-1.8.dist-info/METADATA,sha256=koey4w8ynY84Z0dM5u9y_P831rtR0w-Z-dBcje4O6ko,8007
19
+ blksprs-1.8.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
20
+ blksprs-1.8.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
21
+ blksprs-1.8.dist-info/RECORD,,