torch-bessel 0.0.2__cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Sign up to get free protection for your applications and to get access to all the features.
- torch_bessel/_C.cpython-310-x86_64-linux-gnu.so +0 -0
- torch_bessel/__init__.py +9 -0
- torch_bessel/ops.py +75 -0
- torch_bessel-0.0.2.dist-info/LICENSE +21 -0
- torch_bessel-0.0.2.dist-info/METADATA +39 -0
- torch_bessel-0.0.2.dist-info/RECORD +8 -0
- torch_bessel-0.0.2.dist-info/WHEEL +6 -0
- torch_bessel-0.0.2.dist-info/top_level.txt +1 -0
Binary file
|
torch_bessel/__init__.py
ADDED
torch_bessel/ops.py
ADDED
@@ -0,0 +1,75 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
|
3
|
+
import torch
|
4
|
+
from torch import Tensor
|
5
|
+
|
6
|
+
__all__ = ["modified_bessel_k0"]

# Load the C extension before calling any torch.library API, see
# https://pytorch.org/tutorials/advanced/cpp_custom_ops.html
so_dir = Path(__file__).parent
so_files = list(so_dir.glob("_C*.so"))
# Explicit check rather than `assert`: assertions are stripped under
# `python -O`, which would turn a missing library into an opaque IndexError
# (or silently load an arbitrary one of several candidates). ImportError also
# lets `try: import torch_bessel except ImportError:` handle a broken install.
if len(so_files) != 1:
    raise ImportError(
        f"Expected one _C*.so file at {so_dir}, found {len(so_files)}"
    )
torch.ops.load_library(so_files[0])
|
16
|
+
|
17
|
+
|
18
|
+
class ModifiedBesselK0(torch.autograd.Function):
    """Differentiable wrapper around ``torch.special.modified_bessel_k0``.

    Used for real-valued inputs. The backward pass applies the identity
    dK0/dx = -K1(x).
    """

    @staticmethod
    def forward(x):
        return torch.special.modified_bessel_k0(x)

    @staticmethod
    def setup_context(ctx, inputs, _):
        # Keep the input alive only when a gradient will actually be requested.
        if ctx.needs_input_grad[0]:
            ctx.save_for_backward(*inputs)
        # Let backward receive None (instead of zeros) when no gradient flows in.
        ctx.set_materialize_grads(False)

    @staticmethod
    def backward(ctx, grad):
        wants_input_grad = ctx.needs_input_grad[0]
        if grad is not None and wants_input_grad:
            saved_x = ctx.saved_tensors[0]
            # -K1(x), scaled in place by the upstream gradient.
            derivative = torch.special.modified_bessel_k1(saved_x).neg_()
            return derivative.mul_(grad)
        return None
|
36
|
+
|
37
|
+
|
38
|
+
def modified_bessel_k0(z: Tensor) -> Tensor:
    """Compute the modified Bessel function of the second kind, order zero.

    Real inputs go through ``torch.special.modified_bessel_k0`` wrapped in
    ``ModifiedBesselK0``, which adds a backward pass (dK0/dx = -K1(x)).
    Complex inputs are dispatched to the ``torch_bessel`` C extension ops.
    The package README describes complex support for Im(z) >= 0.

    Args:
        z: Real or complex input tensor.

    Returns:
        Tensor of K0 values with the same shape as ``z``.
    """
    if not z.is_complex():
        return ModifiedBesselK0.apply(z)
    # Use the fused forward+derivative kernel only when autograd will actually
    # record the op. Under torch.no_grad()/inference_mode a tensor can still
    # have requires_grad=True; computing the derivative there is wasted work.
    if not (z.requires_grad and torch.is_grad_enabled()):
        return torch.ops.torch_bessel.modified_bessel_k0_complex_forward.default(z)
    # First output is K0(z); the second is consumed only by the registered backward.
    return torch.ops.torch_bessel.modified_bessel_k0_complex_forward_backward.default(
        z
    )[0]
|
46
|
+
|
47
|
+
|
48
|
+
@torch.library.register_fake("torch_bessel::modified_bessel_k0_complex_forward")
def _(z):
    """Fake kernel: describe an output matching ``z`` in shape, dtype and device."""
    result = torch.empty_like(z)
    return result
|
51
|
+
|
52
|
+
|
53
|
+
@torch.library.register_fake(
    "torch_bessel::modified_bessel_k0_complex_forward_backward"
)
def _(z):
    """Fake kernel: two outputs, each shaped/typed/placed like ``z``."""
    first_out = torch.empty_like(z)
    second_out = torch.empty_like(z)
    return first_out, second_out
|
58
|
+
|
59
|
+
|
60
|
+
def modified_bessel_k0_backward(ctx, grad, _):
    """Backward formula for ``modified_bessel_k0_complex_forward_backward``.

    Multiplies the gradient of the op's first output by the tensor stashed in
    ``modified_bessel_k0_setup_context`` (the op's last output, presumably
    dK0/dz given the op name). The gradient of the second output is ignored.

    NOTE(review): PyTorch's complex-autograd convention for holomorphic ops is
    ``grad * f'(z).conj()``. Unless the C++ kernel already returns the
    conjugated derivative, a ``.conj()`` is missing here. Confirm against the
    kernel source.
    """
    if not ctx.needs_input_grad[0]:
        return None
    stashed_derivative = ctx.saved_tensors[0]
    return grad * stashed_derivative
|
64
|
+
|
65
|
+
|
66
|
+
def modified_bessel_k0_setup_context(ctx, inputs, output):
    """Save what ``modified_bessel_k0_backward`` reads.

    Saves only the last output of the fused op, and only when a gradient with
    respect to the first input is required. ``inputs`` is unused.
    """
    wants_input_grad = ctx.needs_input_grad[0]
    if wants_input_grad:
        last_output = output[-1]
        ctx.save_for_backward(last_output)
|
69
|
+
|
70
|
+
|
71
|
+
# Attach the Python backward to the fused complex op. setup_context decides
# what is saved from the op's inputs/outputs; backward then turns the
# upstream gradient into the gradient w.r.t. z.
torch.library.register_autograd(
    "torch_bessel::modified_bessel_k0_complex_forward_backward",
    modified_bessel_k0_backward,
    setup_context=modified_bessel_k0_setup_context,
)
|
@@ -0,0 +1,21 @@
|
|
1
|
+
MIT License
|
2
|
+
|
3
|
+
Copyright (c) 2025 Ho Yin Chau
|
4
|
+
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
7
|
+
in the Software without restriction, including without limitation the rights
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
10
|
+
furnished to do so, subject to the following conditions:
|
11
|
+
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
13
|
+
copies or substantial portions of the Software.
|
14
|
+
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21
|
+
SOFTWARE.
|
@@ -0,0 +1,39 @@
|
|
1
|
+
Metadata-Version: 2.2
|
2
|
+
Name: torch_bessel
|
3
|
+
Version: 0.0.2
|
4
|
+
Summary: PyTorch extension package for Bessel functions with arbitrary real order and complex inputs
|
5
|
+
Home-page: https://github.com/pytorch/torch-bessel
|
6
|
+
Author: Ho Yin Chau
|
7
|
+
Requires-Python: >= 3.9
|
8
|
+
Description-Content-Type: text/markdown
|
9
|
+
License-File: LICENSE
|
10
|
+
Requires-Dist: torch
|
11
|
+
Dynamic: author
|
12
|
+
Dynamic: description
|
13
|
+
Dynamic: description-content-type
|
14
|
+
Dynamic: home-page
|
15
|
+
Dynamic: requires-dist
|
16
|
+
Dynamic: requires-python
|
17
|
+
Dynamic: summary
|
18
|
+
|
19
|
+
# About
|
20
|
+
PyTorch extension package for Bessel functions with arbitrary real order and complex inputs
|
21
|
+
|
22
|
+
# Install
|
23
|
+
```
|
24
|
+
pip install torch-bessel
|
25
|
+
```
|
26
|
+
|
27
|
+
# Example usage
|
28
|
+
```
|
29
|
+
import torch
import torch_bessel
|
30
|
+
|
31
|
+
z = torch.randn(10) + 1j
|
32
|
+
torch_bessel.ops.modified_bessel_k0(z)
|
33
|
+
```
|
34
|
+
|
35
|
+
# Implemented functions
|
36
|
+
- `modified_bessel_k0`: Same as `torch.special.modified_bessel_k0`, but also supports backpropagation and complex inputs with $\mathrm{Im}(z) \geq 0$ on CPU and CUDA.
|
37
|
+
|
38
|
+
# WIP
|
39
|
+
- `modified_bessel_kv`: Analogue of `scipy.special.kv`.
|
@@ -0,0 +1,8 @@
|
|
1
|
+
torch_bessel-0.0.2.dist-info/METADATA,sha256=F8Mbw3qz2XuV82NoMc7HWKQ2lKivcx1X9ht2AScvSfo,999
|
2
|
+
torch_bessel-0.0.2.dist-info/RECORD,,
|
3
|
+
torch_bessel-0.0.2.dist-info/top_level.txt,sha256=cbDIjTj71LuAlVyyYyDt8fOAeLaVeX3Vums5F2FBa-4,13
|
4
|
+
torch_bessel-0.0.2.dist-info/WHEEL,sha256=ViyZsTV2upbIniGkknQiIrLPLs1cJIoIfr1wsV7PMic,151
|
5
|
+
torch_bessel-0.0.2.dist-info/LICENSE,sha256=do0DI6wu4mF3VXnEXXPYZqVEatoRSSamgz9t80wU7_o,1068
|
6
|
+
torch_bessel/_C.cpython-310-x86_64-linux-gnu.so,sha256=sUtqFAU-sf9TTDvutaeQZLt416SMVZVWrqDCVhUpTCQ,66691960
|
7
|
+
torch_bessel/ops.py,sha256=Q9BrLxi15MS53xSt_S9dyE3g8_8_GCFYhfAztIor8Fw,2043
|
8
|
+
torch_bessel/__init__.py,sha256=oohbWz8vxekl7kqDNSWqDH3ORabf9-Tug1KJryKw51A,230
|
@@ -0,0 +1 @@
|
|
1
|
+
torch_bessel
|