blksprs 1.8.3__py3-none-any.whl → 1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/layouting/distribution_layout.py CHANGED
@@ -10,13 +10,14 @@ from blksprs.utils.validation import validate_triton_block_size, validate_dimens
 
 
 def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: Tensor,
-                              size_target: torch.Size,
+                              dim: int, size_target: torch.Size,
                               sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
     """Builds the sparsity layout of either the source of a gather or the target of a scatter operation.
 
     Args:
         indices (BlksprsTensor): The block-sparse indices tensor in compressed form used for the gather or scatter operation.
         sparsity_layout_indices (Tensor): The sparsity layout of the indices block-sparse tensor.
+        dim (int): The dimension along which the operation is conducted.
         size_target (torch.Size): The size of the block-sparse target tensor in regular form.
         sparsity_block_size (int): The size of the sparsity blocks.
         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
@@ -31,6 +32,8 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
 
     sparsity_lut_i = torch.nonzero(sparsity_layout_indices).contiguous()
 
+    adjusted_dim = dim % 3
+
     output = torch.zeros(size_target[0], size_target[1] // sparsity_block_size, size_target[2] // sparsity_block_size,
                          dtype=torch.bool, device=indices.device)
 
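The `dim % 3` normalization resolves negative dimension arguments the same way PyTorch does for three-dimensional (batch, row, column) tensors. A minimal check of the mapping:

```python
# dim % 3 maps negative dims onto (batch, row, col), as torch.gather does for 3-D tensors
for dim in (0, 1, 2, -1, -2, -3):
    print(dim, "->", dim % 3)  # 0->0, 1->1, 2->2, -1->2, -2->1, -3->0
```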
@@ -55,6 +58,7 @@ def build_distribution_layout(indices: BlksprsTensor, sparsity_layout_indices: T
          i_b, i_b_s, i_r_s, i_c_s,
          sparsity_lut_i,
          s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+         adjusted_dim,
          output,
          o_b, o_b_s, o_r_s, o_c_s,
          sparsity_block_size,
@@ -68,6 +72,7 @@ def kernel_distribution_layout(i,
                                i_b, i_b_s, i_r_s, i_c_s,
                                s_lut_i,
                                s_lut_i_r, s_lut_i_r_s, s_lut_i_c_s,
+                               dim,
                                o,
                                o_b, o_b_s, o_r_s, o_c_s,
                                sparsity_block_size,
@@ -86,17 +91,30 @@ def kernel_distribution_layout(i,
     spa_row_i_msk = (spa_row_i_idx < s_lut_i_r * s_lut_i_r_s)
     spa_row_i = tl.load(s_lut_i + spa_row_i_idx, mask=spa_row_i_msk)
 
+    spa_col_i_idx = (pid_blk * s_lut_i_r_s + 2 * s_lut_i_c_s)
+    spa_col_i_msk = (spa_col_i_idx < s_lut_i_r * s_lut_i_r_s)
+    spa_col_i = tl.load(s_lut_i + spa_col_i_idx, mask=spa_col_i_msk)
+
     blk_i_idx = (pid_blk * i_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_c_s)[None, :])
     blk_i_msk = (blk_i_idx < i_b * i_b_s)
     blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk)
 
-    blk_i = blk_i // sparsity_block_size
+    dst_bat_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_i, dtype=tl.int32)
+    dst_row_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_i, dtype=tl.int32)
+    dst_col_idx = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_i, dtype=tl.int32)
+    if dim == 0:
+        dst_bat_idx = blk_i
+    elif dim == 1:
+        dst_row_idx = blk_i // sparsity_block_size
+    elif dim == 2:
+        dst_col_idx = blk_i // sparsity_block_size
+
     blk_v = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), 1, dtype=tl.int32)
 
-    blk_o_idx = ((spa_bat_i * o_b_s) +
-                 (spa_row_i * o_r_s) +
-                 (blk_i * o_c_s))
+    blk_o_idx = ((dst_bat_idx * o_b_s) +
+                 (dst_row_idx * o_r_s) +
+                 (dst_col_idx * o_c_s))
     blk_o_msk = (blk_o_idx < o_b * o_b_s)
     tl.store(o + blk_o_idx, blk_v, mask=blk_o_msk)
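The effect of the new `dim` dispatch is easiest to see on a dense tensor: each index value marks one destination sparsity block, with `dim` deciding whether the value selects the batch, the row block, or the column block; the other two coordinates come from the position of the index element itself. A dense sketch of that logic (my own illustration under those assumptions, not the library's code):

```python
import torch

def distribution_layout_dense(idx, dim, size_target, sparsity_block_size):
    # Mark which sparsity blocks of the target a gather/scatter along `dim` touches.
    b, r, c = size_target
    layout = torch.zeros(b, r // sparsity_block_size, c // sparsity_block_size, dtype=torch.bool)
    bat = torch.arange(idx.size(0)).view(-1, 1, 1).expand_as(idx)
    row = torch.arange(idx.size(1)).view(1, -1, 1).expand_as(idx) // sparsity_block_size
    col = torch.arange(idx.size(2)).view(1, 1, -1).expand_as(idx) // sparsity_block_size
    if dim == 0:
        bat = idx                         # index values are batch indices
    elif dim == 1:
        row = idx // sparsity_block_size  # index values select row blocks
    else:
        col = idx // sparsity_block_size  # index values select column blocks
    layout[bat.flatten(), row.flatten(), col.flatten()] = True
    return layout

idx = torch.randint(0, 8, (2, 8, 8))
print(distribution_layout_dense(idx, 2, (2, 8, 8), 4).shape)  # torch.Size([2, 2, 2])
```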
blksprs/ops/distribution.py CHANGED
@@ -3,19 +3,23 @@ import triton
 from torch import Tensor
 from triton import language as tl
 
+from blksprs.ops.conversion import to_dense
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_dtype_int, validate_sparsity_block_size, validate_triton_block_size
 
 
-def gather(src: BlksprsTensor, sparsity_layout_src: Tensor, idx: BlksprsTensor, sparsity_layout_idx: Tensor,
+def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
+           dim: int,
+           idx: BlksprsTensor, sparsity_layout_idx: Tensor,
            sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
     """Applies a gather operation on a block-sparse tensor in compressed form.
 
     Args:
         src (BlksprsTensor): The source block-sparse tensor in compressed form to gather from.
         sparsity_layout_src (Tensor): The sparsity layout of the source block-sparse tensor.
+        dim (int): The dimension along which to gather.
         idx (BlksprsTensor): The block-sparse indices tensor in compressed form specifying how to gather from the source tensor.
         sparsity_layout_idx (Tensor): The sparsity layout of the indices block-sparse tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
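With the new signature, `dim` sits directly after the source tensor and its layout, mirroring `torch.gather(input, dim, index)`. A dense analogue for intuition (plain PyTorch; the commented blksprs call is a sketch of the 1.9 API with hypothetical variable names):

```python
import torch

src = torch.arange(2 * 4 * 4, dtype=torch.float32).reshape(2, 4, 4)
idx = torch.randint(0, 4, (2, 4, 4))
out = torch.gather(src, 2, idx)  # gather along columns
# blksprs 1.9 (sketch): gather(src_bs, layout_src, 2, idx_bs, layout_idx, sparsity_block_size)
```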
@@ -46,16 +50,18 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor, idx: BlksprsTensor,
     validate_contiguous(sparsity_layout_src, sparsity_reverse_lut_x,
                         sparsity_layout_idx, sparsity_lut_i)
 
+    adjusted_dim = dim % 3
+
     return BlksprsTensor(_BlocksparseGather.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
-                                                  idx, sparsity_layout_idx, sparsity_lut_i,
-                                                  sparsity_block_size, triton_block_size))
+                                                  adjusted_dim, idx, sparsity_layout_idx, sparsity_lut_i,
+                                                  sparsity_block_size, triton_block_size))
 
 
 class _BlocksparseGather(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
-                i: Tensor, sparsity_layout_i: Tensor, sparsity_lut_i: Tensor,
+                dim: int, i: Tensor, sparsity_layout_i: Tensor, sparsity_lut_i: Tensor,
                 sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
         output = torch.empty_like(i, dtype=x.dtype)
 
@@ -82,6 +88,7 @@ class _BlocksparseGather(torch.autograd.Function):
              x_b, x_b_s, x_r_s, x_c_s,
              s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
              sparsity_reverse_lut_x,
+             dim,
              i,
              i_b, i_b_s, i_r_s, i_c_s,
              output,
@@ -91,6 +98,7 @@ class _BlocksparseGather(torch.autograd.Function):
              triton_block_size))
 
         ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_i)
+        ctx.dim = dim
         ctx.sparsity_block_size = sparsity_block_size
         ctx.triton_block_size = triton_block_size
 
@@ -99,15 +107,15 @@ class _BlocksparseGather(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
         sparsity_layout_x, i, sparsity_layout_i = ctx.saved_tensors
+        dim = ctx.dim
         sparsity_block_size = ctx.sparsity_block_size
         triton_block_size = ctx.triton_block_size
 
         return scatter_reduce(grad_output, sparsity_layout_i,
-                              i,
-                              sparsity_layout_x,
-                              sparsity_block_size,
+                              dim, i,
+                              sparsity_layout_x, sparsity_block_size,
                               reduce_op="sum",
-                              triton_block_size=triton_block_size), None, None, None, None, None, None, None
+                              triton_block_size=triton_block_size), None, None, None, None, None, None, None, None
 
     @staticmethod
     @triton.jit
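The backward pass works because gather and sum-scatter through the same indices are adjoint along the same dimension. A quick dense check of that identity (plain PyTorch, my own illustration):

```python
import torch

src = torch.randn(2, 4, 4, requires_grad=True)
idx = torch.randint(0, 4, (2, 4, 4))
torch.gather(src, 2, idx).sum().backward()

# The gradient equals scattering ones back through the same indices.
manual = torch.zeros(2, 4, 4).scatter_add(2, idx, torch.ones(2, 4, 4))
assert torch.allclose(src.grad, manual)
```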
@@ -115,6 +123,7 @@ class _BlocksparseGather(torch.autograd.Function):
                                   x_b, x_b_s, x_r_s, x_c_s,
                                   s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
                                   r_lut_x,
+                                  dim,
                                   i,
                                   i_b, i_b_s, i_r_s, i_c_s,
                                   o,
@@ -136,6 +145,10 @@ class _BlocksparseGather(torch.autograd.Function):
         spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
         spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
 
+        spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
+        spa_col_o_msk = (spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+        spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+
         # Load index values
         blk_i_idx = ((pid_blk * i_b_s) +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * i_r_s)[:, None] +
@@ -143,33 +156,50 @@ class _BlocksparseGather(torch.autograd.Function):
         blk_i_msk = (blk_i_idx < i_b * i_b_s)
         blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
 
-        # Get positions of sparsity blocks
+        # Get indices of sparsity blocks and positions within the blocks
         pos_spa_blk_x = blk_i // sparsity_block_size
-        pos_spa_col_x = blk_i % sparsity_block_size
+        pos_spa_int_x = blk_i % sparsity_block_size
+
+        rev_dst_bat_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_o, dtype=tl.int32)
+        rev_dst_row_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_o, dtype=tl.int32)
+        rev_dst_col_x = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_o, dtype=tl.int32)
+        dst_row_x = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
+                     .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+        dst_col_x = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
+                     .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+        if dim == 0:
+            rev_dst_bat_x = blk_i
+        elif dim == 1:
+            rev_dst_row_x = pos_spa_blk_x
+            dst_row_x = pos_spa_int_x * x_r_s
+        elif dim == 2:
+            rev_dst_col_x = pos_spa_blk_x
+            dst_col_x = pos_spa_int_x * x_c_s
 
         # Load reverse sparsity indices for x
-        rev_idx_spa_x_idx = ((spa_bat_o * s_l_x_b_s) +
-                             (spa_row_o * s_l_x_r_s) +
-                             (pos_spa_blk_x * s_l_x_c_s))
+        rev_idx_spa_x_idx = ((rev_dst_bat_x * s_l_x_b_s) +
+                             (rev_dst_row_x * s_l_x_r_s) +
+                             (rev_dst_col_x * s_l_x_c_s))
         rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
         rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
 
         # Load x values
         blk_x_idx = ((rev_idx_spa_x * x_b_s) +
-                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                     (pos_spa_col_x * x_c_s))
-        blk_x_msk = (blk_x_idx < x_b * x_b_s)
+                     dst_row_x +
+                     dst_col_x)
+        blk_x_msk = ((blk_x_idx < x_b * x_b_s) & rev_idx_spa_x_msk != -1)
        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
         # Store output
         blk_o_idx = ((pid_blk * o_b_s) +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-        blk_o_msk = (blk_o_idx < o_b * o_b_s)
+        blk_o_msk = ((blk_o_idx < o_b * o_b_s) & rev_idx_spa_x_msk != -1)
         tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
 
 
 def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
+            dim: int,
             idx: BlksprsTensor,
             sparsity_layout_tgt: Tensor,
             sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
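The `!= -1` terms in the rewritten masks lean on the reverse-LUT convention used throughout the library: present blocks map to their linear block index, absent blocks map to -1, so masked loads and stores skip blocks that do not exist in the sparse layout. A minimal sketch of building such a LUT (my own illustration, not the library's code):

```python
import torch

sparsity_layout = torch.tensor([[1, 0], [0, 1]], dtype=torch.bool)
rev_lut = torch.full(sparsity_layout.shape, -1, dtype=torch.long)
rev_lut[sparsity_layout] = torch.arange(int(sparsity_layout.sum()))
print(rev_lut)  # tensor([[ 0, -1], [-1,  1]])
```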
@@ -184,6 +214,7 @@ def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
 
 
 def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
+                   dim: int,
                    idx: BlksprsTensor,
                    sparsity_layout_tgt: Tensor,
                    sparsity_block_size: int,
@@ -193,6 +224,7 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
     Args:
         src (BlksprsTensor): The source block-sparse tensor in compressed form to scatter from.
         sparsity_layout_src (Tensor): The sparsity layout of the source block-sparse tensor.
+        dim (int): The dimension along which to scatter.
         idx (BlksprsTensor): The block-sparse indices tensor in compressed form specifying how to scatter to the target tensor.
         sparsity_layout_tgt (Tensor): The sparsity layout of the target block-sparse tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
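The dense analogue of the sum reduction is torch's own scatter-add (sketch with hypothetical variable names, not the compressed-form call):

```python
import torch

src = torch.randn(2, 4, 4)
idx = torch.randint(0, 4, (2, 4, 4))
tgt = torch.zeros(2, 4, 4).scatter_add(1, idx, src)  # scatter along rows
# blksprs 1.9 (sketch): scatter_reduce(src_bs, layout_src, 1, idx_bs, layout_tgt,
#                                      sparsity_block_size, reduce_op="sum")
```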
@@ -230,18 +262,20 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
     validate_contiguous(sparsity_layout_src, sparsity_lut_x,
                         sparsity_layout_tgt, sparsity_reverse_lut_o)
 
+    adjusted_dim = dim % 3
+
     return BlksprsTensor(_BlocksparseScatterReduce.apply(src, sparsity_layout_src, sparsity_lut_x,
-                                                         idx,
-                                                         sparsity_layout_tgt, sparsity_reverse_lut_o,
-                                                         sparsity_block_size, n_sparse_blocks,
-                                                         reduce_op, triton_block_size))
+                                                         adjusted_dim, idx,
+                                                         sparsity_layout_tgt, sparsity_reverse_lut_o,
+                                                         sparsity_block_size, n_sparse_blocks,
+                                                         reduce_op, triton_block_size))
 
 
 class _BlocksparseScatterReduce(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_lut_x: Tensor,
-                i: Tensor,
+                dim: int, i: Tensor,
                 sparsity_layout_o: Tensor, sparsity_reverse_lut_o: Tensor,
                 sparsity_block_size: int, n_sparse_blocks: int,
                 reduce_op: str, triton_block_size: int) -> Tensor:
@@ -274,10 +308,11 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
             (x,
              x_b, x_b_s, x_r_s, x_c_s,
              sparsity_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+             dim,
              i,
              i_b, i_b_s, i_r_s, i_c_s,
              output,
-             o_b, o_b_s, o_r_s, o_c_s,
+             o_b, o_b_s,
              s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
              sparsity_reverse_lut_o,
              reduce_op_ind,
@@ -285,6 +320,7 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
              triton_block_size))
 
         ctx.save_for_backward(sparsity_layout_x, i, sparsity_layout_o)
+        ctx.dim = dim
         ctx.sparsity_block_size = sparsity_block_size
         ctx.reduce_op = reduce_op
         ctx.triton_block_size = triton_block_size
@@ -294,13 +330,14 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
         sparsity_layout_x, i, sparsity_layout_o = ctx.saved_tensors
+        dim = ctx.dim
         sparsity_block_size = ctx.sparsity_block_size
         reduce_op = ctx.reduce_op
         triton_block_size = ctx.triton_block_size
 
         if reduce_op == "sum":
-            return gather(grad_output, sparsity_layout_o, i, sparsity_layout_x, sparsity_block_size,
-                          triton_block_size=triton_block_size), None, None, None, None, None, None, None, None, None
+            return gather(grad_output, sparsity_layout_o, dim, i, sparsity_layout_x, sparsity_block_size,
+                          triton_block_size=triton_block_size), None, None, None, None, None, None, None, None, None, None
         else:
             raise ValueError(f"Reduction operation '{reduce_op}' does not support backward pass")
 
@@ -309,10 +346,11 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
     def kernel_blocksparse_scatter(x,
                                    x_b, x_b_s, x_r_s, x_c_s,
                                    s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+                                   dim,
                                    i,
                                    i_b, i_b_s, i_r_s, i_c_s,
                                    o,
-                                   o_b, o_b_s, o_r_s, o_c_s,
+                                   o_b, o_b_s,
                                    s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
                                    r_lut_o,
                                    reduce_op_ind,
@@ -332,6 +370,10 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
         spa_row_x_msk = (spa_row_x_idx < s_lut_x_r * s_lut_x_r_s)
         spa_row_x = tl.load(s_lut_x + spa_row_x_idx, mask=spa_row_x_msk)
 
+        spa_col_x_idx = (pid_blk * s_lut_x_r_s + 2 * s_lut_x_c_s)
+        spa_col_x_msk = (spa_col_x_idx < s_lut_x_r * s_lut_x_r_s)
+        spa_col_x = tl.load(s_lut_x + spa_col_x_idx, mask=spa_col_x_msk)
+
         # Load x values
         blk_x_idx = ((pid_blk * x_b_s) +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
@@ -346,22 +388,38 @@ class _BlocksparseScatterReduce(torch.autograd.Function):
         blk_i_msk = (blk_i_idx < i_b * i_b_s)
         blk_i = tl.load(i + blk_i_idx, mask=blk_i_msk).to(tl.int32)
 
-        # Get positions of sparsity blocks
-        pos_spa_blk_o = blk_i // sparsity_block_size
-        pos_spa_col_o = blk_i % sparsity_block_size
+        # Get indices of sparsity blocks and positions within the blocks
+        pos_spa_blk_x = blk_i // sparsity_block_size
+        pos_spa_int_x = blk_i % sparsity_block_size
+
+        rev_dst_bat_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_bat_x, dtype=tl.int32)
+        rev_dst_row_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_row_x, dtype=tl.int32)
+        rev_dst_col_o = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), spa_col_x, dtype=tl.int32)
+        dst_row_o = (((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None]
+                     .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+        dst_col_o = (((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :]
+                     .broadcast_to((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE)))
+        if dim == 0:
+            rev_dst_bat_o = blk_i
+        elif dim == 1:
+            rev_dst_row_o = pos_spa_blk_x
+            dst_row_o = pos_spa_int_x * x_r_s
+        elif dim == 2:
+            rev_dst_col_o = pos_spa_blk_x
+            dst_col_o = pos_spa_int_x * x_c_s
 
         # Load reverse sparsity indices for o
-        rev_idx_spa_o_idx = ((spa_bat_x * s_l_o_b_s) +
-                             (spa_row_x * s_l_o_r_s) +
-                             (pos_spa_blk_o * s_l_o_c_s))
+        rev_idx_spa_o_idx = ((rev_dst_bat_o * s_l_o_b_s) +
+                             (rev_dst_row_o * s_l_o_r_s) +
+                             (rev_dst_col_o * s_l_o_c_s))
         rev_idx_spa_o_msk = (rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
         rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
 
         # Store output
         blk_o_idx = ((rev_idx_spa_o * o_b_s) +
-                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                     (pos_spa_col_o * o_c_s))
-        blk_o_msk = (blk_o_idx < o_b * o_b_s)
+                     dst_row_o +
+                     dst_col_o)
+        blk_o_msk = ((blk_o_idx < o_b * o_b_s) & rev_idx_spa_o_msk != -1)
 
         if reduce_op_ind == 0:
             tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
blksprs/ops/experimental/distribution_mdi.py CHANGED
@@ -153,6 +153,10 @@ class _BlocksparseGatherMDI(torch.autograd.Function):
         rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
         rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
 
+        if rev_idx_spa_x == -1:
+            tl.device_assert(False)
+            return
+
         # Load x values
         blk_x_idx = ((rev_idx_spa_x * x_b_s) +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
@@ -342,6 +346,10 @@ class _BlocksparseScatterReduceMDI(torch.autograd.Function):
         rev_idx_spa_o_msk = (rev_idx_spa_o_idx < s_l_o_b * s_l_o_b_s)
         rev_idx_spa_o = tl.load(r_lut_o + rev_idx_spa_o_idx, mask=rev_idx_spa_o_msk).to(tl.int32)
 
+        if rev_idx_spa_o == -1:
+            tl.device_assert(False)
+            return
+
         # Store output
         blk_o_idx = ((rev_idx_spa_o * o_b_s) +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
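With this guard, the experimental MDI kernels fail fast instead of silently misaddressing memory: when the reverse LUT yields -1 (the indexed block is absent from the target layout), the program trips a device-side assertion and returns. A host-side sketch of the equivalent check (hypothetical helper, my own illustration):

```python
import torch

def check_indices(rev_lut: torch.Tensor, blk_idx: torch.Tensor) -> None:
    # Raise if any looked-up block is absent from the sparsity layout (-1 entry).
    if (rev_lut.flatten()[blk_idx.flatten()] == -1).any():
        raise IndexError("index refers to a sparsity block absent from the layout")
```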
blksprs/ops/misc/row_wise.py CHANGED
@@ -117,6 +117,10 @@ def kernel_blocksparse_row_wise_sum(x,
     rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
     rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
+    if rev_idx_spa == -1:
+        tl.device_assert(False)
+        return
+
     blk_idx = ((pid_blk * x_b_s) +
                ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
@@ -240,6 +244,10 @@ def kernel_blocksparse_row_wise_max(x,
     rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
     rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
+    if rev_idx_spa == -1:
+        tl.device_assert(False)
+        return
+
     blk_idx = ((pid_blk * x_b_s) +
                ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
blksprs/ops/softmax.py CHANGED
@@ -238,6 +238,10 @@ class _BlocksparseSoftmax(torch.autograd.Function):
         rev_idx_spa_s_msk = (rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
         rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
 
+        if rev_idx_spa_s == -1:
+            tl.device_assert(False)
+            return
+
         blk_s_idx = (rev_idx_spa_s * s_b_s +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
                      (tl.arange(0, 1) * s_c_s)[None, :])
blksprs-1.8.3.dist-info/METADATA → blksprs-1.9.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.8.3
+Version: 1.9
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
blksprs-1.8.3.dist-info/RECORD → blksprs-1.9.dist-info/RECORD CHANGED
@@ -1,23 +1,23 @@
 blksprs/__init__.py,sha256=YMrERuEf1hTv5vVdOvPEzh9rESn4uqOB7WHB12Qs5lU,1836
-blksprs/layouting/distribution_layout.py,sha256=wmj1SwWyY_fhbvMmh6AXrR77LoSp6xLwUWCCyO9i5lk,4239
+blksprs/layouting/distribution_layout.py,sha256=9f_Bx2YQF4LTH95C0S7OuB9eeOuh73NcE0Z7Wrtug38,5034
 blksprs/layouting/sparsity_layout.py,sha256=-sScIn4hhG35j9BXytrojEzp8jnFkMargJjtivPV1fc,9755
 blksprs/ops/conversion.py,sha256=ol-iV45wDzp9G1dJEkY53EdrvnmHzcl7QQmPJ-xqQTs,22410
-blksprs/ops/distribution.py,sha256=fXZV6UegCVpIwzh-A825OSYClHWu5k0UMYdO2UGDUpM,17067
+blksprs/ops/distribution.py,sha256=OWTH_dfO43uIMY6S44wpvRoIBuKzaTy1f57BOEf7EYA,19925
 blksprs/ops/matmul.py,sha256=yh2ZnO0ZltT1AgadiFP0vX28YJ4n74xO-I_5vFUmOmA,11452
 blksprs/ops/partitioning.py,sha256=K0ExR2a3W62d_9xxCJzsdJDLgtbxTI6P8loOOBdhPzE,7674
 blksprs/ops/repeat.py,sha256=IvSIRbuyFn0b57LObymLgup0LqlWQ3ndIw-QuiYQcaU,14564
-blksprs/ops/softmax.py,sha256=CDQT2KnwkJ4hGIgT0EUp6P92uiYpCdJQ9zxcdgSAAJA,12102
+blksprs/ops/softmax.py,sha256=V-1vqRefjjwSp6JPwKxVxh5pTng9gOdtgGlXHDPbpYM,12190
 blksprs/ops/transpose.py,sha256=jxzFFffrj4S_9tiCrwwUMdz6EA98o1dziWXjlqb64a4,6859
-blksprs/ops/experimental/distribution_mdi.py,sha256=HaRUu6LTWATzjuHWgddIUE-0fgY-O87STpJO4JY7k_8,20357
+blksprs/ops/experimental/distribution_mdi.py,sha256=F_0tl4Gn-9JZs_TZfDtZqO_RPFl7sejqQNF8UNIoCbs,20533
 blksprs/ops/misc/broadcast_ops.py,sha256=cPtRJa3pkZfY1QG51CJ-zDn4SK-CRpX5LEXoKGGMvRU,5418
 blksprs/ops/misc/exp.py,sha256=FnSFosBfJHuiEbD0MD-i4axLghRn4a0f8KvHXrKBB6M,3802
-blksprs/ops/misc/row_wise.py,sha256=SvJuNww-_QoVKTyTjMvjmzHlBuUlTKamkuq_rKzwAqs,17081
+blksprs/ops/misc/row_wise.py,sha256=U4Kk0-P4oOuMNjMHXxP2gP9njMIeMfz8RZrzItNIF94,17229
 blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
 blksprs/utils/blksprs_tensor.py,sha256=VjplBgDhnf9sxf-1R5feA0xp5FDCDdaeZmCeoIRdCnc,151
 blksprs/utils/processing.py,sha256=hYsFxEbQKcbqU4WtZWusPnWMHg8ZAZF1SKZJYjez9aU,2060
 blksprs/utils/tools.py,sha256=r7Y4C37vfSWUyQTGwa8NyRqgovmsq9hMufkenqYHOxo,539
 blksprs/utils/validation.py,sha256=IZxH2HZpePmv7lRqLsSwV_6FwsdnTXv9q4j98vCMSsQ,4195
-blksprs-1.8.3.dist-info/METADATA,sha256=DZkJ_HeetF1V6-_F6GeG0uXT-QmttMFOq4ao8fiSMgQ,8458
-blksprs-1.8.3.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-blksprs-1.8.3.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
-blksprs-1.8.3.dist-info/RECORD,,
+blksprs-1.9.dist-info/METADATA,sha256=9mMjmvJ2_Rz0uyiY9S8SKTRcs6YW5Jk1w6PRobh6Q3c,8456
+blksprs-1.9.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+blksprs-1.9.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
+blksprs-1.9.dist-info/RECORD,,