blksprs 1.11 → 2.0rc2 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
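The headline change in this release of the conversion module is that the public API no longer takes a manual triton_block_size argument: the Triton kernels are now autotuned and registered through torch.library. A minimal before/after usage sketch, assuming the functions are imported from blksprs.ops.conversion and a CUDA device is available (shapes and values below are illustrative, not taken from the diff):

    # blksprs 1.11: an explicit Triton block size could be passed
    #   x_sparse = to_sparse(x, sparsity_layout, sparsity_block_size, triton_block_size=16)
    # blksprs 2.0rc2: the kernels are autotuned, so only the sparsity parameters remain
    import torch
    from blksprs.ops.conversion import to_sparse, to_dense

    x = torch.randn(2, 128, 128, device="cuda")                            # dense input (illustrative shape)
    sparsity_layout = torch.ones(2, 4, 4, dtype=torch.int, device="cuda")  # 4 x 4 grid of 32 x 32 blocks, all present
    sparsity_block_size = 32

    x_sparse = to_sparse(x, sparsity_layout, sparsity_block_size)          # compressed form (BlksprsTensor)
    x_dense = to_dense(x_sparse, sparsity_layout, sparsity_block_size, fill_value=0)  # back to regular form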
blksprs/ops/conversion.py CHANGED
@@ -1,321 +1,333 @@
  import torch
  import triton
  from torch import Tensor
+ from torch._library.triton import wrap_triton, triton_op
  from triton import language as tl

  from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
  from blksprs.utils.blksprs_tensor import BlksprsTensor
- from blksprs.utils.tools import get_triton_block_size, stride
+ from blksprs.utils.tools import stride, get_autotune_configs
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
- validate_sparsity, validate_sparsity_block_size, validate_triton_block_size, validate_sparsity_dense
+ validate_sparsity, validate_sparsity_block_size, validate_sparsity_dense


- def from_blksprs(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
- triton_block_size: int = None) -> Tensor:
- """Wrapper for ``to_dense``.
+ def to_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int) -> BlksprsTensor:
+ """Wrapper for ``to_sparse``.

  """
- return to_dense(x, sparsity_layout, sparsity_block_size, fill_value, triton_block_size)
+ return to_sparse(x, sparsity_layout, sparsity_block_size)


- def to_dense(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
- triton_block_size: int = None, lut: dict = None) -> Tensor:
- """Converts a block-sparse tensor in compressed form to a block-sparse tensor in regular form based on the given
- sparsity layout.
+ def to_sparse(x: Tensor, sparsity_layout: Tensor,
+ sparsity_block_size: int, lut: dict = None) -> BlksprsTensor:
+ """Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
+ sparsity layout.

- Args:
- x (BlksprsTensor): A block-sparse tensor in compressed form.
+ Args:
+ x (Tensor): A block-sparse tensor in regular form.
  sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
  sparsity_block_size (int): The size of the sparsity blocks.
- fill_value (float): The value to fill the resulting dense tensor with where the block-sparse tensor is not
- present (default ``0``).
- triton_block_size (int): The block size to use for the triton kernel (default ``None``).
  lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

  Returns:
- Tensor: The block-sparse tensor converted to regular form.
+ BlksprsTensor: The block-sparse tensor converted to compressed form.

  """
  x = x.contiguous()

  validate_dimensions(x)
- validate_contiguous(x, sparsity_layout)
+ validate_contiguous(x)
  validate_device(x)
- validate_sparsity(sparsity_block_size, (x, sparsity_layout))
+ validate_sparsity_dense(sparsity_block_size, (x, sparsity_layout))
  validate_sparsity_block_size(sparsity_block_size, x)
- validate_triton_block_size(triton_block_size, sparsity_block_size)

- lut = _BlocksparseToDense.build_lut(lut, sparsity_layout)
+ lut = to_sparse_build_lut(lut, sparsity_layout)

  if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
- return x
+ return BlksprsTensor(x)

- return _BlocksparseToDense.apply(x,
- sparsity_layout, lut["sparsity_reverse_lut"],
- sparsity_block_size, fill_value,
- triton_block_size)
-
-
- class _BlocksparseToDense(torch.autograd.Function):
-
- @staticmethod
- def build_lut(lut: dict, sparsity_layout: Tensor):
- if lut is None:
- lut = dict()
-
- if "sparsity_reverse_lut" not in lut:
- sparsity_layout_flat = sparsity_layout.reshape(-1)
- sparsity_reverse_lut = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
- (sparsity_layout_flat == 1) -
- (1 * (sparsity_layout_flat == 0)))
- lut["sparsity_reverse_lut"] = sparsity_reverse_lut
-
- validate_contiguous(lut["sparsity_reverse_lut"])
-
- return lut
-
- @staticmethod
- def forward(ctx, x: Tensor,
- sparsity_layout: Tensor, sparsity_reverse_lut: Tensor,
- sparsity_block_size: int, fill_value: float,
- triton_block_size: int) -> Tensor:
- output = torch.full(size=(sparsity_layout.size(0), sparsity_layout.size(1) * sparsity_block_size,
- sparsity_layout.size(2) * sparsity_block_size), fill_value=fill_value,
- dtype=x.dtype, device=x.device)
-
- x_b, x_r, x_c = x.shape
- x_b_s, x_r_s, x_c_s = stride(x)
- s_l_b, s_l_r, s_l_c = sparsity_layout.size()
- s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout)
- o_b, o_r, o_c = output.size()
- o_b_s, o_r_s, o_c_s = stride(output)
-
- if triton_block_size is None:
- triton_block_size = get_triton_block_size(sparsity_block_size)
-
- triton_grid = lambda meta: [o_b,
- triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
- triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
- (_BlocksparseToDense.kernel_blocksparse_to_dense[triton_grid]
- (x,
- x_b, x_b_s, x_r_s, x_c_s,
- s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
- sparsity_reverse_lut,
- output,
- o_b, o_b_s, o_r_s, o_c_s,
- sparsity_block_size,
- triton_block_size))
-
- ctx.save_for_backward(sparsity_layout)
- ctx.sparsity_block_size = sparsity_block_size
- ctx.triton_block_size = triton_block_size
-
- return output
-
- @staticmethod
- def backward(ctx, grad_output):
- sparsity_layout = ctx.saved_tensors[0]
- sparsity_block_size = ctx.sparsity_block_size
- triton_block_size = ctx.triton_block_size
-
- return to_sparse(grad_output, sparsity_layout, sparsity_block_size,
- triton_block_size), None, None, None, None, None
-
- @staticmethod
- @triton.jit
- def kernel_blocksparse_to_dense(x,
- x_b, x_b_s, x_r_s, x_c_s,
- s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
- sparsity_reverse_lut,
- o,
- o_b, o_b_s, o_r_s, o_c_s,
- sparsity_block_size,
- TRITON_BLOCK_SIZE: tl.constexpr) -> None:
- # Get triton block indices
- pid_blk = tl.program_id(axis=0)
- pid_row = tl.program_id(axis=1)
- pid_col = tl.program_id(axis=2)
-
- # Get sparsity index of current block
- spa_row = (pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size
- spa_col = (pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size
-
- # Get reverse sparsity index for current block
- rev_idx_spa_idx = (pid_blk * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)
- rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_b * s_l_b_s)
- rev_idx_spa = tl.load(sparsity_reverse_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
-
- # If block is present commence operations
- if rev_idx_spa >= 0:
- blk_idx = (rev_idx_spa * x_b_s +
- (((pid_row % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
- tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
- (((pid_col % (sparsity_block_size // TRITON_BLOCK_SIZE)) * TRITON_BLOCK_SIZE +
- tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
- blk_msk = (blk_idx >= 0 and blk_idx < x_b * x_b_s)
- blk = tl.load(x + blk_idx, mask=blk_msk)
-
- o_idx = (pid_blk * o_b_s +
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
- o_msk = (o_idx >= 0 and o_idx < o_b * o_b_s)
- tl.store(o + o_idx, blk, o_msk)
-
-
- def to_blksprs(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
- triton_block_size: int = None) -> BlksprsTensor:
- """Wrapper for ``to_sparse``.
+ return BlksprsTensor(to_sparse_forward(x, sparsity_layout,
+ lut["sparsity_lut"], sparsity_block_size, lut["n_sparse_blocks"]))
+
+
+ @triton_op("blksprs::to_sparse", mutates_args={})
+ def to_sparse_forward(x: Tensor, _: Tensor,
+ sparsity_lut: Tensor, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
+ output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+ dtype=x.dtype, device=x.device)
+
+ x_b, x_r, x_c = x.size()
+ x_b_s, x_r_s, x_c_s = stride(x)
+ s_lut_r, s_lut_c = sparsity_lut.size()
+ s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
+ o_b, o_r, o_c = output.size()
+ o_b_s, o_r_s, o_c_s = stride(output)
+
+ triton_grid = lambda meta: [o_b,
+ triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
+ triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
+
+ (wrap_triton(to_sparse_kernel)[triton_grid]
+ (x, x_b, x_b_s, x_r_s, x_c_s,
+ sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+ output, o_b_s, o_r_s, o_c_s,
+ sparsity_block_size))
+
+ return output
+
+
+ def to_sparse_backward(ctx, grad_output):
+ sparsity_layout = ctx.saved_tensors[0]
+ sparsity_block_size = ctx.sparsity_block_size
+
+ return to_dense(grad_output, sparsity_layout, sparsity_block_size), None, None, None, None
+
+
+ @triton.autotune(
+ configs=get_autotune_configs(),
+ key=[],
+ )
+ @triton.jit
+ def to_sparse_kernel(x,
+ x_b, x_b_s, x_r_s, x_c_s,
+ s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
+ o,
+ o_b_s, o_r_s, o_c_s,
+ sparsity_block_size,
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+ # Get triton block indices
+ pid_blk = tl.program_id(axis=0)
+ pid_row = tl.program_id(axis=1)
+ pid_col = tl.program_id(axis=2)
+
+ # Get valid triton block size
+ val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
+
+ # Get sparsity index of current output block consisting of its batch, row, and column index
+ spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
+ spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
+ spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
+
+ spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
+ spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
+ spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
+
+ spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
+ spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
+ spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
+
+ # Load block from dense tensor
+ blk_d_idx = (spa_bat * x_b_s +
+ ((pid_row * val_tbs + spa_row * sparsity_block_size +
+ tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+ ((pid_col * val_tbs + spa_col * sparsity_block_size +
+ tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+ blk_d_msk = ((blk_d_idx >= 0 and
+ blk_d_idx < x_b * x_b_s) and
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+ blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)
+
+ # Store block in sparse tensor
+ blk_o_idx = ((pid_blk * o_b_s) +
+ ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+ ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
+ blk_o_msk = ((blk_o_idx >= 0 and
+ blk_o_idx < (pid_blk + 1) * o_b_s) and
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+ tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
+
+
+ def to_sparse_build_lut(lut: dict, sparsity_layout: Tensor):
+ if lut is None:
+ lut = dict()
+
+ if "sparsity_lut" not in lut:
+ sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
+ lut["sparsity_lut"] = sparsity_lut
+
+ if "n_sparse_blocks" not in lut:
+ n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
+ lut["n_sparse_blocks"] = n_sparse_blocks
+
+ validate_contiguous(sparsity_layout, lut["sparsity_lut"])
+
+ return lut
+
+
+ # noinspection PyUnusedLocal
+ def to_sparse_setup_context(ctx, inputs, output):
+ (_, sparsity_layout, _, sparsity_block_size, _) = inputs
+
+ ctx.save_for_backward(sparsity_layout, )
+ ctx.sparsity_block_size = sparsity_block_size
+
+
+ to_sparse_forward.register_autograd(to_sparse_backward, setup_context=to_sparse_setup_context)
+
+
+ def from_blksprs(x: BlksprsTensor, sparsity_layout: Tensor,
+ sparsity_block_size: int, fill_value: float = 0, lut: dict = None) -> Tensor:
+ """Wrapper for ``to_dense``.

  """
- return to_sparse(x, sparsity_layout, sparsity_block_size, triton_block_size)
+ return to_dense(x, sparsity_layout, sparsity_block_size, fill_value=fill_value, lut=lut)


- def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int,
- triton_block_size: int = None, lut: dict = None) -> BlksprsTensor:
- """Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
- sparsity layout.
+ def to_dense(x: BlksprsTensor, sparsity_layout: Tensor,
+ sparsity_block_size: int, fill_value: float = 0, lut: dict = None) -> Tensor:
+ """Converts a block-sparse tensor in compressed form to a block-sparse tensor in regular form based on the given
+ sparsity layout.

- Args:
- x (Tensor): A block-sparse tensor in regular form.
+ Args:
+ x (BlksprsTensor): A block-sparse tensor in compressed form.
  sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
  sparsity_block_size (int): The size of the sparsity blocks.
- triton_block_size (int): The block size to use for the triton kernel (default ``None``).
+ fill_value (float): The value to fill the resulting dense tensor with where the block-sparse tensor is not
+ present (default ``0``).
  lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).

  Returns:
- BlksprsTensor: The block-sparse tensor converted to compressed form.
+ Tensor: The block-sparse tensor converted to regular form.

  """
  x = x.contiguous()

  validate_dimensions(x)
- validate_contiguous(x)
+ validate_contiguous(x, sparsity_layout)
  validate_device(x)
- validate_sparsity_dense(sparsity_block_size, (x, sparsity_layout))
+ validate_sparsity(sparsity_block_size, (x, sparsity_layout))
  validate_sparsity_block_size(sparsity_block_size, x)
- validate_triton_block_size(triton_block_size, sparsity_block_size)

- lut = _BlocksparseToSparse.build_lut(lut, sparsity_layout)
+ lut = to_dense_build_lut(lut, sparsity_layout)

  if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
- return BlksprsTensor(x)
+ return x

- return BlksprsTensor(_BlocksparseToSparse.apply(x,
- sparsity_layout, lut["sparsity_lut"],
- sparsity_block_size, lut["n_sparse_blocks"],
- triton_block_size))
-
-
- class _BlocksparseToSparse(torch.autograd.Function):
-
- @staticmethod
- def build_lut(lut: dict, sparsity_layout: Tensor):
- if lut is None:
- lut = dict()
-
- if "sparsity_lut" not in lut:
- sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
- lut["sparsity_lut"] = sparsity_lut
-
- if "n_sparse_blocks" not in lut:
- n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
- lut["n_sparse_blocks"] = n_sparse_blocks
-
- validate_contiguous(sparsity_layout, lut["sparsity_lut"])
-
- return lut
-
- @staticmethod
- def forward(ctx, x: Tensor,
- sparsity_layout: Tensor, sparsity_lut: Tensor,
- sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
- output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
- dtype=x.dtype, device=x.device)
-
- x_b, x_r, x_c = x.size()
- x_b_s, x_r_s, x_c_s = stride(x)
- s_lut_r, s_lut_c = sparsity_lut.size()
- s_lut_r_s, s_lut_c_s = stride(sparsity_lut)
- o_b, o_r, o_c = output.size()
- o_b_s, o_r_s, o_c_s = stride(output)
-
- if triton_block_size is None:
- triton_block_size = get_triton_block_size(sparsity_block_size)
-
- triton_grid = lambda meta: [o_b,
- triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
- triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
- (_BlocksparseToSparse.kernel_blocksparse_to_sparse[triton_grid]
- (x, x_b, x_b_s, x_r_s, x_c_s,
- sparsity_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
- output, o_b_s, o_r_s, o_c_s,
- sparsity_block_size,
- triton_block_size))
-
- ctx.save_for_backward(sparsity_layout)
- ctx.sparsity_block_size = sparsity_block_size
- ctx.triton_block_size = triton_block_size
-
- return output
-
- @staticmethod
- def backward(ctx, grad_output):
- sparsity_layout = ctx.saved_tensors[0]
- sparsity_block_size = ctx.sparsity_block_size
- triton_block_size = ctx.triton_block_size
-
- return to_dense(grad_output, sparsity_layout, sparsity_block_size,
- triton_block_size=triton_block_size), None, None, None, None, None
-
- @staticmethod
- @triton.jit
- def kernel_blocksparse_to_sparse(x,
- x_b, x_b_s, x_r_s, x_c_s,
- s_lut, s_lut_r, s_lut_r_s, s_lut_c_s,
- o,
- o_b_s, o_r_s, o_c_s,
- sparsity_block_size,
- TRITON_BLOCK_SIZE: tl.constexpr) -> None:
- # Get triton block indices
- pid_blk = tl.program_id(axis=0)
- pid_row = tl.program_id(axis=1)
- pid_col = tl.program_id(axis=2)
-
- # Get sparsity index of current output block consisting of its batch, row, and column index
- spa_bat_idx = (pid_blk * s_lut_r_s + 0 * s_lut_c_s)
- spa_bat_msk = (spa_bat_idx >= 0 and spa_bat_idx < s_lut_r * s_lut_r_s)
- spa_bat = tl.load(s_lut + spa_bat_idx, mask=spa_bat_msk)
-
- spa_row_idx = (pid_blk * s_lut_r_s + 1 * s_lut_c_s)
- spa_row_msk = (spa_row_idx >= 0 and spa_row_idx < s_lut_r * s_lut_r_s)
- spa_row = tl.load(s_lut + spa_row_idx, mask=spa_row_msk)
-
- spa_col_idx = (pid_blk * s_lut_r_s + 2 * s_lut_c_s)
- spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
- spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
-
- # Load block from dense tensor
- blk_d_idx = (spa_bat * x_b_s +
- ((spa_row * sparsity_block_size + pid_row * TRITON_BLOCK_SIZE +
- tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
- ((spa_col * sparsity_block_size + pid_col * TRITON_BLOCK_SIZE +
- tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
- blk_d_msk = (blk_d_idx >= 0 and blk_d_idx < x_b * x_b_s)
- blk_d = tl.load(x + blk_d_idx, mask=blk_d_msk)
-
- # Store block in sparse tensor
- blk_o_idx = ((pid_blk * o_b_s) +
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
- blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < (pid_blk + 1) * o_b_s)
- tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
+ return to_dense_forward(x, sparsity_layout,
+ lut["sparsity_reverse_lut"], sparsity_block_size, fill_value)
+
+
+ @triton_op("blksprs::to_dense", mutates_args={})
+ def to_dense_forward(x: Tensor, sparsity_layout: Tensor,
+ sparsity_reverse_lut: Tensor,
+ sparsity_block_size: int, fill_value: float) -> Tensor:
+ output = torch.full(size=(sparsity_layout.size(0), sparsity_layout.size(1) * sparsity_block_size,
+ sparsity_layout.size(2) * sparsity_block_size), fill_value=fill_value,
+ dtype=x.dtype, device=x.device)
+
+ x_b, x_r, x_c = x.shape
+ x_b_s, x_r_s, x_c_s = stride(x)
+ s_l_b, s_l_r, s_l_c = sparsity_layout.size()
+ s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout)
+ o_b, o_r, o_c = output.size()
+ o_b_s, o_r_s, o_c_s = stride(output)
+
+ triton_grid = lambda meta: [o_b,
+ triton.cdiv(o_r, min(meta["sparsity_block_size"], meta["TRITON_BLOCK_SIZE"])),
+ triton.cdiv(o_c, min(meta["sparsity_block_size"], meta["TRITON_BLOCK_SIZE"]))]
+
+ (wrap_triton(to_dense_kernel)[triton_grid]
+ (x,
+ x_b, x_b_s, x_r_s, x_c_s,
+ s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+ sparsity_reverse_lut,
+ output,
+ o_b, o_b_s, o_r_s, o_c_s,
+ sparsity_block_size))
+
+ return output
+
+
+ def to_dense_backward(ctx, grad_output):
+ sparsity_layout = ctx.saved_tensors[0]
+ sparsity_block_size = ctx.sparsity_block_size
+
+ return to_sparse(grad_output, sparsity_layout, sparsity_block_size), None, None, None, None
+
+
+ @triton.autotune(
+ configs=get_autotune_configs(),
+ key=[],
+ )
+ @triton.jit
+ def to_dense_kernel(x,
+ x_b, x_b_s, x_r_s, x_c_s,
+ s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+ sparsity_reverse_lut,
+ o,
+ o_b, o_b_s, o_r_s, o_c_s,
+ sparsity_block_size,
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+ # Get triton block indices
+ pid_blk = tl.program_id(axis=0)
+ pid_row = tl.program_id(axis=1)
+ pid_col = tl.program_id(axis=2)
+
+ # Get valid triton block size
+ val_tbs = min(sparsity_block_size, TRITON_BLOCK_SIZE)
+
+ # Get sparsity index of current block
+ spa_row = (pid_row * val_tbs) // sparsity_block_size
+ spa_col = (pid_col * val_tbs) // sparsity_block_size
+
+ # Get reverse sparsity index for current block
+ rev_idx_spa_idx = (pid_blk * s_l_b_s + spa_row * s_l_r_s + spa_col * s_l_c_s)
+ rev_idx_spa_msk = (rev_idx_spa_idx >= 0 and rev_idx_spa_idx < s_l_b * s_l_b_s)
+ rev_idx_spa = tl.load(sparsity_reverse_lut + rev_idx_spa_idx, mask=rev_idx_spa_msk).to(tl.int32)
+
+ # If block is present commence operations
+ if rev_idx_spa >= 0:
+ blk_idx = (rev_idx_spa * x_b_s +
+ (((pid_row % (sparsity_block_size // val_tbs)) * val_tbs +
+ tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+ (((pid_col % (sparsity_block_size // val_tbs)) * val_tbs +
+ tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+ blk_msk = ((blk_idx >= 0 and
+ blk_idx < x_b * x_b_s) and
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+ blk = tl.load(x + blk_idx, mask=blk_msk)
+
+ o_idx = (pid_blk * o_b_s +
+ ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+ ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+ o_msk = ((o_idx >= 0 and o_idx < o_b * o_b_s) and
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+ tl.store(o + o_idx, blk, o_msk)
+
+
+ def to_dense_build_lut(lut: dict, sparsity_layout: Tensor):
+ if lut is None:
+ lut = dict()
+
+ if "sparsity_reverse_lut" not in lut:
+ sparsity_layout_flat = sparsity_layout.reshape(-1)
+ sparsity_reverse_lut = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+ (sparsity_layout_flat == 1) -
+ (1 * (sparsity_layout_flat == 0)))
+ lut["sparsity_reverse_lut"] = sparsity_reverse_lut
+
+ validate_contiguous(lut["sparsity_reverse_lut"])
+
+ return lut
+
+
+ # noinspection PyUnusedLocal
+ def to_dense_setup_context(ctx, inputs, output):
+ (_, sparsity_layout, _, sparsity_block_size, _) = inputs
+
+ ctx.save_for_backward(sparsity_layout)
+ ctx.sparsity_block_size = sparsity_block_size
+
+
+ to_dense_forward.register_autograd(to_dense_backward, setup_context=to_dense_setup_context)


  def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_size_from: int,
- sparsity_block_size_to: int, sparsity_layout_to: Tensor = None,
- triton_block_size: int = None) -> (BlksprsTensor, Tensor):
+ sparsity_block_size_to: int, sparsity_layout_to: Tensor = None) -> (BlksprsTensor, Tensor):
  """Adapts the sparsity layout of a block-sparse tensor, resulting in a new block-sparse tensor in compressed form
  conforming to the new sparsity layout (and sparsity block size) definition.

@@ -325,7 +337,6 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
  sparsity_block_size_from (int): The size of the sparsity blocks of the input sparsity layout.
  sparsity_block_size_to (int): The size of the sparsity blocks of the output sparsity layout.
  sparsity_layout_to (Tensor): The sparsity layout of the output block-sparse tensor (default ``None``).
- triton_block_size (int): The block size to use for the triton kernel (default ``None``).

  Returns:
  BlksprsTensor: The block-sparse tensor in compressed form with the adapted sparsity layout and sparsity block size.
@@ -340,8 +351,6 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
  validate_sparsity(sparsity_block_size_from, (x, sparsity_layout_from))
  validate_sparsity_block_size(sparsity_block_size_from, x)
  validate_sparsity_block_size(sparsity_block_size_to)
- min_sparsity_block_size = min(sparsity_block_size_from, sparsity_block_size_to)
- validate_triton_block_size(triton_block_size, min_sparsity_block_size)

  sparsity_layout_from_flat = sparsity_layout_from.reshape(-1)
  sparsity_reverse_lut_from = ((torch.cumsum(sparsity_layout_from_flat, dim=-1) - 1) *
@@ -350,8 +359,7 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_

  if sparsity_layout_to is None:
  sparsity_layout_to = build_sparsity_layout_adaption(x, sparsity_layout_from,
- sparsity_block_size_from, sparsity_block_size_to,
- triton_block_size)
+ sparsity_block_size_from, sparsity_block_size_to)

  sparsity_lut_to = torch.nonzero(sparsity_layout_to).contiguous()

@@ -362,134 +370,148 @@ def adapt_layout(x: BlksprsTensor, sparsity_layout_from: Tensor, sparsity_block_
  if (sparsity_block_size_from == sparsity_block_size_to) and torch.equal(sparsity_layout_from, sparsity_layout_to):
  return BlksprsTensor(x), sparsity_layout_to

- return BlksprsTensor(_BlocksparseAdaptLayout.apply(x,
- sparsity_layout_from, sparsity_reverse_lut_from,
- sparsity_block_size_from,
- sparsity_layout_to, sparsity_lut_to,
- sparsity_block_size_to,
- n_sparse_blocks_to, min_sparsity_block_size,
- triton_block_size)), sparsity_layout_to
-
-
- class _BlocksparseAdaptLayout(torch.autograd.Function):
-
- @staticmethod
- def forward(ctx, x: Tensor,
- sparsity_layout_from: Tensor, sparsity_reverse_lut_from: Tensor,
- sparsity_block_size_from: int,
- sparsity_layout_to: Tensor, sparsity_lut_to: Tensor,
- sparsity_block_size_to: int,
- n_sparse_blocks_to: int, min_sparsity_block_size: int, triton_block_size: int) -> Tensor:
- output = torch.zeros(size=(n_sparse_blocks_to, sparsity_block_size_to, sparsity_block_size_to),
- dtype=x.dtype, device=x.device)
-
- x_b, x_r, x_c = x.size()
- x_b_s, x_r_s, x_c_s = stride(x)
- s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_from.size()
- s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_from)
- o_b, o_r, o_c = output.size()
- o_b_s, o_r_s, o_c_s = stride(output)
- s_lut_o_r, s_lut_o_c = sparsity_lut_to.size()
- s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_to)
-
- if triton_block_size is None:
- triton_block_size = get_triton_block_size(min_sparsity_block_size)
-
- triton_grid = lambda meta: [o_b,
- triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
- triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
- (_BlocksparseAdaptLayout.kernel_adapt_layout[triton_grid]
- (x,
- x_b, x_b_s, x_r_s, x_c_s,
- s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
- sparsity_reverse_lut_from,
- output,
- o_b, o_b_s, o_r_s, o_c_s,
- sparsity_lut_to, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
- sparsity_block_size_from,
- sparsity_block_size_to,
- triton_block_size))
-
- ctx.save_for_backward(x, sparsity_layout_from, sparsity_layout_to)
- ctx.sparsity_block_size_from = sparsity_block_size_from
- ctx.sparsity_block_size_to = sparsity_block_size_to
- ctx.triton_block_size = triton_block_size
-
- return output
-
- @staticmethod
- def backward(ctx, grad_output):
- x, sparsity_layout_from, sparsity_layout_to = ctx.saved_tensors
- sparsity_block_size_from = ctx.sparsity_block_size_from
- sparsity_block_size_to = ctx.sparsity_block_size_to
- triton_block_size = ctx.triton_block_size
-
- return adapt_layout(
- grad_output, sparsity_layout_to, sparsity_block_size_to, sparsity_block_size_from,
- sparsity_layout_to=sparsity_layout_from,
- triton_block_size=triton_block_size)[0], None, None, None, None, None, None, None, None, None
-
- @staticmethod
- @triton.jit
- def kernel_adapt_layout(x,
- x_b, x_b_s, x_r_s, x_c_s,
- s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
- r_lut_x,
- o,
- o_b, o_b_s, o_r_s, o_c_s,
- s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
- sparsity_block_size_from,
- sparsity_block_size_to,
- TRITON_BLOCK_SIZE: tl.constexpr) -> None:
- # Get triton block indices
- pid_blk = tl.program_id(axis=0)
- pid_row = tl.program_id(axis=1)
- pid_col = tl.program_id(axis=2)
-
- # Get position of current sparsity block consisting of its batch, row, and column index
- spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
- spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
- spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
-
- spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
- spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
- spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
-
- spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
- spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
- spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
-
- # Get equivalent sparsity block in from layout
- spa_bat_x = spa_bat_o
- spa_row_x = (spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size_from
- spa_col_x = (spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size_from
-
- # Get reverse sparsity indices for x
- rev_idx_spa_x_idx = (spa_bat_x * s_l_x_b_s +
- spa_row_x * s_l_x_r_s +
- spa_col_x * s_l_x_c_s)
- rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
- rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
-
- # If block is present commence operations
- if rev_idx_spa_x >= 0:
- # Calculate triton block size shifts
- shift_row_x = ((spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE)
- % sparsity_block_size_from) // TRITON_BLOCK_SIZE
- shift_col_x = ((spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE)
- % sparsity_block_size_from) // TRITON_BLOCK_SIZE
-
- # Load x values
- blk_x_idx = ((rev_idx_spa_x * x_b_s) +
- ((shift_row_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
- ((shift_col_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
- blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
- blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
- # Store output
- blk_o_idx = ((pid_blk * o_b_s) +
- ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
- ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
- blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
- tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+ return BlksprsTensor(adapt_layout_forward(x,
+ sparsity_layout_from, sparsity_reverse_lut_from,
+ sparsity_block_size_from,
+ sparsity_layout_to, sparsity_lut_to,
+ sparsity_block_size_to,
+ n_sparse_blocks_to)), sparsity_layout_to
+
+
+ @triton_op("blksprs::adapt_layout", mutates_args={})
+ def adapt_layout_forward(x: Tensor,
+ sparsity_layout_from: Tensor, sparsity_reverse_lut_from: Tensor,
+ sparsity_block_size_from: int,
+ _: Tensor, sparsity_lut_to: Tensor,
+ sparsity_block_size_to: int,
+ n_sparse_blocks_to: int) -> Tensor:
+ output = torch.zeros(size=(n_sparse_blocks_to, sparsity_block_size_to, sparsity_block_size_to),
+ dtype=x.dtype, device=x.device)
+
+ x_b, x_r, x_c = x.size()
+ x_b_s, x_r_s, x_c_s = stride(x)
+ s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_from.size()
+ s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = stride(sparsity_layout_from)
+ o_b, o_r, o_c = output.size()
+ o_b_s, o_r_s, o_c_s = stride(output)
+ s_lut_o_r, s_lut_o_c = sparsity_lut_to.size()
+ s_lut_o_r_s, s_lut_o_c_s = stride(sparsity_lut_to)
+
+ triton_grid = lambda meta: [o_b,
+ triton.cdiv(o_r, min(meta["sparsity_block_size_from"], meta["sparsity_block_size_to"],
+ meta["TRITON_BLOCK_SIZE"])),
+ triton.cdiv(o_c, min(meta["sparsity_block_size_from"], meta["sparsity_block_size_to"],
+ meta["TRITON_BLOCK_SIZE"]))]
+
+ (wrap_triton(adapt_layout_kernel)[triton_grid]
+ (x,
+ x_b, x_b_s, x_r_s, x_c_s,
+ s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+ sparsity_reverse_lut_from,
+ output,
+ o_b, o_b_s, o_r_s, o_c_s,
+ sparsity_lut_to, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+ sparsity_block_size_from,
+ sparsity_block_size_to))
+
+ return output
+
+
+ def adapt_layout_backward(ctx, grad_output):
+ x, sparsity_layout_from, sparsity_layout_to = ctx.saved_tensors
+ sparsity_block_size_from = ctx.sparsity_block_size_from
+ sparsity_block_size_to = ctx.sparsity_block_size_to
+
+ return adapt_layout(
+ grad_output, sparsity_layout_to, sparsity_block_size_to, sparsity_block_size_from,
+ sparsity_layout_to=sparsity_layout_from)[0], None, None, None, None, None, None, None
+
+
+ @triton.autotune(
+ configs=get_autotune_configs(),
+ key=[],
+ reset_to_zero=["o"]
+ )
+ @triton.jit
+ def adapt_layout_kernel(x,
+ x_b, x_b_s, x_r_s, x_c_s,
+ s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
+ r_lut_x,
+ o,
+ o_b, o_b_s, o_r_s, o_c_s,
+ s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
+ sparsity_block_size_from,
+ sparsity_block_size_to,
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+ # Get triton block indices
+ pid_blk = tl.program_id(axis=0)
+ pid_row = tl.program_id(axis=1)
+ pid_col = tl.program_id(axis=2)
+
+ # Get valid triton block size (Triton can only handle 2-valued min)
+ val_tbs = min(min(sparsity_block_size_from, sparsity_block_size_to), TRITON_BLOCK_SIZE)
+
+ # Get position of current sparsity block consisting of its batch, row, and column index
+ spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
+ spa_bat_o_msk = (spa_bat_o_idx >= 0 and spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
+ spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
+
+ spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
+ spa_row_o_msk = (spa_row_o_idx >= 0 and spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
+ spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
+
+ spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
+ spa_col_o_msk = (spa_col_o_idx >= 0 and spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
+ spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
+
+ # Get equivalent sparsity block in from layout
+ spa_bat_x = spa_bat_o
+ spa_row_x = (spa_row_o * sparsity_block_size_to + pid_row * val_tbs) // sparsity_block_size_from
+ spa_col_x = (spa_col_o * sparsity_block_size_to + pid_col * val_tbs) // sparsity_block_size_from
+
+ # Get reverse sparsity indices for x
+ rev_idx_spa_x_idx = (spa_bat_x * s_l_x_b_s +
+ spa_row_x * s_l_x_r_s +
+ spa_col_x * s_l_x_c_s)
+ rev_idx_spa_x_msk = (rev_idx_spa_x_idx >= 0 and rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
+ rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
+
+ # If block is present commence operations
+ if rev_idx_spa_x >= 0:
+ # Calculate triton block size shifts
+ shift_row_x = ((spa_row_o * sparsity_block_size_to + pid_row * val_tbs)
+ % sparsity_block_size_from) // val_tbs
+ shift_col_x = ((spa_col_o * sparsity_block_size_to + pid_col * val_tbs)
+ % sparsity_block_size_from) // val_tbs
+
+ # Load x values
+ blk_x_idx = ((rev_idx_spa_x * x_b_s) +
+ ((shift_row_x * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
+ ((shift_col_x * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
+ blk_x_msk = ((blk_x_idx >= 0 and
+ blk_x_idx < x_b * x_b_s) and
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+ blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
+
+ # Store output
+ blk_o_idx = ((pid_blk * o_b_s) +
+ ((pid_row * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
+ ((pid_col * val_tbs + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
+ blk_o_msk = ((blk_o_idx >= 0 and
+ blk_o_idx < o_b * o_b_s) and
+ (tl.arange(0, TRITON_BLOCK_SIZE)[:, None] < val_tbs and
+ tl.arange(0, TRITON_BLOCK_SIZE)[None, :] < val_tbs))
+ tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)
+
+
+ # noinspection PyUnusedLocal
+ def adapt_layout_setup_context(ctx, inputs, output):
+ (x, sparsity_layout_from, _, sparsity_block_size_from, sparsity_layout_to, _, sparsity_block_size_to, _) = inputs
+
+ ctx.save_for_backward(x, sparsity_layout_from, sparsity_layout_to)
+ ctx.sparsity_block_size_from = sparsity_block_size_from
+ ctx.sparsity_block_size_to = sparsity_block_size_to
+
+
+ adapt_layout_forward.register_autograd(adapt_layout_backward, setup_context=adapt_layout_setup_context)
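
Across all three operations the structural change is the same: the torch.autograd.Function subclasses are replaced by operators declared with triton_op, launched through wrap_triton, and given a backward pass via register_autograd. A minimal, self-contained sketch of that pattern (the mylib::scale operator, its kernel, and the block size are illustrative and not part of blksprs; the diff itself imports the helpers from the private torch._library.triton path, which recent PyTorch also exposes publicly as torch.library.triton_op and torch.library.wrap_triton):

    import torch
    import triton
    import triton.language as tl
    from torch.library import triton_op, wrap_triton

    @triton.jit
    def scale_kernel(x_ptr, o_ptr, n_elements, factor, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        tl.store(o_ptr + offsets, tl.load(x_ptr + offsets, mask=mask) * factor, mask=mask)

    @triton_op("mylib::scale", mutates_args={})
    def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
        x = x.contiguous()
        output = torch.empty_like(x)
        grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
        # wrap_triton makes the raw kernel launch traceable by torch.compile
        wrap_triton(scale_kernel)[grid](x, output, x.numel(), factor, BLOCK_SIZE=1024)
        return output

    def scale_backward(ctx, grad_output):
        # one gradient per forward input; the non-tensor input receives None
        return scale(grad_output, ctx.factor), None

    def scale_setup_context(ctx, inputs, output):
        _, factor = inputs
        ctx.factor = factor

    scale.register_autograd(scale_backward, setup_context=scale_setup_context)

With this layout the autograd wiring lives next to the op definition instead of inside an autograd.Function, and the hand-picked TRITON_BLOCK_SIZE disappears because each blksprs kernel is now wrapped in triton.autotune with configurations supplied by get_autotune_configs.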