blksprs 1.10.2__py3-none-any.whl → 1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/__init__.py CHANGED
@@ -13,7 +13,6 @@ class ops:
     class misc:
         from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_add, row_wise_sub
         from blksprs.ops.misc.broadcast_ops import broadcast_add, broadcast_sub
-        from blksprs.ops.misc.exp import exp


     class layouting:
blksprs/ops/conversion.py CHANGED
@@ -19,7 +19,7 @@ def from_blksprs(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size:


 def to_dense(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
-             triton_block_size: int = None) -> Tensor:
+             triton_block_size: int = None, lut: dict = None) -> Tensor:
     """Converts a block-sparse tensor in compressed form to a block-sparse tensor in regular form based on the given
     sparsity layout.

@@ -30,6 +30,7 @@ def to_dense(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int
         fill_value (float): The value to fill the resulting dense tensor with where the block-sparse tensor is not
             present (default ``0``).
         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

     Returns:
         Tensor: The block-sparse tensor converted to regular form.
@@ -44,24 +45,35 @@ def to_dense(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int
     validate_sparsity_block_size(sparsity_block_size, x)
     validate_triton_block_size(triton_block_size, sparsity_block_size)

-    sparsity_layout_flat = sparsity_layout.reshape(-1)
-    sparsity_reverse_lut = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                            (sparsity_layout_flat == 1) -
-                            (1 * (sparsity_layout_flat == 0)))
-
-    validate_contiguous(sparsity_reverse_lut)
+    lut = _BlocksparseToDense.build_lut(lut, sparsity_layout)

     if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
         return x

     return _BlocksparseToDense.apply(x,
-                                     sparsity_layout, sparsity_reverse_lut,
+                                     sparsity_layout, lut["sparsity_reverse_lut"],
                                      sparsity_block_size, fill_value,
                                      triton_block_size)


 class _BlocksparseToDense(torch.autograd.Function):

+    @staticmethod
+    def build_lut(lut: dict, sparsity_layout: Tensor):
+        if lut is None:
+            lut = dict()
+
+        if "sparsity_reverse_lut" not in lut:
+            sparsity_layout_flat = sparsity_layout.reshape(-1)
+            sparsity_reverse_lut = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                    (sparsity_layout_flat == 1) -
+                                    (1 * (sparsity_layout_flat == 0)))
+            lut["sparsity_reverse_lut"] = sparsity_reverse_lut
+
+        validate_contiguous(lut["sparsity_reverse_lut"])
+
+        return lut
+
     @staticmethod
     def forward(ctx, x: Tensor,
                 sparsity_layout: Tensor, sparsity_reverse_lut: Tensor,
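
The new `lut` parameter defers look-up-table construction to the `build_lut` helper above, which only computes entries missing from the dict it receives, so a caller-held dict is filled once and reused on later calls. The reverse LUT itself is the cumsum trick visible in `build_lut`: it maps each position of the flattened sparsity layout to that block's index in compressed storage, or -1 where no block is stored. A minimal standalone sketch (the 1x2x2 layout is invented for illustration):

import torch

# Hypothetical layout: two of four blocks present.
sparsity_layout = torch.tensor([[[1, 0],
                                 [0, 1]]])

# Same cumsum trick as build_lut: present positions get their index in
# the compressed block list, absent positions get -1.
flat = sparsity_layout.reshape(-1)
reverse_lut = ((torch.cumsum(flat, dim=-1) - 1) * (flat == 1)
               - (1 * (flat == 0)))
print(reverse_lut)  # tensor([ 0, -1, -1,  1])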
@@ -160,7 +172,7 @@ def to_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,


 def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
-              triton_block_size: int = None) -> BlksprsTensor:
+              triton_block_size: int = None, lut: dict = None) -> BlksprsTensor:
     """Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
     sparsity layout.

@@ -169,6 +181,7 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

     Returns:
         BlksprsTensor: The block-sparse tensor converted to compressed form.
@@ -183,22 +196,36 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
     validate_sparsity_block_size(sparsity_block_size, x)
     validate_triton_block_size(triton_block_size, sparsity_block_size)

-    sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
-    n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
-
-    validate_contiguous(sparsity_layout, sparsity_lut)
+    lut = _BlocksparseToSparse.build_lut(lut, sparsity_layout)

     if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
         return BlksprsTensor(x)

     return BlksprsTensor(_BlocksparseToSparse.apply(x,
-                                                    sparsity_layout, sparsity_lut,
-                                                    sparsity_block_size, n_sparse_blocks,
+                                                    sparsity_layout, lut["sparsity_lut"],
+                                                    sparsity_block_size, lut["n_sparse_blocks"],
                                                     triton_block_size))


 class _BlocksparseToSparse(torch.autograd.Function):

+    @staticmethod
+    def build_lut(lut: dict, sparsity_layout: Tensor):
+        if lut is None:
+            lut = dict()
+
+        if "sparsity_lut" not in lut:
+            sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
+            lut["sparsity_lut"] = sparsity_lut
+
+        if "n_sparse_blocks" not in lut:
+            n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
+            lut["n_sparse_blocks"] = n_sparse_blocks
+
+        validate_contiguous(sparsity_layout, lut["sparsity_lut"])
+
+        return lut
+
     @staticmethod
     def forward(ctx, x: Tensor,
                 sparsity_layout: Tensor, sparsity_lut: Tensor,
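
Because `build_lut` mutates and returns the dict it receives, callers can pass the same dict into every conversion and pay the `torch.nonzero`/`torch.cumsum`/`.item()` cost only once per layout. A hedged usage sketch; `batches`, `sparsity_layout`, and `sparsity_block_size` are assumed to be defined by the caller, and the layout must stay fixed across iterations:

from blksprs.ops.conversion import to_dense, to_sparse

# One dict per operation, since to_sparse and to_dense populate
# different keys; each is filled in place on its first call.
lut_to_sparse = {}
lut_to_dense = {}

for x in batches:  # assumed iterable of dense tensors
    x_bs = to_sparse(x, sparsity_layout, sparsity_block_size, lut=lut_to_sparse)
    x_rt = to_dense(x_bs, sparsity_layout, sparsity_block_size, lut=lut_to_dense)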
@@ -13,7 +13,7 @@ from blksprs.utils.validation import validate_contiguous, validate_dimensions, v
 def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
            dim: int,
            idx: BlksprsTensor, sparsity_layout_idx: Tensor,
-           sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
+           sparsity_block_size: int, triton_block_size: int = None, lut: dict = None) -> BlksprsTensor:
     """Applies a gather operation on a block-sparse tensor in compressed form.

     Args:
@@ -24,6 +24,7 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
         sparsity_layout_idx (Tensor): The sparsity layout of the indices block-sparse tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

     Returns:
         BlksprsTensor: The result of the gather operation as a block-sparse tensor in compressed form.
@@ -40,25 +41,38 @@ def gather(src: BlksprsTensor, sparsity_layout_src: Tensor,
     validate_sparsity_block_size(sparsity_block_size, src, idx)
     validate_triton_block_size(triton_block_size, sparsity_block_size)

-    sparsity_layout_x_flat = sparsity_layout_src.reshape(-1)
-    sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
-                              (sparsity_layout_x_flat == 1) -
-                              (1 * (sparsity_layout_x_flat == 0)))
-
-    sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
-
-    validate_contiguous(sparsity_layout_src, sparsity_reverse_lut_x,
-                        sparsity_layout_idx, sparsity_lut_i)
-
     adjusted_dim = dim % 3

-    return BlksprsTensor(_BlocksparseGather.apply(src, sparsity_layout_src, sparsity_reverse_lut_x,
-                                                  adjusted_dim, idx, sparsity_layout_idx, sparsity_lut_i,
+    lut = _BlocksparseGather.build_lut(lut, sparsity_layout_src, sparsity_layout_idx)
+
+    return BlksprsTensor(_BlocksparseGather.apply(src, sparsity_layout_src, lut["sparsity_reverse_lut_x"],
+                                                  adjusted_dim, idx, sparsity_layout_idx, lut["sparsity_lut_i"],
                                                   sparsity_block_size, triton_block_size))


 class _BlocksparseGather(torch.autograd.Function):

+    @staticmethod
+    def build_lut(lut: dict, sparsity_layout_src: Tensor, sparsity_layout_idx: Tensor):
+        if lut is None:
+            lut = dict()
+
+        if "sparsity_reverse_lut_x" not in lut:
+            sparsity_layout_x_flat = sparsity_layout_src.reshape(-1)
+            sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
+                                      (sparsity_layout_x_flat == 1) -
+                                      (1 * (sparsity_layout_x_flat == 0)))
+            lut["sparsity_reverse_lut_x"] = sparsity_reverse_lut_x
+
+        if "sparsity_lut_i" not in lut:
+            sparsity_lut_i = torch.nonzero(sparsity_layout_idx).contiguous()
+            lut["sparsity_lut_i"] = sparsity_lut_i
+
+        validate_contiguous(sparsity_layout_src, lut["sparsity_reverse_lut_x"],
+                            sparsity_layout_idx, lut["sparsity_lut_i"])
+
+        return lut
+
     @staticmethod
     def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
                 dim: int, i: Tensor, sparsity_layout_i: Tensor, sparsity_lut_i: Tensor,
@@ -202,7 +216,7 @@ def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
             dim: int,
             idx: BlksprsTensor,
             sparsity_layout_tgt: Tensor,
-            sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
+            sparsity_block_size: int, triton_block_size: int = None, lut: dict = None) -> BlksprsTensor:
     """Wrapper for ``scatter_reduce`` with ``reduce_op="none"``.

     """
@@ -219,7 +233,7 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
                    idx: BlksprsTensor,
                    sparsity_layout_tgt: Tensor,
                    sparsity_block_size: int,
-                   reduce_op: str = "sum", triton_block_size: int = None) -> BlksprsTensor:
+                   reduce_op: str = "sum", triton_block_size: int = None, lut: dict = None) -> BlksprsTensor:
     """Applies a scatter operation on a block-sparse tensor in compressed form.

     Args:
@@ -232,6 +246,7 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
         reduce_op (str, optional): The reduction operation to apply during the scatter operation (default ``"sum"``).
             Supported operations are ``"none"`` and ``"sum"``.
         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

     Returns:
         BlksprsTensor: The result of the scatter operation as a block-sparse tensor in compressed form.
@@ -251,29 +266,44 @@ def scatter_reduce(src: BlksprsTensor, sparsity_layout_src: Tensor,
     if reduce_op not in ["none", "sum"]:
         raise ValueError(f"Reduction operation '{reduce_op}' is not supported")

-    sparsity_lut_x = torch.nonzero(sparsity_layout_src).contiguous()
-
-    sparsity_layout_o_flat = sparsity_layout_tgt.reshape(-1)
-    sparsity_reverse_lut_o = ((torch.cumsum(sparsity_layout_o_flat, dim=-1) - 1) *
-                              (sparsity_layout_o_flat == 1) -
-                              (1 * (sparsity_layout_o_flat == 0)))
-
-    n_sparse_blocks = torch.sum(sparsity_layout_tgt.to(torch.int)).item()
-
-    validate_contiguous(sparsity_layout_src, sparsity_lut_x,
-                        sparsity_layout_tgt, sparsity_reverse_lut_o)
-
     adjusted_dim = dim % 3

-    return BlksprsTensor(_BlocksparseScatterReduce.apply(src, sparsity_layout_src, sparsity_lut_x,
+    lut = _BlocksparseScatterReduce.build_lut(lut, sparsity_layout_src, sparsity_layout_tgt)
+
+    return BlksprsTensor(_BlocksparseScatterReduce.apply(src, sparsity_layout_src, lut["sparsity_lut_x"],
                                                          adjusted_dim, idx,
-                                                         sparsity_layout_tgt, sparsity_reverse_lut_o,
-                                                         sparsity_block_size, n_sparse_blocks,
+                                                         sparsity_layout_tgt, lut["sparsity_reverse_lut_o"],
+                                                         sparsity_block_size, lut["n_sparse_blocks"],
                                                          reduce_op, triton_block_size))


 class _BlocksparseScatterReduce(torch.autograd.Function):

+    @staticmethod
+    def build_lut(lut: dict, sparsity_layout_src: Tensor, sparsity_layout_tgt: Tensor):
+        if lut is None:
+            lut = dict()
+
+        if "sparsity_lut_x" not in lut:
+            sparsity_lut_x = torch.nonzero(sparsity_layout_src).contiguous()
+            lut["sparsity_lut_x"] = sparsity_lut_x
+
+        if "sparsity_reverse_lut_o" not in lut:
+            sparsity_layout_o_flat = sparsity_layout_tgt.reshape(-1)
+            sparsity_reverse_lut_o = ((torch.cumsum(sparsity_layout_o_flat, dim=-1) - 1) *
+                                      (sparsity_layout_o_flat == 1) -
+                                      (1 * (sparsity_layout_o_flat == 0)))
+            lut["sparsity_reverse_lut_o"] = sparsity_reverse_lut_o
+
+        if "n_sparse_blocks" not in lut:
+            n_sparse_blocks = torch.sum(sparsity_layout_tgt.to(torch.int)).item()
+            lut["n_sparse_blocks"] = n_sparse_blocks
+
+        validate_contiguous(sparsity_layout_src, lut["sparsity_lut_x"],
+                            sparsity_layout_tgt, lut["sparsity_reverse_lut_o"])
+
+        return lut
+
     @staticmethod
     def forward(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_lut_x: Tensor,
                 dim: int, i: Tensor,
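
`gather`, `scatter`, and `scatter_reduce` get the same treatment: LUT construction moves into a `build_lut` staticmethod on the corresponding autograd function, and the wrapper indexes the returned dict. A sketch of caller-side reuse; the diff omits this file's header, so the module path below is a guess, and `samples` plus all layouts are assumed to be prepared elsewhere and held fixed:

from blksprs.ops.distribution import gather  # module path assumed

lut_gather = {}  # filled on the first call, reused afterwards

for src, idx in samples:  # assumed iterable of (source, index) pairs
    out = gather(src, sparsity_layout_src, 2,
                 idx, sparsity_layout_idx,
                 sparsity_block_size, lut=lut_gather)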
blksprs/ops/flow.py CHANGED
@@ -40,21 +40,18 @@ def kernel_blocksparse_flow_pull(x,
     rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
     rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

-    if rev_idx_spa == -1:
-        tl.device_assert(False)
-        return
+    if rev_idx_spa >= 0:
+        blk_x_idx = (rev_idx_spa * x_b_s +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+        blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
+        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

-    blk_x_idx = (rev_idx_spa * x_b_s +
-                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
-    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
-    blk_o_idx = (pid_blk * o_b_s +
-                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-    blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
-    tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+        blk_o_idx = (pid_blk * o_b_s +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+        blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
+        tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)


 @triton.jit
@@ -91,25 +88,22 @@ def kernel_blocksparse_flow_push(x,
     rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_x_b * s_l_x_b_s)
     rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

-    if rev_idx_spa == -1:
-        tl.device_assert(False)
-        return
-
-    blk_x_idx = (pid_blk * x_b_s +
-                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
-    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+    if rev_idx_spa >= 0:
+        blk_x_idx = (pid_blk * x_b_s +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+        blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
+        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)

-    blk_o_idx = (rev_idx_spa * o_b_s +
-                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-    blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
-    tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
+        blk_o_idx = (rev_idx_spa * o_b_s +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+        blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
+        tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)


-def flow_forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
-                 sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+def flow_forward_pull(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                      sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
     output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                          dtype=x.dtype, device=x.device)

@@ -144,3 +138,42 @@ def flow_forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor
     ctx.triton_block_size = triton_block_size

     return output
+
+
+def flow_forward_push(ctx, x: Tensor, sparsity_layout_x: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                      sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+    output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                         dtype=x.dtype, device=x.device)
+
+    x_b, x_r, x_c = x.size()
+    x_b_s, x_r_s, x_c_s = stride(x)
+    s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
+    s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
+    s_lut_r, s_lut_c = sparsity_lut.size()
+    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+    o_b, o_r, o_c = output.size()
+    o_b_s, o_r_s, o_c_s = stride(output)
+
+    if triton_block_size is None:
+        triton_block_size = get_triton_block_size(sparsity_block_size)
+
+    triton_grid = lambda meta: [x_b,
+                                triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+    (kernel_blocksparse_flow_push[triton_grid]
+     (x,
+      x_b, x_b_s, x_r_s, x_c_s,
+      s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+      sparsity_reverse_lut,
+      output,
+      o_b, o_b_s, o_r_s, o_c_s,
+      triton_block_size))
+
+    # Save for backward pass
+    if ctx is not None:
+        ctx.sparsity_block_size = sparsity_block_size
+        ctx.triton_block_size = triton_block_size
+
+    return output
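
The kernel changes in flow.py soften a hard failure into a skip: a reverse-LUT entry of -1 marks a block position absent from the sparsity layout, and where the old kernels tripped `tl.device_assert(False)` on it, the new ones nest the load/store body under `if rev_idx_spa >= 0`. The diff also splits `flow_forward` into `flow_forward_pull` and `flow_forward_push`; the push variant allocates its output with `torch.zeros` rather than `torch.empty`, since `tl.atomic_add` accumulates into it instead of overwriting every block. A plain-Python analogue of the new guard (`reverse_lut` as built in the earlier sketch; `copy_block` is a hypothetical stand-in for the masked load/store pair):

for block_pos, rev_idx in enumerate(reverse_lut.tolist()):
    if rev_idx >= 0:
        # Block present in compressed storage: move its data.
        copy_block(rev_idx, block_pos)  # hypothetical helper
    # rev_idx == -1: block absent from the layout, silently skipped
    # (previously this aborted the kernel via a device-side assert).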
blksprs/ops/matmul.py CHANGED
@@ -13,7 +13,7 @@ from blksprs.utils.validation import validate_contiguous, validate_dimensions, v
 def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
            y: BlksprsTensor, sparsity_layout_y: Tensor,
            sparsity_layout_output: Tensor,
-           sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
+           sparsity_block_size: int, triton_block_size: int = None, lut: dict = None) -> BlksprsTensor:
     """Performs matrix multiplication between two block-sparse tensors.

     The sparsity layout of the output tensor is used to only calculate blocks that will be present in the output.
@@ -26,6 +26,7 @@ def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
         sparsity_layout_output (Tensor): The sparsity layout of the output tensor.
         sparsity_block_size (int): The size of the sparsity blocks.
         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+        lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

     Returns:
         BlksprsTensor: The result of the matrix multiplication as a block-sparse tensor in compressed form.
@@ -44,35 +45,52 @@ def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
     validate_sparsity_block_size(sparsity_block_size, x, y)
     validate_triton_block_size(triton_block_size, sparsity_block_size)

-    sparsity_layout_x_flat = sparsity_layout_x.reshape(-1)
-    sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
-                              (sparsity_layout_x_flat == 1) -
-                              (1 * (sparsity_layout_x_flat == 0)))
-
-    sparsity_layout_y_flat = sparsity_layout_y.reshape(-1)
-    sparsity_reverse_lut_y = ((torch.cumsum(sparsity_layout_y_flat, dim=-1) - 1) *
-                              (sparsity_layout_y_flat == 1) -
-                              (1 * (sparsity_layout_y_flat == 0)))
-
-    sparsity_lut_o = torch.nonzero(sparsity_layout_output).contiguous()
-
-    n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
-
-    validate_contiguous(sparsity_layout_x, sparsity_reverse_lut_x,
-                        sparsity_layout_y, sparsity_reverse_lut_y,
-                        sparsity_layout_output, sparsity_lut_o)
+    lut = _BlocksparseMatmulSSS.build_lut(lut, sparsity_layout_x, sparsity_layout_y, sparsity_layout_output)

     return BlksprsTensor(_BlocksparseMatmulSSS.apply(x, y,
-                                                     sparsity_layout_x, sparsity_reverse_lut_x,
-                                                     sparsity_layout_y, sparsity_reverse_lut_y,
-                                                     sparsity_layout_output, sparsity_lut_o,
+                                                     sparsity_layout_x, lut["sparsity_reverse_lut_x"],
+                                                     sparsity_layout_y, lut["sparsity_reverse_lut_y"],
+                                                     sparsity_layout_output, lut["sparsity_lut_o"],
                                                      sparsity_block_size,
-                                                     n_sparse_blocks,
+                                                     lut["n_sparse_blocks"],
                                                      triton_block_size))


 class _BlocksparseMatmulSSS(torch.autograd.Function):

+    @staticmethod
+    def build_lut(lut: dict, sparsity_layout_x: Tensor, sparsity_layout_y: Tensor, sparsity_layout_output: Tensor):
+        if lut is None:
+            lut = dict()
+
+        if "sparsity_reverse_lut_x" not in lut:
+            sparsity_layout_x_flat = sparsity_layout_x.reshape(-1)
+            sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
+                                      (sparsity_layout_x_flat == 1) -
+                                      (1 * (sparsity_layout_x_flat == 0)))
+            lut["sparsity_reverse_lut_x"] = sparsity_reverse_lut_x
+
+        if "sparsity_reverse_lut_y" not in lut:
+            sparsity_layout_y_flat = sparsity_layout_y.reshape(-1)
+            sparsity_reverse_lut_y = ((torch.cumsum(sparsity_layout_y_flat, dim=-1) - 1) *
+                                      (sparsity_layout_y_flat == 1) -
+                                      (1 * (sparsity_layout_y_flat == 0)))
+            lut["sparsity_reverse_lut_y"] = sparsity_reverse_lut_y
+
+        if "sparsity_lut_o" not in lut:
+            sparsity_lut_o = torch.nonzero(sparsity_layout_output).contiguous()
+            lut["sparsity_lut_o"] = sparsity_lut_o
+
+        if "n_sparse_blocks" not in lut:
+            n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
+            lut["n_sparse_blocks"] = n_sparse_blocks
+
+        validate_contiguous(sparsity_layout_x, lut["sparsity_reverse_lut_x"],
+                            sparsity_layout_y, lut["sparsity_reverse_lut_y"],
+                            sparsity_layout_output, lut["sparsity_lut_o"])
+
+        return lut
+
     @staticmethod
     def forward(ctx, x: Tensor, y: Tensor,
                 sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
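
`matmul` gains the same caching hook, which pays off when the same pair of layouts is multiplied repeatedly, for example once per step in a training loop. A hedged sketch; every tensor, layout, and `num_steps` is assumed to be defined by the caller:

from blksprs.ops.matmul import matmul

lut_mm = {}  # built in place on the first call, then reused

for step in range(num_steps):
    out = matmul(x, sparsity_layout_x,
                 y, sparsity_layout_y,
                 sparsity_layout_output,
                 sparsity_block_size, lut=lut_mm)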