blksprs 1.10.2__py3-none-any.whl → 2.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/ops/conversion.py CHANGED
@@ -1,294 +1,333 @@
  import torch
  import triton
  from torch import Tensor
+ from torch._library.triton import wrap_triton, triton_op
  from triton import language as tl

  from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
  from blksprs.utils.blksprs_tensor import BlksprsTensor
- from blksprs.utils.tools import get_triton_block_size, stride
+ from blksprs.utils.tools import stride, get_autotune_configs
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-     validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_sparsity_dense
+     validate_sparsity, validate_sparsity_block_size, validate_sparsity_dense


- def from_blksprs(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
-                  triton_block_size: int = None) -> Tensor:
-     """Wrapper for ``to_dense``.
+ def to_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int) -> BlksprsTensor:
+     """Wrapper for ``to_sparse``.

      """
-     return to_dense(x, sparsity_layout, sparsity_block_size, fill_value, triton_block_size)
+     return to_sparse(x, sparsity_layout, sparsity_block_size)


- def to_dense(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
-              triton_block_size: int = None) -> Tensor:
-     """Converts a block-sparse tensor in compressed form to a block-sparse tensor in regular form based on the given
-     sparsity layout.
+ def to_sparse(x: Tensor, sparsity_layout: Tensor,
+               sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
+     """Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
+     sparsity layout.

-     Args:
-         x (BlksprsTensor): A block-sparse tensor in compressed form.
+     Args:
+         x (Tensor): A block-sparse tensor in regular form.
          sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
          sparsity_block_size (int): The size of the sparsity blocks.
-         fill_value (float): The value to fill the resulting dense tensor with where the block-sparse tensor is not
-             present (default ``0``).
-         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+         lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

      Returns:
-         Tensor: The block-sparse tensor converted to regular form.
+         BlksprsTensor: The block-sparse tensor converted to compressed form.

      """
      x = x.contiguous()

      validate_dimensions(x)
-     validate_contiguous(x, sparsity_layout)
+     validate_contiguous(x)
      validate_device(x)
-     validate_sparsity(sparsity_block_size, (x, sparsity_layout))
+     validate_sparsity_dense(sparsity_block_size, (x, sparsity_layout))
      validate_sparsity_block_size(sparsity_block_size, x)
-     validate_triton_block_size(triton_block_size, sparsity_block_size)
-
-     sparsity_layout_flat = sparsity_layout.reshape(-1)
-     sparsity_reverse_lut = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
-                             (sparsity_layout_flat == 1) -
-                             (1 * (sparsity_layout_flat == 0)))

-     validate_contiguous(sparsity_reverse_lut)
+     lut = to_sparse_build_lut(lut, sparsity_layout)

      if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
-         return x
+         return BlksprsTensor(x)

-     return _BlocksparseToDense.apply(x,
-                                      sparsity_layout, sparsity_reverse_lut,
-                                      sparsity_block_size, fill_value,
-                                      triton_block_size)
-
-
- class _BlocksparseToDense(torch.autograd.Function):
-
-     @staticmethod
-     def forward(ctx, x: Tensor,
-                 sparsity_layout: Tensor, sparsity_reverse_lut: Tensor,
-                 sparsity_block_size: int, fill_value: float,
-                 triton_block_size: int) -> Tensor:
-         output = torch.full(size=(sparsity_layout.size(0), sparsity_layout.size(1) * sparsity_block_size,
-                                   sparsity_layout.size(2) * sparsity_block_size), fill_value=fill_value,
-                             dtype=x.dtype, device=x.device)
-
-         x_b, x_r, x_c = x.shape
-         x_b_s, x_r_s, x_c_s = stride(x)
-         s_l_b, s_l_r, s_l_c = sparsity_layout.size()
-         s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout)
-         o_b, o_r, o_c = output.size()
-         o_b_s, o_r_s, o_c_s = stride(output)
-
-         if triton_block_size is None:
-             triton_block_size = get_triton_block_size(sparsity_block_size)
-
-         triton_grid = lambda meta: [o_b,
-                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-         (_BlocksparseToDense.kernel_blocksparse_to_dense[triton_grid]
-          (x,
-           x_b, x_b_s, x_r_s, x_c_s,
-           s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
-           sparsity_reverse_lut,
-           output,
-           o_b, o_b_s, o_r_s, o_c_s,
-           sparsity_block_size,
-           triton_block_size))
-
-         ctx.save_for_backward(sparsity_layout)
-         ctx.sparsity_block_size = sparsity_block_size
-         ctx.triton_block_size = triton_block_size
-
-         return output
-
-     @staticmethod
-     def backward(ctx, grad_output):
-         sparsity_layout = ctx.saved_tensors[0]
-         sparsity_block_size = ctx.sparsity_block_size
-         triton_block_size = ctx.triton_block_size
-
-         return to_sparse(grad_output, sparsity_layout, sparsity_block_size,
-                          triton_block_size), None, None, None, None, None
-
-     @staticmethod
-     @triton.jit
-     def kernel_blocksparse_to_dense(x,
-                                     x_b, x_b_s, x_r_s, x_c_s,
-                                     s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
-                                     sparsity_reverse_lut,
-                                     o,
-                                     o_b, o_b_s, o_r_s, o_c_s,
-                                     sparsity_block_size,
-                                     TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-         # Get triton block indices
-         pid_blk = tl.program_id(axis=0)
-         pid_row = tl.program_id(axis=1)
-         pid_col = tl.program_id(axis=2)
-
-         # Get sparsity index of current block
-         spa_row = (pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size
-         spa_col = (pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size
-
-         # Get reverse sparsity index for current block
-         rev_idx_spa_idx = (pid_blk * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)
-         rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_b * s_l_b_s)
-         rev_idx_spa = tl.load(sparsity_reverse_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
-
-         # If block is present commence operations
-         if rev_idx_spa >= 0:
-             blk_idx = (rev_idx_spa * x_b_s +
-                        (((pid_row % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
-                          tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                        (((pid_col % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
-                          tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-             blk_msk = (blk_idx >= 0 and blk_idx < x_b * x_b_s)
-             blk = tl.load(x + blk_idx, mask=blk_msk)
-
-             o_idx = (pid_blk * o_b_s +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-             o_msk = (o_idx >= 0 and o_idx < o_b * o_b_s)
-             tl.store(o + o_idx, blk, o_msk)
-
-
- def to_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
-                triton_block_size: int = None) -> BlksprsTensor:
-     """Wrapper for ``to_sparse``.
+     return BlksprsTensor(to_sparse_forward(x, sparsity_layout,
+                                            lut["sparsity_lut"], sparsity_block_size, lut["n_sparse_blocks"]))
+
+
+ @triton_op("blksprs::to_sparse", mutates_args={})
+ def to_sparse_forward(x: Tensor, _: Tensor,
+                       sparsity_lut: Tensor, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+     output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+                          dtype=x.dtype, device=x.device)
+
+     x_b, x_r, x_c = x.size()
+     x_b_s, x_r_s, x_c_s = stride(x)
+     s_lut_r, s_lut_c = sparsity_lut.size()
+     s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+     o_b, o_r, o_c = output.size()
+     o_b_s, o_r_s, o_c_s = stride(output)
+
+     triton_grid = lambda meta: [o_b,
+                                 triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+                                 triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+     (wrap_triton(to_sparse_kernel)[triton_grid]
+      (x, x_b, x_b_s, x_r_s, x_c_s,
+       sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+       output, o_b_s, o_r_s, o_c_s,
+       sparsity_block_size))
+
+     return output
+
+
+ def to_sparse_backward(ctx, grad_output):
+     sparsity_layout = ctx.saved_tensors[0]
+     sparsity_block_size = ctx.sparsity_block_size
+
+     return to_dense(grad_output, sparsity_layout, sparsity_block_size), None, None, None, None
+
+
+ @triton.autotune(
+     configs=get_autotune_configs(),
+     key=[],
+ )
+ @triton.jit
+ def to_sparse_kernel(x,
+                      x_b, x_b_s, x_r_s, x_c_s,
+                      s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+                      o,
+                      o_b_s, o_r_s, o_c_s,
+                      sparsity_block_size,
+                      TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     # Get triton block indices
+     pid_blk = tl.program_id(axis=0)
+     pid_row = tl.program_id(axis=1)
+     pid_col = tl.program_id(axis=2)
+
+     # Get valid triton block size
+     val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
+
+     # Get sparsity index of current output block consisting of its batch, row, and column index
+     spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+     spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
+     spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+     spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+     spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
+     spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+     spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
+     spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
+     spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+
+     # Load block from dense tensor
+     blk_d_idx = (spa_bat * x_b_s +
+                  ((pid_row * val_tbs + spa_row * sparsity_block_size +
+                    tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                  ((pid_col * val_tbs + spa_col * sparsity_block_size +
+                    tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+     blk_d_msk = ((blk_d_idx >= 0 and
+                   blk_d_idx < x_b * x_b_s) and
+                  (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                   tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+     blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)
+
+     # Store block in sparse tensor
+     blk_o_idx = ((pid_blk * o_b_s) +
+                  ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                  ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+     blk_o_msk = ((blk_o_idx >= 0 and
+                   blk_o_idx < (pid_blk + 1) * o_b_s) and
+                  (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                   tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+     tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
+
+
+ def to_sparse_build_lut(lut: dict, sparsity_layout: Tensor):
+     if lut is None:
+         lut = dict()
+
+     if "sparsity_lut" not in lut:
+         sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
+         lut["sparsity_lut"] = sparsity_lut
+
+     if "n_sparse_blocks" not in lut:
+         n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
+         lut["n_sparse_blocks"] = n_sparse_blocks
+
+     validate_contiguous(sparsity_layout, lut["sparsity_lut"])
+
+     return lut
+
+
+ # noinspection PyUnusedLocal
+ def to_sparse_setup_context(ctx, inputs, output):
+     (_, sparsity_layout, _, sparsity_block_size, _) = inputs
+
+     ctx.save_for_backward(sparsity_layout)
+     ctx.sparsity_block_size = sparsity_block_size
+
+
+ to_sparse_forward.register_autograd(to_sparse_backward, setup_context=to_sparse_setup_context)
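
The triton_op/register_autograd pairing above replaces the 1.x torch.autograd.Function subclasses while keeping the Triton launches visible to PyTorch's tracing machinery, which is what wrap_triton and triton_op are designed for. A minimal gradient-flow sketch of the resulting API, assuming a CUDA device, an all-ones layout, and that BlksprsTensor behaves like a regular Tensor (shapes are illustrative, not taken from the package docs):

import torch

from blksprs.ops.conversion import to_sparse

sparsity_block_size = 16
# Mark every 16x16 block of a (1, 32, 32) batch as present.
sparsity_layout = torch.ones(1, 2, 2, dtype=torch.int, device="cuda")
x = torch.rand(1, 32, 32, device="cuda", requires_grad=True)

x_sparse = to_sparse(x, sparsity_layout, sparsity_block_size)
# The backward registered above runs to_dense on the incoming gradient.
x_sparse.sum().backward()
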
+
+
+ def from_blksprs(x: BlksprsTensor, sparsity_layout: Tensor,
+                  sparsity_block_size: int, fill_value: float = 0, lut: dict = None) -> Tensor:
+     """Wrapper for ``to_dense``.

      """
-     return to_sparse(x, sparsity_layout, sparsity_block_size, triton_block_size)
+     return to_dense(x, sparsity_layout, sparsity_block_size, fill_value=fill_value, lut=lut)


- def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
-               triton_block_size: int = None) -> BlksprsTensor:
-     """Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
-     sparsity layout.
+ def to_dense(x: BlksprsTensor, sparsity_layout: Tensor,
+              sparsity_block_size: int, fill_value: float = 0, lut: dict = None) -> Tensor:
+     """Converts a block-sparse tensor in compressed form to a block-sparse tensor in regular form based on the given
+     sparsity layout.

-     Args:
-         x (Tensor): A block-sparse tensor in regular form.
+     Args:
+         x (BlksprsTensor): A block-sparse tensor in compressed form.
          sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
          sparsity_block_size (int): The size of the sparsity blocks.
-         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+         fill_value (float): The value to fill the resulting dense tensor with where the block-sparse tensor is not
+             present (default ``0``).
+         lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

      Returns:
-         BlksprsTensor: The block-sparse tensor converted to compressed form.
+         Tensor: The block-sparse tensor converted to regular form.

      """
      x = x.contiguous()

      validate_dimensions(x)
-     validate_contiguous(x)
+     validate_contiguous(x, sparsity_layout)
      validate_device(x)
-     validate_sparsity_dense(sparsity_block_size, (x, sparsity_layout))
+     validate_sparsity(sparsity_block_size, (x, sparsity_layout))
      validate_sparsity_block_size(sparsity_block_size, x)
-     validate_triton_block_size(triton_block_size, sparsity_block_size)
-
-     sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
-     n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()

-     validate_contiguous(sparsity_layout, sparsity_lut)
+     lut = to_dense_build_lut(lut, sparsity_layout)

      if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
-         return BlksprsTensor(x)
+         return x

-     return BlksprsTensor(_BlocksparseToSparse.apply(x,
-                                                     sparsity_layout, sparsity_lut,
-                                                     sparsity_block_size, n_sparse_blocks,
-                                                     triton_block_size))
-
-
- class _BlocksparseToSparse(torch.autograd.Function):
-
-     @staticmethod
-     def forward(ctx, x: Tensor,
-                 sparsity_layout: Tensor, sparsity_lut: Tensor,
-                 sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
-         output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
-                              dtype=x.dtype, device=x.device)
-
-         x_b, x_r, x_c = x.size()
-         x_b_s, x_r_s, x_c_s = stride(x)
-         s_lut_r, s_lut_c = sparsity_lut.size()
-         s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
-         o_b, o_r, o_c = output.size()
-         o_b_s, o_r_s, o_c_s = stride(output)
-
-         if triton_block_size is None:
-             triton_block_size = get_triton_block_size(sparsity_block_size)
-
-         triton_grid = lambda meta: [o_b,
-                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-         (_BlocksparseToSparse.kernel_blocksparse_to_sparse[triton_grid]
-          (x, x_b, x_b_s, x_r_s, x_c_s,
-           sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-           output, o_b_s, o_r_s, o_c_s,
-           sparsity_block_size,
-           triton_block_size))
-
-         ctx.save_for_backward(sparsity_layout)
-         ctx.sparsity_block_size = sparsity_block_size
-         ctx.triton_block_size = triton_block_size
-
-         return output
-
-     @staticmethod
-     def backward(ctx, grad_output):
-         sparsity_layout = ctx.saved_tensors[0]
-         sparsity_block_size = ctx.sparsity_block_size
-         triton_block_size = ctx.triton_block_size
-
-         return to_dense(grad_output, sparsity_layout, sparsity_block_size,
-                         triton_block_size=triton_block_size), None, None, None, None, None
-
-     @staticmethod
-     @triton.jit
-     def kernel_blocksparse_to_sparse(x,
-                                      x_b, x_b_s, x_r_s, x_c_s,
-                                      s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
-                                      o,
-                                      o_b_s, o_r_s, o_c_s,
-                                      sparsity_block_size,
-                                      TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-         # Get triton block indices
-         pid_blk = tl.program_id(axis=0)
-         pid_row = tl.program_id(axis=1)
-         pid_col = tl.program_id(axis=2)
-
-         # Get sparsity index of current output block consisting of its batch, row, and column index
-         spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
-         spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
-         spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
-
-         spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
-         spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
-         spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
-
-         spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
-         spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
-         spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
-
-         # Load block from dense tensor
-         blk_d_idx = (spa_bat * x_b_s +
-                      ((spa_row * sparsity_block_size + pid_row * TRITON_BLOCK_SIZE +
-                        tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                      ((spa_col * sparsity_block_size + pid_col * TRITON_BLOCK_SIZE +
-                        tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-         blk_d_msk = (blk_d_idx >= 0 and blk_d_idx < x_b * x_b_s)
-         blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)
-
-         # Store block in sparse tensor
-         blk_o_idx = ((pid_blk * o_b_s) +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
-         blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < (pid_blk + 1) * o_b_s)
-         tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
+     return to_dense_forward(x, sparsity_layout,
+                             lut["sparsity_reverse_lut"], sparsity_block_size, fill_value)
+
+
+ @triton_op("blksprs::to_dense", mutates_args={})
+ def to_dense_forward(x: Tensor, sparsity_layout: Tensor,
+                      sparsity_reverse_lut: Tensor,
+                      sparsity_block_size: int, fill_value: float) -> Tensor:
+     output = torch.full(size=(sparsity_layout.size(0), sparsity_layout.size(1) * sparsity_block_size,
+                               sparsity_layout.size(2) * sparsity_block_size), fill_value=fill_value,
+                         dtype=x.dtype, device=x.device)
+
+     x_b, x_r, x_c = x.shape
+     x_b_s, x_r_s, x_c_s = stride(x)
+     s_l_b, s_l_r, s_l_c = sparsity_layout.size()
+     s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout)
+     o_b, o_r, o_c = output.size()
+     o_b_s, o_r_s, o_c_s = stride(output)
+
+     triton_grid = lambda meta: [o_b,
+                                 triton.cdiv(o_r, min(meta["sparsity_block_size"], meta["TRITON_BLOCK_SIZE"])),
+                                 triton.cdiv(o_c, min(meta["sparsity_block_size"], meta["TRITON_BLOCK_SIZE"]))]
+
+     (wrap_triton(to_dense_kernel)[triton_grid]
+      (x,
+       x_b, x_b_s, x_r_s, x_c_s,
+       s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+       sparsity_reverse_lut,
+       output,
+       o_b, o_b_s, o_r_s, o_c_s,
+       sparsity_block_size))
+
+     return output
+
+
+ def to_dense_backward(ctx, grad_output):
+     sparsity_layout = ctx.saved_tensors[0]
+     sparsity_block_size = ctx.sparsity_block_size
+
+     return to_sparse(grad_output, sparsity_layout, sparsity_block_size), None, None, None, None
+
+
+ @triton.autotune(
+     configs=get_autotune_configs(),
+     key=[],
+ )
+ @triton.jit
+ def to_dense_kernel(x,
+                     x_b, x_b_s, x_r_s, x_c_s,
+                     s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+                     sparsity_reverse_lut,
+                     o,
+                     o_b, o_b_s, o_r_s, o_c_s,
+                     sparsity_block_size,
+                     TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     # Get triton block indices
+     pid_blk = tl.program_id(axis=0)
+     pid_row = tl.program_id(axis=1)
+     pid_col = tl.program_id(axis=2)
+
+     # Get valid triton block size
+     val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
+
+     # Get sparsity index of current block
+     spa_row = (pid_row * val_tbs) // sparsity_block_size
+     spa_col = (pid_col * val_tbs) // sparsity_block_size
+
+     # Get reverse sparsity index for current block
+     rev_idx_spa_idx = (pid_blk * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)
+     rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_b * s_l_b_s)
+     rev_idx_spa = tl.load(sparsity_reverse_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
+
+     # If block is present commence operations
+     if rev_idx_spa >= 0:
+         blk_idx = (rev_idx_spa * x_b_s +
+                    (((pid_row % (sparsity_block_size // val_tbs)) * val_tbs +
+                      tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                    (((pid_col % (sparsity_block_size // val_tbs)) * val_tbs +
+                      tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+         blk_msk = ((blk_idx >= 0 and
+                     blk_idx < x_b * x_b_s) and
+                    (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                     tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+         blk = tl.load(x + blk_idx, mask=blk_msk)
+
+         o_idx = (pid_blk * o_b_s +
+                  ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                  ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+         o_msk = ((o_idx >= 0 and o_idx < o_b * o_b_s) and
+                  (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                   tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+         tl.store(o + o_idx, blk, o_msk)
+
+
+ def to_dense_build_lut(lut: dict, sparsity_layout: Tensor):
+     if lut is None:
+         lut = dict()
+
+     if "sparsity_reverse_lut" not in lut:
+         sparsity_layout_flat = sparsity_layout.reshape(-1)
+         sparsity_reverse_lut = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                 (sparsity_layout_flat == 1) -
+                                 (1 * (sparsity_layout_flat == 0)))
+         lut["sparsity_reverse_lut"] = sparsity_reverse_lut
+
+     validate_contiguous(lut["sparsity_reverse_lut"])
+
+     return lut
+
+
+ # noinspection PyUnusedLocal
+ def to_dense_setup_context(ctx, inputs, output):
+     (_, sparsity_layout, _, sparsity_block_size, _) = inputs
+
+     ctx.save_for_backward(sparsity_layout)
+     ctx.sparsity_block_size = sparsity_block_size
+
+
+ to_dense_forward.register_autograd(to_dense_backward, setup_context=to_dense_setup_context)
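
Both converters now accept an optional lut dict that is filled on first use and can be reused across calls that share a layout; to_sparse_build_lut and to_dense_build_lut populate disjoint keys, so a single dict can serve both directions. A minimal reuse sketch, again assuming a CUDA device and illustrative shapes:

import torch

from blksprs.ops.conversion import to_sparse, to_dense

sparsity_block_size = 16
sparsity_layout = torch.randint(0, 2, size=(2, 4, 4), device="cuda")
x = torch.rand(2, 64, 64, device="cuda")

lut = {}  # populated in place on the first call, then reused
x_sparse = to_sparse(x, sparsity_layout, sparsity_block_size, lut=lut)
x_dense = to_dense(x_sparse, sparsity_layout, sparsity_block_size, lut=lut)
# Blocks absent from the layout come back as fill_value (0 by default),
# so x_dense equals x only where the layout is 1.

Reusing the dict skips the torch.nonzero and cumsum bookkeeping on later calls; the n_sparse_blocks entry in particular is computed with .item(), which forces a host synchronization.
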


  def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_size_from: int,
-                  sparsity_block_size_to: int, sparsity_layout_to: Tensor = None,
-                  triton_block_size: int = None) -> (BlksprsTensor, Tensor):
+                  sparsity_block_size_to: int, sparsity_layout_to: Tensor = None) -> (BlksprsTensor, Tensor):
      """Adapts the sparsity layout of a block-sparse tensor, resulting in a new block-sparse tensor in compressed form
      conforming to the new sparsity layout (and sparsity block size) definition.

@@ -298,7 +337,6 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
          sparsity_block_size_from (int): The size of the sparsity blocks of the input sparsity layout.
          sparsity_block_size_to (int): The size of the sparsity blocks of the output sparsity layout.
          sparsity_layout_to (Tensor): The sparsity layout of the output block-sparse tensor (default ``None``).
-         triton_block_size (int): The block size to use for the triton kernel (default ``None``).

      Returns:
          BlksprsTensor: The block-sparse tensor in compressed form with the adapted sparsity layout and sparsity block size.
@@ -313,8 +351,6 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
      validate_sparsity(sparsity_block_size_from, (x, sparsity_layout_from))
      validate_sparsity_block_size(sparsity_block_size_from, x)
      validate_sparsity_block_size(sparsity_block_size_to)
-     min_sparsity_block_size = min(sparsity_block_size_from, sparsity_block_size_to)
-     validate_triton_block_size(triton_block_size, min_sparsity_block_size)

      sparsity_layout_from_flat = sparsity_layout_from.reshape(-1)
      sparsity_reverse_lut_from = ((torch.cumsum(sparsity_layout_from_flat, dim=-1) - 1) *
@@ -323,8 +359,7 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_

      if sparsity_layout_to is None:
          sparsity_layout_to = build_sparsity_layout_adaption(x, sparsity_layout_from,
-                                                             sparsity_block_size_from, sparsity_block_size_to,
-                                                             triton_block_size)
+                                                             sparsity_block_size_from, sparsity_block_size_to)

      sparsity_lut_to = torch.nonzero(sparsity_layout_to).contiguous()

@@ -335,134 +370,148 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
      if (sparsity_block_size_from == sparsity_block_size_to) and torch.equal(sparsity_layout_from, sparsity_layout_to):
          return BlksprsTensor(x), sparsity_layout_to

-     return BlksprsTensor(_BlocksparseAdaptLayout.apply(x,
-                                                        sparsity_layout_from, sparsity_reverse_lut_from,
-                                                        sparsity_block_size_from,
-                                                        sparsity_layout_to, sparsity_lut_to,
-                                                        sparsity_block_size_to,
-                                                        n_sparse_blocks_to, min_sparsity_block_size,
-                                                        triton_block_size)), sparsity_layout_to
-
-
- class _BlocksparseAdaptLayout(torch.autograd.Function):
-
-     @staticmethod
-     def forward(ctx, x: Tensor,
-                 sparsity_layout_from: Tensor, sparsity_reverse_lut_from: Tensor,
-                 sparsity_block_size_from: int,
-                 sparsity_layout_to: Tensor, sparsity_lut_to: Tensor,
-                 sparsity_block_size_to: int,
-                 n_sparse_blocks_to: int, min_sparsity_block_size: int, triton_block_size: int) -> Tensor:
-         output = torch.zeros(size=(n_sparse_blocks_to, sparsity_block_size_to, sparsity_block_size_to),
-                              dtype=x.dtype, device=x.device)
-
-         x_b, x_r, x_c = x.size()
-         x_b_s, x_r_s, x_c_s = stride(x)
-         s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_from.size()
-         s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_from)
-         o_b, o_r, o_c = output.size()
-         o_b_s, o_r_s, o_c_s = stride(output)
-         s_lut_o_r, s_lut_o_c = sparsity_lut_to.size()
-         s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_to)
-
-         if triton_block_size is None:
-             triton_block_size = get_triton_block_size(min_sparsity_block_size)
-
-         triton_grid = lambda meta: [o_b,
-                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-         (_BlocksparseAdaptLayout.kernel_adapt_layout[triton_grid]
-          (x,
-           x_b, x_b_s, x_r_s, x_c_s,
-           s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
-           sparsity_reverse_lut_from,
-           output,
-           o_b, o_b_s, o_r_s, o_c_s,
-           sparsity_lut_to, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
-           sparsity_block_size_from,
-           sparsity_block_size_to,
-           triton_block_size))
-
-         ctx.save_for_backward(x, sparsity_layout_from, sparsity_layout_to)
-         ctx.sparsity_block_size_from = sparsity_block_size_from
-         ctx.sparsity_block_size_to = sparsity_block_size_to
-         ctx.triton_block_size = triton_block_size
-
-         return output
-
-     @staticmethod
-     def backward(ctx, grad_output):
-         x, sparsity_layout_from, sparsity_layout_to = ctx.saved_tensors
-         sparsity_block_size_from = ctx.sparsity_block_size_from
-         sparsity_block_size_to = ctx.sparsity_block_size_to
-         triton_block_size = ctx.triton_block_size
-
-         return adapt_layout(
-             grad_output, sparsity_layout_to, sparsity_block_size_to, sparsity_block_size_from,
-             sparsity_layout_to=sparsity_layout_from,
-             triton_block_size=triton_block_size)[0], None, None, None, None, None, None, None, None, None
-
-     @staticmethod
-     @triton.jit
-     def kernel_adapt_layout(x,
-                             x_b, x_b_s, x_r_s, x_c_s,
-                             s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
-                             r_lut_x,
-                             o,
-                             o_b, o_b_s, o_r_s, o_c_s,
-                             s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
-                             sparsity_block_size_from,
-                             sparsity_block_size_to,
-                             TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-         # Get triton block indices
-         pid_blk = tl.program_id(axis=0)
-         pid_row = tl.program_id(axis=1)
-         pid_col = tl.program_id(axis=2)
-
-         # Get position of current sparsity block consisting of its batch, row, and column index
-         spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
-         spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
-         spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
-
-         spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
-         spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
-         spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
-
-         spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
-         spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
-         spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
-
-         # Get equivalent sparsity block in from layout
-         spa_bat_x = spa_bat_o
-         spa_row_x = (spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size_from
-         spa_col_x = (spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size_from
-
-         # Get reverse sparsity indices for x
-         rev_idx_spa_x_idx = (spa_bat_x * s_l_x_b_s +
-                              spa_row_x * s_l_x_r_s +
-                              spa_col_x * s_l_x_c_s)
-         rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
-         rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
-
-         # If block is present commence operations
-         if rev_idx_spa_x >= 0:
-             # Calculate triton block size shifts
-             shift_row_x = ((spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE)
-                            % sparsity_block_size_from) // TRITON_BLOCK_SIZE
-             shift_col_x = ((spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE)
-                            % sparsity_block_size_from) // TRITON_BLOCK_SIZE
-
-             # Load x values
-             blk_x_idx = ((rev_idx_spa_x * x_b_s) +
-                          ((shift_row_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                          ((shift_col_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-             blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
-             blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
-             # Store output
-             blk_o_idx = ((pid_blk * o_b_s) +
-                          ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                          ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-             blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
-             tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+     return BlksprsTensor(adapt_layout_forward(x,
+                                               sparsity_layout_from, sparsity_reverse_lut_from,
+                                               sparsity_block_size_from,
+                                               sparsity_layout_to, sparsity_lut_to,
+                                               sparsity_block_size_to,
+                                               n_sparse_blocks_to)), sparsity_layout_to
+
+
+ @triton_op("blksprs::adapt_layout", mutates_args={})
+ def adapt_layout_forward(x: Tensor,
+                          sparsity_layout_from: Tensor, sparsity_reverse_lut_from: Tensor,
+                          sparsity_block_size_from: int,
+                          _: Tensor, sparsity_lut_to: Tensor,
+                          sparsity_block_size_to: int,
+                          n_sparse_blocks_to: int) -> Tensor:
+     output = torch.zeros(size=(n_sparse_blocks_to, sparsity_block_size_to, sparsity_block_size_to),
+                          dtype=x.dtype, device=x.device)
+
+     x_b, x_r, x_c = x.size()
+     x_b_s, x_r_s, x_c_s = stride(x)
+     s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_from.size()
+     s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_from)
+     o_b, o_r, o_c = output.size()
+     o_b_s, o_r_s, o_c_s = stride(output)
+     s_lut_o_r, s_lut_o_c = sparsity_lut_to.size()
+     s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_to)
+
+     triton_grid = lambda meta: [o_b,
+                                 triton.cdiv(o_r, min(meta["sparsity_block_size_from"], meta["sparsity_block_size_to"],
+                                                      meta["TRITON_BLOCK_SIZE"])),
+                                 triton.cdiv(o_c, min(meta["sparsity_block_size_from"], meta["sparsity_block_size_to"],
+                                                      meta["TRITON_BLOCK_SIZE"]))]
+
+     (wrap_triton(adapt_layout_kernel)[triton_grid]
+      (x,
+       x_b, x_b_s, x_r_s, x_c_s,
+       s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+       sparsity_reverse_lut_from,
+       output,
+       o_b, o_b_s, o_r_s, o_c_s,
+       sparsity_lut_to, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+       sparsity_block_size_from,
+       sparsity_block_size_to))
+
+     return output
+
+
+ def adapt_layout_backward(ctx, grad_output):
+     x, sparsity_layout_from, sparsity_layout_to = ctx.saved_tensors
+     sparsity_block_size_from = ctx.sparsity_block_size_from
+     sparsity_block_size_to = ctx.sparsity_block_size_to
+
+     return adapt_layout(
+         grad_output, sparsity_layout_to, sparsity_block_size_to, sparsity_block_size_from,
+         sparsity_layout_to=sparsity_layout_from)[0], None, None, None, None, None, None, None
+
+
+ @triton.autotune(
+     configs=get_autotune_configs(),
+     key=[],
+     reset_to_zero=["o"]
+ )
+ @triton.jit
+ def adapt_layout_kernel(x,
+                         x_b, x_b_s, x_r_s, x_c_s,
+                         s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+                         r_lut_x,
+                         o,
+                         o_b, o_b_s, o_r_s, o_c_s,
+                         s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+                         sparsity_block_size_from,
+                         sparsity_block_size_to,
+                         TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     # Get triton block indices
+     pid_blk = tl.program_id(axis=0)
+     pid_row = tl.program_id(axis=1)
+     pid_col = tl.program_id(axis=2)
+
+     # Get valid triton block size (Triton can only handle 2-valued min)
+     val_tbs = min(min(sparsity_block_size_from, sparsity_block_size_to), TRITON_BLOCK_SIZE)
+
+     # Get position of current sparsity block consisting of its batch, row, and column index
+     spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
+     spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+     spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
+
+     spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
+     spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+     spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+
+     spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
+     spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+     spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+
+     # Get equivalent sparsity block in from layout
+     spa_bat_x = spa_bat_o
+     spa_row_x = (spa_row_o * sparsity_block_size_to + pid_row * val_tbs) // sparsity_block_size_from
+     spa_col_x = (spa_col_o * sparsity_block_size_to + pid_col * val_tbs) // sparsity_block_size_from
+
+     # Get reverse sparsity indices for x
+     rev_idx_spa_x_idx = (spa_bat_x * s_l_x_b_s +
+                          spa_row_x * s_l_x_r_s +
+                          spa_col_x * s_l_x_c_s)
+     rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+     rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
+
+     # If block is present commence operations
+     if rev_idx_spa_x >= 0:
+         # Calculate triton block size shifts
+         shift_row_x = ((spa_row_o * sparsity_block_size_to + pid_row * val_tbs)
+                        % sparsity_block_size_from) // val_tbs
+         shift_col_x = ((spa_col_o * sparsity_block_size_to + pid_col * val_tbs)
+                        % sparsity_block_size_from) // val_tbs
+
+         # Load x values
+         blk_x_idx = ((rev_idx_spa_x * x_b_s) +
+                      ((shift_row_x * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+                      ((shift_col_x * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+         blk_x_msk = ((blk_x_idx >= 0 and
+                       blk_x_idx < x_b * x_b_s) and
+                      (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                       tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+         # Store output
+         blk_o_idx = ((pid_blk * o_b_s) +
+                      ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+                      ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+         blk_o_msk = ((blk_o_idx >= 0 and
+                       blk_o_idx < o_b * o_b_s) and
+                      (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+                       tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+         tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+ # noinspection PyUnusedLocal
+ def adapt_layout_setup_context(ctx, inputs, output):
+     (x, sparsity_layout_from, _, sparsity_block_size_from, sparsity_layout_to, _, sparsity_block_size_to, _) = inputs
+
+     ctx.save_for_backward(x, sparsity_layout_from, sparsity_layout_to)
+     ctx.sparsity_block_size_from = sparsity_block_size_from
+     ctx.sparsity_block_size_to = sparsity_block_size_to
+
+
+ adapt_layout_forward.register_autograd(adapt_layout_backward, setup_context=adapt_layout_setup_context)
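
The reverse look-up table built in to_dense_build_lut (and again for adapt_layout above) maps each flattened layout position to the index of its block in the compressed tensor, or to -1 where no block exists. A small CPU-only sketch of the cumsum construction, with illustrative values:

import torch

sparsity_layout_flat = torch.tensor([1, 0, 1, 1, 0])
sparsity_reverse_lut = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
                        (sparsity_layout_flat == 1) -
                        (1 * (sparsity_layout_flat == 0)))
print(sparsity_reverse_lut)  # tensor([ 0, -1,  1,  2, -1])

Present positions receive consecutive indices 0, 1, 2 in flat row-major order, matching the order in which torch.nonzero enumerates blocks for the forward conversion; absent positions yield -1, which the kernels test with rev_idx_spa >= 0 before loading.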