blksprs 1.9-py3-none-any.whl → 1.9.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/__init__.py CHANGED
@@ -1,5 +1,6 @@
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 
+
 class ops:
     from blksprs.ops.conversion import to_dense, to_sparse, from_blksprs, to_blksprs, adapt_layout
     from blksprs.ops.distribution import gather, scatter, scatter_reduce
@@ -22,13 +23,15 @@ class layouting:
     from blksprs.layouting.distribution_layout import build_distribution_layout
     from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption, \
         build_sparsity_layout_matmul, build_sparsity_layout_matmul_fast
+    from blksprs.utils.layout_utils import build_full_sparsity_layout
 
 
 class experimental:
     from blksprs.ops.experimental.distribution_mdi import build_distribution_layout_mdi
 
 
 class utils:
-    from blksprs.utils.processing import apply_torch_linear
+    from blksprs.utils.processing import apply_torch_linear, apply_torch_normalisation, apply_torch_dropout, \
+        apply_function_applicable_row_wise
     from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
     from blksprs.utils.validation import disable_validation
 
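For orientation, the new names land on the package root as follows (a minimal sketch; it only demonstrates where the 1.9.2 additions are exposed):

```python
import blksprs as bs

# Newly reachable in 1.9.2, per the import changes above:
build_full = bs.layouting.build_full_sparsity_layout
apply_norm = bs.utils.apply_torch_normalisation
apply_drop = bs.utils.apply_torch_dropout
apply_row_wise = bs.utils.apply_function_applicable_row_wise
```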
blksprs/ops/conversion.py CHANGED
@@ -289,8 +289,8 @@ class _BlocksparseToSparse(torch.autograd.Function):
 
 
 def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_size_from: int,
-                 sparsity_block_size_to: int,
-                 preprocess_data: dict = None, triton_block_size: int = None) -> BlksprsTensor:
+                 sparsity_block_size_to: int, sparsity_layout_to: Tensor = None,
+                 triton_block_size: int = None) -> (BlksprsTensor, Tensor):
     """Adapts the sparsity layout of a block-sparse tensor, resulting in a new block-sparse tensor in compressed form
     conforming to the new sparsity layout (and sparsity block size) definition.
 
@@ -299,11 +299,12 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
         sparsity_layout_from (Tensor): The sparsity layout of the input block-sparse tensor.
         sparsity_block_size_from (int): The size of the sparsity blocks of the input sparsity layout.
         sparsity_block_size_to (int): The size of the sparsity blocks of the output sparsity layout.
-        preprocess_data (dict): A dictionary containing data otherwise computed by the function (default ``None``).
+        sparsity_layout_to (Tensor): The sparsity layout of the output block-sparse tensor (default ``None``).
         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
 
     Returns:
         BlksprsTensor: The block-sparse tensor in compressed form with the adapted sparsity layout and sparsity block size.
+        Tensor: The sparsity layout of the resulting output tensor.
 
     """
     x = x.contiguous()
@@ -317,52 +318,42 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
     min_sparsity_block_size = min(sparsity_block_size_from, sparsity_block_size_to)
     validate_triton_block_size(triton_block_size, min_sparsity_block_size)
 
-    if preprocess_data is None:
-        preprocess_data = {}
+    sparsity_layout_from_flat = sparsity_layout_from.reshape(-1)
+    sparsity_reverse_lut_from = ((torch.cumsum(sparsity_layout_from_flat, dim=-1) - 1) *
+                                 (sparsity_layout_from_flat == 1) -
+                                 (1 * (sparsity_layout_from_flat == 0)))
 
-    if "sparsity_reverse_lut_from" not in preprocess_data:
-        sparsity_layout_from_flat = sparsity_layout_from.reshape(-1)
-        sparsity_reverse_lut_from = ((torch.cumsum(sparsity_layout_from_flat, dim=-1) - 1) *
-                                     (sparsity_layout_from_flat == 1) -
-                                     (1 * (sparsity_layout_from_flat == 0)))
-    else:
-        sparsity_reverse_lut_from = preprocess_data["sparsity_reverse_lut_from"]
-
-    if "sparsity_layout_to" not in preprocess_data:
+    if sparsity_layout_to is None:
         sparsity_layout_to = build_sparsity_layout_adaption(x, sparsity_layout_from,
                                                             sparsity_block_size_from, sparsity_block_size_to,
                                                             triton_block_size)
-    else:
-        sparsity_layout_to = preprocess_data["sparsity_layout_to"]
 
-    if "sparsity_lut_to" not in preprocess_data:
-        sparsity_lut_to = torch.nonzero(sparsity_layout_to).contiguous()
-    else:
-        sparsity_lut_to = preprocess_data["sparsity_lut_to"]
+    sparsity_lut_to = torch.nonzero(sparsity_layout_to).contiguous()
 
-    if "n_sparse_blocks_to" not in preprocess_data:
-        n_sparse_blocks_to = torch.sum(sparsity_layout_to.to(torch.int)).item()
-    else:
-        n_sparse_blocks_to = preprocess_data["n_sparse_blocks_to"]
+    n_sparse_blocks_to = torch.sum(sparsity_layout_to.to(torch.int)).item()
 
-    validate_contiguous(sparsity_layout_to, sparsity_reverse_lut_from, sparsity_lut_to)
+    validate_contiguous(sparsity_reverse_lut_from, sparsity_layout_to, sparsity_lut_to)
 
     if (sparsity_block_size_from == sparsity_block_size_to) and torch.equal(sparsity_layout_from, sparsity_layout_to):
-        return BlksprsTensor(x)
+        return BlksprsTensor(x), sparsity_layout_to
 
     return BlksprsTensor(_BlocksparseAdaptLayout.apply(x,
                                                        sparsity_layout_from, sparsity_reverse_lut_from,
                                                        sparsity_block_size_from,
-                                                       sparsity_layout_to, sparsity_lut_to, sparsity_block_size_to,
-                                                       n_sparse_blocks_to, min_sparsity_block_size, triton_block_size))
+                                                       sparsity_layout_to, sparsity_lut_to,
+                                                       sparsity_block_size_to,
+                                                       n_sparse_blocks_to, min_sparsity_block_size,
+                                                       triton_block_size)), sparsity_layout_to
 
 
 class _BlocksparseAdaptLayout(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, x: Tensor,
-                sparsity_layout_from: Tensor, sparsity_reverse_lut_from: Tensor, sparsity_block_size_from: int,
-                sparsity_layout_to: Tensor, sparsity_lut_to: Tensor, sparsity_block_size_to: int,
+                sparsity_layout_from: Tensor, sparsity_reverse_lut_from: Tensor,
+                sparsity_block_size_from: int,
+                sparsity_layout_to: Tensor, sparsity_lut_to: Tensor,
+                sparsity_block_size_to: int,
                 n_sparse_blocks_to: int, min_sparsity_block_size: int, triton_block_size: int) -> Tensor:
         output = torch.zeros(size=(n_sparse_blocks_to, sparsity_block_size_to, sparsity_block_size_to),
                              dtype=x.dtype, device=x.device)
@@ -409,9 +400,10 @@ class _BlocksparseAdaptLayout(torch.autograd.Function):
         sparsity_block_size_to = ctx.sparsity_block_size_to
         triton_block_size = ctx.triton_block_size
 
-        return adapt_layout(grad_output, sparsity_layout_to, sparsity_block_size_to, sparsity_block_size_from,
-                            preprocess_data={"sparsity_layout_to": sparsity_layout_from},
-                            triton_block_size=triton_block_size), None, None, None, None, None, None, None, None, None
+        return adapt_layout(
+            grad_output, sparsity_layout_to, sparsity_block_size_to, sparsity_block_size_from,
+            sparsity_layout_to=sparsity_layout_from,
+            triton_block_size=triton_block_size)[0], None, None, None, None, None, None, None, None, None
 
     @staticmethod
     @triton.jit
@@ -448,7 +440,7 @@ class _BlocksparseAdaptLayout(torch.autograd.Function):
         spa_row_x = (spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size_from
         spa_col_x = (spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size_from
 
-        # # Get reverse sparsity indices for x
+        # Get reverse sparsity indices for x
         rev_idx_spa_x_idx = (spa_bat_x * s_l_x_b_s +
                              spa_row_x * s_l_x_r_s +
                              spa_col_x * s_l_x_c_s)
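A minimal usage sketch of the reworked `adapt_layout` (hedged: shapes, block sizes, and the CUDA device are illustrative assumptions, not taken from the diff):

```python
import torch
import blksprs as bs

# Assumed setup: batch of 2, 128x128 matrices, fully populated 32x32 block layout.
layout_from = torch.ones(2, 4, 4, dtype=torch.bool, device="cuda")
x = bs.ops.to_sparse(torch.randn(2, 128, 128, device="cuda"), layout_from, 32)

# 1.9.2: adapt_layout now returns a (tensor, layout) tuple. If sparsity_layout_to
# is omitted, the output layout is derived via build_sparsity_layout_adaption.
x_64, layout_64 = bs.ops.adapt_layout(x, layout_from, 32, 64)
```

Callers that previously supplied `preprocess_data={"sparsity_layout_to": ...}` now pass `sparsity_layout_to=...` directly, as the backward pass above does.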
blksprs/ops/distribution.py CHANGED
@@ -207,6 +207,7 @@ def scatter(src: BlksprsTensor, sparsity_layout_src: Tensor,
 
     """
     return scatter_reduce(src, sparsity_layout_src,
+                          dim,
                           idx,
                           sparsity_layout_tgt,
                           sparsity_block_size,
blksprs/ops/flow.py ADDED
@@ -0,0 +1,147 @@
+import torch
+import triton
+from torch import Tensor
+from triton import language as tl
+
+from blksprs.utils.tools import stride, get_triton_block_size
+
+
+@triton.jit
+def kernel_blocksparse_flow_pull(x,
+                                 x_b, x_b_s, x_r_s, x_c_s,
+                                 o,
+                                 o_b, o_b_s, o_r_s, o_c_s,
+                                 s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+                                 s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                                 r_lut,
+                                 TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get sparsity index of current output block consisting of its batch, row, and column index
+    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+    spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+    spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+    spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
+    spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
+    spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+
+    # Get reverse sparsity index
+    rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
+                       spa_row * s_l_o_r_s +
+                       spa_col * s_l_o_c_s)
+    rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
+    rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
+
+    if rev_idx_spa == -1:
+        tl.device_assert(False)
+        return
+
+    blk_x_idx = (rev_idx_spa * x_b_s +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_x_msk = (blk_x_idx < x_b * x_b_s)
+    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+    blk_o_idx = (pid_blk * o_b_s +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+    blk_o_msk = (blk_o_idx < o_b * o_b_s)
+    tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+@triton.jit
+def kernel_blocksparse_flow_push(x,
+                                 x_b, x_b_s, x_r_s, x_c_s,
+                                 s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+                                 s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                                 r_lut,
+                                 o,
+                                 o_b, o_b_s, o_r_s, o_c_s,
+                                 TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get sparsity index of current input block consisting of its batch, row, and column index
+    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+    spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+    spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+    spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
+    spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
+    spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+
+    # Get reverse sparsity index
+    rev_idx_spa_idx = (spa_bat * s_l_x_b_s +
+                       spa_row * s_l_x_r_s +
+                       spa_col * s_l_x_c_s)
+    rev_idx_spa_msk = (rev_idx_spa_idx < s_l_x_b * s_l_x_b_s)
+    rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
+
+    if rev_idx_spa == -1:
+        tl.device_assert(False)
+        return
+
+    blk_x_idx = (pid_blk * x_b_s +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_x_msk = (blk_x_idx < x_b * x_b_s)
+    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+    blk_o_idx = (rev_idx_spa * o_b_s +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+    blk_o_msk = (blk_o_idx < o_b * o_b_s)
+    tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+def flow_forward(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
+                 sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
+    output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                         dtype=x.dtype, device=x.device)
+    output = torch.zeros_like(output)
+
+    x_b, x_r, x_c = x.size()
+    x_b_s, x_r_s, x_c_s = stride(x)
+    o_b, o_r, o_c = output.size()
+    o_b_s, o_r_s, o_c_s = stride(output)
+    s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
+    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
+    s_lut_r, s_lut_c = sparsity_lut.size()
+    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+
+    if triton_block_size is None:
+        triton_block_size = get_triton_block_size(sparsity_block_size)
+
+    triton_grid = lambda meta: [o_b,
+                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+    (kernel_blocksparse_flow_pull[triton_grid]
+     (x,
+      x_b, x_b_s, x_r_s, x_c_s,
+      output,
+      o_b, o_b_s, o_r_s, o_c_s,
+      s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+      sparsity_reverse_lut,
+      triton_block_size))
+
+    # Save for backward pass
+    ctx.sparsity_block_size = sparsity_block_size
+    ctx.triton_block_size = triton_block_size
+
+    return output
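For intuition, the pull kernel's semantics can be sketched in plain PyTorch (a reference sketch, not part of the package; it assumes the reverse LUT is indexable in the shape of the output sparsity layout and that `x` holds compressed blocks):

```python
import torch

def flow_pull_reference(x: torch.Tensor, sparsity_lut: torch.Tensor,
                        sparsity_reverse_lut: torch.Tensor, sparsity_block_size: int) -> torch.Tensor:
    # One output block per LUT row; each LUT row holds (batch, row, col) of that block.
    out = torch.zeros(sparsity_lut.size(0), sparsity_block_size, sparsity_block_size,
                      dtype=x.dtype, device=x.device)
    for blk in range(sparsity_lut.size(0)):
        bat, row, col = sparsity_lut[blk].tolist()
        src = int(sparsity_reverse_lut[bat, row, col])  # -1 marks an absent source block
        if src != -1:  # the kernel device-asserts instead of skipping
            out[blk] = x[src]
    return out
```

`kernel_blocksparse_flow_push` is the adjoint of this gather: it scatters block `pid_blk` of `x` into block `rev_idx_spa` of the output via `tl.atomic_add`, so multiple source blocks may accumulate into one target block.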
blksprs/ops/partitioning.py CHANGED
@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor
 
-from blksprs.ops.repeat import forward_flow
+from blksprs.ops.flow import flow_forward
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 
 from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
@@ -66,7 +66,7 @@ class _BlocksparseSplit(torch.autograd.Function):
         ctx.save_for_backward(sparsity_layout_o)
         ctx.num_partitions = num_partitions
 
-        return forward_flow(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+        return flow_forward(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
                             n_sparse_blocks, triton_block_size)
 
     @staticmethod
@@ -140,7 +140,7 @@ class _BlocksparseMerge(torch.autograd.Function):
         ctx.save_for_backward(sparsity_layout_o)
         ctx.num_partitions = num_partitions
 
-        return forward_flow(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+        return flow_forward(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
                             n_sparse_blocks, triton_block_size)
 
     @staticmethod
blksprs/ops/repeat.py CHANGED
@@ -1,8 +1,8 @@
 import torch
 import triton
-from triton import language as tl
 from torch import Tensor
 
+from blksprs.ops.flow import kernel_blocksparse_flow_push, flow_forward
 from blksprs.utils.blksprs_tensor import BlksprsTensor
 from blksprs.utils.tools import get_triton_block_size, stride
 from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, \
@@ -64,8 +64,9 @@ def repeat(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: tuple[int, int,
 
     validate_contiguous(sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
 
-    return BlksprsTensor(_BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
-                                                  sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_o
+    return BlksprsTensor(
+        _BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
+                                 sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_o
 
 
 def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
@@ -122,8 +123,9 @@ def repeat_interleave(x: BlksprsTensor, sparsity_layout_x: Tensor, repeats: int,
 
     validate_contiguous(sparsity_layout_o, sparsity_lut, sparsity_reverse_lut)
 
-    return BlksprsTensor(_BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
-                                                  sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_o
+    return BlksprsTensor(
+        _BlocksparseRepeat.apply(x, sparsity_layout_x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut,
+                                 sparsity_block_size, n_sparse_blocks, triton_block_size)), sparsity_layout_o
 
 
 class _BlocksparseRepeat(torch.autograd.Function):
@@ -137,7 +139,7 @@ class _BlocksparseRepeat(torch.autograd.Function):
         ctx.x_size = x.size()
         ctx.x_stride = stride(x)
 
-        return forward_flow(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
+        return flow_forward(ctx, x, sparsity_layout_o, sparsity_lut, sparsity_reverse_lut, sparsity_block_size,
                             n_sparse_blocks, triton_block_size)
 
     @staticmethod
@@ -180,144 +182,3 @@ class _BlocksparseRepeat(torch.autograd.Function):
                                              triton_block_size))
 
         return output, None, None, None, None, None, None, None
-
-
-@triton.jit
-def kernel_blocksparse_flow_pull(x,
-                                 x_b, x_b_s, x_r_s, x_c_s,
-                                 o,
-                                 o_b, o_b_s, o_r_s, o_c_s,
-                                 s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
-                                 s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-                                 r_lut,
-                                 TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-    # Get triton block indices
-    pid_blk = tl.program_id(axis=0)
-    pid_row = tl.program_id(axis=1)
-    pid_col = tl.program_id(axis=2)
-
-    # Get sparsity index of current output block consisting of its batch, row, and column index
-    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-    spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
-    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
-
-    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-    spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
-    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
-
-    spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
-    spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
-    spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
-
-    # Get reverse sparsity index
-    rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
-                       spa_row * s_l_o_r_s +
-                       spa_col * s_l_o_c_s)
-    rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
-    rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
-
-    if rev_idx_spa == -1:
-        tl.device_assert(False)
-        return
-
-    blk_x_idx = (rev_idx_spa * x_b_s +
-                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx < x_b * x_b_s)
-    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
-    blk_o_idx = (pid_blk * o_b_s +
-                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-    blk_o_msk = (blk_o_idx < o_b * o_b_s)
-    tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
-
-
-@triton.jit
-def kernel_blocksparse_flow_push(x,
-                                 x_b, x_b_s, x_r_s, x_c_s,
-                                 s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
-                                 s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-                                 r_lut,
-                                 o,
-                                 o_b, o_b_s, o_r_s, o_c_s,
-                                 TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-    # Get triton block indices
-    pid_blk = tl.program_id(axis=0)
-    pid_row = tl.program_id(axis=1)
-    pid_col = tl.program_id(axis=2)
-
-    # Get sparsity index of current input block consisting of its batch, row, and column index
-    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-    spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
-    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
-
-    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-    spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
-    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
-
-    spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
-    spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
-    spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
-
-    # Get reverse sparsity index
-    rev_idx_spa_idx = (spa_bat * s_l_x_b_s +
-                       spa_row * s_l_x_r_s +
-                       spa_col * s_l_x_c_s)
-    rev_idx_spa_msk = (rev_idx_spa_idx < s_l_x_b * s_l_x_b_s)
-    rev_idx_spa = tl.load(r_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
-
-    if rev_idx_spa == -1:
-        tl.device_assert(False)
-        return
-
-    blk_x_idx = (pid_blk * x_b_s +
-                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-    blk_x_msk = (blk_x_idx < x_b * x_b_s)
-    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
-    blk_o_idx = (rev_idx_spa * o_b_s +
-                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-    blk_o_msk = (blk_o_idx < o_b * o_b_s)
-    tl.atomic_add(o + blk_o_idx, blk_x, mask=blk_o_msk)
-
-
-def forward_flow(ctx, x: Tensor, sparsity_layout_o: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
-                 sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
-    output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
-                         dtype=x.dtype, device=x.device)
-    output = torch.zeros_like(output)
-
-    x_b, x_r, x_c = x.size()
-    x_b_s, x_r_s, x_c_s = stride(x)
-    o_b, o_r, o_c = output.size()
-    o_b_s, o_r_s, o_c_s = stride(output)
-    s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_o.size()
-    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = stride(sparsity_layout_o)
-    s_lut_r, s_lut_c = sparsity_lut.size()
-    s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-
-    if triton_block_size is None:
-        triton_block_size = get_triton_block_size(sparsity_block_size)
-
-    triton_grid = lambda meta: [o_b,
-                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-    (kernel_blocksparse_flow_pull[triton_grid]
-     (x,
-      x_b, x_b_s, x_r_s, x_c_s,
-      output,
-      o_b, o_b_s, o_r_s, o_c_s,
-      s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
-      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-      sparsity_reverse_lut,
-      triton_block_size))
-
-    # Save for backward pass
-    ctx.sparsity_block_size = sparsity_block_size
-    ctx.triton_block_size = triton_block_size
-
-    return output
blksprs/utils/layout_utils.py ADDED
@@ -0,0 +1,17 @@
+import math
+
+import torch
+import triton
+from torch import Tensor
+from torch.xpu import device
+from triton import language as tl
+
+from blksprs.utils.blksprs_tensor import BlksprsTensor
+from blksprs.utils.tools import get_triton_block_size, stride
+from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
+    validate_contiguous, validate_sparsity, validate_sparsity_block_size
+
+
+def build_full_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
+    return torch.ones(size=(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size),
+                      dtype=torch.bool, device=x.device)
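Usage is straightforward (illustrative shapes; the tensor's last two dimensions are assumed divisible by `sparsity_block_size`):

```python
import torch
from blksprs.utils.layout_utils import build_full_sparsity_layout

x = torch.randn(2, 64, 128)                 # (batch, rows, cols)
layout = build_full_sparsity_layout(x, 32)  # bool tensor of shape (2, 2, 4), all True
```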
blksprs/utils/processing.py CHANGED
@@ -1,7 +1,9 @@
+from collections.abc import Callable
+
 import torch
 from torch import Tensor, nn
-from triton.language import dtype
 
+import blksprs as bs
 from blksprs.layouting.sparsity_layout import build_sparsity_layout_matmul_fast
 from blksprs.ops.conversion import to_sparse
 from blksprs.ops.matmul import matmul
@@ -10,7 +12,7 @@ from blksprs.utils.blksprs_tensor import BlksprsTensor
 
 
 def apply_torch_linear(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
-                       linear: nn.Linear) -> (BlksprsTensor, Tensor):
+                       linear: nn.Linear, bias: nn.Parameter = None) -> (BlksprsTensor, Tensor):
     # Extract weight and bias
     w = linear.weight
     b = linear.bias
@@ -27,6 +29,8 @@ def apply_torch_linear(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block
     interim = xw
 
     # Apply bias
+    if bias is not None:
+        b = bias
     if b is not None:
         b_slice = b.unsqueeze(0).unsqueeze(0).repeat(1, sparsity_block_size, 1)
         sparsity_layout_b_slice = torch.ones(size=(1, b_slice.size(1) // sparsity_block_size,
@@ -39,3 +43,32 @@ def apply_torch_linear(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block
         interim = interim + b_bs
 
     return interim, sparsity_layout_xw
+
+
+def apply_torch_normalisation(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
+                              normalisation: nn.Module) -> BlksprsTensor:
+    return apply_function_applicable_row_wise(x, sparsity_layout, sparsity_block_size, normalisation)
+
+
+def apply_torch_dropout(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
+                        dropout: nn.Dropout) -> BlksprsTensor:
+    return apply_function_applicable_row_wise(x, sparsity_layout, sparsity_block_size, dropout)
+
+
+def apply_function_applicable_row_wise(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
+                                       function: Callable) -> BlksprsTensor:
+    sparsity_layout_packed = _pack_layout(sparsity_layout)
+    blksprs_pseudo_dense = bs.ops.to_dense(x, sparsity_layout_packed, sparsity_block_size)
+    normalisation_out = function(blksprs_pseudo_dense)
+    blksprs_sparse = bs.ops.to_sparse(normalisation_out, sparsity_layout_packed, sparsity_block_size)
+
+    return blksprs_sparse
+
+
+def _pack_layout(sparsity_layout: Tensor) -> BlksprsTensor:
+    sparsity_layout_reshaped = sparsity_layout.reshape(1, sparsity_layout.size(0) * sparsity_layout.size(1),
+                                                       sparsity_layout.size(2))
+    non_zero_rows = torch.any(sparsity_layout_reshaped, dim=-1)
+    sparsity_layout_filtered = sparsity_layout_reshaped[non_zero_rows].unsqueeze(0)
+
+    return sparsity_layout_filtered
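A hedged usage sketch of the new helpers (shapes and the CUDA device are illustrative assumptions): `apply_function_applicable_row_wise` packs the non-empty layout rows into one pseudo-dense tensor, applies the module, and re-compresses, so it is only valid for modules that act row-wise, e.g. ``nn.LayerNorm`` over the last dimension or ``nn.Dropout``.

```python
import torch
from torch import nn
import blksprs as bs

sparsity_block_size = 32
layout = torch.ones(2, 2, 4, dtype=torch.bool, device="cuda")
x = bs.ops.to_sparse(torch.randn(2, 64, 128, device="cuda"), layout, sparsity_block_size)

x_normed = bs.utils.apply_torch_normalisation(x, layout, sparsity_block_size,
                                              nn.LayerNorm(128, device="cuda"))
x_dropped = bs.utils.apply_torch_dropout(x, layout, sparsity_block_size, nn.Dropout(p=0.1))
```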
blksprs/utils/validation.py CHANGED
@@ -36,7 +36,8 @@ def validate_dtype_int(*tensors: Tensor) -> None:
         return
 
     for tensor in tensors:
-        if tensor.dtype != torch.int32 and tensor.dtype != torch.int64:
+        if (tensor.dtype !=
+                torch.int32 and tensor.dtype != torch.int64):
             raise ValueError("Tensor must have int32 or int64 dtype")
 
 
{blksprs-1.9.dist-info → blksprs-1.9.2.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.9
+Version: 1.9.2
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -8,14 +8,15 @@ Project-URL: Bugtracker, https://github.com/FelixSchoen/blksprs/issues
 Requires-Python: >=3.11
 Description-Content-Type: text/markdown
 Requires-Dist: torch
-Provides-Extra: build
-Requires-Dist: build; extra == "build"
+Requires-Dist: numpy
 Provides-Extra: test
 Requires-Dist: pytest; extra == "test"
 Requires-Dist: pytest-xdist; extra == "test"
 Requires-Dist: pytest-cov; extra == "test"
 Requires-Dist: coverage; extra == "test"
 Requires-Dist: matplotlib; extra == "test"
+Provides-Extra: build
+Requires-Dist: build; extra == "build"
 
 # blksprs
 
@@ -64,8 +65,12 @@ Further helpful operations (included in the ``bs.ops.misc`` module) that do **no
 - Row-wise sum, max, addition, and subtraction
 - Broadcast addition and subtraction between slices
 
-Furthermore, the library provides a set of utility functions for the creation of sparsity layouts based on existing
-dense tensors and for the scatter operation (module ``bs.layouting``), as well as utility functions to apply linear layers,
+Furthermore, the library provides a set of utility functions
+
+- for the creation of sparsity layouts based on existing
+  dense tensors and for the scatter operation (module ``bs.layouting``),
+- for the application of ``nn.Linear``, ``nn.Dropout``, and ``nn.LayerNorm`` layers to block-sparse tensors,
+- as well as utility functions to apply linear layers,
 ensure correct input dimensionality, and validate input (module ``bs.utils``).
 
 ## Installation
@@ -79,7 +84,8 @@ We recommend installing blksprs from [PyPI](https://pypi.org/project/blksprs/) u
 
 ### Dependencies
 
-- [PyTorch](https://pytorch.org/) (built with v2.5.0)
+- [PyTorch](https://pytorch.org/) (built with v2.5.1)
+- _[NumPy](https://numpy.org/) (to get rid of warnings, built with v2.2.0)_
 - _[Triton](https://github.com/triton-lang/triton) (included with PyTorch)_
 
 ## Changelog
{blksprs-1.9.dist-info → blksprs-1.9.2.dist-info}/RECORD RENAMED
@@ -1,11 +1,12 @@
-blksprs/__init__.py,sha256=YMrERuEf1hTv5vVdOvPEzh9rESn4uqOB7WHB12Qs5lU,1836
+blksprs/__init__.py,sha256=L2wP3sFBjfcIOuI2WhQW1eUEYuKoZLKxSV9z0aQmknM,2001
 blksprs/layouting/distribution_layout.py,sha256=9f_Bx2YQF4LTH95C0S7OuB9eeOuh73NcE0Z7Wrtug38,5034
 blksprs/layouting/sparsity_layout.py,sha256=-sScIn4hhG35j9BXytrojEzp8jnFkMargJjtivPV1fc,9755
-blksprs/ops/conversion.py,sha256=ol-iV45wDzp9G1dJEkY53EdrvnmHzcl7QQmPJ-xqQTs,22410
-blksprs/ops/distribution.py,sha256=OWTH_dfO43uIMY6S44wpvRoIBuKzaTy1f57BOEf7EYA,19925
+blksprs/ops/conversion.py,sha256=2lQZfPd1iFheXIcoH0LbN2m7vqFRQ8XUzhGFlDckBsM,22052
+blksprs/ops/distribution.py,sha256=JGa-eLY-1OgicU3vPAwuhqsoUIeyadzmTk2t25aYyak,19956
+blksprs/ops/flow.py,sha256=RBXNOA6O0Ay2sotH8uNoltZywkdxJocJCn3bfB1fGjM,6185
 blksprs/ops/matmul.py,sha256=yh2ZnO0ZltT1AgadiFP0vX28YJ4n74xO-I_5vFUmOmA,11452
-blksprs/ops/partitioning.py,sha256=K0ExR2a3W62d_9xxCJzsdJDLgtbxTI6P8loOOBdhPzE,7674
-blksprs/ops/repeat.py,sha256=IvSIRbuyFn0b57LObymLgup0LqlWQ3ndIw-QuiYQcaU,14564
+blksprs/ops/partitioning.py,sha256=z7kx4FrC-ugxZP-IsOHCfdbsF__ld0P-vDota5CbU4s,7672
+blksprs/ops/repeat.py,sha256=RCa-dITomA5v12K5Oxa5_ReA361zS7WHPNNHxSp9PGw,8578
 blksprs/ops/softmax.py,sha256=V-1vqRefjjwSp6JPwKxVxh5pTng9gOdtgGlXHDPbpYM,12190
 blksprs/ops/transpose.py,sha256=jxzFFffrj4S_9tiCrwwUMdz6EA98o1dziWXjlqb64a4,6859
 blksprs/ops/experimental/distribution_mdi.py,sha256=F_0tl4Gn-9JZs_TZfDtZqO_RPFl7sejqQNF8UNIoCbs,20533
@@ -14,10 +15,11 @@ blksprs/ops/misc/exp.py,sha256=FnSFosBfJHuiEbD0MD-i4axLghRn4a0f8KvHXrKBB6M,3802
 blksprs/ops/misc/row_wise.py,sha256=U4Kk0-P4oOuMNjMHXxP2gP9njMIeMfz8RZrzItNIF94,17229
 blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
 blksprs/utils/blksprs_tensor.py,sha256=VjplBgDhnf9sxf-1R5feA0xp5FDCDdaeZmCeoIRdCnc,151
-blksprs/utils/processing.py,sha256=hYsFxEbQKcbqU4WtZWusPnWMHg8ZAZF1SKZJYjez9aU,2060
+blksprs/utils/layout_utils.py,sha256=49ZdPS_gMn_IrWty3FARbi2rda5a8g5DmAEL8LOrC30,670
+blksprs/utils/processing.py,sha256=WLuMJQ8v-YovXwcDjhlDn3N31WMZXrtyeeyKSgq_zn4,3642
 blksprs/utils/tools.py,sha256=r7Y4C37vfSWUyQTGwa8NyRqgovmsq9hMufkenqYHOxo,539
-blksprs/utils/validation.py,sha256=IZxH2HZpePmv7lRqLsSwV_6FwsdnTXv9q4j98vCMSsQ,4195
-blksprs-1.9.dist-info/METADATA,sha256=9mMjmvJ2_Rz0uyiY9S8SKTRcs6YW5Jk1w6PRobh6Q3c,8456
-blksprs-1.9.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-blksprs-1.9.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
-blksprs-1.9.dist-info/RECORD,,
+blksprs/utils/validation.py,sha256=CbxBbeQWJo8wox5eMoVzaTlP9FVBwt3-gxUOmi3EUgw,4213
+blksprs-1.9.2.dist-info/METADATA,sha256=JIHA58YnLfFrUyAOsPmHMWbDz_XmkDiXypLhg1ijO0E,8670
+blksprs-1.9.2.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+blksprs-1.9.2.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
+blksprs-1.9.2.dist-info/RECORD,,
{blksprs-1.9.dist-info → blksprs-1.9.2.dist-info}/WHEEL RENAMED
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (75.3.0)
+Generator: setuptools (75.6.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 