blksprs 1.0__py3-none-any.whl → 1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/ops/conversion.py CHANGED
@@ -1,24 +1,39 @@
1
+ from typing import Any
2
+
1
3
  import torch
2
4
  import triton
3
5
  from torch import Tensor
4
6
  from triton import language as tl
5
7
 
8
+ from blksprs.layouting.sparsity_layout import build_sparsity_layout_adaption
6
9
  from blksprs.utils.tools import get_triton_block_size
7
- from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_dtype_float, validate_device
10
+ from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
11
+ validate_sparsity, validate_sparsity_block_size, validate_triton_block_size
8
12
 
9
13
 
10
14
  def to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_value: float = 0,
11
15
  triton_block_size: int = None) -> Tensor:
12
- """Converts a blocksparse tensor to a dense tensor based on the given sparsity layout.
16
+ """Converts a block-sparse tensor in compressed form to a block-sparse tensor in regular form based on the given
17
+ sparsity layout.
18
+
19
+ Args:
20
+ x (Tensor): A block-sparse tensor in compressed form.
21
+ sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
22
+ sparsity_block_size (int): The size of the sparsity blocks.
23
+ fill_value (float): The value to fill the resulting dense tensor with where the block-sparse tensor is not
24
+ present (default ``0``).
25
+ triton_block_size (int): The block size to use for the triton kernel (default ``None``).
13
26
 
14
- The ``fill_value`` is used to fill the resulting dense tensor with a specific value (default ``0``) where the
15
- blocksparse tensor is not present.
27
+ Returns:
28
+ Tensor: The block-sparse tensor converted to regular form.
16
29
 
17
30
  """
18
31
  validate_dimensions(x)
19
32
  validate_contiguous(x, sparsity_layout)
20
- validate_dtype_float(x)
21
33
  validate_device(x)
34
+ validate_sparsity(sparsity_block_size, (x, sparsity_layout))
35
+ validate_sparsity_block_size(sparsity_block_size, x)
36
+ validate_triton_block_size(triton_block_size, sparsity_block_size)
22
37
 
23
38
  sparsity_layout_flat = sparsity_layout.reshape(-1)
24
39
  sparsity_reverse_lut = ((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
@@ -27,6 +42,9 @@ def to_dense(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, fill_
27
42
 
28
43
  validate_contiguous(sparsity_reverse_lut)
29
44
 
45
+ if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
46
+ return x
47
+
30
48
  return _BlocksparseToDense.apply(x,
31
49
  sparsity_layout, sparsity_reverse_lut,
32
50
  sparsity_block_size, fill_value,
@@ -68,7 +86,7 @@ class _BlocksparseToDense(torch.autograd.Function):
68
86
  sparsity_block_size,
69
87
  triton_block_size))
70
88
 
71
- ctx.sparsity_layout = sparsity_layout
89
+ ctx.save_for_backward(sparsity_layout)
72
90
  ctx.sparsity_block_size = sparsity_block_size
73
91
  ctx.triton_block_size = triton_block_size
74
92
 
@@ -76,11 +94,12 @@ class _BlocksparseToDense(torch.autograd.Function):
76
94
 
77
95
  @staticmethod
78
96
  def backward(ctx, grad_output):
79
- sparsity_layout = ctx.sparsity_layout
97
+ sparsity_layout = ctx.saved_tensors[0]
80
98
  sparsity_block_size = ctx.sparsity_block_size
81
99
  triton_block_size = ctx.triton_block_size
82
100
 
83
- return to_sparse(grad_output, sparsity_layout, sparsity_block_size, triton_block_size), None, None, None, None, None
101
+ return to_sparse(grad_output, sparsity_layout, sparsity_block_size,
102
+ triton_block_size), None, None, None, None, None
84
103
 
85
104
  @staticmethod
86
105
  @triton.jit
@@ -124,18 +143,32 @@ class _BlocksparseToDense(torch.autograd.Function):
124
143
 
125
144
 
126
145
  def to_sparse(x: Tensor, sparsity_layout: Tensor, sparsity_block_size: int, triton_block_size: int = None) -> Tensor:
127
- """Converts a dense tensor to a blocksparse tensor based on the given sparsity layout.
146
+ """Converts a block-sparse tensor in regular form to a block-sparse tensor in compressed form based on the given
147
+ sparsity layout.
148
+
149
+ Args:
150
+ x (Tensor): A block-sparse tensor in regular form.
151
+ sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
152
+ sparsity_block_size (int): The size of the sparsity blocks.
153
+ triton_block_size (int): The block size to use for the triton kernel (default ``None``).
154
+
155
+ Returns:
156
+ Tensor: The block-sparse tensor converted to compressed form.
128
157
 
129
158
  """
130
159
  validate_dimensions(x)
131
- validate_contiguous(x, sparsity_layout)
132
- validate_dtype_float(x)
160
+ validate_contiguous(x)
133
161
  validate_device(x)
162
+ validate_sparsity_block_size(sparsity_block_size, x)
163
+ validate_triton_block_size(triton_block_size, sparsity_block_size)
134
164
 
135
165
  sparsity_lut = torch.nonzero(sparsity_layout).contiguous()
136
166
  n_sparse_blocks = torch.sum(sparsity_layout.to(torch.int)).item()
137
167
 
138
- validate_contiguous(sparsity_lut)
168
+ validate_contiguous(sparsity_layout, sparsity_lut)
169
+
170
+ if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
171
+ return x
139
172
 
140
173
  return _BlocksparseToSparse.apply(x,
141
174
  sparsity_layout, sparsity_lut,
@@ -149,14 +182,15 @@ class _BlocksparseToSparse(torch.autograd.Function):
149
182
  def forward(ctx, x: Tensor,
150
183
  sparsity_layout: Tensor, sparsity_lut: Tensor,
151
184
  sparsity_block_size: int, n_sparse_blocks: int, triton_block_size: int) -> Tensor:
152
- output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size), device=x.device)
185
+ output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size), dtype=x.dtype,
186
+ device=x.device)
153
187
 
154
188
  x_b, x_r, x_c = x.size()
155
189
  x_b_s, x_r_s, x_c_s = x.stride()
156
- o_b, o_r, o_c = output.size()
157
- o_b_s, o_r_s, o_c_s = output.stride()
158
190
  s_lut_r, s_lut_c = sparsity_lut.size()
159
191
  s_lut_r_s, s_lut_c_s = sparsity_lut.stride()
192
+ o_b, o_r, o_c = output.size()
193
+ o_b_s, o_r_s, o_c_s = output.stride()
160
194
 
161
195
  if triton_block_size is None:
162
196
  triton_block_size = get_triton_block_size(sparsity_block_size)
@@ -172,7 +206,7 @@ class _BlocksparseToSparse(torch.autograd.Function):
172
206
  sparsity_block_size,
173
207
  triton_block_size))
174
208
 
175
- ctx.sparsity_layout = sparsity_layout
209
+ ctx.save_for_backward(sparsity_layout)
176
210
  ctx.sparsity_block_size = sparsity_block_size
177
211
  ctx.triton_block_size = triton_block_size
178
212
 
@@ -180,7 +214,7 @@ class _BlocksparseToSparse(torch.autograd.Function):
180
214
 
181
215
  @staticmethod
182
216
  def backward(ctx, grad_output):
183
- sparsity_layout = ctx.sparsity_layout
217
+ sparsity_layout = ctx.saved_tensors[0]
184
218
  sparsity_block_size = ctx.sparsity_block_size
185
219
  triton_block_size = ctx.triton_block_size
186
220
 
@@ -229,3 +263,189 @@ class _BlocksparseToSparse(torch.autograd.Function):
229
263
  ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE) * o_c_s))[None, :])
230
264
  blk_o_msk = (blk_o_idx < (pid_blk + 1) * o_b_s)
231
265
  tl.store(o + blk_o_idx, blk_d, mask=blk_o_msk)
266
+
267
+
268
+ def adapt_layout(x: Tensor, sparsity_layout_from: Tensor, sparsity_block_size_from: int, sparsity_block_size_to: int,
269
+ preprocess_data: dict = None, triton_block_size: int = None) -> Tensor:
270
+ """Adapts the sparsity layout of a block-sparse tensor, resulting in a new block-sparse tensor in compressed form
271
+ conforming to the new sparsity layout (and sparsity block size) definition.
272
+
273
+ Args:
274
+ x (Tensor): A block-sparse tensor in compressed form.
275
+ sparsity_layout_from (Tensor): The sparsity layout of the input block-sparse tensor.
276
+ sparsity_block_size_from (int): The size of the sparsity blocks of the input sparsity layout.
277
+ sparsity_block_size_to (int): The size of the sparsity blocks of the output sparsity layout.
278
+ preprocess_data (dict): A dictionary containing data otherwise computed by the function (default ``None``).
279
+ triton_block_size (int): The block size to use for the triton kernel (default ``None``).
280
+
281
+ Returns:
282
+ Tensor: The block-sparse tensor in compressed form with the adapted sparsity layout and sparsity block size.
283
+
284
+ """
285
+ validate_dimensions(x)
286
+ validate_contiguous(x, sparsity_layout_from)
287
+ validate_device(x)
288
+ validate_sparsity(sparsity_block_size_from, (x, sparsity_layout_from))
289
+ validate_sparsity_block_size(sparsity_block_size_from, x)
290
+ validate_sparsity_block_size(sparsity_block_size_to)
291
+ min_sparsity_block_size = min(sparsity_block_size_from, sparsity_block_size_to)
292
+ validate_triton_block_size(triton_block_size, min_sparsity_block_size)
293
+
294
+ if preprocess_data is None:
295
+ preprocess_data = {}
296
+
297
+ if "sparsity_reverse_lut_from" not in preprocess_data:
298
+ sparsity_layout_from_flat = sparsity_layout_from.reshape(-1)
299
+ sparsity_reverse_lut_from = ((torch.cumsum(sparsity_layout_from_flat, dim=-1) - 1) *
300
+ (sparsity_layout_from_flat == 1) -
301
+ (1 * (sparsity_layout_from_flat == 0)))
302
+ else:
303
+ sparsity_reverse_lut_from = preprocess_data["sparsity_reverse_lut_from"]
304
+
305
+ if "sparsity_layout_to" not in preprocess_data:
306
+ sparsity_layout_to = build_sparsity_layout_adaption(x, sparsity_layout_from,
307
+ sparsity_block_size_from, sparsity_block_size_to,
308
+ triton_block_size)
309
+ else:
310
+ sparsity_layout_to = preprocess_data["sparsity_layout_to"]
311
+
312
+ if "sparsity_lut_to" not in preprocess_data:
313
+ sparsity_lut_to = torch.nonzero(sparsity_layout_to).contiguous()
314
+ else:
315
+ sparsity_lut_to = preprocess_data["sparsity_lut_to"]
316
+
317
+ if "n_sparse_blocks_to" not in preprocess_data:
318
+ n_sparse_blocks_to = torch.sum(sparsity_layout_to.to(torch.int)).item()
319
+ else:
320
+ n_sparse_blocks_to = preprocess_data["n_sparse_blocks_to"]
321
+
322
+ validate_contiguous(sparsity_layout_to, sparsity_reverse_lut_from, sparsity_lut_to)
323
+
324
+ if (sparsity_block_size_from == sparsity_block_size_to) and torch.equal(sparsity_layout_from, sparsity_layout_to):
325
+ return x
326
+
327
+ return _BlocksparseAdaptLayout.apply(x,
328
+ sparsity_layout_from, sparsity_reverse_lut_from, sparsity_block_size_from,
329
+ sparsity_layout_to, sparsity_lut_to, sparsity_block_size_to,
330
+ n_sparse_blocks_to, min_sparsity_block_size, triton_block_size)
331
+
332
+
333
+ class _BlocksparseAdaptLayout(torch.autograd.Function):
334
+
335
+ @staticmethod
336
+ def forward(ctx, x: Tensor,
337
+ sparsity_layout_from: Tensor, sparsity_reverse_lut_from: Tensor, sparsity_block_size_from: int,
338
+ sparsity_layout_to: Tensor, sparsity_lut_to: Tensor, sparsity_block_size_to: int,
339
+ n_sparse_blocks_to: int, min_sparsity_block_size: int, triton_block_size: int) -> Tensor:
340
+ output = torch.zeros(size=(n_sparse_blocks_to, sparsity_block_size_to, sparsity_block_size_to),
341
+ dtype=x.dtype, device=x.device)
342
+
343
+ x_b, x_r, x_c = x.size()
344
+ x_b_s, x_r_s, x_c_s = x.stride()
345
+ s_l_x_b, s_l_x_r, s_l_x_c = sparsity_layout_from.size()
346
+ s_l_x_b_s, s_l_x_r_s, s_l_x_c_s = sparsity_layout_from.stride()
347
+ o_b, o_r, o_c = output.size()
348
+ o_b_s, o_r_s, o_c_s = output.stride()
349
+ s_lut_o_r, s_lut_o_c = sparsity_lut_to.size()
350
+ s_lut_o_r_s, s_lut_o_c_s = sparsity_lut_to.stride()
351
+
352
+ if triton_block_size is None:
353
+ triton_block_size = get_triton_block_size(min_sparsity_block_size)
354
+
355
+ triton_grid = lambda meta: [o_b,
356
+ triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
357
+ triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
358
+
359
+ (_BlocksparseAdaptLayout.kernel_adapt_layout[triton_grid]
360
+ (x,
361
+ x_b, x_b_s, x_r_s, x_c_s,
362
+ s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
363
+ sparsity_reverse_lut_from,
364
+ output,
365
+ o_b, o_b_s, o_r_s, o_c_s,
366
+ sparsity_lut_to, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
367
+ sparsity_block_size_from,
368
+ sparsity_block_size_to,
369
+ triton_block_size))
370
+
371
+ ctx.save_for_backward(x, sparsity_layout_from, sparsity_layout_to)
372
+ ctx.sparsity_block_size_from = sparsity_block_size_from
373
+ ctx.sparsity_block_size_to = sparsity_block_size_to
374
+ ctx.triton_block_size = triton_block_size
375
+
376
+ return output
377
+
378
+ @staticmethod
379
+ def backward(ctx, grad_output):
380
+ x, sparsity_layout_from, sparsity_layout_to = ctx.saved_tensors
381
+ sparsity_block_size_from = ctx.sparsity_block_size_from
382
+ sparsity_block_size_to = ctx.sparsity_block_size_to
383
+ triton_block_size = ctx.triton_block_size
384
+
385
+ return adapt_layout(grad_output, sparsity_layout_to, sparsity_block_size_to, sparsity_block_size_from,
386
+ preprocess_data={"sparsity_layout_to": sparsity_layout_from},
387
+ triton_block_size=triton_block_size), None, None, None, None, None, None, None, None, None
388
+
389
+ @staticmethod
390
+ @triton.jit
391
+ def kernel_adapt_layout(x,
392
+ x_b, x_b_s, x_r_s, x_c_s,
393
+ s_l_x_b, s_l_x_b_s, s_l_x_r_s, s_l_x_c_s,
394
+ r_lut_x,
395
+ o,
396
+ o_b, o_b_s, o_r_s, o_c_s,
397
+ s_lut_o, s_lut_o_r, s_lut_o_r_s, s_lut_o_c_s,
398
+ sparsity_block_size_from,
399
+ sparsity_block_size_to,
400
+ TRITON_BLOCK_SIZE: tl.constexpr) -> None:
401
+ # Get triton block indices
402
+ pid_blk = tl.program_id(axis=0)
403
+ pid_row = tl.program_id(axis=1)
404
+ pid_col = tl.program_id(axis=2)
405
+
406
+ # Get position of current sparsity block consisting of its batch, row, and column index
407
+ spa_bat_o_idx = (pid_blk * s_lut_o_r_s + 0 * s_lut_o_c_s)
408
+ spa_bat_o_msk = (spa_bat_o_idx < s_lut_o_r * s_lut_o_r_s)
409
+ spa_bat_o = tl.load(s_lut_o + spa_bat_o_idx, mask=spa_bat_o_msk)
410
+
411
+ spa_row_o_idx = (pid_blk * s_lut_o_r_s + 1 * s_lut_o_c_s)
412
+ spa_row_o_msk = (spa_row_o_idx < s_lut_o_r * s_lut_o_r_s)
413
+ spa_row_o = tl.load(s_lut_o + spa_row_o_idx, mask=spa_row_o_msk)
414
+
415
+ spa_col_o_idx = (pid_blk * s_lut_o_r_s + 2 * s_lut_o_c_s)
416
+ spa_col_o_msk = (spa_col_o_idx < s_lut_o_r * s_lut_o_r_s)
417
+ spa_col_o = tl.load(s_lut_o + spa_col_o_idx, mask=spa_col_o_msk)
418
+
419
+ # Get equivalent sparsity block in from layout
420
+ spa_bat_x = spa_bat_o
421
+ spa_row_x = (spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE) // sparsity_block_size_from
422
+ spa_col_x = (spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE) // sparsity_block_size_from
423
+
424
+ # Get reverse sparsity indices for x
425
+ rev_idx_spa_x_idx = (spa_bat_x * s_l_x_b_s +
426
+ spa_row_x * s_l_x_r_s +
427
+ spa_col_x * s_l_x_c_s)
428
+ rev_idx_spa_x_msk = (rev_idx_spa_x_idx < s_l_x_b * s_l_x_b_s)
429
+ rev_idx_spa_x = tl.load(r_lut_x + rev_idx_spa_x_idx, mask=rev_idx_spa_x_msk).to(tl.int32)
430
+
431
+ # If block is present commence operations
432
+ if rev_idx_spa_x >= 0:
433
+ # Calculate triton block size shifts
434
+ shift_row_x = ((spa_row_o * sparsity_block_size_to + pid_row * TRITON_BLOCK_SIZE)
435
+ % sparsity_block_size_from) // TRITON_BLOCK_SIZE
436
+ shift_col_x = ((spa_col_o * sparsity_block_size_to + pid_col * TRITON_BLOCK_SIZE)
437
+ % sparsity_block_size_from) // TRITON_BLOCK_SIZE
438
+
439
+ # Load x values
440
+ blk_x_idx = ((rev_idx_spa_x * x_b_s) +
441
+ ((shift_row_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
442
+ ((shift_col_x * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
443
+ blk_x_msk = (blk_x_idx < x_b * x_b_s)
444
+ blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
445
+
446
+ # Store output
447
+ blk_o_idx = ((pid_blk * o_b_s) +
448
+ ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
449
+ ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
450
+ blk_o_msk = (blk_o_idx < o_b * o_b_s)
451
+ tl.store(o + blk_o_idx, blk_x, mask=blk_o_msk)