blksprs 1.10.2__py3-none-any.whl → 2.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/ops/matmul.py CHANGED
@@ -1,19 +1,20 @@
  import torch
  import triton
  from torch import Tensor
+ from torch.library import triton_op, wrap_triton
  from triton import language as tl
 
  from blksprs.ops.transpose import transpose
  from blksprs.utils.blksprs_tensor import BlksprsTensor
- from blksprs.utils.tools import get_triton_block_size, stride
+ from blksprs.utils.tools import stride, get_autotune_configs
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_dtype_float
+     validate_sparsity, validate_sparsity_block_size, validate_dtype_float
 
 
  def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
             y: BlksprsTensor, sparsity_layout_y: Tensor,
             sparsity_layout_output: Tensor,
-            sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
+            sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
      """Performs matrix multiplication between two block-sparse tensors.
 
      The sparsity layout of the output tensor is used to only calculate blocks that will be present in the output.
@@ -25,7 +26,7 @@ def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
          sparsity_layout_y (Tensor): The sparsity layout of the second block-sparse tensor.
          sparsity_layout_output (Tensor): The sparsity layout of the output tensor.
          sparsity_block_size (int): The size of the sparsity blocks.
-         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
+         lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
 
      Returns:
          BlksprsTensor: The result of the matrix multiplication as a block-sparse tensor in compressed form.
@@ -42,186 +43,219 @@ def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
      if sparsity_layout_x.size(-1) != sparsity_layout_y.size(-2):
          raise ValueError("Inner dimensions of tensors must match")
      validate_sparsity_block_size(sparsity_block_size, x, y)
-     validate_triton_block_size(triton_block_size, sparsity_block_size)
-
-     sparsity_layout_x_flat = sparsity_layout_x.reshape(-1)
-     sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
-                               (sparsity_layout_x_flat == 1) -
-                               (1 * (sparsity_layout_x_flat == 0)))
-
-     sparsity_layout_y_flat = sparsity_layout_y.reshape(-1)
-     sparsity_reverse_lut_y = ((torch.cumsum(sparsity_layout_y_flat, dim=-1) - 1) *
-                               (sparsity_layout_y_flat == 1) -
-                               (1 * (sparsity_layout_y_flat == 0)))
-
-     sparsity_lut_o = torch.nonzero(sparsity_layout_output).contiguous()
-
-     n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
-
-     validate_contiguous(sparsity_layout_x, sparsity_reverse_lut_x,
-                         sparsity_layout_y, sparsity_reverse_lut_y,
-                         sparsity_layout_output, sparsity_lut_o)
-
-     return BlksprsTensor(_BlocksparseMatmulSSS.apply(x, y,
-                                                      sparsity_layout_x, sparsity_reverse_lut_x,
-                                                      sparsity_layout_y, sparsity_reverse_lut_y,
-                                                      sparsity_layout_output, sparsity_lut_o,
-                                                      sparsity_block_size,
-                                                      n_sparse_blocks,
-                                                      triton_block_size))
-
-
- class _BlocksparseMatmulSSS(torch.autograd.Function):
-
-     @staticmethod
-     def forward(ctx, x: Tensor, y: Tensor,
-                 sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
-                 sparsity_layout_y: Tensor, sparsity_reverse_lut_y: Tensor,
-                 sparsity_layout_o: Tensor, sparsity_lut_o: Tensor,
-                 sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
-         output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
-                              dtype=x.dtype, device=x.device)
-
-         x_b, x_r, x_c = x.size()
-         x_b_s, x_r_s, x_c_s = stride(x)
-         s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
-         s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
-         y_b, y_r, y_c = y.size()
-         y_b_s, y_r_s, y_c_s = stride(y)
-         s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_y.size()
-         s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_y)
-         o_b, o_r, o_c = output.size()
-         o_b_s, o_r_s, o_c_s = stride(output)
-         s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
-         s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
-
-         if triton_block_size is None:
-             triton_block_size = get_triton_block_size(sparsity_block_size)
-
-         triton_grid = lambda meta: [o_b,
-                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-         (_BlocksparseMatmulSSS.kernel_blocksparse_matmul_sss[triton_grid]
-          (x,
-           x_b, x_b_s, x_r_s, x_c_s,
-           s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
-           sparsity_reverse_lut_x,
-           y,
-           y_b, y_b_s, y_r_s, y_c_s,
-           s_l_y_b, s_l_y_b_s, s_l_y_r_s, s_l_y_c_s,
-           sparsity_reverse_lut_y,
-           output,
-           o_b, o_b_s, o_r_s, o_c_s,
-           sparsity_lut_o,
-           s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
-           sparsity_block_size,
-           triton_block_size))
-
-         ctx.save_for_backward(x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o)
-         ctx.sparsity_block_size = sparsity_block_size
-         ctx.triton_block_size = triton_block_size
-
-         return output
-
-     @staticmethod
-     def backward(ctx, grad_output):
-         x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o = ctx.saved_tensors
-         sparsity_block_size = ctx.sparsity_block_size
-         triton_block_size = ctx.triton_block_size
-
-         x_t, sparsity_layout_x_t = transpose(x, sparsity_layout_x, sparsity_block_size, triton_block_size)
-         y_t, sparsity_layout_y_t = transpose(y, sparsity_layout_y, sparsity_block_size, triton_block_size)
-
-         grad_x = matmul(grad_output, sparsity_layout_o, y_t, sparsity_layout_y_t, sparsity_layout_x,
-                         sparsity_block_size, triton_block_size)
-         grad_y = matmul(x_t, sparsity_layout_x_t, grad_output, sparsity_layout_o, sparsity_layout_y,
-                         sparsity_block_size, triton_block_size)
-
-         return grad_x, grad_y, None, None, None, None, None, None, None, None, None
-
-     @staticmethod
-     @triton.jit
-     def kernel_blocksparse_matmul_sss(x,
-                                       x_b, x_b_s, x_r_s, x_c_s,
-                                       s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
-                                       r_lut_x,
-                                       y,
-                                       y_b, y_b_s, y_r_s, y_c_s,
-                                       s_l_y_b, s_l_y_b_s, s_l_y_r_s, s_l_y_c_s,
-                                       r_lut_y,
-                                       o,
-                                       o_b, o_b_s, o_r_s, o_c_s,
-                                       s_lut_o,
-                                       s_lut_o_r, s_lut_o_r_s,
-                                       s_lut_o_c_s,
-                                       sparsity_block_size,
-                                       TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-         # Get triton block indices
-         pid_blk = tl.program_id(axis=0)
-         pid_row = tl.program_id(axis=1)
-         pid_col = tl.program_id(axis=2)
-
-         # Get position of current sparsity block consisting of its batch, row, and column index
-         spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
-         spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
-         spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
-
-         spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
-         spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
-         spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
-
-         spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
-         spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
-         spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
-
-         # Setup buffer
-         buf = tl.zeros(shape=(TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), dtype=tl.float32)
-
-         # Slide over triton block sized segments of input tensors
-         for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size, TRITON_BLOCK_SIZE)):
-             # Convert to segment index of sparsity layout
-             i_seg_spa = (i_seg_tri * TRITON_BLOCK_SIZE) // sparsity_block_size
-             # Calculate the triton segment index within a block
-             i_seg_tri_mod = i_seg_tri % (sparsity_block_size // TRITON_BLOCK_SIZE)
-
-             # Get reverse sparsity indices for input tensors x and y
-             # These are either -1 if the block is empty or equal to the index of the block in the sparse tensor
-
-             # Get reverse sparsity indices for x
-             rev_idx_spa_x_idx = (spa_bat_o * s_l_x_b_s +
-                                  spa_row_o * s_l_x_r_s +
-                                  i_seg_spa * s_l_x_c_s)
-             rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
-             rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
-
-             # Get reverse sparsity indices for y
-             rev_idx_spa_y_idx = (spa_bat_o * s_l_y_b_s + i_seg_spa * s_l_y_r_s + spa_col_o * s_l_y_c_s)
-             rev_idx_spa_y_msk = (rev_idx_spa_y_idx >= 0 and rev_idx_spa_y_idx < s_l_y_b * s_l_y_b_s)
-             rev_idx_spa_y = tl.load(r_lut_y + rev_idx_spa_y_idx, mask=rev_idx_spa_y_msk).to(tl.int32)
-
-             # If both blocks are present commence calculation
-             if rev_idx_spa_x >= 0 and rev_idx_spa_y >= 0:
-                 blk_x_idx = ((rev_idx_spa_x * x_b_s) +
-                              ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                              ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
-                                tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-                 blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
-                 blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
-                 blk_y_idx = ((rev_idx_spa_y * y_b_s) +
-                              ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
-                                tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
-                              ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
-                 blk_y_msk = (blk_y_idx >= 0 and blk_y_idx < y_b * y_b_s)
-                 blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
-
-                 # Perform matrix multiplication
-                 buf += tl.dot(blk_x, blk_y, input_precision="tf32")
-
-         # Store output
-         blk_o_idx = ((pid_blk * o_b_s) +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-         blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
-         tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
+
+     lut = matmul_build_lut(lut, sparsity_layout_x, sparsity_layout_y, sparsity_layout_output)
+
+     return BlksprsTensor(matmul_forward(x, y,
+                                         sparsity_layout_x, lut["sparsity_reverse_lut_x"],
+                                         sparsity_layout_y, lut["sparsity_reverse_lut_y"],
+                                         sparsity_layout_output, lut["sparsity_lut_o"],
+                                         sparsity_block_size, lut["n_sparse_blocks"]))
+
+
+ @triton_op("blksprs::matmul", mutates_args={})
+ def matmul_forward(x: Tensor, y: Tensor,
+                    sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
+                    sparsity_layout_y: Tensor, sparsity_reverse_lut_y: Tensor,
+                    _: Tensor, sparsity_lut_o: Tensor,
+                    sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+     output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                          dtype=x.dtype, device=x.device)
+
+     x_b, x_r, x_c = x.size()
+     x_b_s, x_r_s, x_c_s = stride(x)
+     s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_x.size()
+     s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_x)
+     y_b, y_r, y_c = y.size()
+     y_b_s, y_r_s, y_c_s = stride(y)
+     s_l_y_b, s_l_y_r, s_l_y_c = sparsity_layout_y.size()
+     s_l_y_b_s, s_l_y_r_s, s_l_y_c_s = stride(sparsity_layout_y)
+     o_b, o_r, o_c = output.size()
+     o_b_s, o_r_s, o_c_s = stride(output)
+     s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
+     s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
+
+     triton_grid = lambda meta: [o_b,
+                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                 triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+     (wrap_triton(matmul_kernel)[triton_grid]
+      (x,
+       x_b, x_b_s, x_r_s, x_c_s,
+       s_l_x_b, s_l_x_b_s, s_l_x_r_s,
+       s_l_x_c, s_l_x_c_s,
+       sparsity_reverse_lut_x,
+       y,
+       y_b, y_b_s, y_r_s, y_c_s,
+       s_l_y_b, s_l_y_b_s, s_l_y_r_s,
+       s_l_y_c_s,
+       sparsity_reverse_lut_y,
+       output,
+       o_b, o_b_s, o_r_s, o_c_s,
+       sparsity_lut_o,
+       s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+       sparsity_block_size))
+
+     return output
+
+
+ def matmul_backward(ctx, grad_output):
+     x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o = ctx.saved_tensors
+     sparsity_block_size = ctx.sparsity_block_size
+
+     x_t, sparsity_layout_x_t = transpose(x, sparsity_layout_x, sparsity_block_size)
+     y_t, sparsity_layout_y_t = transpose(y, sparsity_layout_y, sparsity_block_size)
+
+     grad_x = matmul(grad_output, sparsity_layout_o, y_t, sparsity_layout_y_t, sparsity_layout_x,
+                     sparsity_block_size)
+     grad_y = matmul(x_t, sparsity_layout_x_t, grad_output, sparsity_layout_o, sparsity_layout_y,
+                     sparsity_block_size)
+
+     return grad_x, grad_y, None, None, None, None, None, None, None, None
+
+
+ @triton.autotune(
+     configs=get_autotune_configs(),
+     key=[],
+ )
+ @triton.jit
+ def matmul_kernel(x,
+                   x_b, x_b_s, x_r_s, x_c_s,
+                   s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
+                   r_lut_x,
+                   y,
+                   y_b, y_b_s, y_r_s, y_c_s,
+                   s_l_y_b, s_l_y_b_s, s_l_y_r_s, s_l_y_c_s,
+                   r_lut_y,
+                   o,
+                   o_b, o_b_s, o_r_s, o_c_s,
+                   s_lut_o,
+                   s_lut_o_r, s_lut_o_r_s,
+                   s_lut_o_c_s,
+                   sparsity_block_size,
+                   TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     # Get triton block indices
+     pid_blk = tl.program_id(axis=0)
+     pid_row = tl.program_id(axis=1)
+     pid_col = tl.program_id(axis=2)
+
+     # Get valid triton block size
+     val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
+
+     # Get position of current sparsity block consisting of its batch, row, and column index
+     spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
+     spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+     spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
+
+     spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
+     spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+     spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+
+     spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
+     spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+     spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+
+     # Setup buffer
+     buf = tl.zeros(shape=(TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), dtype=tl.float32)
+
+     # Slide over triton block sized segments of input tensors
+     for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size, val_tbs)):
+         # Convert to segment index of sparsity layout
+         i_seg_spa = (i_seg_tri * val_tbs) // sparsity_block_size
+         # Calculate the triton segment index within a block
+         i_seg_tri_mod = i_seg_tri % (sparsity_block_size // val_tbs)
+
+         # Get reverse sparsity indices for input tensors x and y
+         # These are either -1 if the block is empty or equal to the index of the block in the sparse tensor
+
+         # Get reverse sparsity indices for x
+         rev_idx_spa_x_idx = (spa_bat_o * s_l_x_b_s +
+                              spa_row_o * s_l_x_r_s +
+                              i_seg_spa * s_l_x_c_s)
+         rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+         rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
+
+         # Get reverse sparsity indices for y
+         rev_idx_spa_y_idx = (spa_bat_o * s_l_y_b_s + i_seg_spa * s_l_y_r_s + spa_col_o * s_l_y_c_s)
+         rev_idx_spa_y_msk = (rev_idx_spa_y_idx >= 0 and rev_idx_spa_y_idx < s_l_y_b * s_l_y_b_s)
+         rev_idx_spa_y = tl.load(r_lut_y + rev_idx_spa_y_idx, mask=rev_idx_spa_y_msk).to(tl.int32)
+
+         # If both blocks are present commence calculation
+         if rev_idx_spa_x >= 0 and rev_idx_spa_y >= 0:
+             blk_x_idx = ((rev_idx_spa_x * x_b_s) +
+                          ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                          ((i_seg_tri_mod * val_tbs +
+                            tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+             blk_x_msk = ((blk_x_idx >= 0 and
+                           blk_x_idx < x_b * x_b_s) and
+                          (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                           tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+             blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+             blk_y_idx = ((rev_idx_spa_y * y_b_s) +
+                          ((i_seg_tri_mod * val_tbs +
+                            tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
+                          ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
+             blk_y_msk = ((blk_y_idx >= 0 and
+                           blk_y_idx < y_b * y_b_s) and
+                          (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                           tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+             blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
+
+             # Perform matrix multiplication
+             buf += tl.dot(blk_x, blk_y)
+
+     # Store output
+     blk_o_idx = ((pid_blk * o_b_s) +
+                  ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                  ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+     blk_o_msk = ((blk_o_idx >= 0 and
+                   blk_o_idx < o_b * o_b_s) and
+                  (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                   tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+     tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
+
+
+ def matmul_build_lut(lut: dict, sparsity_layout_x: Tensor, sparsity_layout_y: Tensor, sparsity_layout_output: Tensor):
+     if lut is None:
+         lut = dict()
+
+     if "sparsity_reverse_lut_x" not in lut:
+         sparsity_layout_x_flat = sparsity_layout_x.reshape(-1)
+         sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
+                                   (sparsity_layout_x_flat == 1) -
+                                   (1 * (sparsity_layout_x_flat == 0)))
+         lut["sparsity_reverse_lut_x"] = sparsity_reverse_lut_x
+
+     if "sparsity_reverse_lut_y" not in lut:
+         sparsity_layout_y_flat = sparsity_layout_y.reshape(-1)
+         sparsity_reverse_lut_y = ((torch.cumsum(sparsity_layout_y_flat, dim=-1) - 1) *
+                                   (sparsity_layout_y_flat == 1) -
+                                   (1 * (sparsity_layout_y_flat == 0)))
+         lut["sparsity_reverse_lut_y"] = sparsity_reverse_lut_y
+
+     if "sparsity_lut_o" not in lut:
+         sparsity_lut_o = torch.nonzero(sparsity_layout_output).contiguous()
+         lut["sparsity_lut_o"] = sparsity_lut_o
+
+     if "n_sparse_blocks" not in lut:
+         n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
+         lut["n_sparse_blocks"] = n_sparse_blocks
+
+     validate_contiguous(sparsity_layout_x, lut["sparsity_reverse_lut_x"],
+                         sparsity_layout_y, lut["sparsity_reverse_lut_y"],
+                         sparsity_layout_output, lut["sparsity_lut_o"])
+
+     return lut
+
+
+ # noinspection PyUnusedLocal
+ def matmul_setup_context(ctx, inputs, output):
+     (x, y, sparsity_layout_x, _, sparsity_layout_y, _,
+      sparsity_layout_o, _, sparsity_block_size, _) = inputs
+
+     ctx.save_for_backward(x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o)
+     ctx.sparsity_block_size = sparsity_block_size
+
+
+ matmul_forward.register_autograd(matmul_backward, setup_context=matmul_setup_context)
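
Taken together, this hunk replaces the old ``torch.autograd.Function`` subclass with a ``torch.library.triton_op`` registration: ``matmul_forward`` is the compile-friendly forward, ``matmul_backward`` and ``matmul_setup_context`` are attached via ``register_autograd``, the kernel launch is autotuned instead of taking an explicit ``triton_block_size``, and the index look-up tables can be built once with ``matmul_build_lut`` and reused through the new ``lut`` argument. A minimal usage sketch of the 2.0rc1 API follows; the shapes, the fully dense layouts, and the direct construction of compressed tensors are illustrative assumptions, not taken from the package (inputs would normally come from the library's dense-to-sparse conversion):

    import torch
    from blksprs.ops.matmul import matmul, matmul_build_lut

    sparsity_block_size = 32
    # Layout: batch of 2, a 4x4 grid of sparsity blocks, here fully populated.
    layout = torch.ones(2, 4, 4, dtype=torch.int, device="cuda")

    # Compressed form stacks the selected blocks along dim 0.
    n_blocks = int(layout.sum().item())
    x = torch.randn(n_blocks, sparsity_block_size, sparsity_block_size, device="cuda")
    y = torch.randn(n_blocks, sparsity_block_size, sparsity_block_size, device="cuda")

    # One-off call: the look-up tables are built internally.
    o = matmul(x, layout, y, layout, layout, sparsity_block_size)

    # Repeated calls with identical layouts can reuse the precomputed tables.
    lut = matmul_build_lut(None, layout, layout, layout)
    o = matmul(x, layout, y, layout, layout, sparsity_block_size, lut=lut)
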
@@ -1,16 +1,18 @@
  import torch
  import triton
  from torch import Tensor
+ from torch._library import triton_op
+ from torch._library.triton import wrap_triton
  from triton import language as tl
 
  from blksprs.utils.blksprs_tensor import BlksprsTensor
- from blksprs.utils.tools import get_triton_block_size, stride
+ from blksprs.utils.tools import stride, get_autotune_configs
  from blksprs.utils.validation import validate_contiguous, validate_device, \
-     validate_sparsity_block_size, validate_triton_block_size
+     validate_sparsity_block_size
 
 
  def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
-                   sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
+                   sparsity_block_size: int) -> BlksprsTensor:
      """Performs a broadcast and subsequent addition of two dense tensors x and y. Returns a block-sparse tensor in
      compressed form.
 
@@ -19,7 +21,6 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
          y (Tensor): A dense input tensor.
          sparsity_layout_output (Tensor): The sparsity layout of the output tensor.
          sparsity_block_size (int): The size of the sparsity blocks.
-         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
 
      Returns:
          BlksprsTensor: The result of the operation as a block-sparse tensor in compressed form. Each element o(i, j) of the
@@ -34,7 +35,6 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
      if x.size(-1) != y.size(-1):
          raise ValueError("Dimensions of tensors must match")
      validate_sparsity_block_size(sparsity_block_size)
-     validate_triton_block_size(triton_block_size, sparsity_block_size)
 
      sparsity_lut_o = torch.nonzero(sparsity_layout_output).contiguous()
 
@@ -42,6 +42,21 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
 
      validate_contiguous(sparsity_layout_output, sparsity_lut_o)
 
+     return BlksprsTensor(broadcast_add_forward(x, y, sparsity_lut_o, sparsity_block_size, n_sparse_blocks))
+
+
+ def broadcast_sub(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
+                   sparsity_block_size: int) -> BlksprsTensor:
+     """Wrapper for ``broadcast_add`` with negated y.
+
+     """
+     return broadcast_add(x, torch.neg(y), sparsity_layout_output, sparsity_block_size)
+
+
+ @triton_op("blksprs::broadcast_add", mutates_args={})
+ def broadcast_add_forward(x: Tensor, y: Tensor,
+                           sparsity_lut_o: Tensor,
+                           sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
      output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, dtype=x.dtype, device=x.device)
 
      x_b, x_c = x.size()
@@ -53,14 +68,11 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
      s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
      s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
 
-     if triton_block_size is None:
-         triton_block_size = get_triton_block_size(sparsity_block_size)
-
      triton_grid = lambda meta: [o_b,
                                  triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                  triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
 
-     (kernel_broadcast_addition[triton_grid]
+     (wrap_triton(broadcast_add_kernel)[triton_grid]
       (x,
        x_b, x_b_s, x_c_s,
        y,
@@ -68,35 +80,34 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
        output,
        o_b, o_b_s, o_r_s, o_c_s,
        sparsity_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
-       sparsity_block_size,
-       triton_block_size))
+       sparsity_block_size))
 
      return BlksprsTensor(output)
 
 
- def broadcast_sub(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
-                   sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
-     """Wrapper for ``broadcast_add`` with negated y.
-
-     """
-     return broadcast_add(x, torch.neg(y), sparsity_layout_output, sparsity_block_size, triton_block_size)
-
-
+ @triton.autotune(
+     configs=get_autotune_configs(),
+     key=[],
+     reset_to_zero=["o"]
+ )
  @triton.jit
- def kernel_broadcast_addition(x,
-                               x_b, x_b_s, x_c_s,
-                               y,
-                               y_b, y_b_s, y_c_s,
-                               o,
-                               o_b, o_b_s, o_r_s, o_c_s,
-                               s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
-                               sparsity_block_size,
-                               TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+ def broadcast_add_kernel(x,
+                          x_b, x_b_s, x_c_s,
+                          y,
+                          y_b, y_b_s, y_c_s,
+                          o,
+                          o_b, o_b_s, o_r_s, o_c_s,
+                          s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+                          sparsity_block_size,
+                          TRITON_BLOCK_SIZE: tl.constexpr) -> None:
      # Get triton block indices
      pid_blk = tl.program_id(axis=0)
      pid_row = tl.program_id(axis=1)
      pid_col = tl.program_id(axis=2)
 
+     # Get valid triton block size
+     val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
+
      # Get position of current sparsity block consisting of its batch, row, and column index
      spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
      spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
@@ -112,16 +123,20 @@ def kernel_broadcast_addition(x,
 
      # Load x block
      blk_x_idx = (spa_bat_o * x_b_s +
-                  ((spa_row_o * sparsity_block_size + pid_row * TRITON_BLOCK_SIZE +
+                  ((pid_row * val_tbs + spa_row_o * sparsity_block_size +
                    tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-     blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
+     blk_x_msk = ((blk_x_idx >= 0 and
+                   blk_x_idx < x_b * x_b_s) and
+                  (tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
      blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
 
      # Load y block
      blk_y_idx = (spa_bat_o * y_b_s +
-                  ((spa_col_o * sparsity_block_size + pid_col * TRITON_BLOCK_SIZE +
+                  ((pid_col * val_tbs + spa_col_o * sparsity_block_size +
                    tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
-     blk_y_msk = (blk_y_idx >= 0 and blk_y_idx < y_b * y_b_s)
+     blk_y_msk = ((blk_y_idx >= 0 and
+                   blk_y_idx < y_b * y_b_s) and
+                  (tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
      blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
 
      # Compute sum
@@ -130,7 +145,10 @@ def kernel_broadcast_addition(x,
 
      # Store result
      blk_o_idx = ((pid_blk * o_b_s) +
-                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-     blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
+                  ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                  ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+     blk_o_msk = ((blk_o_idx >= 0 and
+                   blk_o_idx < o_b * o_b_s) and
+                  (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                   tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
      tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
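
The second file in this diff (its path is not shown above) follows the same migration pattern: ``broadcast_add`` drops the ``triton_block_size`` argument, the kernel gains ``@triton.autotune`` with ``reset_to_zero=["o"]``, and ``broadcast_sub`` now negates ``y`` before delegating. A usage sketch under the same illustrative assumptions as the matmul example; the import path below is hypothetical, since the file name is cut off in this diff:

    import torch
    # Hypothetical import path; take broadcast_add/broadcast_sub from wherever
    # the installed package actually exposes them.
    from blksprs.ops.misc import broadcast_add, broadcast_sub

    sparsity_block_size = 32
    # Output layout: batch of 2, a 4x4 grid of sparsity blocks, fully populated.
    layout = torch.ones(2, 4, 4, dtype=torch.int, device="cuda")

    # Dense inputs of shape (batch, n); per the docstring, each output element
    # o(i, j) combines x(i) and y(j) blockwise.
    x = torch.randn(2, 4 * sparsity_block_size, device="cuda")
    y = torch.randn(2, 4 * sparsity_block_size, device="cuda")

    out_sum = broadcast_add(x, y, layout, sparsity_block_size)   # x(i) + y(j)
    out_diff = broadcast_sub(x, y, layout, sparsity_block_size)  # x(i) - y(j)
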