blksprs 1.10.2__py3-none-any.whl → 2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/ops/conversion.py CHANGED
@@ -1,25 +1,181 @@
  import torch
  import triton
  from torch import Tensor
+ from torch._library.triton import wrap_triton, triton_op
  from triton import language as tl

  from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
  from blksprs.utils.blksprs_tensor import BlksprsTensor
- from blksprs.utils.tools import get_triton_block_size, stride
+ from blksprs.utils.tools import stride
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs, prune_autotune_configs_conversion
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_sparsity_dense
+     validate_sparsity, validate_sparsity_block_size, validate_sparsity_dense


- def from_blksprs(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
-                  triton_block_size: int = None) -> Tensor:
+ def to_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int) -> BlksprsTensor:
+     """Wrapper for ``to_sparse``.
+
+     """
+     return to_sparse(x, sparsity_layout, sparsity_block_size)
+
+
+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
+ def to_sparse(x: Tensor, sparsity_layout: Tensor,
+               sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
+     """Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
+     sparsity layout.
+
+     Args:
+         x (Tensor): A block-sparse tensor in regular form.
+         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+         sparsity_block_size (int): The size of the sparsity blocks.
+         lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
+
+     Returns:
+         BlksprsTensor: The block-sparse tensor converted to compressed form.
+
+     """
+     x = x.contiguous()
+
+     validate_dimensions(x)
+     validate_contiguous(x)
+     validate_device(x)
+     validate_sparsity_dense(sparsity_block_size, (x, sparsity_layout))
+     validate_sparsity_block_size(sparsity_block_size, x)
+
+     lut = to_sparse_build_lut(lut, sparsity_layout)
+
+     if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
+         return BlksprsTensor(x)
+
+     return BlksprsTensor(to_sparse_forward(x, sparsity_layout,
+                                            lut["sparsity_lut"], sparsity_block_size, lut["n_sparse_blocks"]))
+
+
+ @triton_op("blksprs::to_sparse_forward", mutates_args={})
+ def to_sparse_forward(x: Tensor, _: Tensor,
+                       sparsity_lut: Tensor, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+     with torch.no_grad():
+         output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                              dtype=x.dtype, device=x.device)
+
+         x_b, x_r, x_c = x.size()
+         x_b_s, x_r_s, x_c_s = stride(x)
+         s_lut_r, s_lut_c = sparsity_lut.size()
+         s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+         o_b, o_r, o_c = output.size()
+         o_b_s, o_r_s, o_c_s = stride(output)
+
+         triton_grid = lambda meta: [o_b,
+                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+         (wrap_triton(to_sparse_kernel)[triton_grid]
+          (x, x_b, x_b_s, x_r_s, x_c_s,
+           sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+           output, o_b_s, o_r_s, o_c_s,
+           sparsity_block_size))
+
+         return output
+
+
+ def to_sparse_wrapper_backward(ctx, grad_output):
+     sparsity_layout = ctx.saved_tensors[0]
+     sparsity_block_size = ctx.sparsity_block_size
+
+     return to_dense(grad_output, sparsity_layout, sparsity_block_size), None, None, None, None
+
+
+ @triton.autotune(
+     configs=get_autotune_configs(),
+     key=["sparsity_block_size"],
+     prune_configs_by={"early_config_prune": prune_autotune_configs},
+     reset_to_zero=["o"]
+ )
+ @triton.jit
+ def to_sparse_kernel(x,
+                      x_b, x_b_s, x_r_s, x_c_s,
+                      s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                      o,
+                      o_b_s, o_r_s, o_c_s,
+                      sparsity_block_size,
+                      TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     # Get triton block indices
+     pid_blk = tl.program_id(axis=0)
+     pid_row = tl.program_id(axis=1)
+     pid_col = tl.program_id(axis=2)
+
+     # Get sparsity index of current output block consisting of its batch, row, and column index
+     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+     spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
+     spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+     spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+     spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
+     spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+     spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
+     spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
+     spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+
+     # Load block from dense tensor
+     blk_d_idx = (spa_bat * x_b_s +
+                  ((pid_row * TRITON_BLOCK_SIZE + spa_row * sparsity_block_size +
+                    tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                  ((pid_col * TRITON_BLOCK_SIZE + spa_col * sparsity_block_size +
+                    tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+     blk_d_msk = (blk_d_idx >= 0 and
+                  blk_d_idx < x_b * x_b_s)
+     blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)
+
+     # Store block in sparse tensor
+     blk_o_idx = ((pid_blk * o_b_s) +
+                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
+     blk_o_msk = (blk_o_idx >= 0 and
+                  blk_o_idx < (pid_blk + 1) * o_b_s)
+     tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
+
+
+ def to_sparse_build_lut(lut: dict, sparsity_layout: Tensor):
+     if lut is None:
+         lut = dict()
+
+     if "sparsity_lut" not in lut:
+         sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
+         lut["sparsity_lut"] = sparsity_lut
+
+     if "n_sparse_blocks" not in lut:
+         n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
+         lut["n_sparse_blocks"] = n_sparse_blocks
+
+     validate_contiguous(sparsity_layout, lut["sparsity_lut"])
+
+     return lut
+
+
+ # noinspection PyUnusedLocal
+ def to_sparse_setup_context(ctx, inputs, output):
+     (_, sparsity_layout, _, sparsity_block_size, _) = inputs
+
+     ctx.save_for_backward(sparsity_layout)
+     ctx.sparsity_block_size = sparsity_block_size
+
+
+ to_sparse_forward.register_autograd(to_sparse_wrapper_backward, setup_context=to_sparse_setup_context)
+
+
+ def from_blksprs(x: BlksprsTensor, sparsity_layout: Tensor,
+                  sparsity_block_size: int, fill_value: float = 0, lut: dict = None) -> Tensor:
      """Wrapper for ``to_dense``.

      """
-     return to_dense(x, sparsity_layout, sparsity_block_size, fill_value, triton_block_size)
+     return to_dense(x, sparsity_layout, sparsity_block_size, fill_value=fill_value, lut=lut)


- def to_dense(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
-              triton_block_size: int = None) -> Tensor:
+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
+ def to_dense(x: BlksprsTensor, sparsity_layout: Tensor,
+              sparsity_block_size: int, fill_value: float = 0, lut: dict = None) -> Tensor:
      """Converts a block-sparse tensor in compressed form to a block-sparse tensor in regular form based on the given
      sparsity layout.

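The 2.0 API shown above drops the caller-supplied `triton_block_size` in favour of autotuned kernels and adds an optional, caller-owned `lut` dict. A minimal usage sketch, assuming a CUDA device and importing directly from the `blksprs.ops.conversion` module shown in this diff (tensor shapes and values are illustrative, not from the package):

    import torch
    from blksprs.ops.conversion import to_sparse, to_dense

    sparsity_block_size = 64
    # Dense input of shape (batch, rows, cols), divisible by the block size.
    x = torch.randn(2, 128, 128, device="cuda")
    # Layout of shape (batch, rows // block, cols // block); nonzero entries keep a block.
    sparsity_layout = torch.tensor([[[1, 0],
                                     [0, 1]],
                                    [[1, 1],
                                     [0, 0]]], device="cuda")

    # The lut dict is filled in on first use (see to_sparse_build_lut above) and
    # can be passed again to skip rebuilding the look-up tables for this layout.
    lut = {}
    x_sparse = to_sparse(x, sparsity_layout, sparsity_block_size, lut=lut)
    x_dense = to_dense(x_sparse, sparsity_layout, sparsity_block_size)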
@@ -29,7 +185,7 @@ def to_dense(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int
          sparsity_block_size (int): The size of the sparsity blocks.
          fill_value (float): The value to fill the resulting dense tensor with where the block-sparse tensor is not
              present (default ``0``).
-         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+         lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

      Returns:
          Tensor: The block-sparse tensor converted to regular form.
@@ -42,31 +198,21 @@ def to_dense(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int
      validate_device(x)
      validate_sparsity(sparsity_block_size, (x, sparsity_layout))
      validate_sparsity_block_size(sparsity_block_size, x)
-     validate_triton_block_size(triton_block_size, sparsity_block_size)
-
-     sparsity_layout_flat = sparsity_layout.reshape(-1)
-     sparsity_reverse_lut = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                             (sparsity_layout_flat == 1) -
-                             (1 * (sparsity_layout_flat == 0)))

-     validate_contiguous(sparsity_reverse_lut)
+     lut = to_dense_build_lut(lut, sparsity_layout)

      if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
          return x

-     return _BlocksparseToDense.apply(x,
-                                      sparsity_layout, sparsity_reverse_lut,
-                                      sparsity_block_size, fill_value,
-                                      triton_block_size)
+     return Tensor(to_dense_forward(x, sparsity_layout,
+                                    lut["sparsity_reverse_lut"], sparsity_block_size, fill_value))


- class _BlocksparseToDense(torch.autograd.Function):
-
-     @staticmethod
-     def forward(ctx, x: Tensor,
-                 sparsity_layout: Tensor, sparsity_reverse_lut: Tensor,
-                 sparsity_block_size: int, fill_value: float,
-                 triton_block_size: int) -> Tensor:
+ @triton_op("blksprs::to_dense_forward", mutates_args={})
+ def to_dense_forward(x: Tensor, sparsity_layout: Tensor,
+                      sparsity_reverse_lut: Tensor,
+                      sparsity_block_size: int, fill_value: float) -> Tensor:
+     with torch.no_grad():
          output = torch.full(size=(sparsity_layout.size(0), sparsity_layout.size(1) * sparsity_block_size,
                                    sparsity_layout.size(2) * sparsity_block_size), fill_value=fill_value,
                              dtype=x.dtype, device=x.device)
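Because `to_dense_build_lut` (shown further below in this hunk) only fills in missing keys, a dict passed via `lut=` acts as a per-layout cache: the cumsum-based `sparsity_reverse_lut` is computed once and reused on later calls. A sketch continuing the example above (loop and variable names are illustrative):

    from blksprs.ops.conversion import to_dense

    lut_dense = {}
    for _ in range(4):
        # The first iteration populates lut_dense["sparsity_reverse_lut"];
        # later iterations reuse it instead of recomputing the cumsum.
        x_dense = to_dense(x_sparse, sparsity_layout, sparsity_block_size,
                           fill_value=0.0, lut=lut_dense)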
@@ -78,217 +224,106 @@ class _BlocksparseToDense(torch.autograd.Function):
          o_b, o_r, o_c = output.size()
          o_b_s, o_r_s, o_c_s = stride(output)

-         if triton_block_size is None:
-             triton_block_size = get_triton_block_size(sparsity_block_size)
-
          triton_grid = lambda meta: [o_b,
                                      triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                      triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

-         (_BlocksparseToDense.kernel_blocksparse_to_dense[triton_grid]
+         (wrap_triton(to_dense_kernel)[triton_grid]
          (x,
           x_b, x_b_s, x_r_s, x_c_s,
           s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
           sparsity_reverse_lut,
           output,
           o_b, o_b_s, o_r_s, o_c_s,
-          sparsity_block_size,
-          triton_block_size))
-
-         ctx.save_for_backward(sparsity_layout)
-         ctx.sparsity_block_size = sparsity_block_size
-         ctx.triton_block_size = triton_block_size
+          sparsity_block_size))

          return output

-     @staticmethod
-     def backward(ctx, grad_output):
-         sparsity_layout = ctx.saved_tensors[0]
-         sparsity_block_size = ctx.sparsity_block_size
-         triton_block_size = ctx.triton_block_size
-
-         return to_sparse(grad_output, sparsity_layout, sparsity_block_size,
-                          triton_block_size), None, None, None, None, None
-
-     @staticmethod
-     @triton.jit
-     def kernel_blocksparse_to_dense(x,
-                                     x_b, x_b_s, x_r_s, x_c_s,
-                                     s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
-                                     sparsity_reverse_lut,
-                                     o,
-                                     o_b, o_b_s, o_r_s, o_c_s,
-                                     sparsity_block_size,
-                                     TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-         # Get triton block indices
-         pid_blk = tl.program_id(axis=0)
-         pid_row = tl.program_id(axis=1)
-         pid_col = tl.program_id(axis=2)
-
-         # Get sparsity index of current block
-         spa_row = (pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size
-         spa_col = (pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size
-
-         # Get reverse sparsity index for current block
-         rev_idx_spa_idx = (pid_blk * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)
-         rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_b * s_l_b_s)
-         rev_idx_spa = tl.load(sparsity_reverse_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
-
-         # If block is present commence operations
-         if rev_idx_spa >= 0:
-             blk_idx = (rev_idx_spa * x_b_s +
-                        (((pid_row % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
-                          tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                        (((pid_col % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
-                          tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-             blk_msk = (blk_idx >= 0 and blk_idx < x_b * x_b_s)
-             blk = tl.load(x + blk_idx, mask=blk_msk)
-
-             o_idx = (pid_blk * o_b_s +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-             o_msk = (o_idx >= 0 and o_idx < o_b * o_b_s)
-             tl.store(o + o_idx, blk, o_msk)
-

- def to_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
-                triton_block_size: int = None) -> BlksprsTensor:
-     """Wrapper for ``to_sparse``.
+ def to_dense_wrapper_backward(ctx, grad_output):
+     sparsity_layout = ctx.saved_tensors[0]
+     sparsity_block_size = ctx.sparsity_block_size

-     """
-     return to_sparse(x, sparsity_layout, sparsity_block_size, triton_block_size)
+     return to_sparse(grad_output, sparsity_layout, sparsity_block_size), None, None, None, None


- def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
-               triton_block_size: int = None) -> BlksprsTensor:
-     """Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
-     sparsity layout.
-
-     Args:
-         x (Tensor): A block-sparse tensor in regular form.
-         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
-         sparsity_block_size (int): The size of the sparsity blocks.
-         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
-
-     Returns:
-         BlksprsTensor: The block-sparse tensor converted to compressed form.
-
-     """
-     x = x.contiguous()
-
-     validate_dimensions(x)
-     validate_contiguous(x)
-     validate_device(x)
-     validate_sparsity_dense(sparsity_block_size, (x, sparsity_layout))
-     validate_sparsity_block_size(sparsity_block_size, x)
-     validate_triton_block_size(triton_block_size, sparsity_block_size)
+ @triton.autotune(
+     configs=get_autotune_configs(),
+     key=["sparsity_block_size"],
+     prune_configs_by={"early_config_prune": prune_autotune_configs},
+     restore_value=["o"]
+ )
+ @triton.jit
+ def to_dense_kernel(x,
+                     x_b, x_b_s, x_r_s, x_c_s,
+                     s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+                     sparsity_reverse_lut,
+                     o,
+                     o_b, o_b_s, o_r_s, o_c_s,
+                     sparsity_block_size,
+                     TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     # Get triton block indices
+     pid_blk = tl.program_id(axis=0)
+     pid_row = tl.program_id(axis=1)
+     pid_col = tl.program_id(axis=2)

-     sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
-     n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
+     # Get sparsity index of current block
+     spa_row = (pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size
+     spa_col = (pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size

-     validate_contiguous(sparsity_layout, sparsity_lut)
+     # Get reverse sparsity index for current block
+     rev_idx_spa_idx = (pid_blk * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)
+     rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_b * s_l_b_s)
+     rev_idx_spa = tl.load(sparsity_reverse_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

-     if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
-         return BlksprsTensor(x)
+     # If block is present commence operations
+     if rev_idx_spa >= 0:
+         blk_idx = (rev_idx_spa * x_b_s +
+                    (((pid_row % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
+                      tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                    (((pid_col % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
+                      tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+         blk_msk = (blk_idx >= 0 and
+                    blk_idx < x_b * x_b_s)
+         blk = tl.load(x + blk_idx, mask=blk_msk)

-     return BlksprsTensor(_BlocksparseToSparse.apply(x,
-                                                     sparsity_layout, sparsity_lut,
-                                                     sparsity_block_size, n_sparse_blocks,
-                                                     triton_block_size))
+         o_idx = (pid_blk * o_b_s +
+                  ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+         o_msk = (o_idx >= 0 and o_idx < o_b * o_b_s)
+         tl.store(o + o_idx, blk, o_msk)


- class _BlocksparseToSparse(torch.autograd.Function):
+ def to_dense_build_lut(lut: dict, sparsity_layout: Tensor):
+     if lut is None:
+         lut = dict()

-     @staticmethod
-     def forward(ctx, x: Tensor,
-                 sparsity_layout: Tensor, sparsity_lut: Tensor,
-                 sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
-         output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
-                              dtype=x.dtype, device=x.device)
+     if "sparsity_reverse_lut" not in lut:
+         sparsity_layout_flat = sparsity_layout.reshape(-1)
+         sparsity_reverse_lut = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                 (sparsity_layout_flat == 1) -
+                                 (1 * (sparsity_layout_flat == 0)))
+         lut["sparsity_reverse_lut"] = sparsity_reverse_lut

-         x_b, x_r, x_c = x.size()
-         x_b_s, x_r_s, x_c_s = stride(x)
-         s_lut_r, s_lut_c = sparsity_lut.size()
-         s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-         o_b, o_r, o_c = output.size()
-         o_b_s, o_r_s, o_c_s = stride(output)
+     validate_contiguous(lut["sparsity_reverse_lut"])

-         if triton_block_size is None:
-             triton_block_size = get_triton_block_size(sparsity_block_size)
+     return lut

-         triton_grid = lambda meta: [o_b,
-                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

-         (_BlocksparseToSparse.kernel_blocksparse_to_sparse[triton_grid]
-          (x, x_b, x_b_s, x_r_s, x_c_s,
-           sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-           output, o_b_s, o_r_s, o_c_s,
-           sparsity_block_size,
-           triton_block_size))
+ # noinspection PyUnusedLocal
+ def to_dense_setup_context(ctx, inputs, output):
+     (_, sparsity_layout, _, sparsity_block_size, _) = inputs

-         ctx.save_for_backward(sparsity_layout)
-         ctx.sparsity_block_size = sparsity_block_size
-         ctx.triton_block_size = triton_block_size
+     ctx.save_for_backward(sparsity_layout)
+     ctx.sparsity_block_size = sparsity_block_size

-         return output

-     @staticmethod
-     def backward(ctx, grad_output):
-         sparsity_layout = ctx.saved_tensors[0]
-         sparsity_block_size = ctx.sparsity_block_size
-         triton_block_size = ctx.triton_block_size
-
-         return to_dense(grad_output, sparsity_layout, sparsity_block_size,
-                         triton_block_size=triton_block_size), None, None, None, None, None
-
-     @staticmethod
-     @triton.jit
-     def kernel_blocksparse_to_sparse(x,
-                                      x_b, x_b_s, x_r_s, x_c_s,
-                                      s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-                                      o,
-                                      o_b_s, o_r_s, o_c_s,
-                                      sparsity_block_size,
-                                      TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-         # Get triton block indices
-         pid_blk = tl.program_id(axis=0)
-         pid_row = tl.program_id(axis=1)
-         pid_col = tl.program_id(axis=2)
-
-         # Get sparsity index of current output block consisting of its batch, row, and column index
-         spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-         spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
-         spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
-
-         spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-         spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
-         spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
-
-         spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
-         spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
-         spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
-
-         # Load block from dense tensor
-         blk_d_idx = (spa_bat * x_b_s +
-                      ((spa_row * sparsity_block_size + pid_row * TRITON_BLOCK_SIZE +
-                        tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                      ((spa_col * sparsity_block_size + pid_col * TRITON_BLOCK_SIZE +
-                        tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-         blk_d_msk = (blk_d_idx >= 0 and blk_d_idx < x_b * x_b_s)
-         blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)
-
-         # Store block in sparse tensor
-         blk_o_idx = ((pid_blk * o_b_s) +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
-         blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < (pid_blk + 1) * o_b_s)
-         tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
+ to_dense_forward.register_autograd(to_dense_wrapper_backward, setup_context=to_dense_setup_context)


+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
  def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_size_from: int,
-                  sparsity_block_size_to: int, sparsity_layout_to: Tensor = None,
-                  triton_block_size: int = None) -> (BlksprsTensor, Tensor):
+                  sparsity_block_size_to: int, sparsity_layout_to: Tensor = None) -> (BlksprsTensor, Tensor):
      """Adapts the sparsity layout of a block-sparse tensor, resulting in a new block-sparse tensor in compressed form
      conforming to the new sparsity layout (and sparsity block size) definition.

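The reverse look-up table built in `to_dense_build_lut` maps every position of the flattened layout either to its index among the present blocks or to -1. A worked example on a tiny layout (values chosen for illustration):

    import torch

    flat = torch.tensor([1, 0, 1, 1])              # flattened sparsity layout
    reverse_lut = ((torch.cumsum(flat, dim=-1) - 1) *
                   (flat == 1) -
                   (1 * (flat == 0)))
    print(reverse_lut)                             # tensor([ 0, -1,  1,  2])
    # cumsum - 1 = [0, 0, 1, 2] is a running count of present blocks;
    # masking by (flat == 1) keeps it only where a block exists, and
    # subtracting 1 where flat == 0 marks absent blocks with -1.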
@@ -298,7 +333,6 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
          sparsity_block_size_from (int): The size of the sparsity blocks of the input sparsity layout.
          sparsity_block_size_to (int): The size of the sparsity blocks of the output sparsity layout.
          sparsity_layout_to (Tensor): The sparsity layout of the output block-sparse tensor (default ``None``).
-         triton_block_size (int): The block size to use for the triton kernel (default ``None``).

      Returns:
          BlksprsTensor: The block-sparse tensor in compressed form with the adapted sparsity layout and sparsity block size.
@@ -313,8 +347,6 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
      validate_sparsity(sparsity_block_size_from, (x, sparsity_layout_from))
      validate_sparsity_block_size(sparsity_block_size_from, x)
      validate_sparsity_block_size(sparsity_block_size_to)
-     min_sparsity_block_size = min(sparsity_block_size_from, sparsity_block_size_to)
-     validate_triton_block_size(triton_block_size, min_sparsity_block_size)

      sparsity_layout_from_flat = sparsity_layout_from.reshape(-1)
      sparsity_reverse_lut_from = ((torch.cumsum(sparsity_layout_from_flat, dim=-1) - 1) *
@@ -323,8 +355,7 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_

      if sparsity_layout_to is None:
          sparsity_layout_to = build_sparsity_layout_adaption(x, sparsity_layout_from,
-                                                             sparsity_block_size_from, sparsity_block_size_to,
-                                                             triton_block_size)
+                                                             sparsity_block_size_from, sparsity_block_size_to)

      sparsity_lut_to = torch.nonzero(sparsity_layout_to).contiguous()

@@ -335,24 +366,22 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
      if (sparsity_block_size_from == sparsity_block_size_to) and torch.equal(sparsity_layout_from, sparsity_layout_to):
          return BlksprsTensor(x), sparsity_layout_to

-     return BlksprsTensor(_BlocksparseAdaptLayout.apply(x,
-                                                        sparsity_layout_from, sparsity_reverse_lut_from,
-                                                        sparsity_block_size_from,
-                                                        sparsity_layout_to, sparsity_lut_to,
-                                                        sparsity_block_size_to,
-                                                        n_sparse_blocks_to, min_sparsity_block_size,
-                                                        triton_block_size)), sparsity_layout_to
-
-
- class _BlocksparseAdaptLayout(torch.autograd.Function):
-
-     @staticmethod
-     def forward(ctx, x: Tensor,
-                 sparsity_layout_from: Tensor, sparsity_reverse_lut_from: Tensor,
-                 sparsity_block_size_from: int,
-                 sparsity_layout_to: Tensor, sparsity_lut_to: Tensor,
-                 sparsity_block_size_to: int,
-                 n_sparse_blocks_to: int, min_sparsity_block_size: int, triton_block_size: int) -> Tensor:
+     return BlksprsTensor(adapt_layout_forward(x,
+                                               sparsity_layout_from, sparsity_reverse_lut_from,
+                                               sparsity_block_size_from,
+                                               sparsity_layout_to, sparsity_lut_to,
+                                               sparsity_block_size_to,
+                                               n_sparse_blocks_to)), sparsity_layout_to
+
+
+ @triton_op("blksprs::adapt_layout_forward", mutates_args={})
+ def adapt_layout_forward(x: Tensor,
+                          sparsity_layout_from: Tensor, sparsity_reverse_lut_from: Tensor,
+                          sparsity_block_size_from: int,
+                          _: Tensor, sparsity_lut_to: Tensor,
+                          sparsity_block_size_to: int,
+                          n_sparse_blocks_to: int) -> Tensor:
+     with torch.no_grad():
          output = torch.zeros(size=(n_sparse_blocks_to, sparsity_block_size_to, sparsity_block_size_to),
                               dtype=x.dtype, device=x.device)

@@ -365,14 +394,11 @@ class _BlocksparseAdaptLayout(torch.autograd.Function):
          s_lut_o_r, s_lut_o_c = sparsity_lut_to.size()
          s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_to)

-         if triton_block_size is None:
-             triton_block_size = get_triton_block_size(min_sparsity_block_size)
-
          triton_grid = lambda meta: [o_b,
                                      triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                      triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

-         (_BlocksparseAdaptLayout.kernel_adapt_layout[triton_grid]
+         (wrap_triton(adapt_layout_kernel)[triton_grid]
          (x,
           x_b, x_b_s, x_r_s, x_c_s,
           s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
@@ -381,88 +407,100 @@ class _BlocksparseAdaptLayout(torch.autograd.Function):
           o_b, o_b_s, o_r_s, o_c_s,
           sparsity_lut_to, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
           sparsity_block_size_from,
-          sparsity_block_size_to,
-          triton_block_size))
-
-         ctx.save_for_backward(x, sparsity_layout_from, sparsity_layout_to)
-         ctx.sparsity_block_size_from = sparsity_block_size_from
-         ctx.sparsity_block_size_to = sparsity_block_size_to
-         ctx.triton_block_size = triton_block_size
+          sparsity_block_size_to))

          return output

-     @staticmethod
-     def backward(ctx, grad_output):
-         x, sparsity_layout_from, sparsity_layout_to = ctx.saved_tensors
-         sparsity_block_size_from = ctx.sparsity_block_size_from
-         sparsity_block_size_to = ctx.sparsity_block_size_to
-         triton_block_size = ctx.triton_block_size
-
-         return adapt_layout(
-             grad_output, sparsity_layout_to, sparsity_block_size_to, sparsity_block_size_from,
-             sparsity_layout_to=sparsity_layout_from,
-             triton_block_size=triton_block_size)[0], None, None, None, None, None, None, None, None, None
-
-     @staticmethod
-     @triton.jit
-     def kernel_adapt_layout(x,
-                             x_b, x_b_s, x_r_s, x_c_s,
-                             s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
-                             r_lut_x,
-                             o,
-                             o_b, o_b_s, o_r_s, o_c_s,
-                             s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
-                             sparsity_block_size_from,
-                             sparsity_block_size_to,
-                             TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-         # Get triton block indices
-         pid_blk = tl.program_id(axis=0)
-         pid_row = tl.program_id(axis=1)
-         pid_col = tl.program_id(axis=2)
-
-         # Get position of current sparsity block consisting of its batch, row, and column index
-         spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
-         spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
-         spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
-
-         spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
-         spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
-         spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
-
-         spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
-         spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
-         spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
-
-         # Get equivalent sparsity block in from layout
-         spa_bat_x = spa_bat_o
-         spa_row_x = (spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size_from
-         spa_col_x = (spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size_from
-
-         # Get reverse sparsity indices for x
-         rev_idx_spa_x_idx = (spa_bat_x * s_l_x_b_s +
-                              spa_row_x * s_l_x_r_s +
-                              spa_col_x * s_l_x_c_s)
-         rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
-         rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
-
-         # If block is present commence operations
-         if rev_idx_spa_x >= 0:
-             # Calculate triton block size shifts
-             shift_row_x = ((spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE)
-                            % sparsity_block_size_from) // TRITON_BLOCK_SIZE
-             shift_col_x = ((spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE)
-                            % sparsity_block_size_from) // TRITON_BLOCK_SIZE
-
-             # Load x values
-             blk_x_idx = ((rev_idx_spa_x * x_b_s) +
-                          ((shift_row_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                          ((shift_col_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-             blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
-             blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
-             # Store output
-             blk_o_idx = ((pid_blk * o_b_s) +
-                          ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                          ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-             blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
-             tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+ def adapt_layout_wrapper_backward(ctx, grad_output):
+     x, sparsity_layout_from, sparsity_layout_to = ctx.saved_tensors
+     sparsity_block_size_from = ctx.sparsity_block_size_from
+     sparsity_block_size_to = ctx.sparsity_block_size_to
+
+     return adapt_layout(
+         grad_output, sparsity_layout_to, sparsity_block_size_to, sparsity_block_size_from,
+         sparsity_layout_to=sparsity_layout_from)[0], None, None, None, None, None, None, None
+
+
+ @triton.autotune(
+     configs=get_autotune_configs(),
+     key=["sparsity_block_size_from", "sparsity_block_size_to"],
+     prune_configs_by={"early_config_prune": prune_autotune_configs_conversion},
+     reset_to_zero=["o"]
+ )
+ @triton.jit
+ def adapt_layout_kernel(x,
+                         x_b, x_b_s, x_r_s, x_c_s,
+                         s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+                         r_lut_x,
+                         o,
+                         o_b, o_b_s, o_r_s, o_c_s,
+                         s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+                         sparsity_block_size_from,
+                         sparsity_block_size_to,
+                         TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     # Get triton block indices
+     pid_blk = tl.program_id(axis=0)
+     pid_row = tl.program_id(axis=1)
+     pid_col = tl.program_id(axis=2)
+
+     # Get position of current sparsity block consisting of its batch, row, and column index
+     spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
+     spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+     spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
+
+     spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
+     spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+     spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+
+     spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
+     spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+     spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+
+     # Get equivalent sparsity block in from layout
+     spa_bat_x = spa_bat_o
+     spa_row_x = (spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size_from
+     spa_col_x = (spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size_from
+
+     # Get reverse sparsity indices for x
+     rev_idx_spa_x_idx = (spa_bat_x * s_l_x_b_s +
+                          spa_row_x * s_l_x_r_s +
+                          spa_col_x * s_l_x_c_s)
+     rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+     rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
+
+     # If block is present commence operations
+     if rev_idx_spa_x >= 0:
+         # Calculate triton block size shifts
+         shift_row_x = ((spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE)
+                        % sparsity_block_size_from) // TRITON_BLOCK_SIZE
+         shift_col_x = ((spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE)
+                        % sparsity_block_size_from) // TRITON_BLOCK_SIZE
+
+         # Load x values
+         blk_x_idx = ((rev_idx_spa_x * x_b_s) +
+                      ((shift_row_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                      ((shift_col_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+         blk_x_msk = (blk_x_idx >= 0 and
+                      blk_x_idx < x_b * x_b_s)
+         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+         # Store output
+         blk_o_idx = ((pid_blk * o_b_s) +
+                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+         blk_o_msk = (blk_o_idx >= 0 and
+                      blk_o_idx < o_b * o_b_s)
+         tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+ # noinspection PyUnusedLocal
+ def adapt_layout_setup_context(ctx, inputs, output):
+     (x, sparsity_layout_from, _, sparsity_block_size_from, sparsity_layout_to, _, sparsity_block_size_to, _) = inputs
+
+     ctx.save_for_backward(x, sparsity_layout_from, sparsity_layout_to)
+     ctx.sparsity_block_size_from = sparsity_block_size_from
+     ctx.sparsity_block_size_to = sparsity_block_size_to
+
+
+ adapt_layout_forward.register_autograd(adapt_layout_wrapper_backward, setup_context=adapt_layout_setup_context)
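As with the other conversions, `adapt_layout` now relies on `triton.autotune` instead of a caller-supplied `triton_block_size`. A sketch of re-blocking a compressed tensor to a smaller sparsity block size, continuing the earlier example (the block sizes are illustrative):

    from blksprs.ops.conversion import adapt_layout

    # Derive the target layout automatically (sparsity_layout_to=None) and
    # re-block from 64x64 to 32x32 sparsity blocks; returns the new compressed
    # tensor together with the adapted layout.
    x_reblocked, sparsity_layout_32 = adapt_layout(x_sparse, sparsity_layout,
                                                   sparsity_block_size_from=64,
                                                   sparsity_block_size_to=32)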