blksprs-2.0rc8-py3-none-any.whl → blksprs-2.1-py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
blksprs/__init__.py CHANGED
@@ -1,12 +1,13 @@
- from blksprs.utils.tools import version
  from blksprs.utils.blksprs_tensor import BlksprsTensor
  
+ __version__ = "2.1"
+
  
  class ops:
      from blksprs.ops.conversion import to_dense, to_sparse, from_blksprs, to_blksprs, adapt_layout
      from blksprs.ops.distribution import gather, scatter, scatter_reduce
      from blksprs.ops.matmul import matmul
-     from blksprs.ops.softmax import softmax
+     from blksprs.ops.softmax import softmax, softmax_fused
      from blksprs.ops.transpose import transpose
      from blksprs.ops.repeat import repeat, repeat_interleave
      from blksprs.ops.partitioning import split, merge
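
The two changes above are related to the rest of this release: the version string is now hard-coded in blksprs/__init__.py rather than read from pyproject.toml at import time (see the removal of version() in blksprs/utils/tools.py below), and the new fused softmax is exported next to the existing one. A minimal sketch of the updated import surface:

    import blksprs

    print(blksprs.__version__)          # "2.1", now a plain module attribute
    fused = blksprs.ops.softmax_fused   # exported alongside blksprs.ops.softmax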
blksprs/ops/conversion.py CHANGED
@@ -204,8 +204,8 @@ def to_dense(x: BlksprsTensor, sparsity_layout: Tensor,
      if sparsity_layout.size(1) == 1 and sparsity_layout.size(2) == 1 and torch.all(sparsity_layout):
          return x
  
-     return to_dense_forward(x, sparsity_layout,
-                             lut["sparsity_reverse_lut"], sparsity_block_size, fill_value)
+     return Tensor(to_dense_forward(x, sparsity_layout,
+                                    lut["sparsity_reverse_lut"], sparsity_block_size, fill_value))
  
  
  @triton_op("blksprs::to_dense_forward", mutates_args={})
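
The only functional change in this hunk is that to_dense now wraps the kernel result in a plain torch.Tensor; a hedged reading, not stated in the diff itself, is that leaving compressed form should also leave the BlksprsTensor type behind. An illustrative round trip (shapes and the all-ones layout are made up for exposition, not taken from the diff):

    import torch
    import blksprs

    sbs = 32                                            # sparsity block size
    x = torch.rand(2, 64, 64, device="cuda")
    layout = torch.ones(2, 2, 2, device="cuda")         # every 32x32 block present
    compressed = blksprs.ops.to_sparse(x, layout, sbs)  # compressed form: 8 blocks of 32x32
    dense = blksprs.ops.to_dense(compressed, layout, sbs)
    # dense should equal x and, after this change, be a plain torch.Tensor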
blksprs/ops/flow.py CHANGED
@@ -78,7 +78,7 @@ def flow_pull_kernel(x,
      spa_col_msk = (spa_col_idx >= 0 and spa_col_idx < s_lut_r * s_lut_r_s)
      spa_col = tl.load(s_lut + spa_col_idx, mask=spa_col_msk)
  
-     # Get reverse sparsity index
+     # Load reverse sparsity index
      rev_idx_spa_idx = (spa_bat * s_l_o_b_s +
                         spa_row * s_l_o_r_s +
                         spa_col * s_l_o_c_s)
blksprs/ops/softmax.py CHANGED
@@ -1,3 +1,5 @@
+ import pdb
+
  import torch
  import triton
  from torch import Tensor
@@ -7,6 +9,7 @@ from triton import language as tl
  
  from blksprs.ops.misc.row_wise import row_wise_sum, row_wise_max, row_wise_sub
  from blksprs.utils.blksprs_tensor import BlksprsTensor
+ from blksprs.utils.debugging import dbg_tensor_full
  from blksprs.utils.tools import stride
  from blksprs.utils.autotuning import get_autotune_configs, prune_autotune_configs
  from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
@@ -100,6 +103,8 @@ def softmax_backward_wrapper(ctx, grad_output):
  def softmax_backward(grad_output: Tensor, o: Tensor, sparsity_lut: Tensor, sparsity_layout: Tensor,
                       sparsity_block_size: int) -> Tensor:
      with torch.no_grad():
+         grad_x = torch.zeros_like(o, dtype=torch.float)
+
          s, sparsity_layout_s = row_wise_sum(grad_output * o, sparsity_layout, sparsity_block_size, flag_slice_only=True)
  
          sparsity_layout_s_flat = sparsity_layout_s.reshape(-1)
@@ -116,8 +121,6 @@ def softmax_backward(grad_output: Tensor, o: Tensor, sparsity_lut: Tensor, spars
          s_l_s_b, s_l_s_r, s_l_s_c = sparsity_layout_s.size()
          s_l_s_b_s, s_l_s_r_s, s_l_s_c_s = stride(sparsity_layout_s)
  
-         grad_x = torch.zeros_like(o, dtype=torch.float)
-
          triton_grid = lambda meta: [o_b,
                                      triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
                                      triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
@@ -302,3 +305,238 @@ def softmax_setup_context(ctx, inputs, output):
  
  
  softmax_forward.register_autograd(softmax_backward_wrapper, setup_context=softmax_setup_context)
+
+
+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
+ def softmax_fused(x: BlksprsTensor, sparsity_layout: Tensor, sparsity_block_size: int,
+                   lut: dict = None) -> BlksprsTensor:
+     """Computes the softmax fused for each row of a block-sparse tensor in compressed form.
+
+     Note:
+         This softmax implementation is a fused version that loads the entire row of a block-sparse tensor into memory.
+         See :func:`softmax` for a true block-wise softmax implementation.
+
+     Args:
+         x (BlksprsTensor): A block-sparse tensor in compressed form.
+         sparsity_layout (Tensor): The sparsity layout of the block-sparse tensor.
+         sparsity_block_size (int): The size of the sparsity blocks.
+         lut (dict, optional): A dictionary containing the look-up tables for the operation (default ``None``).
+
+     Returns:
+         BlksprsTensor: The result of the softmax operation as a block-sparse tensor in compressed form.
+
+     """
+     x = x.contiguous()
+
+     validate_dimensions(x)
+     validate_contiguous(x)
+     validate_dtype_float_32(x)
+     validate_device(x)
+     validate_sparsity(sparsity_block_size, (x, sparsity_layout))
+     validate_sparsity_block_size(sparsity_block_size, x)
+
+     lut = softmax_fused_build_lut(lut, sparsity_layout)
+
+     return BlksprsTensor(softmax_fused_forward(x, sparsity_layout,
+                                                lut["sparsity_reverse_lut"],
+                                                sparsity_block_size))
+
+
+ @triton_op("blksprs::softmax_fused_forward", mutates_args={})
+ def softmax_fused_forward(x: Tensor, sparsity_layout: Tensor,
+                           sparsity_reverse_lut: Tensor,
+                           sparsity_block_size: int) -> Tensor:
+     output = torch.zeros_like(x)
+
+     x_b, x_r, x_c = x.size()
+     x_b_s, x_r_s, x_c_s = stride(x)
+     s_l_b, s_l_r, s_l_c = sparsity_layout.size()
+     s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout)
+
+     triton_grid = lambda meta: [s_l_b,
+                                 s_l_r,
+                                 sparsity_block_size]
+
+     (wrap_triton(softmax_fused_kernel)[triton_grid]
+      (x,
+       x_b, x_b_s, x_r_s, x_c_s,
+       output,
+       s_l_b, s_l_b_s, s_l_r_s, s_l_c, s_l_c_s,
+       sparsity_reverse_lut,
+       sparsity_block_size))
+
+     return output
+
+
+ def softmax_fused_backward_wrapper(ctx, grad_output):
+     o, sparsity_layout, sparsity_reverse_lut = ctx.saved_tensors
+     sparsity_block_size = ctx.sparsity_block_size
+
+     return softmax_fused_backward(grad_output, o, sparsity_reverse_lut, sparsity_layout,
+                                   sparsity_block_size), None, None, None, None, None
+
+
+ @triton_op("blksprs::softmax_fused_backward", mutates_args={})
+ def softmax_fused_backward(grad_output: Tensor, o: Tensor, sparsity_reverse_lut: Tensor, sparsity_layout: Tensor,
+                            sparsity_block_size: int) -> Tensor:
+     with torch.no_grad():
+         grad_x = torch.zeros_like(o)
+
+         g_b, g_r, g_c = grad_output.size()
+         g_b_s, g_r_s, g_c_s = stride(grad_output)
+         o_b, o_r, o_c = o.size()
+         o_b_s, o_r_s, o_c_s = stride(o)
+         s_l_b, s_l_r, s_l_c = sparsity_layout.size()
+         s_l_b_s, s_l_r_s, s_l_c_s = stride(sparsity_layout)
+
+         triton_grid = lambda meta: [s_l_b,
+                                     s_l_r,
+                                     sparsity_block_size]
+
+         (wrap_triton(softmax_fused_kernel_grad)[triton_grid]
+          (grad_output,
+           g_b, g_b_s, g_r_s, g_c_s,
+           o,
+           o_b, o_b_s, o_r_s, o_c_s,
+           s_l_b, s_l_b_s, s_l_r_s, s_l_c, s_l_c_s,
+           sparsity_reverse_lut,
+           grad_x,
+           sparsity_block_size))
+
+         return grad_x
+
+
+ # noinspection PyUnusedLocal
+ @triton.autotune(
+     configs=get_autotune_configs(),
+     key=["sparsity_block_size"],
+     prune_configs_by={"early_config_prune": prune_autotune_configs},
+     reset_to_zero=["o"]
+ )
+ @triton.jit
+ def softmax_fused_kernel(x,
+                          x_b, x_b_s, x_r_s, x_c_s,
+                          o,
+                          s_l_b, s_l_b_s, s_l_r_s, s_l_c: tl.constexpr, s_l_c_s,
+                          r_lut,
+                          sparsity_block_size: tl.constexpr,
+                          TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     # Get triton block indices
+     pid_bat = tl.program_id(axis=0)
+     pid_row = tl.program_id(axis=1)
+     pid_lin = tl.program_id(axis=2)
+
+     # Load reverse sparsity indices of row
+     blk_rev_idx = (pid_bat * s_l_b_s +
+                    pid_row * s_l_r_s +
+                    (tl.arange(0, s_l_c) * s_l_c_s))
+     blk_rev_msk = (blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s)
+     blk_rev = tl.load(r_lut + blk_rev_idx, mask=blk_rev_msk).to(tl.int32)
+
+     if (not (tl.min(blk_rev) == -1 and
+              tl.max(blk_rev) == -1)):
+         # Extend sparsity indices to cover sparsity blocks
+         blk_rev_ext = tl.expand_dims(blk_rev, -1)
+         blk_rev_ext = tl.broadcast_to(blk_rev_ext, (s_l_c, sparsity_block_size))
+         blk_rev_ext = tl.reshape(blk_rev_ext, (s_l_c * sparsity_block_size))
+
+         # Load line of x
+         blk_x_idx = (blk_rev_ext * x_b_s +
+                      pid_lin * x_r_s +
+                      (tl.arange(0, s_l_c * sparsity_block_size) % sparsity_block_size) * x_c_s)
+         blk_x_mask = ((blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
+                       and blk_rev_ext != -1)
+         blk_x = tl.load(x + blk_x_idx, mask=blk_x_mask, other=float("-inf"))
+
+         # Compute softmax
+         blk_x_softmax = tl.softmax(blk_x)
+
+         # Store output
+         tl.store(o + blk_x_idx, blk_x_softmax, mask=blk_x_mask)
+
+
+ # noinspection PyUnusedLocal
+ @triton.autotune(
+     configs=get_autotune_configs(),
+     key=["sparsity_block_size"],
+     prune_configs_by={"early_config_prune": prune_autotune_configs},
+     reset_to_zero=["o"]
+ )
+ @triton.jit
+ def softmax_fused_kernel_grad(g,
+                               g_b, g_b_s, g_r_s, g_c_s,
+                               x,
+                               x_b, x_b_s, x_r_s, x_c_s,
+                               s_l_b, s_l_b_s, s_l_r_s, s_l_c: tl.constexpr, s_l_c_s,
+                               r_lut,
+                               o,
+                               sparsity_block_size: tl.constexpr,
+                               TRITON_BLOCK_SIZE: tl.constexpr) -> None:
+     # Get triton block indices
+     pid_bat = tl.program_id(axis=0)
+     pid_row = tl.program_id(axis=1)
+     pid_lin = tl.program_id(axis=2)
+
+     # Load reverse sparsity indices of row
+     blk_rev_idx = (pid_bat * s_l_b_s +
+                    pid_row * s_l_r_s +
+                    (tl.arange(0, s_l_c) * s_l_c_s))
+     blk_rev_msk = (blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s)
+     blk_rev = tl.load(r_lut + blk_rev_idx, mask=blk_rev_msk).to(tl.int32)
+
+     if (not (tl.min(blk_rev) == -1 and
+              tl.max(blk_rev) == -1)):
+         # Extend sparsity indices to cover sparsity blocks
+         blk_rev_ext = tl.expand_dims(blk_rev, -1)
+         blk_rev_ext = tl.broadcast_to(blk_rev_ext, (s_l_c, sparsity_block_size))
+         blk_rev_ext = tl.reshape(blk_rev_ext, (s_l_c * sparsity_block_size))
+
+         # Load line of g
+         blk_g_idx = (blk_rev_ext * g_b_s +
+                      pid_lin * g_r_s +
+                      (tl.arange(0, s_l_c * sparsity_block_size) % sparsity_block_size) * g_c_s)
+         blk_g_mask = ((blk_g_idx >= 0 and blk_g_idx < g_b * g_b_s)
+                       and blk_rev_ext != -1)
+         blk_g = tl.load(g + blk_g_idx, mask=blk_g_mask)
+
+         # Load line of x
+         blk_x_idx = (blk_rev_ext * x_b_s +
+                      pid_lin * x_r_s +
+                      (tl.arange(0, s_l_c * sparsity_block_size) % sparsity_block_size) * x_c_s)
+         blk_x_mask = ((blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
+                       and blk_rev_ext != -1)
+         blk_x = tl.load(x + blk_x_idx, mask=blk_x_mask)
+
+         # Compute gradients
+         blk_grad = blk_x * (blk_g - tl.sum(blk_x * blk_g))
+
+         tl.store(o + blk_x_idx, blk_grad, mask=blk_x_mask)
+
+
+ def softmax_fused_build_lut(lut: dict, sparsity_layout: Tensor):
+     if lut is None:
+         lut = dict()
+
+     if "sparsity_reverse_lut" not in lut:
+         sparsity_layout_flat = sparsity_layout.reshape(-1)
+         sparsity_reverse_lut = (((torch.cumsum(sparsity_layout_flat, dim=-1) - 1) *
+                                  (sparsity_layout_flat == 1) -
+                                  (1 * (sparsity_layout_flat == 0)))
+                                 .reshape(sparsity_layout.size())
+                                 .reshape(-1).contiguous())
+         lut["sparsity_reverse_lut"] = sparsity_reverse_lut
+
+     validate_contiguous(sparsity_layout, lut["sparsity_reverse_lut"])
+
+     return lut
+
+
+ # noinspection PyUnusedLocal
+ def softmax_fused_setup_context(ctx, inputs, output):
+     (_, sparsity_layout, sparsity_reverse_lut, sparsity_block_size) = inputs
+
+     ctx.save_for_backward(output, sparsity_layout, sparsity_reverse_lut)
+     ctx.sparsity_block_size = sparsity_block_size
+
+
+ softmax_fused_forward.register_autograd(softmax_fused_backward_wrapper, setup_context=softmax_fused_setup_context)
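
Taken together, softmax_fused validates its input, builds (or reuses) the reverse look-up table, and launches one program per (batch, block-row, line) that softmaxes an entire logical row at once. A minimal usage sketch, assuming a CUDA device and float32 input as required by validate_dtype_float_32 (the shapes and all-ones layout are illustrative):

    import torch
    from blksprs.ops.softmax import softmax_fused

    sbs = 64                                     # sparsity block size
    layout = torch.ones(2, 2, 2, device="cuda")  # 2 batches of 2x2 block grids
    x = torch.rand(8, sbs, sbs, device="cuda", requires_grad=True)  # 8 set blocks, compressed form
    y = softmax_fused(x, layout, sbs)            # same compressed shape and layout as x
    y.sum().backward()                           # dispatches to the registered softmax_fused_backward

The backward kernel applies the usual softmax Jacobian-vector product row by row, grad_x = y * (g - sum(y * g)), which is exactly the blk_grad line in softmax_fused_kernel_grad.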
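
softmax_fused_build_lut derives the reverse look-up table from the layout alone: a cumulative sum numbers the set blocks in flattened order, and absent blocks map to -1, which the kernels then mask out. The same arithmetic on a toy flattened layout, for illustration only:

    import torch

    layout_flat = torch.tensor([1, 0, 1, 1, 0])
    rev = ((torch.cumsum(layout_flat, dim=-1) - 1) * (layout_flat == 1)
           - 1 * (layout_flat == 0))
    print(rev)  # tensor([ 0, -1,  1,  2, -1]): block 0 -> slot 0, block 2 -> slot 1, block 3 -> slot 2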
blksprs/utils/tools.py CHANGED
@@ -1,6 +1,3 @@
- import tomllib
- from pathlib import Path
-
  import torch
  from torch import Tensor, Size
  
@@ -8,19 +5,14 @@ from torch import Tensor, Size
  torch._dynamo.config.capture_scalar_outputs = True
  
  
- def version():
-     with open(Path(__file__).parent.parent.parent.joinpath("pyproject.toml"), "rb") as f:
-         return tomllib.load(f)["project"]["version"]
-
-
- def do_shape_blocksparse(x: Tensor):
+ def do_shape_blocksparse(x: Tensor) -> tuple[Tensor, Size]:
      if x.dim() == 3:
          return x.contiguous(), x.size()
  
      return x.reshape(-1, x.size(-2), x.size(-1)).contiguous(), x.size()
  
  
- def undo_shape_blocksparse(x: Tensor, shape: Size):
+ def undo_shape_blocksparse(x: Tensor, shape: Size | tuple[int, ...]) -> Tensor:
      if x.shape[:-2] == shape[:-2]:
          return x
  
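With the version() helper gone, tools.py is reduced to the shape utilities, which now carry explicit type hints: do_shape_blocksparse folds any leading dimensions into a single batch dimension, and undo_shape_blocksparse restores them (the tail of that function is cut off by the hunk, but the visible guard returns early when the leading dimensions already match). A small round-trip sketch, assuming the unshown tail restores the original shape as the name suggests:

    import torch
    from blksprs.utils.tools import do_shape_blocksparse, undo_shape_blocksparse

    x = torch.rand(2, 3, 64, 64)
    flat, orig_shape = do_shape_blocksparse(x)    # -> shape (6, 64, 64), plus the original size
    restored = undo_shape_blocksparse(flat, orig_shape)
    print(flat.shape, restored.shape)             # torch.Size([6, 64, 64]) torch.Size([2, 3, 64, 64])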
blksprs-2.0rc8.dist-info/METADATA → blksprs-2.1.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: blksprs
- Version: 2.0rc8
+ Version: 2.1
  Summary: A lightweight library for operations on blocksparse matrices in PyTorch.
  Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
  Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
blksprs-2.0rc8.dist-info/RECORD → blksprs-2.1.dist-info/RECORD RENAMED
@@ -1,13 +1,13 @@
- blksprs/__init__.py,sha256=283rF0fbrUqsH_KXUvCgbCMqO0GOgenMkxwDVh1QdpU,1617
+ blksprs/__init__.py,sha256=o_Rj7fz_70vbMGLePihczVIVcM8E28vY3ah-d1q4ZO0,1613
  blksprs/layouting/distribution_layout.py,sha256=ur1ty_2U-Hfj78hMWsLZvu7ZuGhzW3qGLKMc72DfTZM,5861
  blksprs/layouting/sparsity_layout.py,sha256=eXHmu2h7K5Q-YUpfOxocJoeP_5ZoQFZf_eHLxRZQbYU,11207
- blksprs/ops/conversion.py,sha256=_g32aEEZdeuHHPj1pBfTNMxknRwJ9O1zk3Wv76pBIrg,21898
+ blksprs/ops/conversion.py,sha256=kf5HKofZ4nVeHCIqQoYKiIlgsAhq33Tnmnr1c17Fkqs,21906
  blksprs/ops/distribution.py,sha256=0tPldv0ARzmCV1CU2jvfqpHBgOuHPrDFiCtqsLs7CZc,20789
- blksprs/ops/flow.py,sha256=PDZAD8u4y9qW1IXERki6ItKbEKnm_ChG8SKWM3_P9Oc,8245
+ blksprs/ops/flow.py,sha256=qdWBCLDSkKaa8CAfkO1NgH-J5N7yMsILyR7qEpyrIUU,8246
  blksprs/ops/matmul.py,sha256=5tVBKU_lglUjaLDi6J_dscdqlmzRz38OGxqAxZxZXDs,11879
  blksprs/ops/partitioning.py,sha256=cfQmY9BZqGTvvJorIhtb-EyuGRJGPraWR-wTKdb47aI,9954
  blksprs/ops/repeat.py,sha256=TLYNxwPuT9y5K9xyM41WK5gnggAJF3lI61Q2K7zWjns,9035
- blksprs/ops/softmax.py,sha256=BwrRQdtRdkiSvl2mf5bpsTmyIxWiJOpa1HFg0st5yGU,12778
+ blksprs/ops/softmax.py,sha256=H0OxST_XX1QLa7HDTDHznzibVHAxnp5sVbMU32HLxf0,21967
  blksprs/ops/transpose.py,sha256=U-VAyLRT6_NDv9qYSFzBqfVlDeIpTqAMEXkqto0VF6w,4072
  blksprs/ops/misc/broadcast_ops.py,sha256=-PrHiSJikZh8nXUmXxSCtFEP27TTxFr4wcrNxBjnimk,5987
  blksprs/ops/misc/row_wise.py,sha256=n5FJjAuOd8BHBJQx4bsQwr-HmXkR9PYVAqfk77wjOFU,19653
@@ -15,9 +15,9 @@ blksprs/utils/autotuning.py,sha256=a-kmWRjJ3eED2XbjkQeOJSyW8bdIs27HgKMPvAKqWeU,2
  blksprs/utils/benchmarking.py,sha256=dLabDscTFn5NkmOI1g7DnKeTneUYW3RIVv9MDF-8BKc,1271
  blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
  blksprs/utils/processing.py,sha256=RNkEDc0g-sNHRuMPkRzNWU13d3_lIkXMJdoqES4yQTM,3738
- blksprs/utils/tools.py,sha256=BozpH3oEXe3K9ZRJsIzlasDk-sZyJqmwSf1gl7xbbdo,865
+ blksprs/utils/tools.py,sha256=CPf7viQ2OTcZFrB1aSL8_us4VE9M6YEfDz2dE30jr9I,715
  blksprs/utils/validation.py,sha256=G8eQlvJVMKfEX3k2AwBD0A6Ck-gFoRLpLNY6HXsB3fA,4348
- blksprs-2.0rc8.dist-info/METADATA,sha256=h70L26BthR6laP7sMQLF9L3dHIRQNCF_oKwZ5g4dZSg,9509
- blksprs-2.0rc8.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
- blksprs-2.0rc8.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
- blksprs-2.0rc8.dist-info/RECORD,,
+ blksprs-2.1.dist-info/METADATA,sha256=uPVm8Y7fX5iModz6j3hNAftdtauCsJ-iYrMa-Pv3xnU,9506
+ blksprs-2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ blksprs-2.1.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
+ blksprs-2.1.dist-info/RECORD,,
blksprs-2.0rc8.dist-info/WHEEL → blksprs-2.1.dist-info/WHEEL RENAMED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (80.8.0)
+ Generator: setuptools (80.9.0)
  Root-Is-Purelib: true
  Tag: py3-none-any