blksprs-1.11-py3-none-any.whl → blksprs-2.0-py3-none-any.whl

blksprs/ops/matmul.py CHANGED
@@ -1,19 +1,22 @@
  import torch
  import triton
  from torch import Tensor
+ from torch.library import triton_op, wrap_triton
  from triton import language as tl
  
  from blksprs.ops.transpose import transpose
  from blksprs.utils.blksprs_tensor import BlksprsTensor
- from blksprs.utils.tools import get_triton_block_size, stride
+ from blksprs.utils.tools import stride
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_dtype_float
+     validate_sparsity, validate_sparsity_block_size, validate_dtype_float
  
  
+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
  def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
             y: BlksprsTensor, sparsity_layout_y: Tensor,
             sparsity_layout_output: Tensor,
-            sparsity_block_size: int, triton_block_size: int = None, lut: dict = None) -> BlksprsTensor:
+            sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
      """Performs matrix multiplication between two block-sparse tensors.
  
      The sparsity layout of the output tensor is used to only calculate blocks that will be present in the output.
@@ -25,7 +28,6 @@ def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
          sparsity_layout_y (Tensor): The sparsity layout of the second block-sparse tensor.
          sparsity_layout_output (Tensor): The sparsity layout of the output tensor.
          sparsity_block_size (int): The size of the sparsity blocks.
-         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
          lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
  
      Returns:
@@ -43,61 +45,24 @@ def matmul(x: BlksprsTensor, sparsity_layout_x: Tensor,
      if sparsity_layout_x.size(-1) != sparsity_layout_y.size(-2):
          raise ValueError("Inner dimensions of tensors must match")
      validate_sparsity_block_size(sparsity_block_size, x, y)
-     validate_triton_block_size(triton_block_size, sparsity_block_size)
- 
-     lut = _BlocksparseMatmulSSS.build_lut(lut, sparsity_layout_x, sparsity_layout_y, sparsity_layout_output)
- 
-     return BlksprsTensor(_BlocksparseMatmulSSS.apply(x, y,
-                                                      sparsity_layout_x, lut["sparsity_reverse_lut_x"],
-                                                      sparsity_layout_y, lut["sparsity_reverse_lut_y"],
-                                                      sparsity_layout_output, lut["sparsity_lut_o"],
-                                                      sparsity_block_size,
-                                                      lut["n_sparse_blocks"],
-                                                      triton_block_size))
- 
- 
- class _BlocksparseMatmulSSS(torch.autograd.Function):
- 
-     @staticmethod
-     def build_lut(lut: dict, sparsity_layout_x: Tensor, sparsity_layout_y: Tensor, sparsity_layout_output: Tensor):
-         if lut is None:
-             lut = dict()
- 
-         if "sparsity_reverse_lut_x" not in lut:
-             sparsity_layout_x_flat = sparsity_layout_x.reshape(-1)
-             sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
-                                       (sparsity_layout_x_flat == 1) -
-                                       (1 * (sparsity_layout_x_flat == 0)))
-             lut["sparsity_reverse_lut_x"] = sparsity_reverse_lut_x
- 
-         if "sparsity_reverse_lut_y" not in lut:
-             sparsity_layout_y_flat = sparsity_layout_y.reshape(-1)
-             sparsity_reverse_lut_y = ((torch.cumsum(sparsity_layout_y_flat, dim=-1) - 1) *
-                                       (sparsity_layout_y_flat == 1) -
-                                       (1 * (sparsity_layout_y_flat == 0)))
-             lut["sparsity_reverse_lut_y"] = sparsity_reverse_lut_y
- 
-         if "sparsity_lut_o" not in lut:
-             sparsity_lut_o = torch.nonzero(sparsity_layout_output).contiguous()
-             lut["sparsity_lut_o"] = sparsity_lut_o
- 
-         if "n_sparse_blocks" not in lut:
-             n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
-             lut["n_sparse_blocks"] = n_sparse_blocks
- 
-         validate_contiguous(sparsity_layout_x, lut["sparsity_reverse_lut_x"],
-                             sparsity_layout_y, lut["sparsity_reverse_lut_y"],
-                             sparsity_layout_output, lut["sparsity_lut_o"])
- 
-         return lut
- 
-     @staticmethod
-     def forward(ctx, x: Tensor, y: Tensor,
-                 sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
-                 sparsity_layout_y: Tensor, sparsity_reverse_lut_y: Tensor,
-                 sparsity_layout_o: Tensor, sparsity_lut_o: Tensor,
-                 sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
-         output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+ 
+     lut = matmul_build_lut(lut, sparsity_layout_x, sparsity_layout_y, sparsity_layout_output)
+ 
+     return BlksprsTensor(matmul_forward(x, y,
+                                         sparsity_layout_x, lut["sparsity_reverse_lut_x"],
+                                         sparsity_layout_y, lut["sparsity_reverse_lut_y"],
+                                         sparsity_layout_output, lut["sparsity_lut_o"],
+                                         sparsity_block_size, lut["n_sparse_blocks"]))
+ 
+ 
+ @triton_op("blksprs::matmul_forward", mutates_args={})
+ def matmul_forward(x: Tensor, y: Tensor,
+                    sparsity_layout_x: Tensor, sparsity_reverse_lut_x: Tensor,
+                    sparsity_layout_y: Tensor, sparsity_reverse_lut_y: Tensor,
+                    _: Tensor, sparsity_lut_o: Tensor,
+                    sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+     with torch.no_grad():
+         output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                               dtype=x.dtype, device=x.device)
  
          x_b, x_r, x_c = x.size()
@@ -113,133 +78,183 @@ class _BlocksparseMatmulSSS(torch.autograd.Function):
          s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
          s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
  
-         if triton_block_size is None:
-             triton_block_size = get_triton_block_size(sparsity_block_size)
- 
          triton_grid = lambda meta: [o_b,
                                      triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                      triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
  
-         (_BlocksparseMatmulSSS.kernel_blocksparse_matmul_sss[triton_grid]
+         (wrap_triton(matmul_kernel)[triton_grid]
           (x,
            x_b, x_b_s, x_r_s, x_c_s,
-           s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
+           s_l_x_b, s_l_x_b_s, s_l_x_r_s,
+           s_l_x_c, s_l_x_c_s,
            sparsity_reverse_lut_x,
            y,
            y_b, y_b_s, y_r_s, y_c_s,
-           s_l_y_b, s_l_y_b_s, s_l_y_r_s, s_l_y_c_s,
+           s_l_y_b, s_l_y_b_s, s_l_y_r_s,
+           s_l_y_c_s,
            sparsity_reverse_lut_y,
            output,
            o_b, o_b_s, o_r_s, o_c_s,
            sparsity_lut_o,
            s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
-           sparsity_block_size,
-           triton_block_size))
- 
-         ctx.save_for_backward(x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o)
-         ctx.sparsity_block_size = sparsity_block_size
-         ctx.triton_block_size = triton_block_size
+           sparsity_block_size))
  
          return output
  
-     @staticmethod
-     def backward(ctx, grad_output):
-         x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o = ctx.saved_tensors
-         sparsity_block_size = ctx.sparsity_block_size
-         triton_block_size = ctx.triton_block_size
- 
-         x_t, sparsity_layout_x_t = transpose(x, sparsity_layout_x, sparsity_block_size, triton_block_size)
-         y_t, sparsity_layout_y_t = transpose(y, sparsity_layout_y, sparsity_block_size, triton_block_size)
- 
-         grad_x = matmul(grad_output, sparsity_layout_o, y_t, sparsity_layout_y_t, sparsity_layout_x,
-                         sparsity_block_size, triton_block_size)
-         grad_y = matmul(x_t, sparsity_layout_x_t, grad_output, sparsity_layout_o, sparsity_layout_y,
-                         sparsity_block_size, triton_block_size)
- 
-         return grad_x, grad_y, None, None, None, None, None, None, None, None, None
- 
-     @staticmethod
-     @triton.jit
-     def kernel_blocksparse_matmul_sss(x,
-                                       x_b, x_b_s, x_r_s, x_c_s,
-                                       s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
-                                       r_lut_x,
-                                       y,
-                                       y_b, y_b_s, y_r_s, y_c_s,
-                                       s_l_y_b, s_l_y_b_s, s_l_y_r_s, s_l_y_c_s,
-                                       r_lut_y,
-                                       o,
-                                       o_b, o_b_s, o_r_s, o_c_s,
-                                       s_lut_o,
-                                       s_lut_o_r, s_lut_o_r_s,
-                                       s_lut_o_c_s,
-                                       sparsity_block_size,
-                                       TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-         # Get triton block indices
-         pid_blk = tl.program_id(axis=0)
-         pid_row = tl.program_id(axis=1)
-         pid_col = tl.program_id(axis=2)
- 
-         # Get position of current sparsity block consisting of its batch, row, and column index
-         spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
-         spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
-         spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
- 
-         spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
-         spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
-         spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
- 
-         spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
-         spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
-         spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
- 
-         # Setup buffer
-         buf = tl.zeros(shape=(TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), dtype=tl.float32)
- 
-         # Slide over triton block sized segments of input tensors
-         for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size, TRITON_BLOCK_SIZE)):
-             # Convert to segment index of sparsity layout
-             i_seg_spa = (i_seg_tri * TRITON_BLOCK_SIZE) // sparsity_block_size
-             # Calculate the triton segment index within a block
-             i_seg_tri_mod = i_seg_tri % (sparsity_block_size // TRITON_BLOCK_SIZE)
- 
-             # Get reverse sparsity indices for input tensors x and y
-             # These are either -1 if the block is empty or equal to the index of the block in the sparse tensor
- 
-             # Get reverse sparsity indices for x
-             rev_idx_spa_x_idx = (spa_bat_o * s_l_x_b_s +
-                                  spa_row_o * s_l_x_r_s +
-                                  i_seg_spa * s_l_x_c_s)
-             rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
-             rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
- 
-             # Get reverse sparsity indices for y
-             rev_idx_spa_y_idx = (spa_bat_o * s_l_y_b_s + i_seg_spa * s_l_y_r_s + spa_col_o * s_l_y_c_s)
-             rev_idx_spa_y_msk = (rev_idx_spa_y_idx >= 0 and rev_idx_spa_y_idx < s_l_y_b * s_l_y_b_s)
-             rev_idx_spa_y = tl.load(r_lut_y + rev_idx_spa_y_idx, mask=rev_idx_spa_y_msk).to(tl.int32)
- 
-             # If both blocks are present commence calculation
-             if rev_idx_spa_x >= 0 and rev_idx_spa_y >= 0:
-                 blk_x_idx = ((rev_idx_spa_x * x_b_s) +
-                              ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                              ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
-                                tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-                 blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
-                 blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
- 
-                 blk_y_idx = ((rev_idx_spa_y * y_b_s) +
-                              ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
-                                tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
-                              ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
-                 blk_y_msk = (blk_y_idx >= 0 and blk_y_idx < y_b * y_b_s)
-                 blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
- 
-                 # Perform matrix multiplication
-                 buf += tl.dot(blk_x, blk_y, input_precision="tf32")
- 
-         # Store output
-         blk_o_idx = ((pid_blk * o_b_s) +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-         blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
-         tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
+ 
+ def matmul_wrapper_backward(ctx, grad_output):
+     x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o = ctx.saved_tensors
+     sparsity_block_size = ctx.sparsity_block_size
+ 
+     x_t, sparsity_layout_x_t = transpose(x, sparsity_layout_x, sparsity_block_size)
+     y_t, sparsity_layout_y_t = transpose(y, sparsity_layout_y, sparsity_block_size)
+ 
+     grad_x = matmul(grad_output, sparsity_layout_o, y_t, sparsity_layout_y_t, sparsity_layout_x,
+                     sparsity_block_size)
+     grad_y = matmul(x_t, sparsity_layout_x_t, grad_output, sparsity_layout_o, sparsity_layout_y,
+                     sparsity_block_size)
+ 
+     return grad_x, grad_y, None, None, None, None, None, None, None, None
+ 
+ 
+ @triton.autotune(
+     configs=get_autotune_configs(),
+     key=["sparsity_block_size"],
+     prune_configs_by={"early_config_prune": prune_autotune_configs},
+     reset_to_zero=["o"]
+ )
+ @triton.jit
+ def matmul_kernel(x,
+                   x_b, x_b_s, x_r_s, x_c_s,
+                   s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c, s_l_x_c_s,
+                   r_lut_x,
+                   y,
+                   y_b, y_b_s, y_r_s, y_c_s,
+                   s_l_y_b, s_l_y_b_s, s_l_y_r_s, s_l_y_c_s,
+                   r_lut_y,
+                   o,
+                   o_b, o_b_s, o_r_s, o_c_s,
+                   s_lut_o,
+                   s_lut_o_r, s_lut_o_r_s,
+                   s_lut_o_c_s,
+                   sparsity_block_size,
+                   TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     # Get triton block indices
+     pid_blk = tl.program_id(axis=0)
+     pid_row = tl.program_id(axis=1)
+     pid_col = tl.program_id(axis=2)
+ 
+     # Get position of current sparsity block consisting of its batch, row, and column index
+     spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
+     spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+     spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
+ 
+     spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
+     spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+     spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+ 
+     spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
+     spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+     spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+ 
+     # Setup buffer
+     buf = tl.zeros(shape=(TRITON_BLOCK_SIZE, TRITON_BLOCK_SIZE), dtype=tl.float32)
+ 
+     # Slide over triton block sized segments of input tensors
+     for i_seg_tri in range(0, tl.cdiv(s_l_x_c * sparsity_block_size, TRITON_BLOCK_SIZE)):
+         # Convert to segment index of sparsity layout
+         i_seg_spa = (i_seg_tri * TRITON_BLOCK_SIZE) // sparsity_block_size
+         # Calculate the triton segment index within a block
+         i_seg_tri_mod = i_seg_tri % (sparsity_block_size // TRITON_BLOCK_SIZE)
+ 
+         # Get reverse sparsity indices for input tensors x and y
+         # These are either -1 if the block is empty or equal to the index of the block in the sparse tensor
+ 
+         # Get reverse sparsity indices for x
+         rev_idx_spa_x_idx = (spa_bat_o * s_l_x_b_s +
+                              spa_row_o * s_l_x_r_s +
+                              i_seg_spa * s_l_x_c_s)
+         rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+         rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
+ 
+         # Get reverse sparsity indices for y
+         rev_idx_spa_y_idx = (spa_bat_o * s_l_y_b_s + i_seg_spa * s_l_y_r_s + spa_col_o * s_l_y_c_s)
+         rev_idx_spa_y_msk = (rev_idx_spa_y_idx >= 0 and rev_idx_spa_y_idx < s_l_y_b * s_l_y_b_s)
+         rev_idx_spa_y = tl.load(r_lut_y + rev_idx_spa_y_idx, mask=rev_idx_spa_y_msk).to(tl.int32)
+ 
+         # If both blocks are present commence calculation
+         if rev_idx_spa_x >= 0 and rev_idx_spa_y >= 0:
+             blk_x_idx = ((rev_idx_spa_x * x_b_s) +
+                          ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                          ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
+                            tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+             blk_x_msk = (blk_x_idx >= 0 and
+                          blk_x_idx < x_b * x_b_s)
+             blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+ 
+             blk_y_idx = ((rev_idx_spa_y * y_b_s) +
+                          ((i_seg_tri_mod * TRITON_BLOCK_SIZE +
+                            tl.arange(0, TRITON_BLOCK_SIZE)) * y_r_s)[:, None] +
+                          ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
+             blk_y_msk = (blk_y_idx >= 0 and
+                          blk_y_idx < y_b * y_b_s)
+             blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
+ 
+             # Perform matrix multiplication
+             buf += tl.dot(blk_x, blk_y)
+ 
+     # Cast buffer
+     buf = buf.to(o.dtype.element_ty)
+ 
+     # Store output
+     blk_o_idx = ((pid_blk * o_b_s) +
+                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+     blk_o_msk = (blk_o_idx >= 0 and
+                  blk_o_idx < o_b * o_b_s)
+     tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
+ 
+ 
+ def matmul_build_lut(lut: dict, sparsity_layout_x: Tensor, sparsity_layout_y: Tensor, sparsity_layout_output: Tensor):
+     if lut is None:
+         lut = dict()
+ 
+     if "sparsity_reverse_lut_x" not in lut:
+         sparsity_layout_x_flat = sparsity_layout_x.reshape(-1)
+         sparsity_reverse_lut_x = ((torch.cumsum(sparsity_layout_x_flat, dim=-1) - 1) *
+                                   (sparsity_layout_x_flat == 1) -
+                                   (1 * (sparsity_layout_x_flat == 0)))
+         lut["sparsity_reverse_lut_x"] = sparsity_reverse_lut_x
+ 
+     if "sparsity_reverse_lut_y" not in lut:
+         sparsity_layout_y_flat = sparsity_layout_y.reshape(-1)
+         sparsity_reverse_lut_y = ((torch.cumsum(sparsity_layout_y_flat, dim=-1) - 1) *
+                                   (sparsity_layout_y_flat == 1) -
+                                   (1 * (sparsity_layout_y_flat == 0)))
+         lut["sparsity_reverse_lut_y"] = sparsity_reverse_lut_y
+ 
+     if "sparsity_lut_o" not in lut:
+         sparsity_lut_o = torch.nonzero(sparsity_layout_output).contiguous()
+         lut["sparsity_lut_o"] = sparsity_lut_o
+ 
+     if "n_sparse_blocks" not in lut:
+         n_sparse_blocks = torch.sum(sparsity_layout_output.to(torch.int)).item()
+         lut["n_sparse_blocks"] = n_sparse_blocks
+ 
+     validate_contiguous(sparsity_layout_x, lut["sparsity_reverse_lut_x"],
+                         sparsity_layout_y, lut["sparsity_reverse_lut_y"],
+                         sparsity_layout_output, lut["sparsity_lut_o"])
+ 
+     return lut
+ 
+ 
+ # noinspection PyUnusedLocal
+ def matmul_setup_context(ctx, inputs, output):
+     (x, y, sparsity_layout_x, _, sparsity_layout_y, _,
+      sparsity_layout_o, _, sparsity_block_size, _) = inputs
+ 
+     ctx.save_for_backward(x, sparsity_layout_x, y, sparsity_layout_y, sparsity_layout_o)
+     ctx.sparsity_block_size = sparsity_block_size
+ 
+ 
+ matmul_forward.register_autograd(matmul_wrapper_backward, setup_context=matmul_setup_context)
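
Taken together, the matmul changes replace the _BlocksparseMatmulSSS autograd Function with a torch.library.triton_op operator wired up through register_autograd, drop the manual triton_block_size parameter in favor of triton.autotune (keyed on sparsity_block_size), and keep the optional lut dict, which matmul_build_lut fills once and reuses on later calls. Below is a minimal call-site sketch of the 2.0 API; the shapes, the all-ones layout, and the CUDA device are illustrative assumptions, not taken from the diff:

    import torch
    from blksprs.ops.matmul import matmul
    from blksprs.utils.blksprs_tensor import BlksprsTensor

    # Hypothetical setup: batch of 1, a 4x4 grid of 32x32 sparsity blocks, all present.
    sparsity_block_size = 32
    layout = torch.ones(1, 4, 4, dtype=torch.int, device="cuda")
    n_blocks = int(layout.sum())

    # Compressed form stores one (sparsity_block_size, sparsity_block_size) tile per set layout entry.
    x = BlksprsTensor(torch.randn(n_blocks, sparsity_block_size, sparsity_block_size, device="cuda"))
    y = BlksprsTensor(torch.randn(n_blocks, sparsity_block_size, sparsity_block_size, device="cuda"))

    # 1.11 call site: matmul(x, layout, y, layout, layout, sparsity_block_size, triton_block_size=16)
    # 2.0 call site: no triton_block_size; the autotuner picks TRITON_BLOCK_SIZE.
    lut = {}
    out = matmul(x, layout, y, layout, layout, sparsity_block_size, lut=lut)
    out = matmul(x, layout, y, layout, layout, sparsity_block_size, lut=lut)  # cached LUT entries are reused

The hunks below apply the same migration to broadcast_add and broadcast_sub.
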
@@ -1,16 +1,20 @@
  import torch
  import triton
  from torch import Tensor
+ from torch._library import triton_op
+ from torch._library.triton import wrap_triton
  from triton import language as tl
  
  from blksprs.utils.blksprs_tensor import BlksprsTensor
- from blksprs.utils.tools import get_triton_block_size, stride
+ from blksprs.utils.tools import stride
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
  from blksprs.utils.validation import validate_contiguous, validate_device, \
-     validate_sparsity_block_size, validate_triton_block_size
+     validate_sparsity_block_size
  
  
+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
  def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
-                   sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
+                   sparsity_block_size: int) -> BlksprsTensor:
      """Performs a broadcast and subsequent addition of two dense tensors x and y. Returns a block-sparse tensor in
      compressed form.
  
@@ -19,7 +23,6 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
          y (Tensor): A dense input tensor.
          sparsity_layout_output (Tensor): The sparsity layout of the output tensor.
          sparsity_block_size (int): The size of the sparsity blocks.
-         triton_block_size (int, optional): The block size to use for the triton kernel (default ``None``).
  
      Returns:
          BlksprsTensor: The result of the operation as a block-sparse tensor in compressed form. Each element o(i, j) of the
@@ -34,7 +37,6 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
      if x.size(-1) != y.size(-1):
          raise ValueError("Dimensions of tensors must match")
      validate_sparsity_block_size(sparsity_block_size)
-     validate_triton_block_size(triton_block_size, sparsity_block_size)
  
      sparsity_lut_o = torch.nonzero(sparsity_layout_output).contiguous()
  
@@ -42,56 +44,66 @@ def broadcast_add(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
  
      validate_contiguous(sparsity_layout_output, sparsity_lut_o)
  
-     output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, dtype=x.dtype, device=x.device)
- 
-     x_b, x_c = x.size()
-     x_b_s, x_c_s = stride(x)
-     y_b, y_c = y.size()
-     y_b_s, y_c_s = stride(y)
-     o_b, o_r, o_c = output.size()
-     o_b_s, o_r_s, o_c_s = stride(output)
-     s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
-     s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
- 
-     if triton_block_size is None:
-         triton_block_size = get_triton_block_size(sparsity_block_size)
- 
-     triton_grid = lambda meta: [o_b,
-                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                 triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
- 
-     (kernel_broadcast_addition[triton_grid]
-      (x,
-       x_b, x_b_s, x_c_s,
-       y,
-       y_b, y_b_s, y_c_s,
-       output,
-       o_b, o_b_s, o_r_s, o_c_s,
-       sparsity_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
-       sparsity_block_size,
-       triton_block_size))
- 
-     return BlksprsTensor(output)
+     return BlksprsTensor(broadcast_add_forward(x, y, sparsity_lut_o, sparsity_block_size, n_sparse_blocks))
  
  
  def broadcast_sub(x: Tensor, y: Tensor, sparsity_layout_output: Tensor,
-                   sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
+                   sparsity_block_size: int) -> BlksprsTensor:
      """Wrapper for ``broadcast_add`` with negated y.
  
      """
-     return broadcast_add(x, torch.neg(y), sparsity_layout_output, sparsity_block_size, triton_block_size)
- 
- 
+     return broadcast_add(x, torch.neg(y), sparsity_layout_output, sparsity_block_size)
+ 
+ 
+ @triton_op("blksprs::broadcast_add_forward", mutates_args={})
+ def broadcast_add_forward(x: Tensor, y: Tensor,
+                           sparsity_lut_o: Tensor,
+                           sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+     with torch.no_grad():
+         output = torch.zeros(n_sparse_blocks, sparsity_block_size, sparsity_block_size, dtype=x.dtype, device=x.device)
+ 
+         x_b, x_c = x.size()
+         x_b_s, x_c_s = stride(x)
+         y_b, y_c = y.size()
+         y_b_s, y_c_s = stride(y)
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = stride(output)
+         s_lut_o_r, s_lut_o_c = sparsity_lut_o.size()
+         s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_o)
+ 
+         triton_grid = lambda meta: [o_b,
+                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+ 
+         (wrap_triton(broadcast_add_kernel)[triton_grid]
+          (x,
+           x_b, x_b_s, x_c_s,
+           y,
+           y_b, y_b_s, y_c_s,
+           output,
+           o_b, o_b_s, o_r_s, o_c_s,
+           sparsity_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+           sparsity_block_size))
+ 
+         return output
+ 
+ 
+ @triton.autotune(
+     configs=get_autotune_configs(),
+     key=["sparsity_block_size"],
+     prune_configs_by={"early_config_prune": prune_autotune_configs},
+     reset_to_zero=["o"]
+ )
  @triton.jit
- def kernel_broadcast_addition(x,
-                               x_b, x_b_s, x_c_s,
-                               y,
-                               y_b, y_b_s, y_c_s,
-                               o,
-                               o_b, o_b_s, o_r_s, o_c_s,
-                               s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
-                               sparsity_block_size,
-                               TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+ def broadcast_add_kernel(x,
+                          x_b, x_b_s, x_c_s,
+                          y,
+                          y_b, y_b_s, y_c_s,
+                          o,
+                          o_b, o_b_s, o_r_s, o_c_s,
+                          s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+                          sparsity_block_size,
+                          TRITON_BLOCK_SIZE: tl.constexpr) -> None:
      # Get triton block indices
      pid_blk = tl.program_id(axis=0)
      pid_row = tl.program_id(axis=1)
@@ -112,16 +124,18 @@ def kernel_broadcast_addition(x,
  
      # Load x block
      blk_x_idx = (spa_bat_o * x_b_s +
-                  ((spa_row_o * sparsity_block_size + pid_row * TRITON_BLOCK_SIZE +
+                  ((pid_row * TRITON_BLOCK_SIZE + spa_row_o * sparsity_block_size +
                     tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-     blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
+     blk_x_msk = (blk_x_idx >= 0 and
+                  blk_x_idx < x_b * x_b_s)
      blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
  
      # Load y block
      blk_y_idx = (spa_bat_o * y_b_s +
-                  ((spa_col_o * sparsity_block_size + pid_col * TRITON_BLOCK_SIZE +
+                  ((pid_col * TRITON_BLOCK_SIZE + spa_col_o * sparsity_block_size +
                     tl.arange(0, TRITON_BLOCK_SIZE)) * y_c_s)[None, :])
-     blk_y_msk = (blk_y_idx >= 0 and blk_y_idx < y_b * y_b_s)
+     blk_y_msk = (blk_y_idx >= 0 and
+                  blk_y_idx < y_b * y_b_s)
      blk_y = tl.load(y + blk_y_idx, mask=blk_y_msk)
  
      # Compute sum
@@ -132,5 +146,6 @@ def kernel_broadcast_addition(x,
      blk_o_idx = ((pid_blk * o_b_s) +
                   ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
                   ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-     blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
+     blk_o_msk = (blk_o_idx >= 0 and
+                  blk_o_idx < o_b * o_b_s)
      tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
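
As with matmul, a minimal sketch of the updated broadcast call sites. The import path is an assumption for illustration (the diff shows only this module's contents, not its file name), and the shapes are hypothetical:

    import torch
    # Assumed module path; the source diff does not name this file.
    from blksprs.ops.misc.broadcast_ops import broadcast_add, broadcast_sub

    sparsity_block_size = 32
    layout = torch.ones(1, 4, 4, dtype=torch.int, device="cuda")  # hypothetical output layout

    # x and y are dense 2D tensors; per the docstring, each output element o(i, j) combines x(i) and y(j).
    x = torch.randn(1, 4 * sparsity_block_size, device="cuda")
    y = torch.randn(1, 4 * sparsity_block_size, device="cuda")

    # 1.11: broadcast_add(x, y, layout, sparsity_block_size, triton_block_size=16)
    # 2.0: the triton_block_size parameter is removed here as well.
    out_add = broadcast_add(x, y, layout, sparsity_block_size)
    out_sub = broadcast_sub(x, y, layout, sparsity_block_size)  # broadcast_add with torch.neg(y)
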