blksprs 1.2.1.tar.gz → 1.4.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. {blksprs-1.2.1 → blksprs-1.4}/PKG-INFO +28 -25
  2. {blksprs-1.2.1 → blksprs-1.4}/README.md +27 -24
  3. blksprs-1.4/blksprs/__init__.py +18 -0
  4. blksprs-1.2.1/blksprs/misc/broadcast_addition.py → blksprs-1.4/blksprs/misc/broadcast_ops.py +6 -6
  5. blksprs-1.4/blksprs/misc/repeat_interleave.py +130 -0
  6. blksprs-1.4/blksprs/misc/row_wise.py +386 -0
  7. {blksprs-1.2.1 → blksprs-1.4}/blksprs/ops/softmax.py +11 -13
  8. {blksprs-1.2.1 → blksprs-1.4}/blksprs/ops/transpose.py +1 -1
  9. {blksprs-1.2.1 → blksprs-1.4}/blksprs/utils/tools.py +1 -1
  10. {blksprs-1.2.1 → blksprs-1.4}/blksprs.egg-info/PKG-INFO +28 -25
  11. {blksprs-1.2.1 → blksprs-1.4}/blksprs.egg-info/SOURCES.txt +4 -2
  12. {blksprs-1.2.1 → blksprs-1.4}/pyproject.toml +1 -1
  13. blksprs-1.2.1/blksprs/ops/row_wise_sum.py +0 -231
  14. {blksprs-1.2.1 → blksprs-1.4}/blksprs/layouting/distribution_layout.py +0 -0
  15. {blksprs-1.2.1 → blksprs-1.4}/blksprs/layouting/sparsity_layout.py +0 -0
  16. {blksprs-1.2.1 → blksprs-1.4}/blksprs/ops/conversion.py +0 -0
  17. {blksprs-1.2.1 → blksprs-1.4}/blksprs/ops/distribution.py +0 -0
  18. {blksprs-1.2.1 → blksprs-1.4}/blksprs/ops/exp.py +0 -0
  19. {blksprs-1.2.1 → blksprs-1.4}/blksprs/ops/matmul.py +0 -0
  20. {blksprs-1.2.1 → blksprs-1.4}/blksprs/utils/benchmarking.py +0 -0
  21. {blksprs-1.2.1 → blksprs-1.4}/blksprs/utils/validation.py +0 -0
  22. {blksprs-1.2.1 → blksprs-1.4}/blksprs.egg-info/dependency_links.txt +0 -0
  23. {blksprs-1.2.1 → blksprs-1.4}/blksprs.egg-info/requires.txt +0 -0
  24. {blksprs-1.2.1 → blksprs-1.4}/blksprs.egg-info/top_level.txt +0 -0
  25. {blksprs-1.2.1 → blksprs-1.4}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.2.1
+Version: 1.4
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -83,14 +83,7 @@ the [test cases](https://github.com/FelixSchoen/blksprs/blob/main/test/cases/tes

 ```python
 import torch
-
-from blksprs.layouting.sparsity_layout import build_sparsity_layout
-from blksprs.ops.conversion import to_sparse, to_dense
-from blksprs.ops.matmul import matmul
-from blksprs.ops.row_wise_sum import row_wise_sum
-from blksprs.ops.softmax import softmax
-from blksprs.ops.transpose import transpose
-from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
+import blksprs as bs


 def test_readme():
@@ -112,47 +105,57 @@ def test_readme():
     y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()

     # Convert tensors to three-dimensional (dense) tensors since Triton can only handle tensors of exactly three dimensions
-    x_dense, x_shape_original = do_shape_blocksparse(x)
-    y_dense, y_shape_original = do_shape_blocksparse(y)
+    x_dense, x_shape_original = bs.util.do_shape_blocksparse(x)
+    y_dense, y_shape_original = bs.util.do_shape_blocksparse(y)

     # Create sparsity layouts from existing tensors
-    sparsity_layout_x = build_sparsity_layout(x_dense, sparsity_block_size, triton_block_size=triton_block_size)
-    sparsity_layout_y = build_sparsity_layout(y_dense, sparsity_block_size, triton_block_size=triton_block_size)
+    sparsity_layout_x = bs.layout.build_sparsity_layout(x_dense, sparsity_block_size,
+                                                        triton_block_size=triton_block_size)
+    sparsity_layout_y = bs.layout.build_sparsity_layout(y_dense, sparsity_block_size,
+                                                        triton_block_size=triton_block_size)

     # Create random sparsity layout for output tensor
     sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)

     # Convert tensors to sparse tensors for matrix multiplication
-    x_sparse = to_sparse(x_dense, sparsity_layout_x, sparsity_block_size, triton_block_size=triton_block_size)
-    y_sparse = to_sparse(y_dense, sparsity_layout_y, sparsity_block_size, triton_block_size=triton_block_size)
+    x_sparse = bs.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size, triton_block_size=triton_block_size)
+    y_sparse = bs.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size, triton_block_size=triton_block_size)

     # Perform matrix multiplication
-    o_sparse = matmul(x_sparse, sparsity_layout_x, y_sparse, sparsity_layout_y, sparsity_layout_o, sparsity_block_size,
-                      triton_block_size=triton_block_size)
-    o_dense = to_dense(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+    o_sparse = bs.matmul(x_sparse, sparsity_layout_x, y_sparse, sparsity_layout_y, sparsity_layout_o,
+                         sparsity_block_size,
+                         triton_block_size=triton_block_size)
+
+    # Apply element-wise operation
+    o_sparse = torch.add(o_sparse, 1)
+
+    o_dense = bs.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)

     # Sanity check
     o_torch = torch.matmul(x_dense, y_dense)
+    o_torch = torch.add(o_torch, 1)

     # Perform round trip to set sparse blocks to 0
-    o_torch_round_trip = to_dense(
-        to_sparse(o_torch, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size),
+    o_torch_round_trip = bs.to_dense(
+        bs.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size),
         sparsity_layout_o, sparsity_block_size, fill_value=0, triton_block_size=triton_block_size)

     # Assert that the output is correct
     assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2)  # Note that small numerical differences are expected

     # Assert that the output has the correct sparsity layout
-    actual_sparsity_layout_o = build_sparsity_layout(o_dense, sparsity_block_size, triton_block_size=triton_block_size)
+    actual_sparsity_layout_o = bs.layout.build_sparsity_layout(o_dense, sparsity_block_size,
+                                                               triton_block_size=triton_block_size)
     assert torch.allclose(actual_sparsity_layout_o, sparsity_layout_o)

     # Convert output tensor back to original shape
-    o = undo_shape_blocksparse(o_dense, x_shape_original)
+    o = bs.util.undo_shape_blocksparse(o_dense, x_shape_original)

     # Other available functions
-    transpose(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
-    softmax(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
-    row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+    bs.transpose(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+    bs.softmax(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+    bs.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+    bs.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)


 def _get_random_sparsity_layout(b, m, n, sparsity_block_size, sparsity_percentage):
@@ -62,14 +62,7 @@ the [test cases](https://github.com/FelixSchoen/blksprs/blob/main/test/cases/tes

 ```python
 import torch
-
-from blksprs.layouting.sparsity_layout import build_sparsity_layout
-from blksprs.ops.conversion import to_sparse, to_dense
-from blksprs.ops.matmul import matmul
-from blksprs.ops.row_wise_sum import row_wise_sum
-from blksprs.ops.softmax import softmax
-from blksprs.ops.transpose import transpose
-from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
+import blksprs as bs


 def test_readme():
@@ -91,47 +84,57 @@ def test_readme():
     y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()

     # Convert tensors to three-dimensional (dense) tensors since Triton can only handle tensors of exactly three dimensions
-    x_dense, x_shape_original = do_shape_blocksparse(x)
-    y_dense, y_shape_original = do_shape_blocksparse(y)
+    x_dense, x_shape_original = bs.util.do_shape_blocksparse(x)
+    y_dense, y_shape_original = bs.util.do_shape_blocksparse(y)

     # Create sparsity layouts from existing tensors
-    sparsity_layout_x = build_sparsity_layout(x_dense, sparsity_block_size, triton_block_size=triton_block_size)
-    sparsity_layout_y = build_sparsity_layout(y_dense, sparsity_block_size, triton_block_size=triton_block_size)
+    sparsity_layout_x = bs.layout.build_sparsity_layout(x_dense, sparsity_block_size,
+                                                        triton_block_size=triton_block_size)
+    sparsity_layout_y = bs.layout.build_sparsity_layout(y_dense, sparsity_block_size,
+                                                        triton_block_size=triton_block_size)

     # Create random sparsity layout for output tensor
     sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)

     # Convert tensors to sparse tensors for matrix multiplication
-    x_sparse = to_sparse(x_dense, sparsity_layout_x, sparsity_block_size, triton_block_size=triton_block_size)
-    y_sparse = to_sparse(y_dense, sparsity_layout_y, sparsity_block_size, triton_block_size=triton_block_size)
+    x_sparse = bs.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size, triton_block_size=triton_block_size)
+    y_sparse = bs.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size, triton_block_size=triton_block_size)

     # Perform matrix multiplication
-    o_sparse = matmul(x_sparse, sparsity_layout_x, y_sparse, sparsity_layout_y, sparsity_layout_o, sparsity_block_size,
-                      triton_block_size=triton_block_size)
-    o_dense = to_dense(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+    o_sparse = bs.matmul(x_sparse, sparsity_layout_x, y_sparse, sparsity_layout_y, sparsity_layout_o,
+                         sparsity_block_size,
+                         triton_block_size=triton_block_size)
+
+    # Apply element-wise operation
+    o_sparse = torch.add(o_sparse, 1)
+
+    o_dense = bs.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)

     # Sanity check
     o_torch = torch.matmul(x_dense, y_dense)
+    o_torch = torch.add(o_torch, 1)

     # Perform round trip to set sparse blocks to 0
-    o_torch_round_trip = to_dense(
-        to_sparse(o_torch, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size),
+    o_torch_round_trip = bs.to_dense(
+        bs.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size),
         sparsity_layout_o, sparsity_block_size, fill_value=0, triton_block_size=triton_block_size)

     # Assert that the output is correct
     assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2)  # Note that small numerical differences are expected

     # Assert that the output has the correct sparsity layout
-    actual_sparsity_layout_o = build_sparsity_layout(o_dense, sparsity_block_size, triton_block_size=triton_block_size)
+    actual_sparsity_layout_o = bs.layout.build_sparsity_layout(o_dense, sparsity_block_size,
+                                                               triton_block_size=triton_block_size)
     assert torch.allclose(actual_sparsity_layout_o, sparsity_layout_o)

     # Convert output tensor back to original shape
-    o = undo_shape_blocksparse(o_dense, x_shape_original)
+    o = bs.util.undo_shape_blocksparse(o_dense, x_shape_original)

     # Other available functions
-    transpose(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
-    softmax(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
-    row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+    bs.transpose(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+    bs.softmax(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+    bs.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+    bs.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)


 def _get_random_sparsity_layout(b, m, n, sparsity_block_size, sparsity_percentage):
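
The example above round-trips through the compressed representation that every kernel in this release operates on. A standalone sketch of that data model, with illustrative sizes; the `[B, M // block, N // block]` layout shape is inferred from the `torch.nonzero` lookup tables and output allocations in the kernels further down this diff:

```python
# Sketch of the data model implied by the kernels in this diff; sizes are
# illustrative assumptions, not values from the package.
import torch

B, M, N, block = 2, 128, 128, 32

# Sparsity layout: one 0/1 flag per block position.
layout = (torch.rand(B, M // block, N // block) < 0.5).to(torch.int)

# Compressed form: only the flagged blocks, stacked along the first dimension.
n_sparse_blocks = int(layout.sum())
compressed_shape = (n_sparse_blocks, block, block)

# Lookup table: one (batch, block-row, block-col) triple per stored block,
# mirroring the torch.nonzero(...) LUTs built throughout this release.
lut = torch.nonzero(layout)
assert lut.shape == (n_sparse_blocks, 3)
```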
@@ -0,0 +1,18 @@
+from blksprs.ops.conversion import to_dense, to_sparse
+from blksprs.ops.distribution import gather, scatter, scatter_reduce
+from blksprs.ops.exp import exp
+from blksprs.ops.matmul import matmul
+from blksprs.ops.softmax import softmax
+from blksprs.ops.transpose import transpose
+
+class layout:
+    from blksprs.layouting.distribution_layout import build_distribution_layout
+    from blksprs.layouting.sparsity_layout import build_sparsity_layout, build_sparsity_layout_adaption
+
+class misc:
+    from blksprs.misc.broadcast_ops import broadcast_add, broadcast_sub
+    from blksprs.misc.repeat_interleave import repeat_interleave
+    from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_add, row_wise_sub
+
+class util:
+    from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
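
The new package-level `__init__.py` collapses the seven module paths of 1.2.1 into a single import. A minimal usage sketch of the 1.4 namespaces (tensor sizes and sparsity values are illustrative assumptions):

```python
# Minimal sketch of the 1.4 namespace; sizes are illustrative assumptions.
import torch

import blksprs as bs

x = torch.randn(2, 64, 64, device="cuda")
sparsity_block_size = 32

# Grouped namespaces mirror the package layout:
#   bs.layout -> blksprs.layouting, bs.misc -> blksprs.misc, bs.util -> blksprs.utils
layout = bs.layout.build_sparsity_layout(x, sparsity_block_size)
x_sparse = bs.to_sparse(x, layout, sparsity_block_size)
row_sums, layout_sums = bs.misc.row_wise_sum(x_sparse, layout, sparsity_block_size)
```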
@@ -8,8 +8,8 @@ from blksprs.utils.validation import validate_contiguous, validate_device, \
     validate_sparsity_block_size, validate_triton_block_size


-def broadcast_addition(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
-                       sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
+                  sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
     """Performs a broadcast and subsequent addition of two dense tensors x and y. Returns a block-sparse tensor in
     compressed form.

@@ -70,12 +70,12 @@ def broadcast_addition(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
     return output


-def broadcast_subtraction(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
-                          sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
-    """Wrapper for ``broadcast_addition`` with negated y.
+def broadcast_sub(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
+                  sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+    """Wrapper for ``broadcast_add`` with negated y.

     """
-    return broadcast_addition(x, torch.neg(y), sparsity_layout_output, sparsity_block_size, triton_block_size)
+    return broadcast_add(x, torch.neg(y), sparsity_layout_output, sparsity_block_size, triton_block_size)


 @triton.jit
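
The rename from `broadcast_addition`/`broadcast_subtraction` matches the shorter verbs used at the top level. Since this hunk only shows the wrapper relationship, a relational sketch is the safest illustration; `x` and `y` are assumed dense tensors whose required shapes are defined elsewhere in `broadcast_ops.py`:

```python
# Relational sketch only: per this hunk, broadcast_sub is broadcast_add with
# y negated. x and y are assumed dense tensors of compatible shapes.
import torch

import blksprs as bs


def sub_equals_negated_add(x, y, sparsity_layout_o, sparsity_block_size) -> bool:
    out_sub = bs.misc.broadcast_sub(x, y, sparsity_layout_o, sparsity_block_size)
    out_ref = bs.misc.broadcast_add(x, torch.neg(y), sparsity_layout_o, sparsity_block_size)
    return torch.allclose(out_sub, out_ref)
```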
@@ -0,0 +1,130 @@
+import torch
+import triton
+from torch import Tensor
+from triton import language as tl
+
+from blksprs.utils.tools import get_triton_block_size
+from blksprs.utils.validation import validate_contiguous, validate_device, \
+    validate_sparsity_block_size, validate_triton_block_size, validate_dimensions
+
+
+def repeat_interleave(x: Tensor, sparsity_layout: Tensor, repeats: int,
+                      sparsity_block_size: int, triton_block_size: int = None) -> tuple[Tensor, Tensor]:
+    """Repeats and interleaves the block-sparse tensor in compressed form.
+
+    Repeats each matrix contained in the tensor ``repeats`` times and places the copies consecutively in the output
+    tensor.
+
+    Args:
+        x (Tensor): A block-sparse tensor in compressed form.
+        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+        repeats (int): The number of times to repeat the matrices.
+        sparsity_block_size (int): The size of the sparsity blocks.
+        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+    Returns:
+        Tensor: A block-sparse tensor in compressed form containing the repeated and interleaved matrices.
+        Tensor: The sparsity layout of the resulting output tensor.
+
+    """
+    validate_dimensions(x)
+    validate_contiguous(x)
+    validate_device(x)
+    validate_sparsity_block_size(sparsity_block_size, x)
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+    # Repeat the layout once per requested copy along the batch dimension
+    sparsity_layout_output = torch.repeat_interleave(sparsity_layout, repeats, dim=0).contiguous()
+
+    sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
+
+    sparsity_layout_output_flat = sparsity_layout_output.reshape(-1)
+    sparsity_output_reverse_lut = ((torch.cumsum(sparsity_layout_output_flat, dim=-1) - 1) *
+                                   (sparsity_layout_output_flat == 1) -
+                                   (1 * (sparsity_layout_output_flat == 0)))
+
+    n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
+
+    validate_contiguous(sparsity_layout, sparsity_lut, sparsity_layout_output, sparsity_output_reverse_lut)
+
+    output = torch.empty(n_sparse_blocks * repeats, sparsity_block_size, sparsity_block_size,
+                         dtype=x.dtype, device=x.device)
+
+    x_b, x_r, x_c = x.size()
+    x_b_s, x_r_s, x_c_s = x.stride()
+    s_lut_r, s_lut_c = sparsity_lut.size()
+    s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
+    o_b, o_r, o_c = output.size()
+    o_b_s, o_r_s, o_c_s = output.stride()
+    s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
+    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_output.stride()
+
+    if triton_block_size is None:
+        triton_block_size = get_triton_block_size(sparsity_block_size)
+
+    triton_grid = lambda meta: [x_b,
+                                triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+    (kernel_repeat_interleave[triton_grid]
+     (x,
+      x_b, x_b_s, x_r_s, x_c_s,
+      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+      output,
+      o_b, o_b_s, o_r_s, o_c_s,
+      s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+      sparsity_output_reverse_lut,
+      repeats,
+      triton_block_size))
+
+    return output, sparsity_layout_output
+
+
+@triton.jit
+def kernel_repeat_interleave(x,
+                             x_b, x_b_s, x_r_s, x_c_s,
+                             s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                             o,
+                             o_b, o_b_s, o_r_s, o_c_s,
+                             s_l_o_b, s_l_o_b_s, s_l_o_r_s, s_l_o_c_s,
+                             r_lut_o,
+                             repeats,
+                             TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get sparsity index of the current input block consisting of its batch, row, and column index
+    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+    spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+    spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+    spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
+    spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
+    spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+
+    # Load block
+    blk_x_idx = ((pid_blk * x_b_s) +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_x_msk = (blk_x_idx < x_b * x_b_s)
+    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+    for repeat in range(repeats):
+        # Get reverse sparsity index
+        rev_idx_spa_idx = ((spa_bat * repeats + repeat) * s_l_o_b_s +
+                           spa_row * s_l_o_r_s +
+                           spa_col * s_l_o_c_s)
+        rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
+        rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
+
+        # Store block
+        blk_o_idx = ((rev_idx_spa * o_b_s) +
+                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+        blk_o_msk = (blk_o_idx < o_b * o_b_s)
+        tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
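
A short usage sketch for the new function; sizes are illustrative, and the `B * repeats` batch count follows from the `torch.repeat_interleave(..., dim=0)` call above:

```python
# Usage sketch for repeat_interleave; sizes are illustrative assumptions.
import torch

import blksprs as bs

B, M, N, block = 2, 64, 64, 32
x = torch.randn(B, M, N, device="cuda")
layout = bs.layout.build_sparsity_layout(x, block)
x_sparse = bs.to_sparse(x, layout, block)

# Each of the B matrices is copied `repeats` times back to back, so the output
# layout gains a factor of `repeats` in its batch dimension.
repeats = 3
out_sparse, out_layout = bs.misc.repeat_interleave(x_sparse, layout, repeats, block)
assert out_layout.size(0) == B * repeats
```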
@@ -0,0 +1,386 @@
+import torch
+import triton
+from torch import Tensor
+from triton import language as tl
+
+from blksprs.utils.tools import get_triton_block_size
+from blksprs.utils.validation import validate_dimensions, validate_contiguous, validate_device, validate_sparsity, \
+    validate_sparsity_block_size, validate_triton_block_size
+
+
+def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
+                 flag_slice_only: bool = False, triton_block_size: int = None) -> tuple[Tensor, Tensor]:
+    """Computes the row-wise sum of a block-sparse tensor.
+
+    Returns a block-sparse tensor in compressed form with only one block per row, where the first entry contains the
+    sum of the corresponding row.
+
+    Note:
+        If ``flag_slice_only`` is set the output will be of shape ``[x.size(0), x.size(1), 1]``.
+
+    Args:
+        x (Tensor): A block-sparse tensor in compressed form.
+        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+        sparsity_block_size (int): The size of the sparsity blocks.
+        flag_slice_only (bool, optional): If set the output will be of shape ``[x.size(0), x.size(1), 1]``
+            (default ``False``).
+        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+    Returns:
+        tuple[Tensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise sum
+        of the input and the sparsity layout of the output tensor.
+
+    """
+    validate_dimensions(x)
+    validate_contiguous(x)
+    validate_device(x)
+    validate_sparsity(sparsity_block_size, (x, sparsity_layout))
+    validate_sparsity_block_size(sparsity_block_size, x)
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+    sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
+
+    sparsity_layout_output, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
+    sparsity_layout_output_flat = sparsity_layout_output.reshape(-1)
+    sparsity_reverse_lut_output = ((torch.cumsum(sparsity_layout_output_flat, dim=-1) - 1) *
+                                   (sparsity_layout_output_flat == 1) -
+                                   (1 * (sparsity_layout_output_flat == 0)))
+
+    n_sparse_blocks_output = torch.sum(sparsity_layout_output.to(torch.int)).item()
+
+    validate_contiguous(sparsity_layout, sparsity_lut,
+                        sparsity_layout_output, sparsity_reverse_lut_output)
+
+    output = torch.zeros(size=(n_sparse_blocks_output,
+                               sparsity_block_size,
+                               1 if flag_slice_only else sparsity_block_size),
+                         device=x.device)
+
+    x_b, x_r, x_c = x.size()
+    x_b_s, x_r_s, x_c_s = x.stride()
+    s_lut_x_r, s_lut_x_c = sparsity_lut.size()
+    s_lut_x_r_s, s_lut_x_c_s = sparsity_lut.stride()
+    o_b, o_r, o_c = output.size()
+    o_b_s, o_r_s, o_c_s = output.stride()
+    s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
+    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_output.stride()
+
+    if triton_block_size is None:
+        triton_block_size = get_triton_block_size(sparsity_block_size)
+
+    triton_grid = lambda meta: [x_b,
+                                triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+    (kernel_blocksparse_row_wise_sum[triton_grid]
+     (x,
+      x_b, x_b_s, x_r_s, x_c_s,
+      sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+      output,
+      o_b, o_b_s, o_r_s,
+      s_l_o_b, s_l_o_b_s, s_l_o_r_s,
+      sparsity_reverse_lut_output,
+      triton_block_size))
+
+    return output, sparsity_layout_output
+
+
+@triton.jit
+def kernel_blocksparse_row_wise_sum(x,
+                                    x_b, x_b_s, x_r_s, x_c_s,
+                                    s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+                                    o,
+                                    o_b, o_b_s, o_r_s,
+                                    s_l_o_b, s_l_o_b_s, s_l_o_r_s,
+                                    r_lut_o,
+                                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get position of current sparsity block consisting of its batch and row index
+    spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
+    spa_bat_msk = (spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
+    spa_bat = tl.load(s_lut_x + spa_bat_idx, mask=spa_bat_msk)
+
+    spa_row_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
+    spa_row_msk = (spa_row_idx < s_lut_x_r * s_lut_x_r_s)
+    spa_row = tl.load(s_lut_x + spa_row_idx, mask=spa_row_msk)
+
+    # Load reverse sparsity index for current block
+    rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
+                       spa_row * s_l_o_r_s)
+    rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
+    rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
+
+    blk_idx = ((pid_blk * x_b_s) +
+               ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+               ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_msk = (blk_idx < x_b * x_b_s)
+    blk = tl.load(x + blk_idx, mask=blk_msk)
+
+    buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
+
+    o_idx = (rev_idx_spa * o_b_s +
+             ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+             (tl.arange(0, 1))[None, :])
+    o_msk = (o_idx < o_b * o_b_s)
+    tl.atomic_add(o + o_idx, buf, o_msk)
+
+
+def row_wise_max(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
+                 flag_slice_only: bool = False, triton_block_size: int = None) -> tuple[Tensor, Tensor]:
+    """Computes the row-wise max of a block-sparse tensor.
+
+    Returns a block-sparse tensor in compressed form with only one block per row, where the first entry contains the
+    maximum of the corresponding row.
+
+    Note:
+        If ``flag_slice_only`` is set the output will be of shape ``[x.size(0), x.size(1), 1]``.
+
+    Args:
+        x (Tensor): A block-sparse tensor in compressed form.
+        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+        sparsity_block_size (int): The size of the sparsity blocks.
+        flag_slice_only (bool, optional): If set the output will be of shape ``[x.size(0), x.size(1), 1]``
+            (default ``False``).
+        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+    Returns:
+        tuple[Tensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise max
+        of the input and the sparsity layout of the output tensor.
+
+    """
+    validate_dimensions(x)
+    validate_contiguous(x)
+    validate_device(x)
+    validate_sparsity(sparsity_block_size, (x, sparsity_layout))
+    validate_sparsity_block_size(sparsity_block_size, x)
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+    sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
+
+    sparsity_layout_output, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
+    sparsity_layout_output_flat = sparsity_layout_output.reshape(-1)
+    sparsity_reverse_lut_output = ((torch.cumsum(sparsity_layout_output_flat, dim=-1) - 1) *
+                                   (sparsity_layout_output_flat == 1) -
+                                   (1 * (sparsity_layout_output_flat == 0)))
+
+    n_sparse_blocks_output = torch.sum(sparsity_layout_output.to(torch.int)).item()
+
+    validate_contiguous(sparsity_layout, sparsity_lut,
+                        sparsity_layout_output, sparsity_reverse_lut_output)
+
+    output = torch.full(size=(n_sparse_blocks_output,
+                              sparsity_block_size,
+                              1 if flag_slice_only else sparsity_block_size),
+                        fill_value=float("-inf"),
+                        device=x.device)
+
+    x_b, x_r, x_c = x.size()
+    x_b_s, x_r_s, x_c_s = x.stride()
+    s_lut_x_r, s_lut_x_c = sparsity_lut.size()
+    s_lut_x_r_s, s_lut_x_c_s = sparsity_lut.stride()
+    o_b, o_r, o_c = output.size()
+    o_b_s, o_r_s, o_c_s = output.stride()
+    s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
+    s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_output.stride()
+
+    if triton_block_size is None:
+        triton_block_size = get_triton_block_size(sparsity_block_size)
+
+    triton_grid = lambda meta: [x_b,
+                                triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
+                                triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
+
+    (kernel_blocksparse_row_wise_max[triton_grid]
+     (x,
+      x_b, x_b_s, x_r_s, x_c_s,
+      sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+      output,
+      o_b, o_b_s, o_r_s,
+      s_l_o_b, s_l_o_b_s, s_l_o_r_s,
+      sparsity_reverse_lut_output,
+      triton_block_size))
+
+    return output, sparsity_layout_output
+
+
+@triton.jit
+def kernel_blocksparse_row_wise_max(x,
+                                    x_b, x_b_s, x_r_s, x_c_s,
+                                    s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
+                                    o,
+                                    o_b, o_b_s, o_r_s,
+                                    s_l_o_b, s_l_o_b_s, s_l_o_r_s,
+                                    r_lut_o,
+                                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get position of current sparsity block consisting of its batch and row index
+    spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
+    spa_bat_msk = (spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
+    spa_bat = tl.load(s_lut_x + spa_bat_idx, mask=spa_bat_msk)
+
+    spa_row_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
+    spa_row_msk = (spa_row_idx < s_lut_x_r * s_lut_x_r_s)
+    spa_row = tl.load(s_lut_x + spa_row_idx, mask=spa_row_msk)
+
+    # Load reverse sparsity index for current block
+    rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
+                       spa_row * s_l_o_r_s)
+    rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
+    rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
+
+    blk_idx = ((pid_blk * x_b_s) +
+               ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+               ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_msk = (blk_idx < x_b * x_b_s)
+    blk = tl.load(x + blk_idx, mask=blk_msk)
+
+    buf = tl.reshape(tl.max(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
+
+    o_idx = (rev_idx_spa * o_b_s +
+             ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+             (tl.arange(0, 1))[None, :])
+    o_msk = (o_idx < o_b * o_b_s)
+    tl.atomic_max(o + o_idx, buf, o_msk)
+
+
+def row_wise_add(x: Tensor, sparsity_layout_x: Tensor, y: Tensor,
+                 sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+    """Adds the first value of each row of ``y`` to every value in the corresponding row of the block-sparse tensor
+    ``x``.
+
+    Args:
+        x (Tensor): A block-sparse tensor in compressed form.
+        sparsity_layout_x (Tensor): The sparsity layout of the block-sparse tensor.
+        y (Tensor): A block-sparse tensor in compressed form with only one value per row and a single column of
+            sparse blocks.
+        sparsity_block_size (int): The size of the sparsity blocks.
+        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+
+    Returns:
+        Tensor: The values of ``x`` with the first value of ``y`` in each row added to them as a block-sparse tensor in
+        compressed form.
+
+    """
+    validate_dimensions(x)
+    validate_contiguous(x)
+    validate_device(x)
+    validate_sparsity(sparsity_block_size, (x, sparsity_layout_x))
+    validate_sparsity_block_size(sparsity_block_size, x)
+    validate_triton_block_size(triton_block_size, sparsity_block_size)
+
+    sparsity_lut = torch.nonzero(sparsity_layout_x).contiguous()
+
+    sparsity_layout_rwm, _ = torch.max(sparsity_layout_x, dim=-1, keepdim=True)
+    sparsity_layout_rwm_flat = sparsity_layout_rwm.reshape(-1)
+    sparsity_reverse_lut_rwm = ((torch.cumsum(sparsity_layout_rwm_flat, dim=-1) - 1) *
+                                (sparsity_layout_rwm_flat == 1) -
+                                (1 * (sparsity_layout_rwm_flat == 0)))
+
+    validate_contiguous(sparsity_layout_x, sparsity_lut, sparsity_reverse_lut_rwm)
+
+    output = torch.empty_like(x)
+
+    x_b, x_r, x_c = x.size()
+    x_b_s, x_r_s, x_c_s = x.stride()
+    s_lut_r, s_lut_c = sparsity_lut.size()
+    s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
+    y_b, y_r, y_c = y.size()
+    y_b_s, y_r_s, y_c_s = y.stride()
+    s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_rwm.size()
+    s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = sparsity_layout_rwm.stride()
+    o_b, o_r, o_c = output.size()
+    o_b_s, o_r_s, o_c_s = output.stride()
+
+    if triton_block_size is None:
+        triton_block_size = get_triton_block_size(sparsity_block_size)
+
+    triton_grid = lambda meta: [o_b,
+                                triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+    (kernel_blocksparse_row_wise_add[triton_grid]
+     (x,
+      x_b, x_b_s, x_r_s, x_c_s,
+      sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+      y, y_b, y_b_s, y_r_s, y_c_s,
+      s_l_y_b, s_l_y_b_s, s_l_y_r_s,
+      sparsity_reverse_lut_rwm,
+      output,
+      o_b, o_b_s, o_r_s, o_c_s,
+      triton_block_size))
+
+    return output
+
+
+def row_wise_sub(x: Tensor, sparsity_layout_x: Tensor, y: Tensor,
+                 sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
+    """Wrapper for ``row_wise_add`` with negated y.
+
+    """
+    return row_wise_add(x, sparsity_layout_x, torch.neg(y), sparsity_block_size, triton_block_size)
+
+
+@triton.jit
+def kernel_blocksparse_row_wise_add(x,
+                                    x_b, x_b_s, x_r_s, x_c_s,
+                                    s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                                    y, y_b, y_b_s, y_r_s, y_c_s,
+                                    s_l_y_b, s_l_y_b_s, s_l_y_r_s,
+                                    r_lut_y,
+                                    o,
+                                    o_b, o_b_s, o_r_s, o_c_s,
+                                    TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+    # Get triton block indices
+    pid_blk = tl.program_id(axis=0)
+    pid_row = tl.program_id(axis=1)
+    pid_col = tl.program_id(axis=2)
+
+    # Get position of current sparsity block consisting of its batch and row index
+    spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+    spa_bat_msk = (spa_bat_idx < s_lut_r * s_lut_r_s)
+    spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+    spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+    spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
+    spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+    # Get reverse sparsity index for y
+    rev_idx_spa_s_idx = (spa_bat * s_l_y_b_s +
+                         spa_row * s_l_y_r_s)
+    rev_idx_spa_s_msk = (rev_idx_spa_s_idx < s_l_y_b * s_l_y_b_s)
+    rev_idx_spa_s = tl.load(r_lut_y + rev_idx_spa_s_idx, mask=rev_idx_spa_s_msk).to(tl.int32)
+
+    if rev_idx_spa_s == -1:
+        assert False, "Invalid sparsity block"
+
+    # Load x block
+    blk_x_idx = ((pid_blk * x_b_s) +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+    blk_x_msk = (blk_x_idx < x_b * x_b_s)
+    blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+    # Load the row-wise value block of y
+    blk_s_idx = (rev_idx_spa_s * y_b_s +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
+                 (tl.arange(0, 1) * y_c_s)[None, :])
+    blk_s_msk = (blk_s_idx < y_b * y_b_s)
+    blk_s = tl.load(y + blk_s_idx, mask=blk_s_msk)
+
+    # Add the broadcast row-wise value to the block
+    buf = blk_x + tl.broadcast_to(blk_s, (TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE))
+
+    # Store block
+    blk_o_idx = ((pid_blk * o_b_s) +
+                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                 ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+    blk_o_msk = (blk_o_idx < o_b * o_b_s)
+    tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
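
All three entry points above build the same reverse lookup table from the flattened layout; a standalone sketch of what that idiom computes:

```python
# Standalone sketch of the reverse-LUT idiom used throughout row_wise.py: for
# each block position in the flattened layout it yields the index of that
# block in the compressed tensor, or -1 where no block is stored.
import torch

layout_flat = torch.tensor([0, 1, 1, 0, 1])
reverse_lut = ((torch.cumsum(layout_flat, dim=-1) - 1) * (layout_flat == 1)
               - 1 * (layout_flat == 0))
print(reverse_lut)  # tensor([-1,  0,  1, -1,  2])
```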
@@ -4,7 +4,7 @@ from torch import Tensor
 from triton import language as tl

 from blksprs.ops.exp import exp
-from blksprs.ops.row_wise_sum import row_wise_sum
+from blksprs.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
 from blksprs.utils.tools import get_triton_block_size
 from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
@@ -33,12 +33,6 @@ def softmax(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton
     validate_sparsity_block_size(sparsity_block_size, x)
     validate_triton_block_size(triton_block_size, sparsity_block_size)

-    if x.size(0) != 0:
-        max_val = torch.max(x).item()
-    else:
-        max_val = 0
-    x_scaled = x - max_val
-
     sparsity_lut = torch.nonzero(sparsity_layout).contiguous()

     sparsity_layout_rws, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
@@ -49,7 +43,7 @@ def softmax(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton

     validate_contiguous(sparsity_layout, sparsity_lut, sparsity_reverse_lut_rws)

-    return _BlocksparseSoftmax.apply(x_scaled, sparsity_layout,
+    return _BlocksparseSoftmax.apply(x, sparsity_layout,
                                      sparsity_lut,
                                      sparsity_reverse_lut_rws,
                                      sparsity_block_size, triton_block_size)
@@ -64,13 +58,17 @@ class _BlocksparseSoftmax(torch.autograd.Function):
                 sparsity_block_size: int, triton_block_size: int) -> Tensor:
         output = torch.empty_like(x)

-        x_b, x_r, x_c = x.shape
+        x_b, x_r, x_c = x.size()
         x_b_s, x_r_s, x_c_s = x.stride()
-        s_lut_r, s_lut_c = sparsity_lut.shape
+        s_lut_r, s_lut_c = sparsity_lut.size()
         s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
-        o_b, o_r, o_c = output.shape
+        o_b, o_r, o_c = output.size()

-        x_exp = exp(x, sparsity_block_size, triton_block_size=triton_block_size)
+        x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,
+                                                           flag_slice_only=True,
+                                                           triton_block_size=triton_block_size)
+        x_scaled = row_wise_sub(x, sparsity_layout, x_row_wise_max, sparsity_block_size, triton_block_size)
+        x_exp = exp(x_scaled, sparsity_block_size, triton_block_size=triton_block_size)
         x_exp_row_wise_sum, sparsity_layout_rws = row_wise_sum(x_exp, sparsity_layout, sparsity_block_size,
                                                                flag_slice_only=True,
                                                                triton_block_size=triton_block_size)
@@ -174,7 +172,7 @@ class _BlocksparseSoftmax(torch.autograd.Function):
         spa_row_msk = (spa_row_idx < s_lut_r * s_lut_r_s)
         spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)

-        # Get reverse sparsity indices for x
+        # Get reverse sparsity indices for s
         rev_idx_spa_s_idx = (spa_bat * s_l_s_b_s +
                              spa_row * s_l_s_r_s)
         rev_idx_spa_s_msk = (rev_idx_spa_s_idx < s_l_s_b * s_l_s_b_s)
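
The net effect of these hunks: the global `torch.max(x).item()` shift (a host-device sync, and a single max shared by the whole batch) is replaced by a per-row max computed block-sparsely inside the autograd function. A plain-PyTorch sketch of the identity that makes the shift safe:

```python
# Plain-PyTorch sketch of the identity behind the change: subtracting the
# per-row maximum leaves softmax unchanged while keeping exp() in range.
import torch

x = torch.randn(4, 8) * 50  # magnitudes large enough to overflow a naive exp
row_max = x.max(dim=-1, keepdim=True).values
num = torch.exp(x - row_max)
stable = num / num.sum(dim=-1, keepdim=True)
assert torch.allclose(stable, torch.softmax(x, dim=-1))
```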
@@ -129,7 +129,7 @@ class _BlocksparseTranspose(torch.autograd.Function):
         spa_col_msk = (spa_col_idx < s_lut_r * s_lut_r_s)
         spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)

-        # Get reverse sparsity indices
+        # Get reverse sparsity index
         rev_idx_spa_idx = (spa_bat * s_l_b_s +
                            spa_row * s_l_r_s +
                            spa_col * s_l_c_s)
@@ -10,7 +10,7 @@ def do_shape_blocksparse(x: Tensor):


 def undo_shape_blocksparse(x: Tensor, shape: Size):
-    if x.shape[-2:] == shape[-2:]:
+    if x.shape[:-2] == shape[:-2]:
         return x

     return x.reshape((*shape[:-2], *x.shape[-2:]))
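
The old early-return compared the wrong end of the shape: after `do_shape_blocksparse` the trailing matrix dimensions always match, so the function never restored the batch dimensions. A sketch of the case the fix addresses (shapes are illustrative):

```python
# Sketch of the case the old check got wrong; shapes are illustrative.
import torch

original_shape = torch.Size((2, 3, 8, 8))
x = torch.randn(6, 8, 8)  # flattened form, as produced by do_shape_blocksparse

assert x.shape[-2:] == original_shape[-2:]   # old check: true, wrongly early-returns
assert x.shape[:-2] != original_shape[:-2]   # new check: differs, so it reshapes
restored = x.reshape((*original_shape[:-2], *x.shape[-2:]))
assert restored.shape == original_shape
```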
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: blksprs
-Version: 1.2.1
+Version: 1.4
 Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -83,14 +83,7 @@ the [test cases](https://github.com/FelixSchoen/blksprs/blob/main/test/cases/tes

 ```python
 import torch
-
-from blksprs.layouting.sparsity_layout import build_sparsity_layout
-from blksprs.ops.conversion import to_sparse, to_dense
-from blksprs.ops.matmul import matmul
-from blksprs.ops.row_wise_sum import row_wise_sum
-from blksprs.ops.softmax import softmax
-from blksprs.ops.transpose import transpose
-from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse
+import blksprs as bs


 def test_readme():
@@ -112,47 +105,57 @@ def test_readme():
     y = torch.randn(size=(b, h, n, k), device="cuda").transpose(-1, -2).contiguous()

     # Convert tensors to three-dimensional (dense) tensors since Triton can only handle tensors of exactly three dimensions
-    x_dense, x_shape_original = do_shape_blocksparse(x)
-    y_dense, y_shape_original = do_shape_blocksparse(y)
+    x_dense, x_shape_original = bs.util.do_shape_blocksparse(x)
+    y_dense, y_shape_original = bs.util.do_shape_blocksparse(y)

     # Create sparsity layouts from existing tensors
-    sparsity_layout_x = build_sparsity_layout(x_dense, sparsity_block_size, triton_block_size=triton_block_size)
-    sparsity_layout_y = build_sparsity_layout(y_dense, sparsity_block_size, triton_block_size=triton_block_size)
+    sparsity_layout_x = bs.layout.build_sparsity_layout(x_dense, sparsity_block_size,
+                                                        triton_block_size=triton_block_size)
+    sparsity_layout_y = bs.layout.build_sparsity_layout(y_dense, sparsity_block_size,
+                                                        triton_block_size=triton_block_size)

     # Create random sparsity layout for output tensor
     sparsity_layout_o = _get_random_sparsity_layout(b * h, m, n, sparsity_block_size, sparsity_percentage)

     # Convert tensors to sparse tensors for matrix multiplication
-    x_sparse = to_sparse(x_dense, sparsity_layout_x, sparsity_block_size, triton_block_size=triton_block_size)
-    y_sparse = to_sparse(y_dense, sparsity_layout_y, sparsity_block_size, triton_block_size=triton_block_size)
+    x_sparse = bs.to_sparse(x_dense, sparsity_layout_x, sparsity_block_size, triton_block_size=triton_block_size)
+    y_sparse = bs.to_sparse(y_dense, sparsity_layout_y, sparsity_block_size, triton_block_size=triton_block_size)

     # Perform matrix multiplication
-    o_sparse = matmul(x_sparse, sparsity_layout_x, y_sparse, sparsity_layout_y, sparsity_layout_o, sparsity_block_size,
-                      triton_block_size=triton_block_size)
-    o_dense = to_dense(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+    o_sparse = bs.matmul(x_sparse, sparsity_layout_x, y_sparse, sparsity_layout_y, sparsity_layout_o,
+                         sparsity_block_size,
+                         triton_block_size=triton_block_size)
+
+    # Apply element-wise operation
+    o_sparse = torch.add(o_sparse, 1)
+
+    o_dense = bs.to_dense(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)

     # Sanity check
     o_torch = torch.matmul(x_dense, y_dense)
+    o_torch = torch.add(o_torch, 1)

     # Perform round trip to set sparse blocks to 0
-    o_torch_round_trip = to_dense(
-        to_sparse(o_torch, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size),
+    o_torch_round_trip = bs.to_dense(
+        bs.to_sparse(o_torch, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size),
         sparsity_layout_o, sparsity_block_size, fill_value=0, triton_block_size=triton_block_size)

     # Assert that the output is correct
     assert torch.allclose(o_dense, o_torch_round_trip, atol=2e-2)  # Note that small numerical differences are expected

     # Assert that the output has the correct sparsity layout
-    actual_sparsity_layout_o = build_sparsity_layout(o_dense, sparsity_block_size, triton_block_size=triton_block_size)
+    actual_sparsity_layout_o = bs.layout.build_sparsity_layout(o_dense, sparsity_block_size,
+                                                               triton_block_size=triton_block_size)
     assert torch.allclose(actual_sparsity_layout_o, sparsity_layout_o)

     # Convert output tensor back to original shape
-    o = undo_shape_blocksparse(o_dense, x_shape_original)
+    o = bs.util.undo_shape_blocksparse(o_dense, x_shape_original)

     # Other available functions
-    transpose(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
-    softmax(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
-    row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+    bs.transpose(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+    bs.softmax(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+    bs.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)
+    bs.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size, triton_block_size=triton_block_size)


 def _get_random_sparsity_layout(b, m, n, sparsity_block_size, sparsity_percentage):
@@ -1,5 +1,6 @@
 README.md
 pyproject.toml
+blksprs/__init__.py
 blksprs.egg-info/PKG-INFO
 blksprs.egg-info/SOURCES.txt
 blksprs.egg-info/dependency_links.txt
@@ -7,12 +8,13 @@ blksprs.egg-info/requires.txt
 blksprs.egg-info/top_level.txt
 blksprs/layouting/distribution_layout.py
 blksprs/layouting/sparsity_layout.py
-blksprs/misc/broadcast_addition.py
+blksprs/misc/broadcast_ops.py
+blksprs/misc/repeat_interleave.py
+blksprs/misc/row_wise.py
 blksprs/ops/conversion.py
 blksprs/ops/distribution.py
 blksprs/ops/exp.py
 blksprs/ops/matmul.py
-blksprs/ops/row_wise_sum.py
 blksprs/ops/softmax.py
 blksprs/ops/transpose.py
 blksprs/utils/benchmarking.py
@@ -1,6 +1,6 @@
 [project]
 name = "blksprs"
-version = "1.2.1"
+version = "1.4"
 authors = [{ name = "Felix Schön", email = "schoen@kr.tuwien.ac.at" }]
 description = "A lightweight library for operations on blocksparse matrices in PyTorch."
 readme = "README.md"
@@ -1,231 +0,0 @@
-import torch
-import triton
-from torch import Tensor
-from triton import language as tl
-
-from blksprs.utils.tools import get_triton_block_size
-from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-    validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
-
-
-def row_wise_sum(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
-                 flag_slice_only: bool = False, triton_block_size: int = None) -> tuple[Tensor, Tensor]:
-    """Computes the row-wise sum of a block-sparse tensor.
-
-    Returns a block-sparse tensor in compressed form with only one block per row, where the first entry contains the
-    sum of the corresponding row.
-
-    Note:
-        If ``flag_slice_only`` is set the output will be of shape ``[x.size(0), x.size(1), 1]``.
-
-    Args:
-        x (Tensor): A block-sparse tensor in compressed form.
-        sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
-        sparsity_block_size (int): The size of the sparsity blocks.
-        flag_slice_only (bool, optional): If set the output will be of shape ``[x.size(0), x.size(1), 1]``
-            (default ``False``).
-        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
-
-    Returns:
-        tuple[Tensor, Tensor]: A tuple containing a block-sparse tensor in compressed form containing the row-wise sum
-        of the input and the sparsity layout of the output tensor.
-
-    """
-    validate_dimensions(x)
-    validate_contiguous(x)
-    validate_device(x)
-    validate_sparsity(sparsity_block_size, (x, sparsity_layout))
-    validate_sparsity_block_size(sparsity_block_size, x)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)
-
-    sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
-    sparsity_layout_flat = sparsity_layout.reshape(-1)
-    sparsity_reverse_lut = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                            (sparsity_layout_flat == 1) -
-                            (1 * (sparsity_layout_flat == 0)))
-
-    sparsity_layout_output, _ = torch.max(sparsity_layout, dim=-1, keepdim=True)
-    sparsity_lut_output = torch.nonzero(sparsity_layout_output).contiguous()
-    sparsity_layout_output_flat = sparsity_layout_output.reshape(-1)
-    sparsity_reverse_lut_output = ((torch.cumsum(sparsity_layout_output_flat, dim=-1) - 1) *
-                                   (sparsity_layout_output_flat == 1) -
-                                   (1 * (sparsity_layout_output_flat == 0)))
-
-    n_sparse_blocks_output = torch.sum(sparsity_layout_output.to(torch.int)).item()
-
-    validate_contiguous(sparsity_layout, sparsity_lut, sparsity_reverse_lut,
-                        sparsity_layout_output, sparsity_lut_output, sparsity_reverse_lut_output)
-
-    return (_BlocksparseRowWiseSum.apply(x,
-                                         sparsity_layout, sparsity_lut, sparsity_reverse_lut,
-                                         sparsity_layout_output, sparsity_lut_output, sparsity_reverse_lut_output,
-                                         n_sparse_blocks_output,
-                                         flag_slice_only,
-                                         sparsity_block_size, triton_block_size),
-            sparsity_layout_output)
-
-
-class _BlocksparseRowWiseSum(torch.autograd.Function):
-    IMPLEMENTATION = "atomic_add"
-
-    @staticmethod
-    def forward(ctx, x: Tensor,
-                sparsity_layout: Tensor, sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
-                sparsity_layout_output: Tensor, sparsity_lut_output: Tensor, sparsity_reverse_lut_output: Tensor,
-                n_sparse_blocks_output: int,
-                flag_slice_only: bool,
-                sparsity_block_size: int, triton_block_size: int) -> Tensor:
-        output = torch.zeros(size=(n_sparse_blocks_output,
-                                   sparsity_block_size,
-                                   1 if flag_slice_only else sparsity_block_size),
-                             device=x.device)
-
-        x_b, x_r, x_c = x.size()
-        x_b_s, x_r_s, x_c_s = x.stride()
-        s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout.size()
-        s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout.stride()
-        s_lut_x_r, s_lut_x_c = sparsity_lut.size()
-        s_lut_x_r_s, s_lut_x_c_s = sparsity_lut.stride()
-        o_b, o_r, o_c = output.size()
-        o_b_s, o_r_s, o_c_s = output.stride()
-        s_l_o_b, s_l_o_r, s_l_o_c = sparsity_layout_output.size()
-        s_l_o_b_s, s_l_o_r_s, s_l_o_c_s = sparsity_layout_output.stride()
-        s_lut_o_r, s_lut_o_c = sparsity_lut_output.size()
-        s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_output.stride()
-
-        if triton_block_size is None:
-            triton_block_size = get_triton_block_size(sparsity_block_size)
-
-        if _BlocksparseRowWiseSum.IMPLEMENTATION == "basic":
-            triton_grid = lambda meta: [o_b,
-                                        triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"])]
-
-            (_BlocksparseRowWiseSum.kernel_blocksparse_row_wise_sum[triton_grid]
-             (x,
-              x_b, x_b_s, x_r_s, x_c_s,
-              s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
-              sparsity_reverse_lut,
-              output,
-              o_b, o_b_s, o_r_s,
-              sparsity_lut_output, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
-              sparsity_block_size,
-              triton_block_size))
-        elif _BlocksparseRowWiseSum.IMPLEMENTATION == "atomic_add":
-            triton_grid = lambda meta: [x_b,
-                                        triton.cdiv(x_r, meta["TRITON_BLOCK_SIZE"]),
-                                        triton.cdiv(x_c, meta["TRITON_BLOCK_SIZE"])]
-
-            (_BlocksparseRowWiseSum.kernel_blocksparse_row_wise_sum_atomic_add[triton_grid]
-             (x,
-              x_b, x_b_s, x_r_s, x_c_s,
-              sparsity_lut, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
-              output,
-              o_b, o_b_s, o_r_s,
-              s_l_o_b, s_l_o_b_s, s_l_o_r_s,
-              sparsity_reverse_lut_output,
-              triton_block_size))
-
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        raise NotImplementedError
-
-    @staticmethod
-    @triton.jit
-    def kernel_blocksparse_row_wise_sum(x,
-                                        x_b, x_b_s, x_r_s, x_c_s,
-                                        s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
-                                        r_lut_x,
-                                        o,
-                                        o_b, o_b_s, o_r_s,
-                                        s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
-                                        sparsity_block_size,
-                                        TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-        pid_blk = tl.program_id(axis=0)
-        pid_row = tl.program_id(axis=1)
-
-        # Get position of current sparsity block consisting of its batch and row index
-        spa_bat_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
-        spa_bat_msk = (spa_bat_idx < s_lut_o_r * s_lut_o_r_s)
-        spa_bat = tl.load(s_lut_o + spa_bat_idx, mask=spa_bat_msk)
-
-        spa_row_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
-        spa_row_msk = (spa_row_idx < s_lut_o_r * s_lut_o_r_s)
-        spa_row = tl.load(s_lut_o + spa_row_idx, mask=spa_row_msk)
-
-        buf = tl.zeros(shape=(TRITON_BLOCK_SIZE, 1), dtype=tl.float32)
-
-        # Slide over triton block sized segments of input tensor
-        for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size, TRITON_BLOCK_SIZE)):
-            # Convert to segment index of sparsity layout
-            i_seg_spa = (i_seg_tri * TRITON_BLOCK_SIZE) // sparsity_block_size
-            # Calculate the triton segment index within a block
-            i_seg_tri_mod = i_seg_tri % (sparsity_block_size // TRITON_BLOCK_SIZE)
-
-            # Load reverse sparsity index for current block
-            rev_idx_spa_idx = (spa_bat * s_l_x_b_s +
-                               spa_row * s_l_x_r_s +
-                               i_seg_spa * s_l_x_c_s)
-            rev_idx_spa_msk = (rev_idx_spa_idx < s_l_x_b * s_l_x_b_s)
-            rev_idx_spa = tl.load(r_lut_x + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
-
-            # If block is present commence operations
-            if rev_idx_spa >= 0:
-                blk_idx = ((rev_idx_spa * x_b_s) +
-                           ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                           ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
-                             tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-                blk_msk = (blk_idx < x_b * x_b_s)
-                blk = tl.load(x + blk_idx, mask=blk_msk)
-
-                buf = buf + tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
-
-        o_idx = (pid_blk * o_b_s +
-                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                 (tl.arange(0, 1))[None, :])
-        o_msk = (o_idx < o_b * o_b_s)
-        tl.store(o + o_idx, buf, o_msk)
-
-    @staticmethod
-    @triton.jit
-    def kernel_blocksparse_row_wise_sum_atomic_add(x,
-                                                   x_b, x_b_s, x_r_s, x_c_s,
-                                                   s_lut_x, s_lut_x_r, s_lut_x_r_s, s_lut_x_c_s,
-                                                   o,
-                                                   o_b, o_b_s, o_r_s,
-                                                   s_l_o_b, s_l_o_b_s, s_l_o_r_s,
-                                                   r_lut_o,
-                                                   TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-        pid_blk = tl.program_id(axis=0)
-        pid_row = tl.program_id(axis=1)
-        pid_col = tl.program_id(axis=2)
-
-        # Get position of current sparsity block consisting of its batch and row index
-        spa_bat_idx = (pid_blk * s_lut_x_r_s + 0 * s_lut_x_c_s)
-        spa_bat_msk = (spa_bat_idx < s_lut_x_r * s_lut_x_r_s)
-        spa_bat = tl.load(s_lut_x + spa_bat_idx, mask=spa_bat_msk)
-
-        spa_row_idx = (pid_blk * s_lut_x_r_s + 1 * s_lut_x_c_s)
-        spa_row_msk = (spa_row_idx < s_lut_x_r * s_lut_x_r_s)
-        spa_row = tl.load(s_lut_x + spa_row_idx, mask=spa_row_msk)
-
-        # Load reverse sparsity index for current block
-        rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
-                           spa_row * s_l_o_r_s)
-        rev_idx_spa_msk = (rev_idx_spa_idx < s_l_o_b * s_l_o_b_s)
-        rev_idx_spa = tl.load(r_lut_o + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
-
-        blk_idx = ((pid_blk * x_b_s) +
-                   ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                   ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_msk = (blk_idx < x_b * x_b_s)
-        blk = tl.load(x + blk_idx, mask=blk_msk)
-
-        buf = tl.reshape(tl.sum(blk, axis=-1), (TRITON_BLOCK_SIZE, 1))
-
-        o_idx = (rev_idx_spa * o_b_s +
-                 ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                 (tl.arange(0, 1))[None, :])
-        o_msk = (o_idx < o_b * o_b_s)
-        tl.atomic_add(o + o_idx, buf, o_msk)