blksprs-1.11-py3-none-any.whl → blksprs-2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
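The main user-visible changes in this file: the conversion routines drop the explicit triton_block_size argument (kernel launch parameters are now chosen by Triton autotuning), gain an optional lut cache argument, and are re-implemented as torch.library triton ops (triton_op / wrap_triton plus register_autograd) instead of torch.autograd.Function subclasses. A minimal usage sketch of the 2.0 API follows; it assumes a CUDA device, assumes to_sparse and to_dense are importable from the module shown below, and uses a made-up layout and tensor purely for illustration.

import torch
from blksprs.ops.conversion import to_sparse, to_dense

device = torch.device("cuda")

# Hypothetical example: one batch, a 2x2 grid of 32x32 sparsity blocks, two blocks present.
sparsity_block_size = 32
sparsity_layout = torch.tensor([[[1, 0],
                                 [0, 1]]], device=device)
x = torch.randn(1, 64, 64, device=device)

# 1.11: to_sparse(x, sparsity_layout, sparsity_block_size, triton_block_size=32)
# 2.0: the triton_block_size argument is gone; the kernels are autotuned instead.
lut = {}  # optional cache, filled in place and reusable across calls with the same layout
x_compressed = to_sparse(x, sparsity_layout, sparsity_block_size, lut=lut)

# Convert back to regular (dense) form, filling absent blocks with zeros.
x_dense = to_dense(x_compressed, sparsity_layout, sparsity_block_size, fill_value=0)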
blksprs/ops/conversion.py CHANGED
@@ -1,25 +1,181 @@
  import torch
  import triton
  from torch import Tensor
+ from torch._library.triton import wrap_triton, triton_op
  from triton import language as tl

  from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
  from blksprs.utils.blksprs_tensor import BlksprsTensor
- from blksprs.utils.tools import get_triton_block_size, stride
+ from blksprs.utils.tools import stride
+ from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs, prune_autotune_configs_conversion
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
- validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_sparsity_dense
+ validate_sparsity, validate_sparsity_block_size, validate_sparsity_dense


- def from_blksprs(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
- triton_block_size: int = None) -> Tensor:
+ def to_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int) -> BlksprsTensor:
+ """Wrapper for ``to_sparse``.
+
+ """
+ return to_sparse(x, sparsity_layout, sparsity_block_size)
+
+
+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
+ def to_sparse(x: Tensor, sparsity_layout: Tensor,
+ sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
+ """Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
+ sparsity layout.
+
+ Args:
+ x (Tensor): A block-sparse tensor in regular form.
+ sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+ sparsity_block_size (int): The size of the sparsity blocks.
+ lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
+
+ Returns:
+ BlksprsTensor: The block-sparse tensor converted to compressed form.
+
+ """
+ x = x.contiguous()
+
+ validate_dimensions(x)
+ validate_contiguous(x)
+ validate_device(x)
+ validate_sparsity_dense(sparsity_block_size, (x, sparsity_layout))
+ validate_sparsity_block_size(sparsity_block_size, x)
+
+ lut = to_sparse_build_lut(lut, sparsity_layout)
+
+ if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
+ return BlksprsTensor(x)
+
+ return BlksprsTensor(to_sparse_forward(x, sparsity_layout,
+ lut["sparsity_lut"], sparsity_block_size, lut["n_sparse_blocks"]))
+
+
+ @triton_op("blksprs::to_sparse_forward", mutates_args={})
+ def to_sparse_forward(x: Tensor, _: Tensor,
+ sparsity_lut: Tensor, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+ with torch.no_grad():
+ output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+ dtype=x.dtype, device=x.device)
+
+ x_b, x_r, x_c = x.size()
+ x_b_s, x_r_s, x_c_s = stride(x)
+ s_lut_r, s_lut_c = sparsity_lut.size()
+ s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+ o_b, o_r, o_c = output.size()
+ o_b_s, o_r_s, o_c_s = stride(output)
+
+ triton_grid = lambda meta: [o_b,
+ triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+ triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+ (wrap_triton(to_sparse_kernel)[triton_grid]
+ (x, x_b, x_b_s, x_r_s, x_c_s,
+ sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+ output, o_b_s, o_r_s, o_c_s,
+ sparsity_block_size))
+
+ return output
+
+
+ def to_sparse_wrapper_backward(ctx, grad_output):
+ sparsity_layout = ctx.saved_tensors[0]
+ sparsity_block_size = ctx.sparsity_block_size
+
+ return to_dense(grad_output, sparsity_layout, sparsity_block_size), None, None, None, None
+
+
+ @triton.autotune(
+ configs=get_autotune_configs(),
+ key=["sparsity_block_size"],
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
+ reset_to_zero=["o"]
+ )
+ @triton.jit
+ def to_sparse_kernel(x,
+ x_b, x_b_s, x_r_s, x_c_s,
+ s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+ o,
+ o_b_s, o_r_s, o_c_s,
+ sparsity_block_size,
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+ # Get triton block indices
+ pid_blk = tl.program_id(axis=0)
+ pid_row = tl.program_id(axis=1)
+ pid_col = tl.program_id(axis=2)
+
+ # Get sparsity index of current output block consisting of its batch, row, and column index
+ spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+ spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
+ spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+ spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+ spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
+ spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+ spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
+ spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
+ spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+
+ # Load block from dense tensor
+ blk_d_idx = (spa_bat * x_b_s +
+ ((pid_row * TRITON_BLOCK_SIZE + spa_row * sparsity_block_size +
+ tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+ ((pid_col * TRITON_BLOCK_SIZE + spa_col * sparsity_block_size +
+ tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+ blk_d_msk = (blk_d_idx >= 0 and
+ blk_d_idx < x_b * x_b_s)
+ blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)
+
+ # Store block in sparse tensor
+ blk_o_idx = ((pid_blk * o_b_s) +
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
+ blk_o_msk = (blk_o_idx >= 0 and
+ blk_o_idx < (pid_blk + 1) * o_b_s)
+ tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
+
+
+ def to_sparse_build_lut(lut: dict, sparsity_layout: Tensor):
+ if lut is None:
+ lut = dict()
+
+ if "sparsity_lut" not in lut:
+ sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
+ lut["sparsity_lut"] = sparsity_lut
+
+ if "n_sparse_blocks" not in lut:
+ n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
+ lut["n_sparse_blocks"] = n_sparse_blocks
+
+ validate_contiguous(sparsity_layout, lut["sparsity_lut"])
+
+ return lut
+
+
+ # noinspection PyUnusedLocal
+ def to_sparse_setup_context(ctx, inputs, output):
+ (_, sparsity_layout, _, sparsity_block_size, _) = inputs
+
+ ctx.save_for_backward(sparsity_layout, )
+ ctx.sparsity_block_size = sparsity_block_size
+
+
+ to_sparse_forward.register_autograd(to_sparse_wrapper_backward, setup_context=to_sparse_setup_context)
+
+
+ def from_blksprs(x: BlksprsTensor, sparsity_layout: Tensor,
+ sparsity_block_size: int, fill_value: float = 0, lut: dict = None) -> Tensor:
  """Wrapper for ``to_dense``.

  """
- return to_dense(x, sparsity_layout, sparsity_block_size, fill_value, triton_block_size)
+ return to_dense(x, sparsity_layout, sparsity_block_size, fill_value=fill_value, lut=lut)


- def to_dense(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
- triton_block_size: int = None, lut: dict = None) -> Tensor:
+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
+ def to_dense(x: BlksprsTensor, sparsity_layout: Tensor,
+ sparsity_block_size: int, fill_value: float = 0, lut: dict = None) -> Tensor:
  """Converts a block-sparse tensor in compressed form to a block-sparse tensor in regular form based on the given
  sparsity layout.
 
@@ -29,7 +185,6 @@ def to_dense(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int
  sparsity_block_size (int): The size of the sparsity blocks.
  fill_value (float): The value to fill the resulting dense tensor with where the block-sparse tensor is not
  present (default ``0``).
- triton_block_size (int): The block size to use for the triton kernel (default ``None``).
  lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

  Returns:
@@ -43,42 +198,21 @@ def to_dense(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int
  validate_device(x)
  validate_sparsity(sparsity_block_size, (x, sparsity_layout))
  validate_sparsity_block_size(sparsity_block_size, x)
- validate_triton_block_size(triton_block_size, sparsity_block_size)

- lut = _BlocksparseToDense.build_lut(lut, sparsity_layout)
+ lut = to_dense_build_lut(lut, sparsity_layout)

  if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
  return x

- return _BlocksparseToDense.apply(x,
- sparsity_layout, lut["sparsity_reverse_lut"],
- sparsity_block_size, fill_value,
- triton_block_size)
-
-
- class _BlocksparseToDense(torch.autograd.Function):
-
- @staticmethod
- def build_lut(lut: dict, sparsity_layout: Tensor):
- if lut is None:
- lut = dict()
+ return Tensor(to_dense_forward(x, sparsity_layout,
+ lut["sparsity_reverse_lut"], sparsity_block_size, fill_value))

- if "sparsity_reverse_lut" not in lut:
- sparsity_layout_flat = sparsity_layout.reshape(-1)
- sparsity_reverse_lut = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
- (sparsity_layout_flat == 1) -
- (1 * (sparsity_layout_flat == 0)))
- lut["sparsity_reverse_lut"] = sparsity_reverse_lut

- validate_contiguous(lut["sparsity_reverse_lut"])
-
- return lut
-
- @staticmethod
- def forward(ctx, x: Tensor,
- sparsity_layout: Tensor, sparsity_reverse_lut: Tensor,
- sparsity_block_size: int, fill_value: float,
- triton_block_size: int) -> Tensor:
+ @triton_op("blksprs::to_dense_forward", mutates_args={})
+ def to_dense_forward(x: Tensor, sparsity_layout: Tensor,
+ sparsity_reverse_lut: Tensor,
+ sparsity_block_size: int, fill_value: float) -> Tensor:
+ with torch.no_grad():
  output = torch.full(size=(sparsity_layout.size(0), sparsity_layout.size(1) * sparsity_block_size,
  sparsity_layout.size(2) * sparsity_block_size), fill_value=fill_value,
  dtype=x.dtype, device=x.device)
@@ -90,232 +224,106 @@ class _BlocksparseToDense(torch.autograd.Function):
  o_b, o_r, o_c = output.size()
  o_b_s, o_r_s, o_c_s = stride(output)

- if triton_block_size is None:
- triton_block_size = get_triton_block_size(sparsity_block_size)
-
  triton_grid = lambda meta: [o_b,
  triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
  triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

- (_BlocksparseToDense.kernel_blocksparse_to_dense[triton_grid]
+ (wrap_triton(to_dense_kernel)[triton_grid]
  (x,
  x_b, x_b_s, x_r_s, x_c_s,
  s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
  sparsity_reverse_lut,
  output,
  o_b, o_b_s, o_r_s, o_c_s,
- sparsity_block_size,
- triton_block_size))
-
- ctx.save_for_backward(sparsity_layout)
- ctx.sparsity_block_size = sparsity_block_size
- ctx.triton_block_size = triton_block_size
+ sparsity_block_size))

  return output

- @staticmethod
- def backward(ctx, grad_output):
- sparsity_layout = ctx.saved_tensors[0]
- sparsity_block_size = ctx.sparsity_block_size
- triton_block_size = ctx.triton_block_size
-
- return to_sparse(grad_output, sparsity_layout, sparsity_block_size,
- triton_block_size), None, None, None, None, None
-
- @staticmethod
- @triton.jit
- def kernel_blocksparse_to_dense(x,
- x_b, x_b_s, x_r_s, x_c_s,
- s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
- sparsity_reverse_lut,
- o,
- o_b, o_b_s, o_r_s, o_c_s,
- sparsity_block_size,
- TRITON_BLOCK_SIZE: tl.constexpr) -> None:
- # Get triton block indices
- pid_blk = tl.program_id(axis=0)
- pid_row = tl.program_id(axis=1)
- pid_col = tl.program_id(axis=2)
-
- # Get sparsity index of current block
- spa_row = (pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size
- spa_col = (pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size
-
- # Get reverse sparsity index for current block
- rev_idx_spa_idx = (pid_blk * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)
- rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_b * s_l_b_s)
- rev_idx_spa = tl.load(sparsity_reverse_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
-
- # If block is present commence operations
- if rev_idx_spa >= 0:
- blk_idx = (rev_idx_spa * x_b_s +
- (((pid_row % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
- tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
- (((pid_col % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
- tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
- blk_msk = (blk_idx >= 0 and blk_idx < x_b * x_b_s)
- blk = tl.load(x + blk_idx, mask=blk_msk)
-
- o_idx = (pid_blk * o_b_s +
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
- o_msk = (o_idx >= 0 and o_idx < o_b * o_b_s)
- tl.store(o + o_idx, blk, o_msk)
-
-
- def to_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
- triton_block_size: int = None) -> BlksprsTensor:
- """Wrapper for ``to_sparse``.
-
- """
- return to_sparse(x, sparsity_layout, sparsity_block_size, triton_block_size)
-
-
- def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
- triton_block_size: int = None, lut: dict = None) -> BlksprsTensor:
- """Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
- sparsity layout.
-
- Args:
- x (Tensor): A block-sparse tensor in regular form.
- sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
- sparsity_block_size (int): The size of the sparsity blocks.
- triton_block_size (int): The block size to use for the triton kernel (default ``None``).
- lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
-
- Returns:
- BlksprsTensor: The block-sparse tensor converted to compressed form.
-
- """
- x = x.contiguous()
-
- validate_dimensions(x)
- validate_contiguous(x)
- validate_device(x)
- validate_sparsity_dense(sparsity_block_size, (x, sparsity_layout))
- validate_sparsity_block_size(sparsity_block_size, x)
- validate_triton_block_size(triton_block_size, sparsity_block_size)

- lut = _BlocksparseToSparse.build_lut(lut, sparsity_layout)
+ def to_dense_wrapper_backward(ctx, grad_output):
+ sparsity_layout = ctx.saved_tensors[0]
+ sparsity_block_size = ctx.sparsity_block_size

- if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
- return BlksprsTensor(x)
+ return to_sparse(grad_output, sparsity_layout, sparsity_block_size), None, None, None, None

- return BlksprsTensor(_BlocksparseToSparse.apply(x,
- sparsity_layout, lut["sparsity_lut"],
- sparsity_block_size, lut["n_sparse_blocks"],
- triton_block_size))

+ @triton.autotune(
+ configs=get_autotune_configs(),
+ key=["sparsity_block_size"],
+ prune_configs_by={"early_config_prune": prune_autotune_configs},
+ restore_value=["o"]
+ )
+ @triton.jit
+ def to_dense_kernel(x,
+ x_b, x_b_s, x_r_s, x_c_s,
+ s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+ sparsity_reverse_lut,
+ o,
+ o_b, o_b_s, o_r_s, o_c_s,
+ sparsity_block_size,
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+ # Get triton block indices
+ pid_blk = tl.program_id(axis=0)
+ pid_row = tl.program_id(axis=1)
+ pid_col = tl.program_id(axis=2)

- class _BlocksparseToSparse(torch.autograd.Function):
+ # Get sparsity index of current block
+ spa_row = (pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size
+ spa_col = (pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size

- @staticmethod
- def build_lut(lut: dict, sparsity_layout: Tensor):
- if lut is None:
- lut = dict()
+ # Get reverse sparsity index for current block
+ rev_idx_spa_idx = (pid_blk * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)
+ rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_b * s_l_b_s)
+ rev_idx_spa = tl.load(sparsity_reverse_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)

- if "sparsity_lut" not in lut:
- sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
- lut["sparsity_lut"] = sparsity_lut
+ # If block is present commence operations
+ if rev_idx_spa >= 0:
+ blk_idx = (rev_idx_spa * x_b_s +
+ (((pid_row % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
+ tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+ (((pid_col % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
+ tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+ blk_msk = (blk_idx >= 0 and
+ blk_idx < x_b * x_b_s)
+ blk = tl.load(x + blk_idx, mask=blk_msk)

- if "n_sparse_blocks" not in lut:
- n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
- lut["n_sparse_blocks"] = n_sparse_blocks
+ o_idx = (pid_blk * o_b_s +
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+ o_msk = (o_idx >= 0 and o_idx < o_b * o_b_s)
+ tl.store(o + o_idx, blk, o_msk)

- validate_contiguous(sparsity_layout, lut["sparsity_lut"])

- return lut
+ def to_dense_build_lut(lut: dict, sparsity_layout: Tensor):
+ if lut is None:
+ lut = dict()

- @staticmethod
- def forward(ctx, x: Tensor,
- sparsity_layout: Tensor, sparsity_lut: Tensor,
- sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
- output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
- dtype=x.dtype, device=x.device)
+ if "sparsity_reverse_lut" not in lut:
+ sparsity_layout_flat = sparsity_layout.reshape(-1)
+ sparsity_reverse_lut = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+ (sparsity_layout_flat == 1) -
+ (1 * (sparsity_layout_flat == 0)))
+ lut["sparsity_reverse_lut"] = sparsity_reverse_lut

- x_b, x_r, x_c = x.size()
- x_b_s, x_r_s, x_c_s = stride(x)
- s_lut_r, s_lut_c = sparsity_lut.size()
- s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
- o_b, o_r, o_c = output.size()
- o_b_s, o_r_s, o_c_s = stride(output)
+ validate_contiguous(lut["sparsity_reverse_lut"])

- if triton_block_size is None:
- triton_block_size = get_triton_block_size(sparsity_block_size)
+ return lut

- triton_grid = lambda meta: [o_b,
- triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
- triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

- (_BlocksparseToSparse.kernel_blocksparse_to_sparse[triton_grid]
- (x, x_b, x_b_s, x_r_s, x_c_s,
- sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
- output, o_b_s, o_r_s, o_c_s,
- sparsity_block_size,
- triton_block_size))
+ # noinspection PyUnusedLocal
+ def to_dense_setup_context(ctx, inputs, output):
+ (_, sparsity_layout, _, sparsity_block_size, _) = inputs

- ctx.save_for_backward(sparsity_layout)
- ctx.sparsity_block_size = sparsity_block_size
- ctx.triton_block_size = triton_block_size
+ ctx.save_for_backward(sparsity_layout)
+ ctx.sparsity_block_size = sparsity_block_size

- return output

- @staticmethod
- def backward(ctx, grad_output):
- sparsity_layout = ctx.saved_tensors[0]
- sparsity_block_size = ctx.sparsity_block_size
- triton_block_size = ctx.triton_block_size
-
- return to_dense(grad_output, sparsity_layout, sparsity_block_size,
- triton_block_size=triton_block_size), None, None, None, None, None
-
- @staticmethod
- @triton.jit
- def kernel_blocksparse_to_sparse(x,
- x_b, x_b_s, x_r_s, x_c_s,
- s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
- o,
- o_b_s, o_r_s, o_c_s,
- sparsity_block_size,
- TRITON_BLOCK_SIZE: tl.constexpr) -> None:
- # Get triton block indices
- pid_blk = tl.program_id(axis=0)
- pid_row = tl.program_id(axis=1)
- pid_col = tl.program_id(axis=2)
-
- # Get sparsity index of current output block consisting of its batch, row, and column index
- spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
- spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
- spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
-
- spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
- spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
- spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
-
- spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
- spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
- spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
-
- # Load block from dense tensor
- blk_d_idx = (spa_bat * x_b_s +
- ((spa_row * sparsity_block_size + pid_row * TRITON_BLOCK_SIZE +
- tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
- ((spa_col * sparsity_block_size + pid_col * TRITON_BLOCK_SIZE +
- tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
- blk_d_msk = (blk_d_idx >= 0 and blk_d_idx < x_b * x_b_s)
- blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)
-
- # Store block in sparse tensor
- blk_o_idx = ((pid_blk * o_b_s) +
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
- blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < (pid_blk + 1) * o_b_s)
- tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
+ to_dense_forward.register_autograd(to_dense_wrapper_backward, setup_context=to_dense_setup_context)


+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float16)
  def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_size_from: int,
- sparsity_block_size_to: int, sparsity_layout_to: Tensor = None,
- triton_block_size: int = None) -> (BlksprsTensor, Tensor):
+ sparsity_block_size_to: int, sparsity_layout_to: Tensor = None) -> (BlksprsTensor, Tensor):
  """Adapts the sparsity layout of a block-sparse tensor, resulting in a new block-sparse tensor in compressed form
  conforming to the new sparsity layout (and sparsity block size) definition.
 
@@ -325,7 +333,6 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
  sparsity_block_size_from (int): The size of the sparsity blocks of the input sparsity layout.
  sparsity_block_size_to (int): The size of the sparsity blocks of the output sparsity layout.
  sparsity_layout_to (Tensor): The sparsity layout of the output block-sparse tensor (default ``None``).
- triton_block_size (int): The block size to use for the triton kernel (default ``None``).

  Returns:
  BlksprsTensor: The block-sparse tensor in compressed form with the adapted sparsity layout and sparsity block size.
@@ -340,8 +347,6 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
  validate_sparsity(sparsity_block_size_from, (x, sparsity_layout_from))
  validate_sparsity_block_size(sparsity_block_size_from, x)
  validate_sparsity_block_size(sparsity_block_size_to)
- min_sparsity_block_size = min(sparsity_block_size_from, sparsity_block_size_to)
- validate_triton_block_size(triton_block_size, min_sparsity_block_size)

  sparsity_layout_from_flat = sparsity_layout_from.reshape(-1)
  sparsity_reverse_lut_from = ((torch.cumsum(sparsity_layout_from_flat, dim=-1) - 1) *
@@ -350,8 +355,7 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
 
  if sparsity_layout_to is None:
  sparsity_layout_to = build_sparsity_layout_adaption(x, sparsity_layout_from,
- sparsity_block_size_from, sparsity_block_size_to,
- triton_block_size)
+ sparsity_block_size_from, sparsity_block_size_to)

  sparsity_lut_to = torch.nonzero(sparsity_layout_to).contiguous()
 
@@ -362,24 +366,22 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
  if (sparsity_block_size_from == sparsity_block_size_to) and torch.equal(sparsity_layout_from, sparsity_layout_to):
  return BlksprsTensor(x), sparsity_layout_to

- return BlksprsTensor(_BlocksparseAdaptLayout.apply(x,
- sparsity_layout_from, sparsity_reverse_lut_from,
- sparsity_block_size_from,
- sparsity_layout_to, sparsity_lut_to,
- sparsity_block_size_to,
- n_sparse_blocks_to, min_sparsity_block_size,
- triton_block_size)), sparsity_layout_to
-
-
- class _BlocksparseAdaptLayout(torch.autograd.Function):
-
- @staticmethod
- def forward(ctx, x: Tensor,
- sparsity_layout_from: Tensor, sparsity_reverse_lut_from: Tensor,
- sparsity_block_size_from: int,
- sparsity_layout_to: Tensor, sparsity_lut_to: Tensor,
- sparsity_block_size_to: int,
- n_sparse_blocks_to: int, min_sparsity_block_size: int, triton_block_size: int) -> Tensor:
+ return BlksprsTensor(adapt_layout_forward(x,
+ sparsity_layout_from, sparsity_reverse_lut_from,
+ sparsity_block_size_from,
+ sparsity_layout_to, sparsity_lut_to,
+ sparsity_block_size_to,
+ n_sparse_blocks_to)), sparsity_layout_to
+
+
+ @triton_op("blksprs::adapt_layout_forward", mutates_args={})
+ def adapt_layout_forward(x: Tensor,
+ sparsity_layout_from: Tensor, sparsity_reverse_lut_from: Tensor,
+ sparsity_block_size_from: int,
+ _: Tensor, sparsity_lut_to: Tensor,
+ sparsity_block_size_to: int,
+ n_sparse_blocks_to: int) -> Tensor:
+ with torch.no_grad():
  output = torch.zeros(size=(n_sparse_blocks_to, sparsity_block_size_to, sparsity_block_size_to),
  dtype=x.dtype, device=x.device)
 
@@ -392,14 +394,11 @@ class _BlocksparseAdaptLayout(torch.autograd.Function):
  s_lut_o_r, s_lut_o_c = sparsity_lut_to.size()
  s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_to)

- if triton_block_size is None:
- triton_block_size = get_triton_block_size(min_sparsity_block_size)
-
  triton_grid = lambda meta: [o_b,
  triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
  triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]

- (_BlocksparseAdaptLayout.kernel_adapt_layout[triton_grid]
+ (wrap_triton(adapt_layout_kernel)[triton_grid]
  (x,
  x_b, x_b_s, x_r_s, x_c_s,
  s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
@@ -408,88 +407,100 @@ class _BlocksparseAdaptLayout(torch.autograd.Function):
  o_b, o_b_s, o_r_s, o_c_s,
  sparsity_lut_to, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
  sparsity_block_size_from,
- sparsity_block_size_to,
- triton_block_size))
-
- ctx.save_for_backward(x, sparsity_layout_from, sparsity_layout_to)
- ctx.sparsity_block_size_from = sparsity_block_size_from
- ctx.sparsity_block_size_to = sparsity_block_size_to
- ctx.triton_block_size = triton_block_size
+ sparsity_block_size_to))

  return output

- @staticmethod
- def backward(ctx, grad_output):
- x, sparsity_layout_from, sparsity_layout_to = ctx.saved_tensors
- sparsity_block_size_from = ctx.sparsity_block_size_from
- sparsity_block_size_to = ctx.sparsity_block_size_to
- triton_block_size = ctx.triton_block_size
-
- return adapt_layout(
- grad_output, sparsity_layout_to, sparsity_block_size_to, sparsity_block_size_from,
- sparsity_layout_to=sparsity_layout_from,
- triton_block_size=triton_block_size)[0], None, None, None, None, None, None, None, None, None
-
- @staticmethod
- @triton.jit
- def kernel_adapt_layout(x,
- x_b, x_b_s, x_r_s, x_c_s,
- s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
- r_lut_x,
- o,
- o_b, o_b_s, o_r_s, o_c_s,
- s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
- sparsity_block_size_from,
- sparsity_block_size_to,
- TRITON_BLOCK_SIZE: tl.constexpr) -> None:
- # Get triton block indices
- pid_blk = tl.program_id(axis=0)
- pid_row = tl.program_id(axis=1)
- pid_col = tl.program_id(axis=2)
-
- # Get position of current sparsity block consisting of its batch, row, and column index
- spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
- spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
- spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
-
- spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
- spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
- spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
-
- spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
- spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
- spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
-
- # Get equivalent sparsity block in from layout
- spa_bat_x = spa_bat_o
- spa_row_x = (spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size_from
- spa_col_x = (spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size_from
-
- # Get reverse sparsity indices for x
- rev_idx_spa_x_idx = (spa_bat_x * s_l_x_b_s +
- spa_row_x * s_l_x_r_s +
- spa_col_x * s_l_x_c_s)
- rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
- rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
-
- # If block is present commence operations
- if rev_idx_spa_x >= 0:
- # Calculate triton block size shifts
- shift_row_x = ((spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE)
- % sparsity_block_size_from) // TRITON_BLOCK_SIZE
- shift_col_x = ((spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE)
- % sparsity_block_size_from) // TRITON_BLOCK_SIZE
-
- # Load x values
- blk_x_idx = ((rev_idx_spa_x * x_b_s) +
- ((shift_row_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
- ((shift_col_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
- blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
- blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
- # Store output
- blk_o_idx = ((pid_blk * o_b_s) +
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
- blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
- tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+ def adapt_layout_wrapper_backward(ctx, grad_output):
+ x, sparsity_layout_from, sparsity_layout_to = ctx.saved_tensors
+ sparsity_block_size_from = ctx.sparsity_block_size_from
+ sparsity_block_size_to = ctx.sparsity_block_size_to
+
+ return adapt_layout(
+ grad_output, sparsity_layout_to, sparsity_block_size_to, sparsity_block_size_from,
+ sparsity_layout_to=sparsity_layout_from)[0], None, None, None, None, None, None, None
+
+
+ @triton.autotune(
+ configs=get_autotune_configs(),
+ key=["sparsity_block_size_from", "sparsity_block_size_to"],
+ prune_configs_by={"early_config_prune": prune_autotune_configs_conversion},
+ reset_to_zero=["o"]
+ )
+ @triton.jit
+ def adapt_layout_kernel(x,
+ x_b, x_b_s, x_r_s, x_c_s,
+ s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+ r_lut_x,
+ o,
+ o_b, o_b_s, o_r_s, o_c_s,
+ s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+ sparsity_block_size_from,
+ sparsity_block_size_to,
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+ # Get triton block indices
+ pid_blk = tl.program_id(axis=0)
+ pid_row = tl.program_id(axis=1)
+ pid_col = tl.program_id(axis=2)
+
+ # Get position of current sparsity block consisting of its batch, row, and column index
+ spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
+ spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+ spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
+
+ spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
+ spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+ spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+
+ spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
+ spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+ spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+
+ # Get equivalent sparsity block in from layout
+ spa_bat_x = spa_bat_o
+ spa_row_x = (spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size_from
+ spa_col_x = (spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size_from
+
+ # Get reverse sparsity indices for x
+ rev_idx_spa_x_idx = (spa_bat_x * s_l_x_b_s +
+ spa_row_x * s_l_x_r_s +
+ spa_col_x * s_l_x_c_s)
+ rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+ rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
+
+ # If block is present commence operations
+ if rev_idx_spa_x >= 0:
+ # Calculate triton block size shifts
+ shift_row_x = ((spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE)
+ % sparsity_block_size_from) // TRITON_BLOCK_SIZE
+ shift_col_x = ((spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE)
+ % sparsity_block_size_from) // TRITON_BLOCK_SIZE
+
+ # Load x values
+ blk_x_idx = ((rev_idx_spa_x * x_b_s) +
+ ((shift_row_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+ ((shift_col_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+ blk_x_msk = (blk_x_idx >= 0 and
+ blk_x_idx < x_b * x_b_s)
+ blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+ # Store output
+ blk_o_idx = ((pid_blk * o_b_s) +
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+ blk_o_msk = (blk_o_idx >= 0 and
+ blk_o_idx < o_b * o_b_s)
+ tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+ # noinspection PyUnusedLocal
+ def adapt_layout_setup_context(ctx, inputs, output):
+ (x, sparsity_layout_from, _, sparsity_block_size_from, sparsity_layout_to, _, sparsity_block_size_to, _) = inputs
+
+ ctx.save_for_backward(x, sparsity_layout_from, sparsity_layout_to)
+ ctx.sparsity_block_size_from = sparsity_block_size_from
+ ctx.sparsity_block_size_to = sparsity_block_size_to
+
+
+ adapt_layout_forward.register_autograd(adapt_layout_wrapper_backward, setup_context=adapt_layout_setup_context)
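
The 2.0 pattern used throughout this file — a plain forward function decorated with triton_op, the kernel launched through wrap_triton, and the backward attached afterwards via register_autograd with a setup_context hook — replaces the torch.autograd.Function subclasses of 1.11. A minimal, self-contained sketch of that structure, using a hypothetical element-wise scaling op (the example:: namespace and the scale_kernel / scale_forward names are illustrative and not part of blksprs; the import path mirrors the one used above):

import torch
import triton
from torch import Tensor
from torch._library.triton import wrap_triton, triton_op
from triton import language as tl


@triton.jit
def scale_kernel(x_ptr, o_ptr, n_elements, scale, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one contiguous run of BLOCK_SIZE elements.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(o_ptr + offsets, x * scale, mask=mask)


@triton_op("example::scale_forward", mutates_args={})
def scale_forward(x: Tensor, scale: float) -> Tensor:
    x = x.contiguous()
    output = torch.empty_like(x)
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    # wrap_triton keeps the raw kernel launch visible to the torch.library machinery.
    wrap_triton(scale_kernel)[grid](x, output, n_elements, scale, BLOCK_SIZE=1024)
    return output


def scale_wrapper_backward(ctx, grad_output):
    # One gradient slot per forward input; the non-tensor scale argument gets None.
    return grad_output * ctx.scale, None


# noinspection PyUnusedLocal
def scale_setup_context(ctx, inputs, output):
    _, scale = inputs
    ctx.scale = scale


scale_forward.register_autograd(scale_wrapper_backward, setup_context=scale_setup_context)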