blksprs 1.9.3__py3-none-any.whl → 1.10__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
blksprs/ops/misc/row_wise.py CHANGED
@@ -104,17 +104,17 @@ def kernel_blocksparse_row_wise_sum(x,
 
     # Get position of current sparsity block consisting of its batch and row index
     spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
-    spa_bat_msk = (spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
+    spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
     spa_bat = tl.load(s_lut_x + spa_bat_idx, mask=spa_bat_msk)
 
     spa_row_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
-    spa_row_msk = (spa_row_idx < s_lut_x_r * s_lut_x_r_s)
+    spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_x_r * s_lut_x_r_s)
     spa_row = tl.load(s_lut_x + spa_row_idx, mask=spa_row_msk)
 
     # Load reverse sparsity index for current block
     rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
                        spa_row * s_l_o_r_s)
-    rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
+    rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
     rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
     if rev_idx_spa == -1:
@@ -124,7 +124,7 @@ def kernel_blocksparse_row_wise_sum(x,
     blk_idx = ((pid_blk * x_b_s) +
                ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_msk = (blk_idx < x_b * x_b_s)
+    blk_msk = (blk_idx >= 0 and blk_idx < x_b * x_b_s)
     blk = tl.load(x + blk_idx, mask=blk_msk)
 
     buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
@@ -132,7 +132,7 @@ def kernel_blocksparse_row_wise_sum(x,
     o_idx = (rev_idx_spa * o_b_s +
              ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
              (tl.arange(0, 1))[None, :])
-    o_msk = (o_idx < o_b * o_b_s)
+    o_msk = (o_idx >= 0 and o_idx < o_b * o_b_s)
     tl.atomic_add(o + o_idx, buf, o_msk)
 
 
@@ -231,17 +231,17 @@ def kernel_blocksparse_row_wise_max(x,
 
     # Get position of current sparsity block consisting of its batch and row index
     spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
-    spa_bat_msk = (spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
+    spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
     spa_bat = tl.load(s_lut_x + spa_bat_idx, mask=spa_bat_msk)
 
     spa_row_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
-    spa_row_msk = (spa_row_idx < s_lut_x_r * s_lut_x_r_s)
+    spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_x_r * s_lut_x_r_s)
     spa_row = tl.load(s_lut_x + spa_row_idx, mask=spa_row_msk)
 
     # Load reverse sparsity index for current block
     rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
                        spa_row * s_l_o_r_s)
-    rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
+    rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
     rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
     if rev_idx_spa == -1:
@@ -251,7 +251,7 @@ def kernel_blocksparse_row_wise_max(x,
     blk_idx = ((pid_blk * x_b_s) +
                ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_msk = (blk_idx < x_b * x_b_s)
+    blk_msk = (blk_idx >= 0 and blk_idx < x_b * x_b_s)
     blk = tl.load(x + blk_idx, mask=blk_msk)
 
     buf = tl.reshape(tl.max(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
@@ -259,7 +259,7 @@ def kernel_blocksparse_row_wise_max(x,
     o_idx = (rev_idx_spa * o_b_s +
              ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
              (tl.arange(0, 1))[None, :])
-    o_msk = (o_idx < o_b * o_b_s)
+    o_msk = (o_idx >= 0 and o_idx < o_b * o_b_s)
     tl.atomic_max(o + o_idx, buf, o_msk)
 
 
@@ -356,17 +356,17 @@ def kernel_blocksparse_row_wise_add(x,
 
     # Get position of current sparsity block consisting of its batch and row index
     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-    spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+    spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
     spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
 
     spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-    spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+    spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
     spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
 
     # Get reverse sparsity indices for s
     rev_idx_spa_s_idx = (spa_bat * s_l_y_b_s +
                          spa_row * s_l_y_r_s)
-    rev_idx_spa_s_msk = (rev_idx_spa_s_idx < s_l_y_b * s_l_y_b_s)
+    rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_y_b * s_l_y_b_s)
     rev_idx_spa_s = tl.load(r_lut_y + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
 
     if rev_idx_spa_s == -1:
@@ -377,25 +377,22 @@
     blk_x_idx = ((pid_blk * x_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx < x_b * x_b_s)
+    blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
     blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
     # Load sum block
     blk_s_idx = (rev_idx_spa_s * y_b_s +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
                  (tl.arange(0, 1) * y_c_s)[None, :])
-    blk_s_msk = (blk_s_idx < y_b * y_b_s)
+    blk_s_msk = (blk_s_idx >= 0 and blk_s_idx < y_b * y_b_s)
     blk_s = tl.load(y + blk_s_idx, mask=blk_s_msk)
 
     # Compute exp
     buf = blk_x + tl.broadcast_to(blk_s, (TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE))
 
-    # debug
-    asdf = tl.full((TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), 1.0, dtype=tl.float32)
-
     # Store block
     blk_o_idx = ((pid_blk * o_b_s) +
                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-    blk_o_msk = (blk_o_idx < o_b * o_b_s)
+    blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
     tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
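
The same guard recurs in every mask hunk of this release: 1.10 adds a lower bound to each index mask so negative offsets are masked out of `tl.load`/`tl.store` instead of being dereferenced. A minimal standalone sketch of the pattern (the kernel and names below are illustrative, not blksprs API; it uses `&` for the elementwise AND, which is what Triton's AST-based lowering makes of the `and` written in the diff):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def masked_copy_kernel(x, o, numel, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    idx = pid * BLOCK + tl.arange(0, BLOCK)
    # Guard both ends of the valid range; checking only `idx < numel`
    # (the 1.9.3 behaviour) leaves negative indices unmasked.
    msk = (idx >= 0) & (idx < numel)
    blk = tl.load(x + idx, mask=msk)
    tl.store(o + idx, blk, mask=msk)


x = torch.rand(1000, device="cuda")
o = torch.empty_like(x)
masked_copy_kernel[(triton.cdiv(x.numel(), 256),)](x, o, x.numel(), BLOCK=256)
```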
blksprs/ops/partitioning.py CHANGED
@@ -9,13 +9,14 @@ from blksprs.utils.validation import validate_dimensions, validate_contiguous, v
 
 
 def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
-          sparsity_block_size: int, triton_block_size: int = None) -> (BlksprsTensor, Tensor):
+          dim: int, sparsity_block_size: int, triton_block_size: int = None) -> (BlksprsTensor, Tensor):
     """Splits a block-sparse tensor in compressed form along the last dimension into partitions.
 
     Args:
         x (BlksprsTensor): A block-sparse tensor in compressed form.
         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
         partitions (int): The number of partitions to split the block-sparse tensor into.
+        dim (int): The dimension along which to split the tensor. Currently only supports dim=2.
         sparsity_block_size (int): The size of the sparsity blocks.
         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
 
@@ -54,17 +55,22 @@ def split(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
 
     validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
 
+    adjusted_dim = dim % 3
+    if adjusted_dim != 2:
+        raise NotImplementedError("Currently only supports dim=2")
+
     return BlksprsTensor(_BlocksparseSplit.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
-                                                 sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_output
+                                                 adjusted_dim, sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_output
 
 
 class _BlocksparseSplit(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
-                num_partitions: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+                num_partitions: int, dim: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
         ctx.save_for_backward(sparsity_layout_o)
         ctx.num_partitions = num_partitions
+        ctx.dim = dim
 
         return flow_forward(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
                             n_sparse_blocks, triton_block_size)
@@ -73,21 +79,23 @@ class _BlocksparseSplit(torch.autograd.Function):
     def backward(ctx, grad_output):
         sparsity_layout = ctx.saved_tensors[0]
         num_partitions = ctx.num_partitions
+        dim = ctx.dim
         sparsity_block_size = ctx.sparsity_block_size
         triton_block_size = ctx.triton_block_size
 
-        return merge(grad_output, sparsity_layout, num_partitions,
-                     sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None
+        return merge(grad_output, sparsity_layout, num_partitions, dim,
+                     sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None, None
 
 
 def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
-          sparsity_block_size: int, triton_block_size: int = None) -> (BlksprsTensor, Tensor):
+          dim: int, sparsity_block_size: int, triton_block_size: int = None) -> (BlksprsTensor, Tensor):
     """Merges the specified partitions of a block-sparse tensor in compressed form along the last dimension.
 
     Args:
         x (BlksprsTensor): A block-sparse tensor in compressed form.
         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
         partitions (int): The number of partitions to be merged.
+        dim (int): The dimension along which to merge the tensor. Currently only supports dim=2.
         sparsity_block_size (int): The size of the sparsity blocks.
         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
 
@@ -128,17 +136,22 @@ def merge(x: BlksprsTensor, sparsity_layout: Tensor, partitions: int,
 
     validate_contiguous(sparsity_layout_output, sparsity_lut, sparsity_reverse_lut)
 
+    adjusted_dim = dim % 3
+    if adjusted_dim != 2:
+        raise NotImplementedError("Currently only supports dim=2")
+
     return BlksprsTensor(_BlocksparseMerge.apply(x, sparsity_layout_output, sparsity_lut, sparsity_reverse_lut, partitions,
-                                                 sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_output
+                                                 adjusted_dim, sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_output
 
 
 class _BlocksparseMerge(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
-                num_partitions: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+                num_partitions: int, dim: int, sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
         ctx.save_for_backward(sparsity_layout_o)
         ctx.num_partitions = num_partitions
+        ctx.dim = dim
 
         return flow_forward(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
                             n_sparse_blocks, triton_block_size)
@@ -147,10 +160,11 @@ class _BlocksparseMerge(torch.autograd.Function):
     def backward(ctx, grad_output):
         sparsity_layout = ctx.saved_tensors[0]
         num_partitions = ctx.num_partitions
+        dim = ctx.dim
         sparsity_block_size = ctx.sparsity_block_size
         triton_block_size = ctx.triton_block_size
 
-        return split(grad_output, sparsity_layout, num_partitions,
-                     sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None
+        return split(grad_output, sparsity_layout, num_partitions, dim,
+                     sparsity_block_size, triton_block_size)[0], None, None, None, None, None, None, None, None
 
 
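A hedged usage sketch of the updated `split`/`merge` signatures above. It assumes `x` is a `BlksprsTensor` with a matching `sparsity_layout` built elsewhere; only the argument order and the dim=2 restriction are taken from the diff:

```python
from blksprs.ops.partitioning import split, merge

# `dim` is new in 1.10. Only the last dimension is accepted: the value is
# taken modulo 3, and anything other than 2 raises NotImplementedError,
# so dim=-1 and dim=2 are equivalent.
parts, layout_parts = split(x, sparsity_layout, partitions=4, dim=-1,
                            sparsity_block_size=32)

# merge() is the inverse operation; split's backward calls merge and vice versa.
merged, layout_merged = merge(parts, layout_parts, partitions=4, dim=-1,
                              sparsity_block_size=32)
```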
blksprs/ops/softmax.py CHANGED
@@ -11,7 +11,8 @@ from blksprs.utils.validation import validate_contiguous, validate_dimensions, v
     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
 
 
-def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
+def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
+            triton_block_size: int = None) -> BlksprsTensor:
     """Computes the softmax of a block-sparse tensor in compressed form.
 
     Note:
@@ -47,9 +48,9 @@ def softmax(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
     validate_contiguous(sparsity_layout, sparsity_lut, sparsity_reverse_lut_rws)
 
     return BlksprsTensor(_BlocksparseSoftmax.apply(x, sparsity_layout,
-                                                    sparsity_lut,
-                                                    sparsity_reverse_lut_rws,
-                                                    sparsity_block_size, triton_block_size))
+                                                   sparsity_lut,
+                                                   sparsity_reverse_lut_rws,
+                                                   sparsity_block_size, triton_block_size))
 
 
 class _BlocksparseSoftmax(torch.autograd.Function):
@@ -168,17 +169,17 @@ class _BlocksparseSoftmax(torch.autograd.Function):
 
         # Get position of current sparsity block consisting of its batch and row index
         spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-        spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+        spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
         spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
 
         spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-        spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+        spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
         spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
 
         # Get reverse sparsity indices for s
         rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
                              spa_row * s_l_s_r_s)
-        rev_idx_spa_s_msk = (rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
+        rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
         rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
 
         if rev_idx_spa_s == -1:
@@ -189,14 +190,14 @@ class _BlocksparseSoftmax(torch.autograd.Function):
         blk_x_idx = ((pid_blk * x_b_s) +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_x_msk = (blk_x_idx < x_b * x_b_s)
+        blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
         # Load sum block
         blk_s_idx = (rev_idx_spa_s * s_b_s +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
                      (tl.arange(0, 1) * s_c_s)[None, :])
-        blk_s_msk = (blk_s_idx < s_b * s_b_s)
+        blk_s_msk = (blk_s_idx >= 0 and blk_s_idx < s_b * s_b_s)
         blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
 
         # Compute softmax
@@ -226,16 +227,16 @@ class _BlocksparseSoftmax(torch.autograd.Function):
 
         # Get position of current sparsity block consisting of its batch and row index
         spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-        spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+        spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
         spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
 
         spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-        spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+        spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
         spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
 
         rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
                              spa_row * s_l_s_r_s)
-        rev_idx_spa_s_msk = (rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
+        rev_idx_spa_s_msk = (rev_idx_spa_s_idx >= 0 and rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
         rev_idx_spa_s = tl.load(r_lut_s + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
 
         if rev_idx_spa_s == -1:
@@ -245,19 +246,19 @@ class _BlocksparseSoftmax(torch.autograd.Function):
         blk_s_idx = (rev_idx_spa_s * s_b_s +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * s_r_s)[:, None] +
                      (tl.arange(0, 1) * s_c_s)[None, :])
-        blk_s_msk = (blk_s_idx < s_b * s_b_s)
+        blk_s_msk = (blk_s_idx >= 0 and blk_s_idx < s_b * s_b_s)
         blk_s = tl.load(s + blk_s_idx, mask=blk_s_msk)
 
         blk_g_idx = ((pid_blk * g_b_s) +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * g_c_s)[None, :])
-        blk_g_msk = (blk_g_idx < g_b * g_b_s)
+        blk_g_msk = (blk_g_idx >= 0 and blk_g_idx < g_b * g_b_s)
         blk_g = tl.load(g + blk_g_idx, mask=blk_g_msk)
 
         blk_x_idx = ((pid_blk * x_b_s) +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_x_msk = (blk_x_idx < x_b * x_b_s)
+        blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
         buf = blk_x * (blk_g - blk_s)
@@ -265,5 +266,5 @@ class _BlocksparseSoftmax(torch.autograd.Function):
         blk_o_idx = ((pid_blk * o_b_s) +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-        blk_o_msk = (blk_o_idx < o_b * o_b_s)
+        blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
         tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
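
For orientation in the backward hunks above: `buf = blk_x * (blk_g - blk_s)` is the standard softmax vector-Jacobian product, assuming (as the surrounding code suggests) that `x` holds the saved softmax output, `g` the incoming gradient, and `s` the row-wise sum of their product:

```latex
\frac{\partial L}{\partial x_i} = y_i \Bigl( g_i - \sum_j g_j \, y_j \Bigr),
\qquad y = \operatorname{softmax}(x), \quad g = \frac{\partial L}{\partial y}.
```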
blksprs/ops/transpose.py CHANGED
@@ -50,8 +50,9 @@ def transpose(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: in
 
     validate_contiguous(sparsity_layout_t, sparsity_lut, sparsity_reverse_lut)
 
-    return BlksprsTensor(_BlocksparseTranspose.apply(x, sparsity_layout_t, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
-                                                     n_sparse_blocks, triton_block_size)), sparsity_layout_t
+    return BlksprsTensor(
+        _BlocksparseTranspose.apply(x, sparsity_layout_t, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+                                    n_sparse_blocks, triton_block_size)), sparsity_layout_t
 
 
 class _BlocksparseTranspose(torch.autograd.Function):
@@ -122,22 +123,22 @@
 
         # Get sparsity index of current output block consisting of its batch, row, and column index
         spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-        spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+        spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
         spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
 
         spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-        spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+        spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
         spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
 
         spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
-        spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
+        spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
         spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
 
         # Get reverse sparsity index
         rev_idx_spa_idx = (spa_bat * s_l_b_s +
                            spa_row * s_l_r_s +
                            spa_col * s_l_c_s)
-        rev_idx_spa_msk = (rev_idx_spa_idx < s_l_b * s_l_b_s)
+        rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_b * s_l_b_s)
         rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
 
         if rev_idx_spa == -1:
@@ -147,7 +148,7 @@ class _BlocksparseTranspose(torch.autograd.Function):
         blk_x_idx = (rev_idx_spa * x_b_s +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_x_msk = (blk_x_idx < x_b * x_b_s)
+        blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
         blk_x_t = tl.trans(blk_x)
@@ -155,5 +156,5 @@ class _BlocksparseTranspose(torch.autograd.Function):
         blk_o_idx = (pid_blk * o_b_s +
                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_o_msk = (blk_o_idx < o_b * o_b_s)
+        blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
         tl.store(o + blk_o_idx, blk_x_t, mask=blk_o_msk)
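
A side note on the store above: the output offsets swap `pid_col` and `pid_row`, which is how a transposed tile lands in the right place when a sparsity block spans several Triton blocks. A small NumPy sketch of that index swap (sizes illustrative, not blksprs code):

```python
import numpy as np

TRITON_BLOCK_SIZE = 4
x = np.arange(64).reshape(8, 8)  # one 8x8 sparsity block, four 4x4 Triton blocks
o = np.zeros_like(x)

for pid_row in range(2):
    for pid_col in range(2):
        r = pid_row * TRITON_BLOCK_SIZE + np.arange(TRITON_BLOCK_SIZE)
        c = pid_col * TRITON_BLOCK_SIZE + np.arange(TRITON_BLOCK_SIZE)
        blk_t = x[np.ix_(r, c)].T  # tl.trans(blk_x)
        o[np.ix_(c, r)] = blk_t    # row/col offsets swapped on store

assert (o == x.T).all()
```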
blksprs-1.10.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.9.3
+Version: 1.10
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -23,14 +23,6 @@ Requires-Dist: build; extra == "build"
 [![GitHub Release](https://img.shields.io/github/v/release/FelixSchoen/blksprs?include_prereleases&label=Latest%20Release)](https://github.com/FelixSchoen/blksprs/releases)
 [![Python Version](https://img.shields.io/badge/Python%20Version-3.11-blue)](https://www.python.org/downloads/release/python-3119/)
 
-## Important Notice
-
-🚨 **Non-Final API** 🚨
-
-Although it already supports a wide variety of functions, this library is still under active development and the API is
-subject to change. For feature requests or bug reports, please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
-We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
-
 ## Overview
 
 A lightweight and efficient library for operations on block-sparse matrices in PyTorch using Triton.
@@ -44,7 +36,7 @@ Currently supported operations (includes gradient calculation):
 - Scatter (_supports either no reduction or summation, gradients are only available for summation_)
 - Repeat (_supports target sparsity layout_)
 - Repeat Interleave (_supports target sparsity layout_)
-- Splitting and merging of matrices along the last dimension
+- Splitting and merging of matrices (_currently* only supports splitting and merging along the last dimension_)
 - Conversion to and from sparse form
 - Conversion to different sparsity layouts and different sparsity block sizes
 
@@ -70,13 +62,15 @@ Furthermore, the library provides a set of utility functions
 - for the creation of sparsity layouts based on existing
   dense tensors and for the scatter operation (module ``bs.layouting``),
 - for the application of ``nn.Linear``, ``nn.Dropout``, and ``nn.LayerNorm`` layers to block-sparse tensors,
-- as well as utility functions to apply linear layers,
-  ensure correct input dimensionality, and validate input (module ``bs.utils``).
+- as well as utility functions to ensure correct input dimensionality, and validate input (module ``bs.utils``).
+
+_* see the [Roadmap](#roadmap) section for more information_
 
 ## Installation
 
-Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is only compatible with
-the Linux platform.
+Note that due to the dependency on [Triton](https://github.com/triton-lang/triton) this library is **only compatible with
+the Linux platform**.
+Keep track of this [issue](https://github.com/triton-lang/triton/issues/1640) for updates.
 
 We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) using pip:
 
@@ -92,6 +86,16 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u
 
 See [`CHANGELOG.md`](https://github.com/FelixSchoen/blksprs/blob/main/CHANGELOG.md) for a detailed changelog.
 
+## Roadmap
+
+Note that since this library covers all our current needs it is in a **bugfix-only** state.
+This means that there are no plans to add new features, e.g., support for dimension specification of the ``split`` and ``merge`` operations.
+We will continue to maintain the library and fix any issues that arise.
+Should you find any bugs please open an [issue](https://github.com/FelixSchoen/blksprs/issues).
+We also encourage [pull requests](https://github.com/FelixSchoen/blksprs/pulls).
+
+It might be that this changes with future projects, but as of December 2024, we are content with the current state of the library.
+
 ## Usage
 
 We provide an example below to demonstrate the usage of the library.
blksprs-1.10.dist-info/RECORD ADDED
@@ -0,0 +1,24 @@
+blksprs/__init__.py,sha256=wnpk-20jXq7xV0xa-WpHfPQuauI2gEZz9sH-0blKxP0,1766
+blksprs/layouting/distribution_layout.py,sha256=xDGY5-J7uSD8oenlf8bEJ2amMiQG3NBf2klTTydbTJE,5140
+blksprs/layouting/sparsity_layout.py,sha256=IVtHc_nN3ZM2y4GFcys70PqDWmWc7tkHlVGlToErANk,9894
+blksprs/ops/conversion.py,sha256=-KeVaOUdMB0aAj68XZjyzZgf0Dfg5Tt5AnWgx4AZVCY,22320
+blksprs/ops/distribution.py,sha256=qK5t5XgQSJxXPced8RohprqCtUMMTaEP2pFm3KU1c8o,20267
+blksprs/ops/flow.py,sha256=SWHDQ5zx0cjnPR0CcAcRNZdSusSAHSU840SwDNUr24g,6437
+blksprs/ops/matmul.py,sha256=LAQyPNwWVmBMRnAex3msLSPD_aG5SblLCMiutJWqvus,11632
+blksprs/ops/partitioning.py,sha256=ugKnpvH36ND7qeJQp56M74qqfACkzcTVuXebzw__28Y,8286
+blksprs/ops/repeat.py,sha256=RCa-dITomA5v12K5Oxa5_ReA361zS7WHPNNHxSp9PGw,8578
+blksprs/ops/softmax.py,sha256=i8NJhvPRYya94AzpN6qiki6_G9KfDrtPifhWd7wbYzk,12496
+blksprs/ops/transpose.py,sha256=oAtUu7QzQnNAH3lvRs_MIvIKpBu9h74f9Sk07AxKnDM,6991
+blksprs/ops/misc/broadcast_ops.py,sha256=pv0nssSDOdDbQFttpqUIs2ZXShqfm2RYCfJH-C5x3H0,5544
+blksprs/ops/misc/exp.py,sha256=ygfw7oD6ALdPwNQX_HelKgO8I3-LCgIXH_x0gWzkUN8,3840
+blksprs/ops/misc/row_wise.py,sha256=DnV5-xEJUbqZlK2fETwHiPQDUMwT-lkc0VUhBlnJ5Y0,17458
+blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
+blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
+blksprs/utils/layout_utils.py,sha256=49ZdPS_gMn_IrWty3FARbi2rda5a8g5DmAEL8LOrC30,670
+blksprs/utils/processing.py,sha256=WLuMJQ8v-YovXwcDjhlDn3N31WMZXrtyeeyKSgq_zn4,3642
+blksprs/utils/tools.py,sha256=r7Y4C37vfSWUyQTGwa8NyRqgovmsq9hMufkenqYHOxo,539
+blksprs/utils/validation.py,sha256=CbxBbeQWJo8wox5eMoVzaTlP9FVBwt3-gxUOmi3EUgw,4213
+blksprs-1.10.dist-info/METADATA,sha256=eiTG-EDaZAlRVbof5WsZQpzGmmL5nPoBVDok4VdatJI,9105
+blksprs-1.10.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+blksprs-1.10.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
+blksprs-1.10.dist-info/RECORD,,