torch-bessel 0.0.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_bessel-0.0.9/LICENSE +21 -0
- torch_bessel-0.0.9/PKG-INFO +50 -0
- torch_bessel-0.0.9/README.md +31 -0
- torch_bessel-0.0.9/benchmarks/__init__.py +0 -0
- torch_bessel-0.0.9/benchmarks/benchmarks.py +143 -0
- torch_bessel-0.0.9/pyproject.toml +9 -0
- torch_bessel-0.0.9/setup.cfg +4 -0
- torch_bessel-0.0.9/setup.py +80 -0
- torch_bessel-0.0.9/tests/test_extension.py +281 -0
- torch_bessel-0.0.9/torch_bessel/__init__.py +11 -0
- torch_bessel-0.0.9/torch_bessel/csrc/iterator.cpp +60 -0
- torch_bessel-0.0.9/torch_bessel/csrc/torch_bessel.cpp +94 -0
- torch_bessel-0.0.9/torch_bessel/ops.py +134 -0
- torch_bessel-0.0.9/torch_bessel.egg-info/PKG-INFO +50 -0
- torch_bessel-0.0.9/torch_bessel.egg-info/SOURCES.txt +16 -0
- torch_bessel-0.0.9/torch_bessel.egg-info/dependency_links.txt +1 -0
- torch_bessel-0.0.9/torch_bessel.egg-info/requires.txt +1 -0
- torch_bessel-0.0.9/torch_bessel.egg-info/top_level.txt +2 -0
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 Ho Yin Chau
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
@@ -0,0 +1,50 @@
+Metadata-Version: 2.4
+Name: torch_bessel
+Version: 0.0.9
+Summary: PyTorch extension package for Bessel functions with arbitrary real order and complex inputs
+Home-page: https://github.com/hchau630/torch-bessel
+Author: Ho Yin Chau
+Requires-Python: >= 3.10
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: torch
+Dynamic: author
+Dynamic: description
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: license-file
+Dynamic: requires-dist
+Dynamic: requires-python
+Dynamic: summary
+
+# About
+PyTorch extension package for modified Bessel functions of the second kind with complex inputs
+
+# Install
+Currently only supports Linux (with CUDA 12.4) or MacOS (Apple silicon, cpu only) with python >= 3.9, <= 3.12.
+```
+pip install torch-bessel
+```
+
+UPDATE (May 26, 2026): It seems the package might no longer be compatible with the latest pytorch version (>= 2.12.0), so try installing an earlier version of pytorch if you encounter issues.
+
+UPDATE (May 28, 2026): On Apple silicon, importing this package with pytorch version >= 2.9.0 causes a crash (bus error: 10) when exiting a python program, though the program runs normally prior to that. So consider installing torch < 2.9.0 on Apple silicon.
+
+# Example
+```
+import torch_bessel
+
+real, imag = torch.randn(2, 5, device="cuda")
+z = torch.complex(real.abs(), imag) # correctness for inputs in the left-half complex plane is not guaranteed.
+torch_bessel.ops.modified_bessel_k0(z)
+```
+
+# Implemented functions
+- `modified_bessel_k0`: Same as `torch.special.modified_bessel_k0`, but also handles backpropagation and complex inputs on cpu and cuda. Correctness is guaranteed on the right-half complex plane. On cuda, `torch.chalf` inputs are also supported, though the underlying cuda kernel just upcasts `chalf` to `cfloat` (note that this uses no extra GPU memory, as opposed to manually casting torch.chalf to torch.cfloat before calling `modified_bessel_k0` which doubles the GPU memory used). On the left-half complex plane, function output appears mostly correct, but with small numerical errors for certain inputs. On the negative real line, output is NaN.
+- `modified_bessel_k1`: Same as `torch.special.modified_bessel_k1`, but also handles complex inputs on cpu and cuda. Backpropagation not implemented, but this can be easily manually implemented yourself by writing a torch.autograd.Function using the recurrence properties of bessel functions. Same caveats as `modified_bessel_k0` apply.
+
+# WIP
+- `modified_bessel_kv`: Analogue of `scipy.special.kv`.
+
+# Benchmarks
+Benchmarking performed with the `asv` package. Results can be viewed at https://hchau630.github.io/torch-bessel.
@@ -0,0 +1,31 @@
+# About
+PyTorch extension package for modified Bessel functions of the second kind with complex inputs
+
+# Install
+Currently only supports Linux (with CUDA 12.4) or MacOS (Apple silicon, cpu only) with python >= 3.9, <= 3.12.
+```
+pip install torch-bessel
+```
+
+UPDATE (May 26, 2026): It seems the package might no longer be compatible with the latest pytorch version (>= 2.12.0), so try installing an earlier version of pytorch if you encounter issues.
+
+UPDATE (May 28, 2026): On Apple silicon, importing this package with pytorch version >= 2.9.0 causes a crash (bus error: 10) when exiting a python program, though the program runs normally prior to that. So consider installing torch < 2.9.0 on Apple silicon.
+
+# Example
+```
+import torch_bessel
+
+real, imag = torch.randn(2, 5, device="cuda")
+z = torch.complex(real.abs(), imag) # correctness for inputs in the left-half complex plane is not guaranteed.
+torch_bessel.ops.modified_bessel_k0(z)
+```
+
+# Implemented functions
+- `modified_bessel_k0`: Same as `torch.special.modified_bessel_k0`, but also handles backpropagation and complex inputs on cpu and cuda. Correctness is guaranteed on the right-half complex plane. On cuda, `torch.chalf` inputs are also supported, though the underlying cuda kernel just upcasts `chalf` to `cfloat` (note that this uses no extra GPU memory, as opposed to manually casting torch.chalf to torch.cfloat before calling `modified_bessel_k0` which doubles the GPU memory used). On the left-half complex plane, function output appears mostly correct, but with small numerical errors for certain inputs. On the negative real line, output is NaN.
+- `modified_bessel_k1`: Same as `torch.special.modified_bessel_k1`, but also handles complex inputs on cpu and cuda. Backpropagation not implemented, but this can be easily manually implemented yourself by writing a torch.autograd.Function using the recurrence properties of bessel functions. Same caveats as `modified_bessel_k0` apply.
+
+# WIP
+- `modified_bessel_kv`: Analogue of `scipy.special.kv`.
+
+# Benchmarks
+Benchmarking performed with the `asv` package. Results can be viewed at https://hchau630.github.io/torch-bessel.
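The README says that backpropagation for `modified_bessel_k1` can be written by hand from the recurrence properties of Bessel functions. The relevant identity is K1'(z) = -K0(z) - K1(z)/z. The sketch below checks it numerically in the right-half plane using only the standard library. The helper names `kv_integral` and `k1_derivative` are hypothetical and not part of `torch_bessel`, and the integral representation is used here only as an independent reference, not as the package's algorithm.

```python
import cmath
import math


def kv_integral(nu, z, t_max=6.0, steps=600):
    """Trapezoid estimate of K_nu(z) = int_0^inf exp(-z cosh t) cosh(nu t) dt.

    Illustrative only: the representation requires Re(z) > 0, and the
    truncation at t_max assumes Re(z) is not close to zero.
    """
    h = t_max / steps
    total = 0.5 * cmath.exp(-z)  # t = 0 endpoint, where cosh(0) = 1
    for i in range(1, steps + 1):
        t = i * h
        term = cmath.exp(-z * math.cosh(t)) * math.cosh(nu * t)
        total += 0.5 * term if i == steps else term
    return total * h


def k1_derivative(z):
    """dK1/dz from the recurrence K1'(z) = -K0(z) - K1(z) / z."""
    return -kv_integral(0, z) - kv_integral(1, z) / z


z = 1.5 + 0.7j
step = 1e-5
finite_difference = (kv_integral(1, z + step) - kv_integral(1, z - step)) / (2 * step)
print(abs(finite_difference - k1_derivative(z)))  # should be tiny
```

In a custom `torch.autograd.Function`, the backward pass for a holomorphic function would typically multiply the incoming gradient by the complex conjugate of this derivative, following PyTorch's convention for complex autograd.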
File without changes
@@ -0,0 +1,143 @@
+import torch
+from asv_runner.benchmarks.mark import skip_benchmark_if
+
+import torch_bessel
+
+
+def _setup(
+    n, is_real, singularity, dtype=torch.float, requires_grad=False, device="cpu"
+):
+    kwargs = {"dtype": dtype, "requires_grad": requires_grad, "device": device}
+    real = torch.randn(n, **kwargs).abs()
+    if is_real:
+        args = (real, singularity)
+    else:
+        imag = torch.randn(n, **kwargs)
+        args = (torch.complex(real, imag), singularity)
+    return args
+
+
+def _setup_k1(n, is_real, dtype=torch.float, device="cpu"):
+    kwargs = {"dtype": dtype, "device": device}
+    real = torch.randn(n, **kwargs).abs()
+    if is_real:
+        args = (real,)
+    else:
+        imag = torch.randn(n, **kwargs)
+        args = (torch.complex(real, imag),)
+    return args
+
+
+class ModifiedBesselK0ForwardCPU:
+    params = (
+        [10_000, 100_000, 1_000_000],
+        [False, True],
+        [None, 0.0],
+        [torch.float32, torch.float64],
+        [False, True],
+    )
+    param_names = ["n", "is_real", "singularity", "dtype", "requires_grad"]
+
+    def setup(self, n, is_real, singularity, dtype, requires_grad):
+        self.args = _setup(n, is_real, singularity, dtype, requires_grad)
+
+    def time_modified_bessel_k0_forward_cpu(
+        self, n, is_real, singularity, dtype, requires_grad
+    ):
+        torch_bessel.ops.modified_bessel_k0(*self.args)
+
+
+class ModifiedBesselK0ForwardCUDA:
+    params = (
+        [10_000, 100_000, 1_000_000],
+        [False, True],
+        [None, 0.0],
+        [torch.float32, torch.float64],
+        [False, True],
+    )
+    param_names = ["n", "is_real", "singularity", "dtype", "requires_grad"]
+
+    def setup(self, n, is_real, singularity, dtype, requires_grad):
+        self.args = _setup(n, is_real, singularity, dtype, requires_grad, device="cuda")
+
+    @skip_benchmark_if(not torch.cuda.is_available())
+    def time_modified_bessel_k0_forward_cuda(
+        self, n, is_real, singularity, dtype, requires_grad
+    ):
+        torch.cuda.synchronize()
+        torch_bessel.ops.modified_bessel_k0(*self.args)
+        torch.cuda.synchronize()
+
+
+class ModifiedBesselK0BackwardCPU:
+    warmup_time = 0.0  # for some reason backward is called multiple times if not 0
+    number = 1  # Avoids calling backward multiple times
+    params = (
+        [10_000, 100_000, 1_000_000],
+        [False, True],
+        [None, 0.0],
+        [torch.float32, torch.float64],
+    )
+    param_names = ["n", "is_real", "singularity", "dtype"]
+
+    def setup(self, n, is_real, singularity, dtype):
+        args = _setup(n, is_real, singularity, dtype, requires_grad=True)
+        self.out = torch_bessel.ops.modified_bessel_k0(*args).norm()
+
+    def time_modified_bessel_k0_backward_cpu(self, n, is_real, singularity, dtype):
+        self.out.backward()
+
+
+class ModifiedBesselK0BackwardCUDA:
+    warmup_time = 0.0  # for some reason backward is called multiple times if not 0
+    number = 1  # Avoids calling backward multiple times
+    params = (
+        [10_000, 100_000, 1_000_000],
+        [False, True],
+        [None, 0.0],
+        [torch.float32, torch.float64],
+    )
+    param_names = ["n", "is_real", "singularity", "dtype"]
+
+    def setup(self, n, is_real, singularity, dtype):
+        args = _setup(n, is_real, singularity, dtype, requires_grad=True, device="cuda")
+        self.out = torch_bessel.ops.modified_bessel_k0(*args).norm()
+
+    @skip_benchmark_if(not torch.cuda.is_available())
+    def time_modified_bessel_k0_backward_cuda(self, n, is_real, singularity, dtype):
+        torch.cuda.synchronize()
+        self.out.backward()
+        torch.cuda.synchronize()
+
+
+class ModifiedBesselK1ForwardCPU:
+    params = (
+        [10_000, 100_000, 1_000_000],
+        [False, True],
+        [torch.float32, torch.float64],
+    )
+    param_names = ["n", "is_real", "dtype"]
+
+    def setup(self, n, is_real, dtype):
+        self.args = _setup_k1(n, is_real, dtype)
+
+    def time_modified_bessel_k1_forward_cpu(self, n, is_real, dtype):
+        torch_bessel.ops.modified_bessel_k1(*self.args)
+
+
+class ModifiedBesselK1ForwardCUDA:
+    params = (
+        [10_000, 100_000, 1_000_000],
+        [False, True],
+        [torch.float32, torch.float64],
+    )
+    param_names = ["n", "is_real", "dtype"]
+
+    def setup(self, n, is_real, dtype):
+        self.args = _setup_k1(n, is_real, dtype, device="cuda")
+
+    @skip_benchmark_if(not torch.cuda.is_available())
+    def time_modified_bessel_k1_forward_cuda(self, n, is_real, dtype):
+        torch.cuda.synchronize()
+        torch_bessel.ops.modified_bessel_k1(*self.args)
+        torch.cuda.synchronize()
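The benchmark classes above declare `params` as a tuple of lists. `asv` runs a parameterized benchmark once per combination of those lists, so each K0 forward class covers 3 × 2 × 2 × 2 × 2 = 48 cases. The sketch below enumerates that grid with the standard library. It is an illustration and not part of the package: `expand` is a hypothetical helper, and the dtypes are written as strings so the example runs without torch.

```python
import itertools

# Stand-ins for the lists in ModifiedBesselK0ForwardCPU; the dtype entries are
# plain strings here instead of torch.float32 / torch.float64.
params = (
    [10_000, 100_000, 1_000_000],
    [False, True],
    [None, 0.0],
    ["float32", "float64"],
    [False, True],
)
param_names = ["n", "is_real", "singularity", "dtype", "requires_grad"]


def expand(params, param_names):
    """Return one dict per combination, mirroring how setup() receives arguments."""
    return [dict(zip(param_names, combo)) for combo in itertools.product(*params)]


combos = expand(params, param_names)
print(len(combos))  # number of (setup, time_*) invocations per benchmark class
```

The backward benchmarks drop `requires_grad` from the grid because their `setup` always builds the input with `requires_grad=True`.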
@@ -0,0 +1,80 @@
+import os
+import glob
+
+from setuptools import find_packages, setup
+
+from torch.utils.cpp_extension import (
+    CppExtension,
+    CUDAExtension,
+    BuildExtension,
+    CUDA_HOME,
+)
+
+library_name = "torch_bessel"
+
+
+def get_extensions():
+    debug_mode = os.getenv("DEBUG", "0") == "1"
+    use_cuda = os.getenv("USE_CUDA", "1") == "1"
+    if debug_mode:
+        print("Compiling in debug mode")
+
+    use_cuda = use_cuda and CUDA_HOME is not None
+    extension = CUDAExtension if use_cuda else CppExtension
+
+    extra_link_args = []
+    extra_compile_args = {
+        "cxx": [
+            "-O3" if not debug_mode else "-O0",
+            "-std=c++17",
+            "-fdiagnostics-color=always",
+            "-DPy_LIMITED_API=0x030A0000",
+        ],
+        "nvcc": [
+            "-O3" if not debug_mode else "-O0",
+            "--extended-lambda",
+        ],
+    }
+    if debug_mode:
+        extra_compile_args["cxx"].append("-g")
+        extra_compile_args["nvcc"].append("-g")
+        extra_link_args.extend(["-O0", "-g"])
+
+    this_dir = os.path.dirname(os.path.curdir)
+    extensions_dir = os.path.join(this_dir, library_name, "csrc")
+    sources = list(glob.glob(os.path.join(extensions_dir, "*.cpp")))
+
+    extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
+    cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "*.cu")))
+
+    if use_cuda:
+        sources += cuda_sources
+
+    ext_modules = [
+        extension(
+            f"{library_name}._C",
+            sources,
+            extra_compile_args=extra_compile_args,
+            extra_link_args=extra_link_args,
+            py_limited_api=True,
+        )
+    ]
+
+    return ext_modules
+
+
+setup(
+    name=library_name,
+    version="0.0.9",
+    author="Ho Yin Chau",
+    packages=find_packages(),
+    ext_modules=get_extensions(),
+    install_requires=["torch"],
+    python_requires=">= 3.10",
+    description="PyTorch extension package for Bessel functions with arbitrary real order and complex inputs",
+    long_description=open("README.md").read(),
+    long_description_content_type="text/markdown",
+    url="https://github.com/hchau630/torch-bessel",
+    cmdclass={"build_ext": BuildExtension},
+    options={"bdist_wheel": {"py_limited_api": "cp310"}},
+)
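The build script above names the minimum Python version in three places: `-DPy_LIMITED_API=0x030A0000`, `python_requires=">= 3.10"`, and the `cp310` wheel tag. The hex constant follows CPython's version encoding, with the major version in the top byte and the minor version in the next byte. The sketch below decodes it to show that the three settings agree. `decode_limited_api` and `wheel_tag` are hypothetical helper names used only for this illustration.

```python
def decode_limited_api(hexversion):
    """Split a Py_LIMITED_API / PY_VERSION_HEX style value into (major, minor)."""
    major = (hexversion >> 24) & 0xFF  # bits 24-31
    minor = (hexversion >> 16) & 0xFF  # bits 16-23
    return major, minor


def wheel_tag(major, minor):
    """Format the CPython tag used by the bdist_wheel py_limited_api option."""
    return f"cp{major}{minor}"


major, minor = decode_limited_api(0x030A0000)
print(major, minor, wheel_tag(major, minor))
```

Keeping these three values in sync matters because a wheel tagged for the stable ABI of one version is expected to load on that version and later ones.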
@@ -0,0 +1,281 @@
+import torch
+from torch.testing._internal.common_utils import TestCase
+from torch.testing._internal.optests import opcheck
+import pytest
+from scipy import special
+
+import torch_bessel
+
+
+def reference_modified_bessel_k0(z, singularity=None):
+    device = z.device
+    dtype = z.dtype
+    if dtype is torch.chalf:
+        z = z.to(torch.cfloat)
+    out = special.kv(0.0, z.detach().cpu()).to(device).to(dtype)
+    if singularity is not None:
+        out = out.where(z != 0, singularity)
+    return out
+
+
+def reference_modified_bessel_k1(z):
+    device = z.device
+    dtype = z.dtype
+    if dtype is torch.chalf:
+        z = z.to(torch.cfloat)
+    out = special.kv(1.0, z.detach().cpu()).to(device).to(dtype)
+    return out
+
+
+class TestBesselK0(TestCase):
+    def sample_inputs(self, device, *, requires_grad=False):
+        def make_z(*size, dtype=None):
+            kwargs = dict(dtype=dtype, device=device, requires_grad=requires_grad)
+            real = torch.randn(size, **kwargs).abs()
+            imag = torch.randn(size, **kwargs)
+            return torch.complex(real, imag)
+
+        out = [
+            [make_z(50, dtype=torch.double)],
+            [make_z(50, dtype=torch.double).real],
+            [make_z(50)],
+            [make_z(50).real],
+            [torch.tensor(70.1162 + 89.0190j)],  # used to fail on this input
+        ]
+        if device == "cuda":
+            # half precision is only supported on CUDA
+            out.append([make_z(50, dtype=torch.half)])
+
+        return out
+
+    def grid_inputs(self, device, *, requires_grad=False):
+        def make_z(*args, dtype=torch.float):
+            real = torch.tensor([0, *torch.logspace(*args, dtype=torch.double)])
+            imag = torch.logspace(*args, dtype=torch.double)
+
+            # Don't test on subnormal numbers, since the AMOS code doesn't consider them
+            real[real < torch.finfo(dtype).tiny] = 0.0
+            imag[imag < torch.finfo(dtype).tiny] = 0.0
+
+            kwargs = dict(dtype=dtype, device=device, requires_grad=requires_grad)
+            real = torch.tensor(real, **kwargs)
+            imag = torch.tensor([*(-imag.flip(0)), 0, *imag], **kwargs)
+            real, imag = real[:, None], imag[None, :]
+            return torch.complex(real, imag)
+
+        out = [
+            [make_z(-350, 350, 75, dtype=torch.double)],
+            [make_z(-50, 50, 75)],
+            [make_z(-5, 5, 75, dtype=torch.double)],  # scipy is buggy on this input
+            [make_z(-5, 5, 75)],  # scipy is buggy on this input
+            [make_z(-50, 50, 75), 1.0],
+            [make_z(-50, 50, 75), torch.randn((76, 151), device=device)],
+        ]
+        if device == "cuda":
+            # half precision is only supported on CUDA
+            out.append([make_z(-5, 5, 75, dtype=torch.half)])
+        return out
+
+    def _test_correctness(self, device):
+        samples = (
+            self.sample_inputs(device)
+            + self.grid_inputs(device)
+            + self.sample_inputs(device, requires_grad=True)
+            + self.grid_inputs(device, requires_grad=True)
+        )
+        for args in samples:
+            result = torch_bessel.ops.modified_bessel_k0(*args)
+            expected = reference_modified_bessel_k0(*args)
+            if expected.dtype in {torch.float, torch.cfloat, torch.chalf}:
+                # ierr = 4, complete loss of significance
+                expected[args[0].abs() > 4194303.98419452] = torch.nan
+                # ierr = 2, overflow
+                mask = (args[0].abs() < 1.1754944e-35) & (args[0] != 0)
+                if expected.is_complex():
+                    expected[mask] = torch.nan
+                    expected[mask & (args[0].imag == 0) & (args[0] != 0)] = torch.inf
+                else:
+                    expected[mask & (args[0] != 0)] = torch.inf
+            # fix scipy bug. See https://github.com/scipy/xsf/issues/46
+            if args[0].is_complex():
+                mask = (
+                    (args[0].real > 685)
+                    & (args[0].real < 690)
+                    & (args[0].imag.abs() > 685)
+                    & (args[0].imag.abs() < 1.1e5)
+                )
+                expected[mask] = 0.0  # scipy returns incorrect/NaN output
+
+            torch.testing.assert_close(result, expected, equal_nan=True)
+
+    def test_correctness_cpu(self):
+        self._test_correctness("cpu")
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda")
+    def test_correctness_cuda(self):
+        self._test_correctness("cuda")
+
+    def _test_gradients(self, device):
+        samples = self.sample_inputs(device, requires_grad=True)
+        for args in samples:
+            if (
+                args[0].dtype in {torch.double, torch.complex128}
+                and args[0].requires_grad
+            ):
+                torch.autograd.gradcheck(torch_bessel.ops.modified_bessel_k0, args)
+
+    def test_gradients_cpu(self):
+        self._test_gradients("cpu")
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda")
+    def test_gradients_cuda(self):
+        self._test_gradients("cuda")
+
+    def _opcheck(self, device):
+        # Use opcheck to check for incorrect usage of operator registration APIs
+        samples = self.sample_inputs(device, requires_grad=False)
+        samples.extend(self.sample_inputs(device, requires_grad=True))
+        for args in samples:
+            if not args[0].is_complex():
+                continue
+            opcheck(
+                torch.ops.torch_bessel.modified_bessel_k0_complex_forward_backward.default,
+                args,
+            )
+
+    def test_opcheck_cpu(self):
+        self._opcheck("cpu")
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda")
+    def test_opcheck_cuda(self):
+        self._opcheck("cuda")
+
+
+@pytest.mark.parametrize(
+    "singularity, in_dims",
+    [
+        (None, (0, None)),
+        (0.0, (0, None)),
+        ((4, 3), (0, 1)),
+    ],
+)
+def test_vmap(singularity, in_dims):
+    z = torch.randn(3, 4).abs() + torch.randn(3, 4) * 1j
+    if isinstance(singularity, tuple):
+        singularity = torch.randn(singularity)
+    func = torch.func.vmap(torch_bessel.ops.modified_bessel_k0, in_dims=in_dims)
+    out = func(z, singularity)
+    if not isinstance(singularity, torch.Tensor):
+        expected = torch_bessel.ops.modified_bessel_k0(z, singularity)
+    else:
+        expected = torch_bessel.ops.modified_bessel_k0(z, singularity.t())
+    torch.testing.assert_close(out, expected)
+
+
+class TestBesselK1(TestCase):
+    def sample_inputs(self, device, *, requires_grad=False):
+        def make_z(*size, dtype=None):
+            kwargs = dict(dtype=dtype, device=device, requires_grad=requires_grad)
+            real = torch.randn(size, **kwargs).abs()
+            imag = torch.randn(size, **kwargs)
+            return torch.complex(real, imag)
+
+        out = [
+            [make_z(50, dtype=torch.double)],
+            [make_z(50, dtype=torch.double).real],
+            [make_z(50)],
+            [make_z(50).real],
+            [torch.tensor(70.1162 + 89.0190j)],  # used to fail on this input
+        ]
+        if device == "cuda":
+            # half precision is only supported on CUDA
+            out.append([make_z(50, dtype=torch.half)])
+
+        return out
+
+    def grid_inputs(self, device, *, requires_grad=False):
+        def make_z(*args, dtype=torch.float):
+            real = torch.tensor([0, *torch.logspace(*args, dtype=torch.double)])
+            imag = torch.logspace(*args, dtype=torch.double)
+
+            # Don't test on subnormal numbers, since the AMOS code doesn't consider them
+            real[real < torch.finfo(dtype).tiny] = 0.0
+            imag[imag < torch.finfo(dtype).tiny] = 0.0
+
+            kwargs = dict(dtype=dtype, device=device, requires_grad=requires_grad)
+            real = torch.tensor(real, **kwargs)
+            imag = torch.tensor([*(-imag.flip(0)), 0, *imag], **kwargs)
+            real, imag = real[:, None], imag[None, :]
+            return torch.complex(real, imag)
+
+        out = [
+            [make_z(-350, 350, 75, dtype=torch.double)],
+            [make_z(-50, 50, 75)],
+            [make_z(-5, 5, 75, dtype=torch.double)],  # scipy is buggy on this input
+            [make_z(-5, 5, 75)],  # scipy is buggy on this input
+        ]
+        if device == "cuda":
+            # half precision is only supported on CUDA
+            out.append([make_z(-5, 5, 75, dtype=torch.half)])
+        return out
+
+    def _test_correctness(self, device):
+        samples = self.sample_inputs(device) + self.grid_inputs(device)
+        for args in samples:
+            result = torch_bessel.ops.modified_bessel_k1(*args)
+            expected = reference_modified_bessel_k1(*args)
+            if expected.dtype in {torch.float, torch.cfloat, torch.chalf}:
+                # ierr = 4, complete loss of significance
+                expected[args[0].abs() > 4194303.98419452] = torch.nan
+                # ierr = 2, overflow
+                mask = (args[0].abs() < 1.1754944e-35) & (args[0] != 0)
+                if expected.is_complex():
+                    expected[mask] = torch.nan
+                    expected[mask & (args[0].imag == 0) & (args[0] != 0)] = torch.inf
+                else:
+                    expected[mask & (args[0] != 0)] = torch.inf
+            # fix scipy bug. See https://github.com/scipy/xsf/issues/46
+            if args[0].is_complex():
+                mask = (
+                    (args[0].real > 685)
+                    & (args[0].real < 690)
+                    & (args[0].imag.abs() > 685)
+                    & (args[0].imag.abs() < 1.1e5)
+                )
+                expected[mask] = 0.0  # scipy returns incorrect/NaN output
+            torch.testing.assert_close(result, expected, equal_nan=True)
+
+    def test_correctness_cpu(self):
+        self._test_correctness("cpu")
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda")
+    def test_correctness_cuda(self):
+        self._test_correctness("cuda")
+
+    def _opcheck(self, device):
+        # Use opcheck to check for incorrect usage of operator registration APIs
+        samples = self.sample_inputs(device, requires_grad=False)
+        for args in samples:
+            if not args[0].is_complex():
+                continue
+            opcheck(
+                torch.ops.torch_bessel.modified_bessel_k1_complex_forward.default, args
+            )
+
+    def test_opcheck_cpu(self):
+        self._opcheck("cpu")
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda")
+    def test_opcheck_cuda(self):
+        self._opcheck("cuda")
+
+
+@pytest.mark.parametrize("in_dims", [0, 1])
+def test_vmap_k1(in_dims):
+    z = torch.randn(3, 4).abs() + torch.randn(3, 4) * 1j
+    func = torch.func.vmap(torch_bessel.ops.modified_bessel_k1, in_dims=in_dims)
+    out = func(z)
+    expected = torch_bessel.ops.modified_bessel_k1(z)
+    if in_dims == 1:
+        expected = expected.t()
+    torch.testing.assert_close(out, expected)
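In the tests above, `grid_inputs` builds its complex grid from a 75-point log-spaced axis. The real axis gets a zero prepended, and the imaginary axis is mirrored around zero. That is why one test case pairs `make_z(-50, 50, 75)` with a singularity tensor of shape `(76, 151)`. The sketch below reproduces only the shape arithmetic with plain lists. `logspace` and `grid_shape` are hypothetical stand-ins for the torch calls, and the dtype handling and subnormal masking are omitted.

```python
def logspace(start, stop, num):
    """Return num points 10**e with e evenly spaced from start to stop."""
    step = (stop - start) / (num - 1)
    return [10.0 ** (start + i * step) for i in range(num)]


def grid_shape(start, stop, num):
    """Shape of the broadcast grid real[:, None] + 1j * imag[None, :]."""
    positive = logspace(start, stop, num)
    real = [0.0, *positive]  # zero plus the positive axis
    imag = [*(-v for v in reversed(positive)), 0.0, *positive]  # mirrored axis
    return len(real), len(imag)


print(grid_shape(-50, 50, 75))
```

The real axis has `num + 1` entries and the imaginary axis has `2 * num + 1`, so any tensor-valued `singularity` argument has to broadcast against that shape.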
@@ -0,0 +1,11 @@
+import warnings
+
+# from . import ops
+from . import _C, ops
+
+# Silence warning emitted at torch/nested/_internal/nested_tensor.py:417
+warnings.filterwarnings(
+    "ignore", message="Failed to initialize NumPy: No module named 'numpy'"
+)
+
+__all__ = ["ops"]
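The package `__init__` above installs a `warnings.filterwarnings("ignore", message=...)` rule. The `message` argument is a regular expression matched against the start of the warning text, so the rule hides only warnings that begin with that string and leaves others visible. A minimal standard-library sketch of that behaviour follows. `count_visible` is a hypothetical helper written for this illustration.

```python
import warnings

MESSAGE = "Failed to initialize NumPy: No module named 'numpy'"


def count_visible(messages):
    """Emit each message as a UserWarning and count those not filtered out."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # record everything by default
        # Added after simplefilter, so this rule is consulted first.
        warnings.filterwarnings("ignore", message=MESSAGE)
        for text in messages:
            warnings.warn(text, UserWarning)
    return len(caught)


print(count_visible([MESSAGE + " (extra detail)", "an unrelated warning"]))
```

Because the filter is process-wide, it also affects matching warnings raised by other libraries after it is installed.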
@@ -0,0 +1,60 @@
+// #include <torch/extension.h>
+#include <torch/library.h>
+#include <ATen/TensorIterator.h>
+
+
+at::TensorIterator build_iterator_11(const at::Tensor& result, const at::Tensor& z) {
+    return (
+        at::TensorIteratorConfig()
+        .set_check_mem_overlap(true)
+        .allow_cpu_scalars(true)
+        .enforce_safe_casting_to_output(true)
+        .add_output(result)
+        .add_input(z)
+    ).build();
+}
+
+
+at::TensorIterator build_iterator_12(const at::Tensor& result, const at::Tensor& v, const at::Tensor& z) {
+    return (
+        at::TensorIteratorConfig()
+        .set_check_mem_overlap(true)
+        .allow_cpu_scalars(true)
+        .promote_inputs_to_common_dtype(true)
+        .cast_common_dtype_to_outputs(true)
+        .enforce_safe_casting_to_output(true)
+        .promote_integer_inputs_to_float(true)
+        .add_output(result)
+        .add_input(v)
+        .add_input(z)
+    ).build();
+}
+
+at::TensorIterator build_iterator_21(const at::Tensor& result1, const at::Tensor& result2, const at::Tensor& z) {
+    return (
+        at::TensorIteratorConfig()
+        .set_check_mem_overlap(true)
+        .allow_cpu_scalars(true)
+        .enforce_safe_casting_to_output(true)
+        .add_output(result1)
+        .add_output(result2)
+        .add_input(z)
+    ).build();
+}
+
+
+at::TensorIterator build_iterator_22(const at::Tensor& result1, const at::Tensor& result2, const at::Tensor& v, const at::Tensor& z) {
+    return (
+        at::TensorIteratorConfig()
+        .set_check_mem_overlap(true)
+        .allow_cpu_scalars(true)
+        .promote_inputs_to_common_dtype(true)
+        .cast_common_dtype_to_outputs(true)
+        .enforce_safe_casting_to_output(true)
+        .promote_integer_inputs_to_float(true)
+        .add_output(result1)
+        .add_output(result2)
+        .add_input(v)
+        .add_input(z)
+    ).build();
+}
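The two-input builders above (`build_iterator_12` and `build_iterator_22`) pass both an order tensor `v` and an argument tensor `z` to a `TensorIterator`, which broadcasts its inputs to a common shape before the elementwise kernel runs. The sketch below shows that shape rule in plain Python. It is an illustration only: `broadcast_shape` is a hypothetical helper, and it does not model dtype promotion, memory-overlap checks, or anything else the iterator does.

```python
def broadcast_shape(*shapes):
    """Combine shapes right-aligned; a size-1 axis stretches to match."""
    ndim = max(len(shape) for shape in shapes)
    result = []
    for axis in range(ndim):
        size = 1
        for shape in shapes:
            idx = axis - (ndim - len(shape))  # right-align shorter shapes
            dim = shape[idx] if idx >= 0 else 1
            if dim != 1:
                if size != 1 and size != dim:
                    raise ValueError(f"shapes {shapes} are not broadcastable")
                size = dim
        result.append(size)
    return tuple(result)


# A scalar order against a 2-D argument, and a column against a row.
print(broadcast_shape((), (3, 4)))
print(broadcast_shape((76, 1), (1, 151)))
```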
@@ -0,0 +1,94 @@
+#include <Python.h>
+// #include <torch/extension.h>
+#include <torch/library.h>
+#include <ATen/ATen.h>
+#include <ATen/TensorIterator.h>
+#include <ATen/native/cpu/Loops.h>
+#include <c10/util/complex.h>
+#include <iostream>
+
+#include "bessel_k.h"
+#include "iterator.h"
+
+// See https://docs.pytorch.org/tutorials/advanced/cpp_custom_ops.html
+extern "C" {
+    /* Creates a dummy empty _C module that can be imported from Python.
+       The import from Python will load the .so consisting of this file
+       in this extension, so that the TORCH_LIBRARY static initializers
+       below are run. */
+    PyObject* PyInit__C(void)
+    {
+        static struct PyModuleDef module_def = {
+            PyModuleDef_HEAD_INIT,
+            "_C",   /* name of module */
+            NULL,   /* module documentation, may be NULL */
+            -1,     /* size of per-interpreter state of the module,
+                       or -1 if the module keeps state in global variables. */
+            NULL,   /* methods */
+        };
+        return PyModule_Create(&module_def);
+    }
+}
+
+namespace torch_bessel {
+
+at::Tensor modified_bessel_k0_complex_forward_cpu(const at::Tensor& z) {
+    TORCH_INTERNAL_ASSERT(z.device().type() == at::DeviceType::CPU);
+    at::ScalarType dtype = z.scalar_type();
+    at::Tensor result = at::empty(at::IntArrayRef(), dtype).resize_(0);
+    at::TensorIterator iter = build_iterator_11(result, z);
+    AT_DISPATCH_COMPLEX_TYPES(dtype, "modified_bessel_k0_complex_forward_cpu", [&]() {
+        at::native::cpu_kernel(iter, [](scalar_t z) -> scalar_t {
+            return modified_bessel_k0_complex_forward(z);
+        });
+    });
+    return result;
+}
+
+std::tuple<at::Tensor, at::Tensor> modified_bessel_k0_complex_forward_backward_cpu(const at::Tensor& z) {
+    TORCH_INTERNAL_ASSERT(z.device().type() == at::DeviceType::CPU);
+    at::ScalarType dtype = z.scalar_type();
+    at::Tensor result1 = at::empty(at::IntArrayRef(), dtype).resize_(0);
+    at::Tensor result2 = at::empty(at::IntArrayRef(), dtype).resize_(0);
+    at::TensorIterator iter = build_iterator_21(result1, result2, z);
+    AT_DISPATCH_COMPLEX_TYPES(dtype, "modified_bessel_k0_complex_forward_backward_cpu", [&]() {
+        at::native::cpu_kernel_multiple_outputs(iter, [](scalar_t z) -> std::tuple<scalar_t, scalar_t> {
+            scalar_t cy[2];
+            modified_bessel_k0_complex_forward_backward(z, cy);
+            return std::make_tuple(cy[0], cy[1]);
+        });
+    });
|
|
61
|
+
return std::make_tuple(result1, result2);
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
at::Tensor modified_bessel_k1_complex_forward_cpu(const at::Tensor& z) {
|
|
65
|
+
TORCH_INTERNAL_ASSERT(z.device().type() == at::DeviceType::CPU);
|
|
66
|
+
at::ScalarType dtype = z.scalar_type();
|
|
67
|
+
at::Tensor result = at::empty(at::IntArrayRef(), dtype).resize_(0);
|
|
68
|
+
at::TensorIterator iter = build_iterator_11(result, z);
|
|
69
|
+
AT_DISPATCH_COMPLEX_TYPES(dtype, "modified_bessel_k1_complex_forward_cpu", [&]() {
|
|
70
|
+
at::native::cpu_kernel(iter, [](scalar_t z) -> scalar_t {
|
|
71
|
+
return modified_bessel_k1_complex_forward(z);
|
|
72
|
+
});
|
|
73
|
+
});
|
|
74
|
+
return result;
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
// // Registers _C as a Python extension module.
|
|
78
|
+
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
|
|
79
|
+
|
|
80
|
+
// Defines the operators
|
|
81
|
+
TORCH_LIBRARY(torch_bessel, m) {
|
|
82
|
+
m.def("modified_bessel_k0_complex_forward(Tensor z) -> Tensor");
|
|
83
|
+
m.def("modified_bessel_k0_complex_forward_backward(Tensor z) -> (Tensor, Tensor)");
|
|
84
|
+
m.def("modified_bessel_k1_complex_forward(Tensor z) -> Tensor");
|
|
85
|
+
}
|
|
86
|
+
|
|
87
|
+
// Registers CPU implementations for bessel_k
|
|
88
|
+
TORCH_LIBRARY_IMPL(torch_bessel, CPU, m) {
|
|
89
|
+
m.impl("modified_bessel_k0_complex_forward", &modified_bessel_k0_complex_forward_cpu);
|
|
90
|
+
m.impl("modified_bessel_k0_complex_forward_backward", &modified_bessel_k0_complex_forward_backward_cpu);
|
|
91
|
+
m.impl("modified_bessel_k1_complex_forward", &modified_bessel_k1_complex_forward_cpu);
|
|
92
|
+
}
|
|
93
|
+
|
|
94
|
+
}
|
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
# from pathlib import Path
|
|
2
|
+
from numbers import Number
|
|
3
|
+
from typing import Union
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from torch import Tensor
|
|
7
|
+
|
|
8
|
+
__all__ = ["modified_bessel_k0", "modified_bessel_k1"]
|
|
9
|
+
|
|
10
|
+
# # load C extension before calling torch.library API, see
|
|
11
|
+
# # https://pytorch.org/tutorials/advanced/cpp_custom_ops.html
|
|
12
|
+
# so_dir = Path(__file__).parent
|
|
13
|
+
# so_files = list(so_dir.glob("_C*.so"))
|
|
14
|
+
# assert len(so_files) == 1, (
|
|
15
|
+
# f"Expected one _C*.so file at {so_dir}, found {len(so_files)}"
|
|
16
|
+
# )
|
|
17
|
+
# torch.ops.load_library(so_files[0])
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ModifiedBesselK0(torch.autograd.Function):
|
|
21
|
+
@staticmethod
|
|
22
|
+
def forward(z, singularity):
|
|
23
|
+
if not z.is_complex():
|
|
24
|
+
out = (torch.special.modified_bessel_k0(z), None)
|
|
25
|
+
elif not z.requires_grad:
|
|
26
|
+
out = (
|
|
27
|
+
torch.ops.torch_bessel.modified_bessel_k0_complex_forward.default(z),
|
|
28
|
+
None,
|
|
29
|
+
)
|
|
30
|
+
else:
|
|
31
|
+
out = torch.ops.torch_bessel.modified_bessel_k0_complex_forward_backward.default(
|
|
32
|
+
z
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
if singularity is None:
|
|
36
|
+
return (*out, None)
|
|
37
|
+
|
|
38
|
+
mask = z != 0
|
|
39
|
+
return (out[0].where(mask, singularity), out[1], mask)
|
|
40
|
+
|
|
41
|
+
@staticmethod
|
|
42
|
+
def setup_context(ctx, inputs, outputs):
|
|
43
|
+
if ctx.needs_input_grad[1]:
|
|
44
|
+
raise NotImplementedError("Gradient w.r.t. singularity is not implemented")
|
|
45
|
+
|
|
46
|
+
if ctx.needs_input_grad[0]:
|
|
47
|
+
if outputs[1] is None:
|
|
48
|
+
ctx.save_for_backward(inputs[0], None, outputs[2])
|
|
49
|
+
else:
|
|
50
|
+
ctx.save_for_backward(None, outputs[1], outputs[2])
|
|
51
|
+
|
|
52
|
+
ctx.set_materialize_grads(False)
|
|
53
|
+
|
|
54
|
+
@staticmethod
|
|
55
|
+
def backward(ctx, grad, _, __):
|
|
56
|
+
if grad is None or not ctx.needs_input_grad[0]:
|
|
57
|
+
return (None, None)
|
|
58
|
+
|
|
59
|
+
x, deriv, mask = ctx.saved_tensors
|
|
60
|
+
if deriv is None:
|
|
61
|
+
out = -torch.special.modified_bessel_k1(x).mul_(grad)
|
|
62
|
+
else:
|
|
63
|
+
out = grad * deriv
|
|
64
|
+
if mask is not None:
|
|
65
|
+
out = out.where(mask, 0)
|
|
66
|
+
return (out, None)
|
|
67
|
+
|
|
68
|
+
@staticmethod
|
|
69
|
+
def vmap(info, in_dims, z, singularity):
|
|
70
|
+
if singularity is None and in_dims[1] is not None:
|
|
71
|
+
raise ValueError("in_dims[1] must be None if singularity is not provided.")
|
|
72
|
+
|
|
73
|
+
if in_dims[0] is not None:
|
|
74
|
+
z = z.movedim(in_dims[0], 0)
|
|
75
|
+
|
|
76
|
+
if in_dims[1] is not None:
|
|
77
|
+
singularity = singularity.movedim(in_dims[1], 0)
|
|
78
|
+
|
|
79
|
+
out = ModifiedBesselK0.apply(z, singularity)
|
|
80
|
+
out_dims = [0] * 3 if any(d is not None for d in in_dims) else [None] * 3
|
|
81
|
+
if out[1] is None:
|
|
82
|
+
out_dims[1] = None
|
|
83
|
+
if out[2] is None:
|
|
84
|
+
out_dims[2] = None
|
|
85
|
+
|
|
86
|
+
return (out, tuple(out_dims))
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def modified_bessel_k0(
|
|
90
|
+
z: Tensor, singularity: Union[Number, Tensor, None] = None
|
|
91
|
+
) -> Tensor:
|
|
92
|
+
return ModifiedBesselK0.apply(z, singularity)[0]
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def modified_bessel_k1(z: Tensor) -> Tensor:
|
|
96
|
+
# non-differentiable for now
|
|
97
|
+
if not z.is_complex():
|
|
98
|
+
return torch.special.modified_bessel_k1(z)
|
|
99
|
+
return torch.ops.torch_bessel.modified_bessel_k1_complex_forward.default(z)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
@torch.library.register_fake("torch_bessel::modified_bessel_k0_complex_forward")
|
|
103
|
+
def _(z):
|
|
104
|
+
return torch.empty_like(z)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
@torch.library.register_fake(
|
|
108
|
+
"torch_bessel::modified_bessel_k0_complex_forward_backward"
|
|
109
|
+
)
|
|
110
|
+
def _(z):
|
|
111
|
+
return torch.empty_like(z), torch.empty_like(z)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
@torch.library.register_fake("torch_bessel::modified_bessel_k1_complex_forward")
|
|
115
|
+
def _(z):
|
|
116
|
+
return torch.empty_like(z)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def modified_bessel_k0_backward(ctx, grad, _):
|
|
120
|
+
if ctx.needs_input_grad[0]:
|
|
121
|
+
return grad * ctx.saved_tensors[0]
|
|
122
|
+
return None
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def modified_bessel_k0_setup_context(ctx, inputs, output):
|
|
126
|
+
if ctx.needs_input_grad[0]:
|
|
127
|
+
ctx.save_for_backward(output[-1])
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
torch.library.register_autograd(
|
|
131
|
+
"torch_bessel::modified_bessel_k0_complex_forward_backward",
|
|
132
|
+
modified_bessel_k0_backward,
|
|
133
|
+
setup_context=modified_bessel_k0_setup_context,
|
|
134
|
+
)
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: torch_bessel
|
|
3
|
+
Version: 0.0.9
|
|
4
|
+
Summary: PyTorch extension package for Bessel functions with arbitrary real order and complex inputs
|
|
5
|
+
Home-page: https://github.com/hchau630/torch-bessel
|
|
6
|
+
Author: Ho Yin Chau
|
|
7
|
+
Requires-Python: >= 3.10
|
|
8
|
+
Description-Content-Type: text/markdown
|
|
9
|
+
License-File: LICENSE
|
|
10
|
+
Requires-Dist: torch
|
|
11
|
+
Dynamic: author
|
|
12
|
+
Dynamic: description
|
|
13
|
+
Dynamic: description-content-type
|
|
14
|
+
Dynamic: home-page
|
|
15
|
+
Dynamic: license-file
|
|
16
|
+
Dynamic: requires-dist
|
|
17
|
+
Dynamic: requires-python
|
|
18
|
+
Dynamic: summary
|
|
19
|
+
|
|
20
|
+
# About
|
|
21
|
+
PyTorch extension package for modified Bessel functions of the second kind with complex inputs
|
|
22
|
+
|
|
23
|
+
# Install
|
|
24
|
+
Currently supports only Linux (with CUDA 12.4) or macOS (Apple silicon, CPU only) with Python >= 3.9, <= 3.12.
|
|
25
|
+
```
|
|
26
|
+
pip install torch-bessel
|
|
27
|
+
```
|
|
28
|
+
|
|
29
|
+
UPDATE (May 26, 2026): The package may no longer be compatible with the latest PyTorch versions (>= 2.12.0); if you encounter issues, try installing an earlier version of PyTorch.
|
|
30
|
+
|
|
31
|
+
UPDATE (May 28, 2026): On Apple silicon, importing this package with PyTorch >= 2.9.0 causes a crash (bus error: 10) when the Python program exits, although the program runs normally until then. Consider installing torch < 2.9.0 on Apple silicon.
|
|
32
|
+
|
|
33
|
+
# Example
|
|
34
|
+
```
|
|
35
|
+
import torch
import torch_bessel
|
|
36
|
+
|
|
37
|
+
real, imag = torch.randn(2, 5, device="cuda")
|
|
38
|
+
z = torch.complex(real.abs(), imag)  # correctness for inputs in the left-half complex plane is not guaranteed.
|
|
39
|
+
torch_bessel.ops.modified_bessel_k0(z)
|
|
40
|
+
```
|
|
41
|
+
|
|
42
|
+
# Implemented functions
|
|
43
|
+
- `modified_bessel_k0`: Same as `torch.special.modified_bessel_k0`, but also handles backpropagation and complex inputs on CPU and CUDA. Correctness is guaranteed on the right-half complex plane. On CUDA, `torch.chalf` inputs are also supported, though the underlying CUDA kernel simply upcasts `chalf` to `cfloat` (this uses no extra GPU memory, whereas manually casting `torch.chalf` to `torch.cfloat` before calling `modified_bessel_k0` doubles the GPU memory used). On the left-half complex plane, the output appears mostly correct, but with small numerical errors for certain inputs. On the negative real line, the output is NaN.
|
|
44
|
+
- `modified_bessel_k1`: Same as `torch.special.modified_bessel_k1`, but also handles complex inputs on CPU and CUDA. Backpropagation is not implemented, but you can add it yourself by writing a `torch.autograd.Function` that uses the recurrence relations of Bessel functions. The same caveats as for `modified_bessel_k0` apply.
|
|
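A custom backward for these functions would rest on the standard derivative identities `K0'(z) = -K1(z)` and `K1'(z) = -K0(z) - K1(z) / z`. The sketch below is not part of this package: it checks both identities numerically using only the Python standard library. The helper names `bessel_k` and `central_difference` are hypothetical. The integral representation used is valid only for `Re(z) > 0`, and the truncation point `t_max` is an assumption suited to `Re(z)` of order 1.

```python
import cmath
import math


def bessel_k(v, z, t_max=6.0, n=2000):
    """Approximate K_v(z) for Re(z) > 0 with the trapezoidal rule applied to
    K_v(z) = integral from 0 to infinity of exp(-z cosh t) cosh(v t) dt.
    Hypothetical reference helper, not the kernel used by torch_bessel."""
    h = t_max / n
    total = 0j
    for i in range(n + 1):
        t = i * h
        weight = 0.5 if i in (0, n) else 1.0  # trapezoid end-point weights
        total += weight * cmath.exp(-z * math.cosh(t)) * math.cosh(v * t)
    return total * h


def central_difference(f, z, eps=1e-5):
    """Finite-difference derivative along the real axis; for a holomorphic
    function this equals the complex derivative."""
    return (f(z + eps) - f(z - eps)) / (2 * eps)


z = 1.5 + 0.7j  # a point in the right-half complex plane
k0, k1 = bessel_k(0, z), bessel_k(1, z)

dk0 = central_difference(lambda w: bessel_k(0, w), z)
dk1 = central_difference(lambda w: bessel_k(1, w), z)

# Both residuals should be close to zero if the identities hold.
print(abs(dk0 - (-k1)))
print(abs(dk1 - (-k0 - k1 / z)))
```

When turning these identities into a `torch.autograd.Function` for complex inputs, note that PyTorch's documented convention for holomorphic functions multiplies the incoming gradient by the complex conjugate of the derivative.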
45
|
+
|
|
46
|
+
# WIP
|
|
47
|
+
- `modified_bessel_kv`: Analogue of `scipy.special.kv`.
|
|
48
|
+
|
|
49
|
+
# Benchmarks
|
|
50
|
+
Benchmarking performed with the `asv` package. Results can be viewed at https://hchau630.github.io/torch-bessel.
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
LICENSE
|
|
2
|
+
README.md
|
|
3
|
+
pyproject.toml
|
|
4
|
+
setup.py
|
|
5
|
+
benchmarks/__init__.py
|
|
6
|
+
benchmarks/benchmarks.py
|
|
7
|
+
tests/test_extension.py
|
|
8
|
+
torch_bessel/__init__.py
|
|
9
|
+
torch_bessel/ops.py
|
|
10
|
+
torch_bessel.egg-info/PKG-INFO
|
|
11
|
+
torch_bessel.egg-info/SOURCES.txt
|
|
12
|
+
torch_bessel.egg-info/dependency_links.txt
|
|
13
|
+
torch_bessel.egg-info/requires.txt
|
|
14
|
+
torch_bessel.egg-info/top_level.txt
|
|
15
|
+
torch_bessel/csrc/iterator.cpp
|
|
16
|
+
torch_bessel/csrc/torch_bessel.cpp
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
torch
|