blksprs 2.1.2-py3-none-any.whl → 2.1.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/ops/conversion.py CHANGED
@@ -56,7 +56,7 @@ def to_sparse(x: Tensor, sparsity_layout: Tensor,
 def to_sparse_forward(x: Tensor, _: Tensor,
                       sparsity_lut: Tensor, sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
     with torch.no_grad():
-        output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                              dtype=x.dtype, device=x.device)
 
         x_b, x_r, x_c = x.size()
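
The forward pass above (and the matching hunks in flow.py and matmul.py below) now allocates its output with torch.empty instead of torch.zeros, presumably because the kernel writes every element of the allocated blocks before they are read, making the zero-fill redundant work. A minimal standalone illustration of the difference (plain PyTorch, not blksprs code; the names only mirror the diff):

import torch

# torch.empty allocates without clearing memory; that is only safe when every
# element is written before it is read. Here fill_ stands in for the Triton
# kernel that populates the whole buffer.
n_sparse_blocks, sparsity_block_size = 4, 16
out = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                  dtype=torch.float32)
out.fill_(1.0)
assert bool((out == 1.0).all())
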
blksprs/ops/flow.py CHANGED
@@ -14,7 +14,7 @@ def flow_pull_forward(x: Tensor, sparsity_layout_o: Tensor,
                       sparsity_lut: Tensor, sparsity_reverse_lut: Tensor,
                       sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
     with torch.no_grad():
-        output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                              dtype=x.dtype, device=x.device)
 
         x_b, x_r, x_c = x.size()
blksprs/ops/matmul.py CHANGED
@@ -62,7 +62,7 @@ def matmul_forward(x: Tensor, y: Tensor,
                    _: Tensor, sparsity_lut_o: Tensor,
                    sparsity_block_size: int, n_sparse_blocks: int) -> Tensor:
     with torch.no_grad():
-        output = torch.zeros(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
+        output = torch.empty(size=(n_sparse_blocks, sparsity_block_size, sparsity_block_size),
                              dtype=x.dtype, device=x.device)
 
         x_b, x_r, x_c = x.size()
blksprs/ops/softmax.py CHANGED
@@ -66,7 +66,7 @@ def softmax_forward(x: Tensor, sparsity_layout: Tensor,
                     sparsity_lut: Tensor,
                     sparsity_reverse_lut_rws: Tensor,
                     sparsity_block_size: int) -> Tensor:
-    output = torch.zeros_like(x)
+    output = torch.empty_like(x)
 
     x_row_wise_max, sparsity_layout_rwm = row_wise_max(x, sparsity_layout, sparsity_block_size,
                                                        flag_slice_only=True)
@@ -114,7 +114,7 @@ def softmax_backward_wrapper(ctx, grad_output):
 def softmax_backward(grad_output: Tensor, o: Tensor, sparsity_lut: Tensor, sparsity_layout: Tensor,
                      sparsity_block_size: int) -> Tensor:
     with torch.no_grad():
-        grad_x = torch.zeros_like(o, dtype=torch.float)
+        grad_x = torch.empty_like(o, dtype=torch.float)
 
         s, sparsity_layout_s = row_wise_sum(grad_output * o, sparsity_layout, sparsity_block_size, flag_slice_only=True)
 
@@ -359,7 +359,7 @@ def softmax_fused_forward(x: Tensor, sparsity_layout: Tensor,
                           sparsity_reverse_lut_sorted: Tensor,
                           max_blocks_line: int,
                           sparsity_block_size: int) -> Tensor:
-    output = torch.zeros_like(x)
+    output = torch.empty_like(x)
 
     x_b, x_r, x_c = x.size()
     x_b_s, x_r_s, x_c_s = stride(x)
@@ -374,7 +374,7 @@ def softmax_fused_forward(x: Tensor, sparsity_layout: Tensor,
         (x,
          x_b, x_b_s, x_r_s, x_c_s,
          output,
-         s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+         s_l_b, s_l_b_s, s_l_r_s, s_l_c, s_l_c_s,
          sparsity_reverse_lut_sorted,
          max_blocks_line,
          sparsity_block_size))
@@ -399,7 +399,7 @@ def softmax_fused_backward(grad_output: Tensor,
                            max_blocks_line: int,
                            sparsity_block_size: int) -> Tensor:
     with torch.no_grad():
-        grad_x = torch.zeros_like(o)
+        grad_x = torch.empty_like(o)
 
         g_b, g_r, g_c = grad_output.size()
         g_b_s, g_r_s, g_c_s = stride(grad_output)
@@ -417,7 +417,7 @@ def softmax_fused_backward(grad_output: Tensor,
          g_b, g_b_s, g_r_s, g_c_s,
          o,
          o_b, o_b_s, o_r_s, o_c_s,
-         s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+         s_l_b, s_l_b_s, s_l_r_s, s_l_c, s_l_c_s,
          sparsity_reverse_lut_sorted,
          grad_x,
          max_blocks_line,
@@ -437,7 +437,7 @@ def softmax_fused_backward(grad_output: Tensor,
 def softmax_fused_kernel(x,
                          x_b, x_b_s, x_r_s, x_c_s,
                          o,
-                         s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+                         s_l_b, s_l_b_s, s_l_r_s, s_l_c, s_l_c_s,
                          r_lut_s,
                          mbs: tl.constexpr,
                          sparsity_block_size: tl.constexpr,
@@ -451,8 +451,9 @@ def softmax_fused_kernel(x,
     blk_rev_idx = (pid_bat * s_l_b_s +
                    pid_row * s_l_r_s +
                    (tl.arange(0, mbs) * s_l_c_s))
-    blk_rev_msk = (blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s)
-    blk_rev = tl.load(r_lut_s + blk_rev_idx, mask=blk_rev_msk).to(tl.int32)
+    blk_rev_msk = ((blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s) and
+                   (tl.arange(0, mbs) < s_l_c))
+    blk_rev = tl.load(r_lut_s + blk_rev_idx, mask=blk_rev_msk, other=-1).to(tl.int32)
 
     if (not (tl.min(blk_rev) == -1 and
              tl.max(blk_rev) == -1)):
@@ -488,7 +489,7 @@ def softmax_fused_kernel_grad(g,
                               g_b, g_b_s, g_r_s, g_c_s,
                               x,
                               x_b, x_b_s, x_r_s, x_c_s,
-                              s_l_b, s_l_b_s, s_l_r_s, s_l_c_s,
+                              s_l_b, s_l_b_s, s_l_r_s, s_l_c, s_l_c_s,
                               r_lut_s,
                               o,
                               mbs: tl.constexpr,
@@ -503,8 +504,9 @@ def softmax_fused_kernel_grad(g,
     blk_rev_idx = (pid_bat * s_l_b_s +
                    pid_row * s_l_r_s +
                    (tl.arange(0, mbs) * s_l_c_s))
-    blk_rev_msk = (blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s)
-    blk_rev = tl.load(r_lut_s + blk_rev_idx, mask=blk_rev_msk).to(tl.int32)
+    blk_rev_msk = ((blk_rev_idx >= 0 and blk_rev_idx < s_l_b * s_l_b_s) and
+                   (tl.arange(0, mbs) < s_l_c))
+    blk_rev = tl.load(r_lut_s + blk_rev_idx, mask=blk_rev_msk, other=-1).to(tl.int32)
 
     if (not (tl.min(blk_rev) == -1 and
              tl.max(blk_rev) == -1)):
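
The mask change in the two kernel hunks above adds a lane bound, tl.arange(0, mbs) < s_l_c, and an explicit other=-1 fill. Since mbs (max_blocks_line) is padded up to a power of two (see the build_lut hunk below), lanes beyond the actual number of layout columns have to be masked off, and the -1 fill matches the "no block" sentinel the kernels already test with tl.min/tl.max. A rough standalone PyTorch analogy of that guarded lookup (illustration only, not the Triton code; the values are made up):

import torch

mbs = 8                                  # lane count, padded up to a power of two
s_l_c = 5                                # actual number of layout columns in this row
r_lut = torch.tensor([4, -1, 0, -1, 2])  # one reverse-LUT row; -1 already means "no block"

lanes = torch.arange(mbs)
mask = lanes < s_l_c                     # analogue of tl.arange(0, mbs) < s_l_c
blk_rev = torch.full((mbs,), -1)         # analogue of other=-1 for masked-off lanes
blk_rev[mask] = r_lut[lanes[mask]]
# Padding lanes now carry the -1 sentinel instead of an undefined out-of-bounds value.
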
@@ -557,7 +559,7 @@ def softmax_fused_build_lut(lut: dict, sparsity_layout: Tensor):
                               .sum(dim=-1)
                               .max()
                               .item())
-        lut["max_blocks_line"] = min(ceil_pow2(max(max_blocks_line, 2)), sparsity_layout.size(-1))
+        lut["max_blocks_line"] = ceil_pow2(max(max_blocks_line, 2))
 
     validate_contiguous(sparsity_layout, lut["sparsity_reverse_lut_sorted"])
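
The build_lut change drops the clamp to sparsity_layout.size(-1): max_blocks_line is now just the largest per-row block count rounded up to the next power of two, presumably because tl.arange in the kernels requires a power-of-two length, and the new lane mask above hides the resulting padding. A sketch of what a next-power-of-two helper such as ceil_pow2 (defined in blksprs/utils/tools.py) can look like; this is an assumption, not the library's actual implementation:

def ceil_pow2_sketch(n: int) -> int:
    # Round n up to the next power of two (n >= 1), e.g. 5 -> 8, 8 -> 8.
    return 1 << (n - 1).bit_length()

# With the clamp removed, a row holding at most 5 blocks yields mbs = 8,
# and the 3 padding lanes are masked out inside the kernels.
assert ceil_pow2_sketch(max(5, 2)) == 8
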
 
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: blksprs
-Version: 2.1.2
+Version: 2.1.3
 Summary: A lightweight library for operations on block-sparse matrices in PyTorch.
 Author-email: Felix Schön <schoen@kr.tuwien.ac.at>
 Project-URL: Homepage, https://github.com/FelixSchoen/blksprs
@@ -196,8 +196,8 @@ def test_readme():
 
     # Other available functions
     bs.ops.transpose(o_sparse, sparsity_layout_o, sparsity_block_size)
-    bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size)
-    bs.ops.softmax_fused(o_sparse, sparsity_layout_o, sparsity_block_size)  # Significantly faster version that requires that rows of matrix fit into memory
+    bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size, flag_fused=False)
+    bs.ops.softmax_fused(o_sparse, sparsity_layout_o, sparsity_block_size)  # Significantly faster version that requires that rows of matrix fit into memory (default if flag is not set)
     bs.ops.misc.row_wise_sum(o_sparse, sparsity_layout_o, sparsity_block_size)
     bs.ops.misc.row_wise_max(o_sparse, sparsity_layout_o, sparsity_block_size)
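
The README example embedded in the package metadata now spells out the relationship between the two softmax entry points. As a usage fragment (assuming o_sparse, sparsity_layout_o, and sparsity_block_size are set up as earlier in that example, which is not shown in this diff):

import blksprs as bs

# The fused row-wise softmax is the default; pass flag_fused=False for the non-fused path.
out_default = bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size)
out_classic = bs.ops.softmax(o_sparse, sparsity_layout_o, sparsity_block_size, flag_fused=False)
out_fused = bs.ops.softmax_fused(o_sparse, sparsity_layout_o, sparsity_block_size)
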
 
@@ -1,13 +1,13 @@
 blksprs/__init__.py,sha256=NRxydw4i9jg7WeDuojfEePdtdbughV9AZsEcT9yywK4,1615
 blksprs/layouting/distribution_layout.py,sha256=ur1ty_2U-Hfj78hMWsLZvu7ZuGhzW3qGLKMc72DfTZM,5861
 blksprs/layouting/sparsity_layout.py,sha256=eXHmu2h7K5Q-YUpfOxocJoeP_5ZoQFZf_eHLxRZQbYU,11207
-blksprs/ops/conversion.py,sha256=RgVSyiULLwv8KWQqSyXpKwTr4Qp-lpDK9i-zKlN841I,21914
+blksprs/ops/conversion.py,sha256=nv5gXiyZkUtk1kCIlPr0Vpaj4G8G6dJdW7StlbV3nDw,21914
 blksprs/ops/distribution.py,sha256=0tPldv0ARzmCV1CU2jvfqpHBgOuHPrDFiCtqsLs7CZc,20789
-blksprs/ops/flow.py,sha256=qdWBCLDSkKaa8CAfkO1NgH-J5N7yMsILyR7qEpyrIUU,8246
-blksprs/ops/matmul.py,sha256=5tVBKU_lglUjaLDi6J_dscdqlmzRz38OGxqAxZxZXDs,11879
+blksprs/ops/flow.py,sha256=oUn_xDT74220-EmnBnB8bRNtbS1mjbxWpm76PFsK22o,8246
+blksprs/ops/matmul.py,sha256=ES9bpiCIRBxaynNIL5ftDP0c9LSArbj8YJqkPEzBaIU,11879
 blksprs/ops/partitioning.py,sha256=cfQmY9BZqGTvvJorIhtb-EyuGRJGPraWR-wTKdb47aI,9954
 blksprs/ops/repeat.py,sha256=TLYNxwPuT9y5K9xyM41WK5gnggAJF3lI61Q2K7zWjns,9035
-blksprs/ops/softmax.py,sha256=1UIovPrdE_zgAIPqjmOTFn8CMbd_2Z8tPP-vMBxU07I,23526
+blksprs/ops/softmax.py,sha256=tfC_jaAKrA956rxGeb57klMuYRKTiyMCd5Zg5DIH3fc,23649
 blksprs/ops/transpose.py,sha256=U-VAyLRT6_NDv9qYSFzBqfVlDeIpTqAMEXkqto0VF6w,4072
 blksprs/ops/misc/broadcast_ops.py,sha256=-PrHiSJikZh8nXUmXxSCtFEP27TTxFr4wcrNxBjnimk,5987
 blksprs/ops/misc/row_wise.py,sha256=n5FJjAuOd8BHBJQx4bsQwr-HmXkR9PYVAqfk77wjOFU,19653
@@ -17,7 +17,7 @@ blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4
 blksprs/utils/processing.py,sha256=RNkEDc0g-sNHRuMPkRzNWU13d3_lIkXMJdoqES4yQTM,3738
 blksprs/utils/tools.py,sha256=TKygEKge4wJtJnXXDg8BTL8vzBpqIJsQ_A3_5FmLpcE,859
 blksprs/utils/validation.py,sha256=G8eQlvJVMKfEX3k2AwBD0A6Ck-gFoRLpLNY6HXsB3fA,4348
-blksprs-2.1.2.dist-info/METADATA,sha256=U20ZL7XLhrgiMd_0QGFik0Ci43SDoCT8q876-1yCeNo,9665
-blksprs-2.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-blksprs-2.1.2.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
-blksprs-2.1.2.dist-info/RECORD,,
+blksprs-2.1.3.dist-info/METADATA,sha256=6ZrxPPpkLwXgmq1d-4VQBNPNjlRm76dEMI-LJyiqlfI,9712
+blksprs-2.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+blksprs-2.1.3.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
+blksprs-2.1.3.dist-info/RECORD,,