blksprs 1.10.2__py3-none-any.whl → 2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blksprs/__init__.py +4 -6
- blksprs/layouting/distribution_layout.py +64 -48
- blksprs/layouting/sparsity_layout.py +96 -72
- blksprs/ops/conversion.py +350 -312
- blksprs/ops/distribution.py +320 -266
- blksprs/ops/flow.py +135 -89
- blksprs/ops/matmul.py +184 -151
- blksprs/ops/misc/broadcast_ops.py +68 -53
- blksprs/ops/misc/row_wise.py +223 -176
- blksprs/ops/partitioning.py +140 -89
- blksprs/ops/repeat.py +118 -108
- blksprs/ops/softmax.py +201 -167
- blksprs/ops/transpose.py +71 -131
- blksprs/utils/autotuning.py +78 -0
- blksprs/utils/benchmarking.py +3 -3
- blksprs/utils/processing.py +2 -1
- blksprs/utils/tools.py +5 -6
- blksprs/utils/validation.py +22 -16
- {blksprs-1.10.2.dist-info → blksprs-2.0.dist-info}/METADATA +55 -36
- blksprs-2.0.dist-info/RECORD +23 -0
- {blksprs-1.10.2.dist-info → blksprs-2.0.dist-info}/WHEEL +1 -1
- blksprs/ops/misc/exp.py +0 -104
- blksprs/utils/layout_utils.py +0 -17
- blksprs-1.10.2.dist-info/RECORD +0 -24
- {blksprs-1.10.2.dist-info → blksprs-2.0.dist-info}/top_level.txt +0 -0
blksprs/ops/misc/exp.py
DELETED
@@ -1,104 +0,0 @@
-import torch
-import triton
-from torch import Tensor
-from triton import language as tl
-
-from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import get_triton_block_size, stride
-from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-    validate_sparsity_block_size, validate_triton_block_size
-
-
-def exp(x: BlksprsTensor, sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
-    """Applies the element-wise exponential function to a block-sparse tensor.
-
-    Note:
-        This operation does not consider sparse blocks, i.e., these will not be set to ``e^0``.
-        Consider this when converting back to tensors in regular form.
-
-    Args:
-        x (BlksprsTensor): A block-sparse tensor in compressed form.
-        sparsity_block_size (int): The size of the sparsity blocks.
-        triton_block_size (int): The block size to use for the triton kernel (default ``None``).
-
-    Returns:
-        BlksprsTensor: The exponential function applied to all elements of the input tensor as a block-sparse tensor in
-            compressed form.
-
-    """
-    x = x.contiguous()
-
-    validate_dimensions(x)
-    validate_contiguous(x)
-    validate_device(x)
-    validate_sparsity_block_size(sparsity_block_size, x)
-    validate_triton_block_size(triton_block_size, sparsity_block_size)
-
-    return BlksprsTensor(_BlocksparseExp.apply(x, sparsity_block_size, triton_block_size))
-
-
-class _BlocksparseExp(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, x: Tensor, sparsity_block_size: int, triton_block_size: int) -> Tensor:
-        output = torch.empty_like(x)
-
-        x_b, x_r, x_c = x.shape
-        x_b_s, x_r_s, x_c_s = stride(x)
-        o_b, o_r, o_c = output.shape
-        o_b_s, o_r_s, o_c_s = stride(output)
-
-        if triton_block_size is None:
-            triton_block_size = get_triton_block_size(sparsity_block_size)
-
-        triton_grid = lambda meta: [o_b,
-                                    triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                    triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-        (_BlocksparseExp.kernel_blocksparse_exp[triton_grid]
-         (x,
-          x_b, x_b_s, x_r_s, x_c_s,
-          output,
-          o_b, o_b_s, o_r_s, o_c_s,
-          triton_block_size))
-
-        ctx.save_for_backward(output)
-
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        o = ctx.saved_tensors[0]
-
-        grad_x = torch.mul(grad_output, o)
-
-        return grad_x, None, None
-
-    @staticmethod
-    @triton.jit
-    def kernel_blocksparse_exp(x,
-                               x_b, x_b_s, x_r_s, x_c_s,
-                               o,
-                               o_b, o_b_s, o_r_s, o_c_s,
-                               TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-        # Get triton block indices
-        pid_blk = tl.program_id(axis=0)
-        pid_row = tl.program_id(axis=1)
-        pid_col = tl.program_id(axis=2)
-
-        # Load block
-        blk_x_idx = ((pid_blk * x_b_s) +
-                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-        blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
-        blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
-        # Compute exp
-        buf = tl.exp(blk_x)
-
-        # Store block
-        blk_o_idx = ((pid_blk * o_b_s) +
-                     ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                     ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-        blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
-        tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
blksprs/utils/layout_utils.py
DELETED
@@ -1,17 +0,0 @@
-import math
-
-import torch
-import triton
-from torch import Tensor
-from torch.xpu import device
-from triton import language as tl
-
-from blksprs.utils.blksprs_tensor import BlksprsTensor
-from blksprs.utils.tools import get_triton_block_size, stride
-from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
-    validate_contiguous, validate_sparsity, validate_sparsity_block_size
-
-
-def build_full_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
-    return torch.ones(size=(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size),
-                      dtype=torch.bool, device=x.device)
blksprs-1.10.2.dist-info/RECORD
DELETED
@@ -1,24 +0,0 @@
-blksprs/__init__.py,sha256=wnpk-20jXq7xV0xa-WpHfPQuauI2gEZz9sH-0blKxP0,1766
-blksprs/layouting/distribution_layout.py,sha256=xDGY5-J7uSD8oenlf8bEJ2amMiQG3NBf2klTTydbTJE,5140
-blksprs/layouting/sparsity_layout.py,sha256=IVtHc_nN3ZM2y4GFcys70PqDWmWc7tkHlVGlToErANk,9894
-blksprs/ops/conversion.py,sha256=NK5uXMepPJ9yYh0vnxKwx5_Ffj_bAvhqPVogf_7PY0g,22248
-blksprs/ops/distribution.py,sha256=qK5t5XgQSJxXPced8RohprqCtUMMTaEP2pFm3KU1c8o,20267
-blksprs/ops/flow.py,sha256=Wv15oAhX4iqUzehj0XcNUWKjUcLaVB-5uSLEIsEREzA,6399
-blksprs/ops/matmul.py,sha256=LAQyPNwWVmBMRnAex3msLSPD_aG5SblLCMiutJWqvus,11632
-blksprs/ops/partitioning.py,sha256=ugKnpvH36ND7qeJQp56M74qqfACkzcTVuXebzw__28Y,8286
-blksprs/ops/repeat.py,sha256=RCa-dITomA5v12K5Oxa5_ReA361zS7WHPNNHxSp9PGw,8578
-blksprs/ops/softmax.py,sha256=i8NJhvPRYya94AzpN6qiki6_G9KfDrtPifhWd7wbYzk,12496
-blksprs/ops/transpose.py,sha256=oAtUu7QzQnNAH3lvRs_MIvIKpBu9h74f9Sk07AxKnDM,6991
-blksprs/ops/misc/broadcast_ops.py,sha256=pv0nssSDOdDbQFttpqUIs2ZXShqfm2RYCfJH-C5x3H0,5544
-blksprs/ops/misc/exp.py,sha256=ygfw7oD6ALdPwNQX_HelKgO8I3-LCgIXH_x0gWzkUN8,3840
-blksprs/ops/misc/row_wise.py,sha256=DnV5-xEJUbqZlK2fETwHiPQDUMwT-lkc0VUhBlnJ5Y0,17458
-blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
-blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
-blksprs/utils/layout_utils.py,sha256=49ZdPS_gMn_IrWty3FARbi2rda5a8g5DmAEL8LOrC30,670
-blksprs/utils/processing.py,sha256=WLuMJQ8v-YovXwcDjhlDn3N31WMZXrtyeeyKSgq_zn4,3642
-blksprs/utils/tools.py,sha256=k2OfEplbQiAwVjP84zZf7SNB8FzvMtOFBL9sC98OCbI,683
-blksprs/utils/validation.py,sha256=CbxBbeQWJo8wox5eMoVzaTlP9FVBwt3-gxUOmi3EUgw,4213
-blksprs-1.10.2.dist-info/METADATA,sha256=sm32ieVfYJ_bM5KtKbqF8DjHJ-4L4LMweEPwJWZvZG0,9107
-blksprs-1.10.2.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
-blksprs-1.10.2.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
-blksprs-1.10.2.dist-info/RECORD,,
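For reference, each RECORD row follows the wheel spec: a CSV triple of path, algorithm=digest (urlsafe base64 with padding stripped), and file size in bytes; the RECORD file's own entry leaves the last two fields empty. A small sketch of how such a row can be parsed and verified (the verify helper is illustrative, not part of blksprs):

import base64
import csv
import hashlib
import io

row = "blksprs/utils/tools.py,sha256=k2OfEplbQiAwVjP84zZf7SNB8FzvMtOFBL9sC98OCbI,683"
path, hash_field, size = next(csv.reader(io.StringIO(row)))
algorithm, _, digest = hash_field.partition("=")

def verify(data: bytes) -> bool:
    # Recompute the urlsafe-base64 digest (no '=' padding) and check the size.
    actual = base64.urlsafe_b64encode(
        hashlib.new(algorithm, data).digest()).rstrip(b"=").decode()
    return actual == digest and len(data) == int(size)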
{blksprs-1.10.2.dist-info → blksprs-2.0.dist-info}/top_level.txt
File without changes