blksprs 1.10.2__py3-none-any.whl → 2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blksprs/ops/misc/exp.py DELETED
@@ -1,104 +0,0 @@
- import torch
- import triton
- from torch import Tensor
- from triton import language as tl
-
- from blksprs.utils.blksprs_tensor import BlksprsTensor
- from blksprs.utils.tools import get_triton_block_size, stride
- from blksprs.utils.validation import validate_contiguous, validate_dimensions, validate_device, \
-     validate_sparsity_block_size, validate_triton_block_size
-
-
- def exp(x: BlksprsTensor, sparsity_block_size: int, triton_block_size: int = None) -> BlksprsTensor:
-     """Applies the element-wise exponential function to a block-sparse tensor.
-
-     Note:
-         This operation does not consider sparse blocks, i.e., these will not be set to ``e^0``.
-         Consider this when converting back to tensors in regular form.
-
-     Args:
-         x (BlksprsTensor): A block-sparse tensor in compressed form.
-         sparsity_block_size (int): The size of the sparsity blocks.
-         triton_block_size (int): The block size to use for the triton kernel (default ``None``).
-
-     Returns:
-         BlksprsTensor: The exponential function applied to all elements of the input tensor as a block-sparse tensor in
-             compressed form.
-
-     """
-     x = x.contiguous()
-
-     validate_dimensions(x)
-     validate_contiguous(x)
-     validate_device(x)
-     validate_sparsity_block_size(sparsity_block_size, x)
-     validate_triton_block_size(triton_block_size, sparsity_block_size)
-
-     return BlksprsTensor(_BlocksparseExp.apply(x, sparsity_block_size, triton_block_size))
-
-
- class _BlocksparseExp(torch.autograd.Function):
-
-     @staticmethod
-     def forward(ctx, x: Tensor, sparsity_block_size: int, triton_block_size: int) -> Tensor:
-         output = torch.empty_like(x)
-
-         x_b, x_r, x_c = x.shape
-         x_b_s, x_r_s, x_c_s = stride(x)
-         o_b, o_r, o_c = output.shape
-         o_b_s, o_r_s, o_c_s = stride(output)
-
-         if triton_block_size is None:
-             triton_block_size = get_triton_block_size(sparsity_block_size)
-
-         triton_grid = lambda meta: [o_b,
-                                     triton.cdiv(o_r, meta["TRITON_BLOCK_SIZE"]),
-                                     triton.cdiv(o_c, meta["TRITON_BLOCK_SIZE"])]
-
-         (_BlocksparseExp.kernel_blocksparse_exp[triton_grid]
-          (x,
-           x_b, x_b_s, x_r_s, x_c_s,
-           output,
-           o_b, o_b_s, o_r_s, o_c_s,
-           triton_block_size))
-
-         ctx.save_for_backward(output)
-
-         return output
-
-     @staticmethod
-     def backward(ctx, grad_output):
-         o = ctx.saved_tensors[0]
-
-         grad_x = torch.mul(grad_output, o)
-
-         return grad_x, None, None
-
-     @staticmethod
-     @triton.jit
-     def kernel_blocksparse_exp(x,
-                                x_b, x_b_s, x_r_s, x_c_s,
-                                o,
-                                o_b, o_b_s, o_r_s, o_c_s,
-                                TRITON_BLOCK_SIZE: tl.constexpr) -> None:
-         # Get triton block indices
-         pid_blk = tl.program_id(axis=0)
-         pid_row = tl.program_id(axis=1)
-         pid_col = tl.program_id(axis=2)
-
-         # Load block
-         blk_x_idx = ((pid_blk * x_b_s) +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_r_s)[:, None] +
-                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * x_c_s)[None, :])
-         blk_x_msk = (blk_x_idx >= 0 and blk_x_idx < x_b * x_b_s)
-         blk_x = tl.load(x + blk_x_idx, mask=blk_x_msk)
-
-         # Compute exp
-         buf = tl.exp(blk_x)
-
-         # Store block
-         blk_o_idx = ((pid_blk * o_b_s) +
-                      ((pid_row * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_r_s)[:, None] +
-                      ((pid_col * TRITON_BLOCK_SIZE + tl.arange(0, TRITON_BLOCK_SIZE)) * o_c_s)[None, :])
-         blk_o_msk = (blk_o_idx >= 0 and blk_o_idx < o_b * o_b_s)
-         tl.store(o + blk_o_idx, buf, mask=blk_o_msk)
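
The removed operator saves its forward output and reuses it in the backward pass, exploiting the identity d/dx e^x = e^x. A minimal dense PyTorch sketch of the same autograd pattern (illustrative only, not part of the blksprs API):

```python
import torch

class DenseExp(torch.autograd.Function):
    """Dense analogue of the removed _BlocksparseExp, for illustration."""

    @staticmethod
    def forward(ctx, x):
        out = torch.exp(x)
        ctx.save_for_backward(out)  # e^x is its own derivative
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (out,) = ctx.saved_tensors
        return grad_output * out  # chain rule: grad * e^x

x = torch.randn(4, requires_grad=True)
y = DenseExp.apply(x)
y.sum().backward()
assert torch.allclose(x.grad, torch.exp(x))
```

As the docstring above notes, the block-sparse version only touched stored blocks, so converting its result back to regular form leaves pruned blocks at 0 rather than e^0 = 1.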
@@ -1,17 +0,0 @@
1
- import math
2
-
3
- import torch
4
- import triton
5
- from torch import Tensor
6
- from torch.xpu import device
7
- from triton import language as tl
8
-
9
- from blksprs.utils.blksprs_tensor import BlksprsTensor
10
- from blksprs.utils.tools import get_triton_block_size, stride
11
- from blksprs.utils.validation import validate_triton_block_size, validate_dimensions, validate_device, \
12
- validate_contiguous, validate_sparsity, validate_sparsity_block_size
13
-
14
-
15
- def build_full_sparsity_layout(x: Tensor, sparsity_block_size: int) -> Tensor:
16
- return torch.ones(size=(x.size(0), x.size(1) // sparsity_block_size, x.size(2) // sparsity_block_size),
17
- dtype=torch.bool, device=x.device)
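
This removed helper built an all-True layout marking every block as present. A quick illustration of the shapes involved (hypothetical sizes, plain PyTorch):

```python
import torch

# For a dense tensor of shape (B, H, W) and sparsity_block_size 16, the
# removed build_full_sparsity_layout returned a boolean layout with one
# True entry per (16 x 16) block of the input.
x = torch.randn(2, 64, 128)
sparsity_block_size = 16
layout = torch.ones(
    size=(x.size(0),
          x.size(1) // sparsity_block_size,
          x.size(2) // sparsity_block_size),
    dtype=torch.bool,
    device=x.device,
)
assert layout.shape == (2, 4, 8)  # one flag per 16x16 block
```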
@@ -1,24 +0,0 @@
1
- blksprs/__init__.py,sha256=wnpk-20jXq7xV0xa-WpHfPQuauI2gEZz9sH-0blKxP0,1766
2
- blksprs/layouting/distribution_layout.py,sha256=xDGY5-J7uSD8oenlf8bEJ2amMiQG3NBf2klTTydbTJE,5140
3
- blksprs/layouting/sparsity_layout.py,sha256=IVtHc_nN3ZM2y4GFcys70PqDWmWc7tkHlVGlToErANk,9894
4
- blksprs/ops/conversion.py,sha256=NK5uXMepPJ9yYh0vnxKwx5_Ffj_bAvhqPVogf_7PY0g,22248
5
- blksprs/ops/distribution.py,sha256=qK5t5XgQSJxXPced8RohprqCtUMMTaEP2pFm3KU1c8o,20267
6
- blksprs/ops/flow.py,sha256=Wv15oAhX4iqUzehj0XcNUWKjUcLaVB-5uSLEIsEREzA,6399
7
- blksprs/ops/matmul.py,sha256=LAQyPNwWVmBMRnAex3msLSPD_aG5SblLCMiutJWqvus,11632
8
- blksprs/ops/partitioning.py,sha256=ugKnpvH36ND7qeJQp56M74qqfACkzcTVuXebzw__28Y,8286
9
- blksprs/ops/repeat.py,sha256=RCa-dITomA5v12K5Oxa5_ReA361zS7WHPNNHxSp9PGw,8578
10
- blksprs/ops/softmax.py,sha256=i8NJhvPRYya94AzpN6qiki6_G9KfDrtPifhWd7wbYzk,12496
11
- blksprs/ops/transpose.py,sha256=oAtUu7QzQnNAH3lvRs_MIvIKpBu9h74f9Sk07AxKnDM,6991
12
- blksprs/ops/misc/broadcast_ops.py,sha256=pv0nssSDOdDbQFttpqUIs2ZXShqfm2RYCfJH-C5x3H0,5544
13
- blksprs/ops/misc/exp.py,sha256=ygfw7oD6ALdPwNQX_HelKgO8I3-LCgIXH_x0gWzkUN8,3840
14
- blksprs/ops/misc/row_wise.py,sha256=DnV5-xEJUbqZlK2fETwHiPQDUMwT-lkc0VUhBlnJ5Y0,17458
15
- blksprs/utils/benchmarking.py,sha256=4pLVlnPW_2EM-NT3n4SClaRznVYEljztLbJcccz8kZE,1360
16
- blksprs/utils/blksprs_tensor.py,sha256=pfoz59aJixj_fIoFx76ySiygwRQUemmgjMKepZ2c4j0,244
17
- blksprs/utils/layout_utils.py,sha256=49ZdPS_gMn_IrWty3FARbi2rda5a8g5DmAEL8LOrC30,670
18
- blksprs/utils/processing.py,sha256=WLuMJQ8v-YovXwcDjhlDn3N31WMZXrtyeeyKSgq_zn4,3642
19
- blksprs/utils/tools.py,sha256=k2OfEplbQiAwVjP84zZf7SNB8FzvMtOFBL9sC98OCbI,683
20
- blksprs/utils/validation.py,sha256=CbxBbeQWJo8wox5eMoVzaTlP9FVBwt3-gxUOmi3EUgw,4213
21
- blksprs-1.10.2.dist-info/METADATA,sha256=sm32ieVfYJ_bM5KtKbqF8DjHJ-4L4LMweEPwJWZvZG0,9107
22
- blksprs-1.10.2.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
23
- blksprs-1.10.2.dist-info/top_level.txt,sha256=qyp0IHeY3H2GQA97i4hk_To5rRBS2YcE1HRPSLy04fk,8
24
- blksprs-1.10.2.dist-info/RECORD,,