torch-bessel 0.0.9__tar.gz

@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Ho Yin Chau
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
@@ -0,0 +1,50 @@
+ Metadata-Version: 2.4
+ Name: torch_bessel
+ Version: 0.0.9
+ Summary: PyTorch extension package for Bessel functions with arbitrary real order and complex inputs
+ Home-page: https://github.com/hchau630/torch-bessel
+ Author: Ho Yin Chau
+ Requires-Python: >= 3.10
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: torch
+ Dynamic: author
+ Dynamic: description
+ Dynamic: description-content-type
+ Dynamic: home-page
+ Dynamic: license-file
+ Dynamic: requires-dist
+ Dynamic: requires-python
+ Dynamic: summary
+
+ # About
+ PyTorch extension package for modified Bessel functions of the second kind with complex inputs
+
+ # Install
+ Currently only supports Linux (with CUDA 12.4) or macOS (Apple silicon, CPU only) with Python >= 3.10, <= 3.12.
+ ```
+ pip install torch-bessel
+ ```
+
+ UPDATE (May 26, 2026): The package may no longer be compatible with the latest PyTorch versions (>= 2.12.0), so try installing an earlier version of PyTorch if you encounter issues.
+
+ UPDATE (May 28, 2026): On Apple silicon, importing this package with PyTorch >= 2.9.0 causes a crash (bus error: 10) when a Python program exits, though the program runs normally until then. Consider installing torch < 2.9.0 on Apple silicon.
+
+ # Example
+ ```
+ import torch
+ import torch_bessel
+
+ real, imag = torch.randn(2, 5, device="cuda")
+ z = torch.complex(real.abs(), imag)  # correctness for inputs in the left-half complex plane is not guaranteed.
+ torch_bessel.ops.modified_bessel_k0(z)
+ ```
+
+ # Implemented functions
+ - `modified_bessel_k0`: Same as `torch.special.modified_bessel_k0`, but also handles backpropagation and complex inputs on CPU and CUDA. Correctness is guaranteed on the right-half complex plane. On CUDA, `torch.chalf` inputs are also supported, though the underlying CUDA kernel just upcasts `chalf` to `cfloat` (this uses no extra GPU memory, whereas manually casting `torch.chalf` to `torch.cfloat` before calling `modified_bessel_k0` doubles the GPU memory used). On the left-half complex plane, the output appears mostly correct, but with small numerical errors for certain inputs. On the negative real line, the output is NaN.
+ - `modified_bessel_k1`: Same as `torch.special.modified_bessel_k1`, but also handles complex inputs on CPU and CUDA. Backpropagation is not implemented, but you can add it yourself by writing a `torch.autograd.Function` that uses the recurrence relations of Bessel functions. The same caveats as for `modified_bessel_k0` apply.
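The recurrence relations mentioned for `modified_bessel_k1` are the standard identities K0'(z) = -K1(z) and K1'(z) = -K0(z) - K1(z)/z. The sketch below is not part of the package: it checks the second identity numerically using only the Python standard library, with a hypothetical helper `bessel_k` that approximates K_nu(z) for Re(z) > 0 from the integral representation K_nu(z) = ∫_0^∞ exp(-z cosh t) cosh(nu t) dt. The helper names, step counts, and the test point are assumptions chosen for illustration.

```python
import cmath
import math


def bessel_k(nu, z, t_max=8.0, steps=4000):
    """Approximate K_nu(z) for Re(z) > 0 with the trapezoid rule applied to
    K_nu(z) = int_0^inf exp(-z * cosh(t)) * cosh(nu * t) dt."""
    h = t_max / steps
    total = 0.5 * cmath.exp(-z)  # t = 0 endpoint (cosh(0) = 1), half weight
    for i in range(1, steps + 1):
        t = i * h
        weight = 0.5 if i == steps else 1.0
        total += weight * cmath.exp(-z * math.cosh(t)) * math.cosh(nu * t)
    return total * h


def dk0(z):
    # K0'(z) = -K1(z)
    return -bessel_k(1, z)


def dk1(z):
    # K1'(z) = -K0(z) - K1(z) / z
    return -bessel_k(0, z) - bessel_k(1, z) / z


def finite_difference(f, z, eps=1e-5):
    # Central difference along the real direction (valid for holomorphic f)
    return (f(z + eps) - f(z - eps)) / (2 * eps)


z = 1.5 + 0.5j  # arbitrary point in the right-half plane
err_k0 = abs(dk0(z) - finite_difference(lambda w: bessel_k(0, w), z))
err_k1 = abs(dk1(z) - finite_difference(lambda w: bessel_k(1, w), z))
print(err_k0, err_k1)  # both should be tiny
```

In a real `torch.autograd.Function`, the same formula for `dk1` would be evaluated with the package's own K0 and K1 kernels; how the derivative must be conjugated for complex gradients depends on PyTorch's autograd convention and is not shown here.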
+
+ # WIP
+ - `modified_bessel_kv`: Analogue of `scipy.special.kv`.
+
+ # Benchmarks
+ Benchmarking performed with the `asv` package. Results can be viewed at https://hchau630.github.io/torch-bessel.
@@ -0,0 +1,31 @@
+ # About
+ PyTorch extension package for modified Bessel functions of the second kind with complex inputs
+
+ # Install
+ Currently only supports Linux (with CUDA 12.4) or macOS (Apple silicon, CPU only) with Python >= 3.10, <= 3.12.
+ ```
+ pip install torch-bessel
+ ```
+
+ UPDATE (May 26, 2026): The package may no longer be compatible with the latest PyTorch versions (>= 2.12.0), so try installing an earlier version of PyTorch if you encounter issues.
+
+ UPDATE (May 28, 2026): On Apple silicon, importing this package with PyTorch >= 2.9.0 causes a crash (bus error: 10) when a Python program exits, though the program runs normally until then. Consider installing torch < 2.9.0 on Apple silicon.
+
+ # Example
+ ```
+ import torch
+ import torch_bessel
+
+ real, imag = torch.randn(2, 5, device="cuda")
+ z = torch.complex(real.abs(), imag)  # correctness for inputs in the left-half complex plane is not guaranteed.
+ torch_bessel.ops.modified_bessel_k0(z)
+ ```
+
+ # Implemented functions
+ - `modified_bessel_k0`: Same as `torch.special.modified_bessel_k0`, but also handles backpropagation and complex inputs on CPU and CUDA. Correctness is guaranteed on the right-half complex plane. On CUDA, `torch.chalf` inputs are also supported, though the underlying CUDA kernel just upcasts `chalf` to `cfloat` (this uses no extra GPU memory, whereas manually casting `torch.chalf` to `torch.cfloat` before calling `modified_bessel_k0` doubles the GPU memory used). On the left-half complex plane, the output appears mostly correct, but with small numerical errors for certain inputs. On the negative real line, the output is NaN.
+ - `modified_bessel_k1`: Same as `torch.special.modified_bessel_k1`, but also handles complex inputs on CPU and CUDA. Backpropagation is not implemented, but you can add it yourself by writing a `torch.autograd.Function` that uses the recurrence relations of Bessel functions. The same caveats as for `modified_bessel_k0` apply.
+
+ # WIP
+ - `modified_bessel_kv`: Analogue of `scipy.special.kv`.
+
+ # Benchmarks
+ Benchmarking performed with the `asv` package. Results can be viewed at https://hchau630.github.io/torch-bessel.
File without changes
@@ -0,0 +1,143 @@
+ import torch
+ from asv_runner.benchmarks.mark import skip_benchmark_if
+
+ import torch_bessel
+
+
+ def _setup(
+     n, is_real, singularity, dtype=torch.float, requires_grad=False, device="cpu"
+ ):
+     kwargs = {"dtype": dtype, "requires_grad": requires_grad, "device": device}
+     real = torch.randn(n, **kwargs).abs()
+     if is_real:
+         args = (real, singularity)
+     else:
+         imag = torch.randn(n, **kwargs)
+         args = (torch.complex(real, imag), singularity)
+     return args
+
+
+ def _setup_k1(n, is_real, dtype=torch.float, device="cpu"):
+     kwargs = {"dtype": dtype, "device": device}
+     real = torch.randn(n, **kwargs).abs()
+     if is_real:
+         args = (real,)
+     else:
+         imag = torch.randn(n, **kwargs)
+         args = (torch.complex(real, imag),)
+     return args
+
+
+ class ModifiedBesselK0ForwardCPU:
+     params = (
+         [10_000, 100_000, 1_000_000],
+         [False, True],
+         [None, 0.0],
+         [torch.float32, torch.float64],
+         [False, True],
+     )
+     param_names = ["n", "is_real", "singularity", "dtype", "requires_grad"]
+
+     def setup(self, n, is_real, singularity, dtype, requires_grad):
+         self.args = _setup(n, is_real, singularity, dtype, requires_grad)
+
+     def time_modified_bessel_k0_forward_cpu(
+         self, n, is_real, singularity, dtype, requires_grad
+     ):
+         torch_bessel.ops.modified_bessel_k0(*self.args)
+
+
+ class ModifiedBesselK0ForwardCUDA:
+     params = (
+         [10_000, 100_000, 1_000_000],
+         [False, True],
+         [None, 0.0],
+         [torch.float32, torch.float64],
+         [False, True],
+     )
+     param_names = ["n", "is_real", "singularity", "dtype", "requires_grad"]
+
+     def setup(self, n, is_real, singularity, dtype, requires_grad):
+         self.args = _setup(n, is_real, singularity, dtype, requires_grad, device="cuda")
+
+     @skip_benchmark_if(not torch.cuda.is_available())
+     def time_modified_bessel_k0_forward_cuda(
+         self, n, is_real, singularity, dtype, requires_grad
+     ):
+         torch.cuda.synchronize()
+         torch_bessel.ops.modified_bessel_k0(*self.args)
+         torch.cuda.synchronize()
+
+
+ class ModifiedBesselK0BackwardCPU:
+     warmup_time = 0.0  # for some reason backward is called multiple times if not 0
+     number = 1  # Avoids calling backward multiple times
+     params = (
+         [10_000, 100_000, 1_000_000],
+         [False, True],
+         [None, 0.0],
+         [torch.float32, torch.float64],
+     )
+     param_names = ["n", "is_real", "singularity", "dtype"]
+
+     def setup(self, n, is_real, singularity, dtype):
+         args = _setup(n, is_real, singularity, dtype, requires_grad=True)
+         self.out = torch_bessel.ops.modified_bessel_k0(*args).norm()
+
+     def time_modified_bessel_k0_backward_cpu(self, n, is_real, singularity, dtype):
+         self.out.backward()
+
+
+ class ModifiedBesselK0BackwardCUDA:
+     warmup_time = 0.0  # for some reason backward is called multiple times if not 0
+     number = 1  # Avoids calling backward multiple times
+     params = (
+         [10_000, 100_000, 1_000_000],
+         [False, True],
+         [None, 0.0],
+         [torch.float32, torch.float64],
+     )
+     param_names = ["n", "is_real", "singularity", "dtype"]
+
+     def setup(self, n, is_real, singularity, dtype):
+         args = _setup(n, is_real, singularity, dtype, requires_grad=True, device="cuda")
+         self.out = torch_bessel.ops.modified_bessel_k0(*args).norm()
+
+     @skip_benchmark_if(not torch.cuda.is_available())
+     def time_modified_bessel_k0_backward_cuda(self, n, is_real, singularity, dtype):
+         torch.cuda.synchronize()
+         self.out.backward()
+         torch.cuda.synchronize()
+
+
+ class ModifiedBesselK1ForwardCPU:
+     params = (
+         [10_000, 100_000, 1_000_000],
+         [False, True],
+         [torch.float32, torch.float64],
+     )
+     param_names = ["n", "is_real", "dtype"]
+
+     def setup(self, n, is_real, dtype):
+         self.args = _setup_k1(n, is_real, dtype)
+
+     def time_modified_bessel_k1_forward_cpu(self, n, is_real, dtype):
+         torch_bessel.ops.modified_bessel_k1(*self.args)
+
+
+ class ModifiedBesselK1ForwardCUDA:
+     params = (
+         [10_000, 100_000, 1_000_000],
+         [False, True],
+         [torch.float32, torch.float64],
+     )
+     param_names = ["n", "is_real", "dtype"]
+
+     def setup(self, n, is_real, dtype):
+         self.args = _setup_k1(n, is_real, dtype, device="cuda")
+
+     @skip_benchmark_if(not torch.cuda.is_available())
+     def time_modified_bessel_k1_forward_cuda(self, n, is_real, dtype):
+         torch.cuda.synchronize()
+         torch_bessel.ops.modified_bessel_k1(*self.args)
+         torch.cuda.synchronize()
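For readers unfamiliar with `asv`, each benchmark class above is parameterized: `asv` calls `setup` and the `time_*` method once for every combination in the Cartesian product of the `params` lists. The following standalone sketch enumerates those combinations for the `ModifiedBesselK0ForwardCPU` grid. It uses strings as stand-ins for the torch dtypes so that it runs without PyTorch, and the variable name `combos` is introduced here only for illustration.

```python
import itertools

# Same grid as ModifiedBesselK0ForwardCPU, with string stand-ins for torch dtypes.
params = (
    [10_000, 100_000, 1_000_000],  # n
    [False, True],                 # is_real
    [None, 0.0],                   # singularity
    ["float32", "float64"],        # dtype (stand-in)
    [False, True],                 # requires_grad
)
param_names = ["n", "is_real", "singularity", "dtype", "requires_grad"]

# One dict per benchmark invocation, in the order itertools.product yields them.
combos = [dict(zip(param_names, values)) for values in itertools.product(*params)]

print(len(combos))  # 3 * 2 * 2 * 2 * 2 = 48
print(combos[0])
```

The backward benchmarks drop the `requires_grad` axis, so the same reasoning gives 24 combinations for each of them.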
@@ -0,0 +1,9 @@
+ [build-system]
+ requires = [
+     "setuptools",
+     "torch>=2.4.0",
+     # "torch>=2.7.0; platform_system == 'Darwin'",
+     # "torch>=2.4.0; platform_system != 'Darwin'",
+     "ninja",
+ ]
+ build-backend = "setuptools.build_meta"
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+
@@ -0,0 +1,80 @@
+ import os
+ import glob
+
+ from setuptools import find_packages, setup
+
+ from torch.utils.cpp_extension import (
+     CppExtension,
+     CUDAExtension,
+     BuildExtension,
+     CUDA_HOME,
+ )
+
+ library_name = "torch_bessel"
+
+
+ def get_extensions():
+     debug_mode = os.getenv("DEBUG", "0") == "1"
+     use_cuda = os.getenv("USE_CUDA", "1") == "1"
+     if debug_mode:
+         print("Compiling in debug mode")
+
+     use_cuda = use_cuda and CUDA_HOME is not None
+     extension = CUDAExtension if use_cuda else CppExtension
+
+     extra_link_args = []
+     extra_compile_args = {
+         "cxx": [
+             "-O3" if not debug_mode else "-O0",
+             "-std=c++17",
+             "-fdiagnostics-color=always",
+             "-DPy_LIMITED_API=0x030A0000",
+         ],
+         "nvcc": [
+             "-O3" if not debug_mode else "-O0",
+             "--extended-lambda",
+         ],
+     }
+     if debug_mode:
+         extra_compile_args["cxx"].append("-g")
+         extra_compile_args["nvcc"].append("-g")
+         extra_link_args.extend(["-O0", "-g"])
+
+     this_dir = os.path.dirname(os.path.curdir)
+     extensions_dir = os.path.join(this_dir, library_name, "csrc")
+     sources = list(glob.glob(os.path.join(extensions_dir, "*.cpp")))
+
+     extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
+     cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "*.cu")))
+
+     if use_cuda:
+         sources += cuda_sources
+
+     ext_modules = [
+         extension(
+             f"{library_name}._C",
+             sources,
+             extra_compile_args=extra_compile_args,
+             extra_link_args=extra_link_args,
+             py_limited_api=True,
+         )
+     ]
+
+     return ext_modules
+
+
+ setup(
+     name=library_name,
+     version="0.0.9",
+     author="Ho Yin Chau",
+     packages=find_packages(),
+     ext_modules=get_extensions(),
+     install_requires=["torch"],
+     python_requires=">= 3.10",
+     description="PyTorch extension package for Bessel functions with arbitrary real order and complex inputs",
+     long_description=open("README.md").read(),
+     long_description_content_type="text/markdown",
+     url="https://github.com/hchau630/torch-bessel",
+     cmdclass={"build_ext": BuildExtension},
+     options={"bdist_wheel": {"py_limited_api": "cp310"}},
+ )
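The build script above uses three settings that point at the same Python version: the compiler define `-DPy_LIMITED_API=0x030A0000`, the wheel tag `cp310`, and `python_requires=">= 3.10"`. The sketch below decodes the hex value using CPython's version-hex layout (one byte each for major, minor, and micro, followed by release level and serial) to show that they agree. The helper name `decode_hex_version` is hypothetical and not part of the package.

```python
def decode_hex_version(hexversion):
    """Split a CPython version hex (as used by Py_LIMITED_API) into
    (major, minor, micro); the lowest byte holds release level and serial."""
    major = (hexversion >> 24) & 0xFF
    minor = (hexversion >> 16) & 0xFF
    micro = (hexversion >> 8) & 0xFF
    return major, minor, micro


major, minor, micro = decode_hex_version(0x030A0000)
wheel_tag = f"cp{major}{minor}"

print((major, minor, micro))  # (3, 10, 0)
print(wheel_tag)              # cp310
```

If the minimum supported Python version were ever raised, all three settings would presumably need to change together.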
@@ -0,0 +1,281 @@
+ import torch
+ from torch.testing._internal.common_utils import TestCase
+ from torch.testing._internal.optests import opcheck
+ import pytest
+ from scipy import special
+
+ import torch_bessel
+
+
+ def reference_modified_bessel_k0(z, singularity=None):
+     device = z.device
+     dtype = z.dtype
+     if dtype is torch.chalf:
+         z = z.to(torch.cfloat)
+     out = special.kv(0.0, z.detach().cpu()).to(device).to(dtype)
+     if singularity is not None:
+         out = out.where(z != 0, singularity)
+     return out
+
+
+ def reference_modified_bessel_k1(z):
+     device = z.device
+     dtype = z.dtype
+     if dtype is torch.chalf:
+         z = z.to(torch.cfloat)
+     out = special.kv(1.0, z.detach().cpu()).to(device).to(dtype)
+     return out
+
+
+ class TestBesselK0(TestCase):
+     def sample_inputs(self, device, *, requires_grad=False):
+         def make_z(*size, dtype=None):
+             kwargs = dict(dtype=dtype, device=device, requires_grad=requires_grad)
+             real = torch.randn(size, **kwargs).abs()
+             imag = torch.randn(size, **kwargs)
+             return torch.complex(real, imag)
+
+         out = [
+             [make_z(50, dtype=torch.double)],
+             [make_z(50, dtype=torch.double).real],
+             [make_z(50)],
+             [make_z(50).real],
+             [torch.tensor(70.1162 + 89.0190j)],  # used to fail on this input
+         ]
+         if device == "cuda":
+             # half precision is only supported on CUDA
+             out.append([make_z(50, dtype=torch.half)])
+
+         return out
+
+     def grid_inputs(self, device, *, requires_grad=False):
+         def make_z(*args, dtype=torch.float):
+             real = torch.tensor([0, *torch.logspace(*args, dtype=torch.double)])
+             imag = torch.logspace(*args, dtype=torch.double)
+
+             # Don't test on subnormal numbers, since the AMOS code doesn't consider them
+             real[real < torch.finfo(dtype).tiny] = 0.0
+             imag[imag < torch.finfo(dtype).tiny] = 0.0
+
+             kwargs = dict(dtype=dtype, device=device, requires_grad=requires_grad)
+             real = torch.tensor(real, **kwargs)
+             imag = torch.tensor([*(-imag.flip(0)), 0, *imag], **kwargs)
+             real, imag = real[:, None], imag[None, :]
+             return torch.complex(real, imag)
+
+         out = [
+             [make_z(-350, 350, 75, dtype=torch.double)],
+             [make_z(-50, 50, 75)],
+             [make_z(-5, 5, 75, dtype=torch.double)],  # scipy is buggy on this input
+             [make_z(-5, 5, 75)],  # scipy is buggy on this input
+             [make_z(-50, 50, 75), 1.0],
+             [make_z(-50, 50, 75), torch.randn((76, 151), device=device)],
+         ]
+         if device == "cuda":
+             # half precision is only supported on CUDA
+             out.append([make_z(-5, 5, 75, dtype=torch.half)])
+         return out
+
+     def _test_correctness(self, device):
+         samples = (
+             self.sample_inputs(device)
+             + self.grid_inputs(device)
+             + self.sample_inputs(device, requires_grad=True)
+             + self.grid_inputs(device, requires_grad=True)
+         )
+         for args in samples:
+             result = torch_bessel.ops.modified_bessel_k0(*args)
+             expected = reference_modified_bessel_k0(*args)
+             if expected.dtype in {torch.float, torch.cfloat, torch.chalf}:
+                 # ierr = 4, complete loss of significance
+                 expected[args[0].abs() > 4194303.98419452] = torch.nan
+                 # ierr = 2, overflow
+                 mask = (args[0].abs() < 1.1754944e-35) & (args[0] != 0)
+                 if expected.is_complex():
+                     expected[mask] = torch.nan
+                     expected[mask & (args[0].imag == 0) & (args[0] != 0)] = torch.inf
+                 else:
+                     expected[mask & (args[0] != 0)] = torch.inf
+             # fix scipy bug. See https://github.com/scipy/xsf/issues/46
+             if args[0].is_complex():
+                 mask = (
+                     (args[0].real > 685)
+                     & (args[0].real < 690)
+                     & (args[0].imag.abs() > 685)
+                     & (args[0].imag.abs() < 1.1e5)
+                 )
+                 expected[mask] = 0.0  # scipy returns incorrect/NaN output
+
+             torch.testing.assert_close(result, expected, equal_nan=True)
+
+     def test_correctness_cpu(self):
+         self._test_correctness("cpu")
+
+     @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda")
+     def test_correctness_cuda(self):
+         self._test_correctness("cuda")
+
+     def _test_gradients(self, device):
+         samples = self.sample_inputs(device, requires_grad=True)
+         for args in samples:
+             if (
+                 args[0].dtype in {torch.double, torch.complex128}
+                 and args[0].requires_grad
+             ):
+                 torch.autograd.gradcheck(torch_bessel.ops.modified_bessel_k0, args)
+
+     def test_gradients_cpu(self):
+         self._test_gradients("cpu")
+
+     @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda")
+     def test_gradients_cuda(self):
+         self._test_gradients("cuda")
+
+     def _opcheck(self, device):
+         # Use opcheck to check for incorrect usage of operator registration APIs
+         samples = self.sample_inputs(device, requires_grad=False)
+         samples.extend(self.sample_inputs(device, requires_grad=True))
+         for args in samples:
+             if not args[0].is_complex():
+                 continue
+             opcheck(
+                 torch.ops.torch_bessel.modified_bessel_k0_complex_forward_backward.default,
+                 args,
+             )
+
+     def test_opcheck_cpu(self):
+         self._opcheck("cpu")
+
+     @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda")
+     def test_opcheck_cuda(self):
+         self._opcheck("cuda")
+
+
+ @pytest.mark.parametrize(
+     "singularity, in_dims",
+     [
+         (None, (0, None)),
+         (0.0, (0, None)),
+         ((4, 3), (0, 1)),
+     ],
+ )
+ def test_vmap(singularity, in_dims):
+     z = torch.randn(3, 4).abs() + torch.randn(3, 4) * 1j
+     if isinstance(singularity, tuple):
+         singularity = torch.randn(singularity)
+     func = torch.func.vmap(torch_bessel.ops.modified_bessel_k0, in_dims=in_dims)
+     out = func(z, singularity)
+     if not isinstance(singularity, torch.Tensor):
+         expected = torch_bessel.ops.modified_bessel_k0(z, singularity)
+     else:
+         expected = torch_bessel.ops.modified_bessel_k0(z, singularity.t())
+     torch.testing.assert_close(out, expected)
+
+
+ class TestBesselK1(TestCase):
+     def sample_inputs(self, device, *, requires_grad=False):
+         def make_z(*size, dtype=None):
+             kwargs = dict(dtype=dtype, device=device, requires_grad=requires_grad)
+             real = torch.randn(size, **kwargs).abs()
+             imag = torch.randn(size, **kwargs)
+             return torch.complex(real, imag)
+
+         out = [
+             [make_z(50, dtype=torch.double)],
+             [make_z(50, dtype=torch.double).real],
+             [make_z(50)],
+             [make_z(50).real],
+             [torch.tensor(70.1162 + 89.0190j)],  # used to fail on this input
+         ]
+         if device == "cuda":
+             # half precision is only supported on CUDA
+             out.append([make_z(50, dtype=torch.half)])
+
+         return out
+
+     def grid_inputs(self, device, *, requires_grad=False):
+         def make_z(*args, dtype=torch.float):
+             real = torch.tensor([0, *torch.logspace(*args, dtype=torch.double)])
+             imag = torch.logspace(*args, dtype=torch.double)
+
+             # Don't test on subnormal numbers, since the AMOS code doesn't consider them
+             real[real < torch.finfo(dtype).tiny] = 0.0
+             imag[imag < torch.finfo(dtype).tiny] = 0.0
+
+             kwargs = dict(dtype=dtype, device=device, requires_grad=requires_grad)
+             real = torch.tensor(real, **kwargs)
+             imag = torch.tensor([*(-imag.flip(0)), 0, *imag], **kwargs)
+             real, imag = real[:, None], imag[None, :]
+             return torch.complex(real, imag)
+
+         out = [
+             [make_z(-350, 350, 75, dtype=torch.double)],
+             [make_z(-50, 50, 75)],
+             [make_z(-5, 5, 75, dtype=torch.double)],  # scipy is buggy on this input
+             [make_z(-5, 5, 75)],  # scipy is buggy on this input
+         ]
+         if device == "cuda":
+             # half precision is only supported on CUDA
+             out.append([make_z(-5, 5, 75, dtype=torch.half)])
+         return out
+
+     def _test_correctness(self, device):
+         samples = self.sample_inputs(device) + self.grid_inputs(device)
+         for args in samples:
+             result = torch_bessel.ops.modified_bessel_k1(*args)
+             expected = reference_modified_bessel_k1(*args)
+             if expected.dtype in {torch.float, torch.cfloat, torch.chalf}:
+                 # ierr = 4, complete loss of significance
+                 expected[args[0].abs() > 4194303.98419452] = torch.nan
+                 # ierr = 2, overflow
+                 mask = (args[0].abs() < 1.1754944e-35) & (args[0] != 0)
+                 if expected.is_complex():
+                     expected[mask] = torch.nan
+                     expected[mask & (args[0].imag == 0) & (args[0] != 0)] = torch.inf
+                 else:
+                     expected[mask & (args[0] != 0)] = torch.inf
+             # fix scipy bug. See https://github.com/scipy/xsf/issues/46
+             if args[0].is_complex():
+                 mask = (
+                     (args[0].real > 685)
+                     & (args[0].real < 690)
+                     & (args[0].imag.abs() > 685)
+                     & (args[0].imag.abs() < 1.1e5)
+                 )
+                 expected[mask] = 0.0  # scipy returns incorrect/NaN output
+             torch.testing.assert_close(result, expected, equal_nan=True)
+
+     def test_correctness_cpu(self):
+         self._test_correctness("cpu")
+
+     @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda")
+     def test_correctness_cuda(self):
+         self._test_correctness("cuda")
+
+     def _opcheck(self, device):
+         # Use opcheck to check for incorrect usage of operator registration APIs
+         samples = self.sample_inputs(device, requires_grad=False)
+         for args in samples:
+             if not args[0].is_complex():
+                 continue
+             opcheck(
+                 torch.ops.torch_bessel.modified_bessel_k1_complex_forward.default, args
+             )
+
+     def test_opcheck_cpu(self):
+         self._opcheck("cpu")
+
+     @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda")
+     def test_opcheck_cuda(self):
+         self._opcheck("cuda")
+
+
+ @pytest.mark.parametrize("in_dims", [0, 1])
+ def test_vmap_k1(in_dims):
+     z = torch.randn(3, 4).abs() + torch.randn(3, 4) * 1j
+     func = torch.func.vmap(torch_bessel.ops.modified_bessel_k1, in_dims=in_dims)
+     out = func(z)
+     expected = torch_bessel.ops.modified_bessel_k1(z)
+     if in_dims == 1:
+         expected = expected.t()
+     torch.testing.assert_close(out, expected)
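The `grid_inputs` helpers in the tests above build a right-half-plane grid from log-spaced magnitudes. A 75-point logspace gives 76 real values (zero plus the logspace) and 151 imaginary values (the negated logspace, zero, and the logspace), which is why the tensor-valued `singularity` argument has shape `(76, 151)`. The sketch below reproduces that bookkeeping in plain Python, including the step that flushes values below the smallest normal float32 to zero. The function names `logspace` and `grid_axes` are hypothetical stand-ins for the torch calls, and the sketch says nothing about the Bessel values themselves.

```python
FLOAT32_TINY = 2.0 ** -126  # smallest positive normal float32 (torch.finfo(torch.float).tiny)


def logspace(start, end, steps):
    # Plain-Python stand-in for torch.logspace with base 10
    step = (end - start) / (steps - 1)
    return [10.0 ** (start + i * step) for i in range(steps)]


def grid_axes(start, end, steps, tiny=FLOAT32_TINY):
    # Flush magnitudes below the smallest normal number to zero, as the tests do
    pos = [x if x >= tiny else 0.0 for x in logspace(start, end, steps)]
    real = [0.0] + pos                                   # Re(z) >= 0 only
    imag = [-x for x in reversed(pos)] + [0.0] + pos     # symmetric in Im(z)
    return real, imag


real, imag = grid_axes(-50, 50, 75)
shape = (len(real), len(imag))
n_flushed = sum(1 for x in real[1:] if x == 0.0)

print(shape)      # (76, 151)
print(n_flushed)  # number of logspace points flushed to zero
```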
@@ -0,0 +1,11 @@
+ import warnings
+
+ # from . import ops
+ from . import _C, ops
+
+ # Silence warning emitted at torch/nested/_internal/nested_tensor.py:417
+ warnings.filterwarnings(
+     "ignore", message="Failed to initialize NumPy: No module named 'numpy'"
+ )
+
+ __all__ = ["ops"]
@@ -0,0 +1,60 @@
+ // #include <torch/extension.h>
+ #include <torch/library.h>
+ #include <ATen/TensorIterator.h>
+
+
+ at::TensorIterator build_iterator_11(const at::Tensor& result, const at::Tensor& z) {
+     return (
+         at::TensorIteratorConfig()
+         .set_check_mem_overlap(true)
+         .allow_cpu_scalars(true)
+         .enforce_safe_casting_to_output(true)
+         .add_output(result)
+         .add_input(z)
+     ).build();
+ }
+
+
+ at::TensorIterator build_iterator_12(const at::Tensor& result, const at::Tensor& v, const at::Tensor& z) {
+     return (
+         at::TensorIteratorConfig()
+         .set_check_mem_overlap(true)
+         .allow_cpu_scalars(true)
+         .promote_inputs_to_common_dtype(true)
+         .cast_common_dtype_to_outputs(true)
+         .enforce_safe_casting_to_output(true)
+         .promote_integer_inputs_to_float(true)
+         .add_output(result)
+         .add_input(v)
+         .add_input(z)
+     ).build();
+ }
+
+ at::TensorIterator build_iterator_21(const at::Tensor& result1, const at::Tensor& result2, const at::Tensor& z) {
+     return (
+         at::TensorIteratorConfig()
+         .set_check_mem_overlap(true)
+         .allow_cpu_scalars(true)
+         .enforce_safe_casting_to_output(true)
+         .add_output(result1)
+         .add_output(result2)
+         .add_input(z)
+     ).build();
+ }
+
+
+ at::TensorIterator build_iterator_22(const at::Tensor& result1, const at::Tensor& result2, const at::Tensor& v, const at::Tensor& z) {
+     return (
+         at::TensorIteratorConfig()
+         .set_check_mem_overlap(true)
+         .allow_cpu_scalars(true)
+         .promote_inputs_to_common_dtype(true)
+         .cast_common_dtype_to_outputs(true)
+         .enforce_safe_casting_to_output(true)
+         .promote_integer_inputs_to_float(true)
+         .add_output(result1)
+         .add_output(result2)
+         .add_input(v)
+         .add_input(z)
+     ).build();
+ }
@@ -0,0 +1,94 @@
+ #include <Python.h>
+ // #include <torch/extension.h>
+ #include <torch/library.h>
+ #include <ATen/ATen.h>
+ #include <ATen/TensorIterator.h>
+ #include <ATen/native/cpu/Loops.h>
+ #include <c10/util/complex.h>
+ #include <iostream>
+
+ #include "bessel_k.h"
+ #include "iterator.h"
+
+ // See https://docs.pytorch.org/tutorials/advanced/cpp_custom_ops.html
+ extern "C" {
+   /* Creates a dummy empty _C module that can be imported from Python.
+      The import from Python will load the .so consisting of this file
+      in this extension, so that the TORCH_LIBRARY static initializers
+      below are run. */
+   PyObject* PyInit__C(void)
+   {
+       static struct PyModuleDef module_def = {
+           PyModuleDef_HEAD_INIT,
+           "_C",   /* name of module */
+           NULL,   /* module documentation, may be NULL */
+           -1,     /* size of per-interpreter state of the module,
+                      or -1 if the module keeps state in global variables. */
+           NULL,   /* methods */
+       };
+       return PyModule_Create(&module_def);
+   }
+ }
+
+ namespace torch_bessel {
+
+ at::Tensor modified_bessel_k0_complex_forward_cpu(const at::Tensor& z) {
+     TORCH_INTERNAL_ASSERT(z.device().type() == at::DeviceType::CPU);
+     at::ScalarType dtype = z.scalar_type();
+     at::Tensor result = at::empty(at::IntArrayRef(), dtype).resize_(0);
+     at::TensorIterator iter = build_iterator_11(result, z);
+     AT_DISPATCH_COMPLEX_TYPES(dtype, "modified_bessel_k0_complex_forward_cpu", [&]() {
+         at::native::cpu_kernel(iter, [](scalar_t z) -> scalar_t {
+             return modified_bessel_k0_complex_forward(z);
+         });
+     });
+     return result;
+ }
+
+ std::tuple<at::Tensor, at::Tensor> modified_bessel_k0_complex_forward_backward_cpu(const at::Tensor& z) {
+     TORCH_INTERNAL_ASSERT(z.device().type() == at::DeviceType::CPU);
+     at::ScalarType dtype = z.scalar_type();
+     at::Tensor result1 = at::empty(at::IntArrayRef(), dtype).resize_(0);
+     at::Tensor result2 = at::empty(at::IntArrayRef(), dtype).resize_(0);
+     at::TensorIterator iter = build_iterator_21(result1, result2, z);
+     AT_DISPATCH_COMPLEX_TYPES(dtype, "modified_bessel_k0_complex_forward_backward_cpu", [&]() {
+         at::native::cpu_kernel_multiple_outputs(iter, [](scalar_t z) -> std::tuple<scalar_t, scalar_t> {
+             scalar_t cy[2];
+             modified_bessel_k0_complex_forward_backward(z, cy);
+             return std::make_tuple(cy[0], cy[1]);
+         });
+     });
+     return std::make_tuple(result1, result2);
+ }
+
+ at::Tensor modified_bessel_k1_complex_forward_cpu(const at::Tensor& z) {
+     TORCH_INTERNAL_ASSERT(z.device().type() == at::DeviceType::CPU);
+     at::ScalarType dtype = z.scalar_type();
+     at::Tensor result = at::empty(at::IntArrayRef(), dtype).resize_(0);
+     at::TensorIterator iter = build_iterator_11(result, z);
+     AT_DISPATCH_COMPLEX_TYPES(dtype, "modified_bessel_k1_complex_forward_cpu", [&]() {
+         at::native::cpu_kernel(iter, [](scalar_t z) -> scalar_t {
+             return modified_bessel_k1_complex_forward(z);
+         });
+     });
+     return result;
+ }
+
+ // // Registers _C as a Python extension module.
+ // PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
+
+ // Defines the operators
+ TORCH_LIBRARY(torch_bessel, m) {
+     m.def("modified_bessel_k0_complex_forward(Tensor z) -> Tensor");
+     m.def("modified_bessel_k0_complex_forward_backward(Tensor z) -> (Tensor, Tensor)");
+     m.def("modified_bessel_k1_complex_forward(Tensor z) -> Tensor");
+ }
+
+ // Registers CPU implementations for bessel_k
+ TORCH_LIBRARY_IMPL(torch_bessel, CPU, m) {
+     m.impl("modified_bessel_k0_complex_forward", &modified_bessel_k0_complex_forward_cpu);
+     m.impl("modified_bessel_k0_complex_forward_backward", &modified_bessel_k0_complex_forward_backward_cpu);
+     m.impl("modified_bessel_k1_complex_forward", &modified_bessel_k1_complex_forward_cpu);
+ }
+
+ }
@@ -0,0 +1,134 @@
+ # from pathlib import Path
+ from numbers import Number
+ from typing import Union
+
+ import torch
+ from torch import Tensor
+
+ __all__ = ["modified_bessel_k0", "modified_bessel_k1"]
+
+ # # load C extension before calling torch.library API, see
+ # # https://pytorch.org/tutorials/advanced/cpp_custom_ops.html
+ # so_dir = Path(__file__).parent
+ # so_files = list(so_dir.glob("_C*.so"))
+ # assert len(so_files) == 1, (
+ #     f"Expected one _C*.so file at {so_dir}, found {len(so_files)}"
+ # )
+ # torch.ops.load_library(so_files[0])
+
+
+ class ModifiedBesselK0(torch.autograd.Function):
+     @staticmethod
+     def forward(z, singularity):
+         if not z.is_complex():
+             out = (torch.special.modified_bessel_k0(z), None)
+         elif not z.requires_grad:
+             out = (
+                 torch.ops.torch_bessel.modified_bessel_k0_complex_forward.default(z),
+                 None,
+             )
+         else:
+             out = torch.ops.torch_bessel.modified_bessel_k0_complex_forward_backward.default(
+                 z
+             )
+
+         if singularity is None:
+             return (*out, None)
+
+         mask = z != 0
+         return (out[0].where(mask, singularity), out[1], mask)
+
+     @staticmethod
+     def setup_context(ctx, inputs, outputs):
+         if ctx.needs_input_grad[1]:
+             raise NotImplementedError("Gradient w.r.t. singularity is not implemented")
+
+         if ctx.needs_input_grad[0]:
+             if outputs[1] is None:
+                 ctx.save_for_backward(inputs[0], None, outputs[2])
+             else:
+                 ctx.save_for_backward(None, outputs[1], outputs[2])
+
+         ctx.set_materialize_grads(False)
+
+     @staticmethod
+     def backward(ctx, grad, _, __):
+         if grad is None or not ctx.needs_input_grad[0]:
+             return (None, None)
+
+         x, deriv, mask = ctx.saved_tensors
+         if deriv is None:
+             out = -torch.special.modified_bessel_k1(x).mul_(grad)
+         else:
+             out = grad * deriv
+         if mask is not None:
+             out = out.where(mask, 0)
+         return (out, None)
+
+     @staticmethod
+     def vmap(info, in_dims, z, singularity):
+         if singularity is None and in_dims[1] is not None:
+             raise ValueError("in_dims[1] must be None if singularity is not provided.")
+
+         if in_dims[0] is not None:
+             z = z.movedim(in_dims[0], 0)
+
+         if in_dims[1] is not None:
+             singularity = singularity.movedim(in_dims[1], 0)
+
+         out = ModifiedBesselK0.apply(z, singularity)
+         out_dims = [0] * 3 if any(d is not None for d in in_dims) else [None] * 3
+         if out[1] is None:
+             out_dims[1] = None
+         if out[2] is None:
+             out_dims[2] = None
+
+         return (out, tuple(out_dims))
+
+
+ def modified_bessel_k0(
+     z: Tensor, singularity: Union[Number, Tensor, None] = None
+ ) -> Tensor:
+     return ModifiedBesselK0.apply(z, singularity)[0]
+
+
+ def modified_bessel_k1(z: Tensor) -> Tensor:
+     # non-differentiable for now
+     if not z.is_complex():
+         return torch.special.modified_bessel_k1(z)
+     return torch.ops.torch_bessel.modified_bessel_k1_complex_forward.default(z)
+
+
+ @torch.library.register_fake("torch_bessel::modified_bessel_k0_complex_forward")
+ def _(z):
+     return torch.empty_like(z)
+
+
+ @torch.library.register_fake(
+     "torch_bessel::modified_bessel_k0_complex_forward_backward"
+ )
+ def _(z):
+     return torch.empty_like(z), torch.empty_like(z)
+
+
+ @torch.library.register_fake("torch_bessel::modified_bessel_k1_complex_forward")
+ def _(z):
+     return torch.empty_like(z)
+
+
+ def modified_bessel_k0_backward(ctx, grad, _):
120
+ if ctx.needs_input_grad[0]:
121
+ return grad * ctx.saved_tensors[0]
122
+ return None
123
+
124
+
125
+ def modified_bessel_k0_setup_context(ctx, inputs, output):
126
+ if ctx.needs_input_grad[0]:
127
+ ctx.save_for_backward(output[-1])
128
+
129
+
130
+ torch.library.register_autograd(
131
+ "torch_bessel::modified_bessel_k0_complex_forward_backward",
132
+ modified_bessel_k0_backward,
133
+ setup_context=modified_bessel_k0_setup_context,
134
+ )
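The real-input branch of `ModifiedBesselK0.backward` above multiplies the incoming gradient by `-K1(x)`, which relies on the identity K0'(x) = -K1(x). The following is a minimal, standard-library-only sketch of how that identity, and the related K1'(x) = -K0(x) - K1(x)/x, could be sanity-checked numerically for real x > 0. It is not part of the package. The helper names `bessel_k` and `central_diff`, the integral-representation quadrature, and the `upper`/`steps`/`eps` values are all assumptions chosen for illustration.

```python
import math


def bessel_k(nu, x, upper=12.0, steps=6000):
    """Approximate K_nu(x) for real x > 0 with the trapezoidal rule.

    Uses K_nu(x) = integral over t in [0, inf) of exp(-x*cosh(t)) * cosh(nu*t).
    `upper` and `steps` are illustrative truncation and resolution choices.
    """
    h = upper / steps
    total = 0.0
    for i in range(steps + 1):
        t = i * h
        weight = 0.5 if i in (0, steps) else 1.0
        total += weight * math.exp(-x * math.cosh(t)) * math.cosh(nu * t)
    return total * h


def central_diff(f, x, eps=1e-5):
    """Second-order central finite difference of f at x."""
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)


x = 1.3
k0, k1 = bessel_k(0, x), bessel_k(1, x)
dk0 = central_diff(lambda t: bessel_k(0, t), x)
dk1 = central_diff(lambda t: bessel_k(1, t), x)

# Both residuals should be close to zero if the identities hold.
print("K0'(x) + K1(x)           =", dk0 + k1)
print("K1'(x) + K0(x) + K1(x)/x =", dk1 + k0 + k1 / x)
```

This quadrature is only meant as an independent reference for moderate real arguments. It is far slower than the compiled kernels and makes no attempt to handle complex inputs or x near zero.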
@@ -0,0 +1,50 @@
1
+ Metadata-Version: 2.4
2
+ Name: torch_bessel
3
+ Version: 0.0.9
4
+ Summary: PyTorch extension package for Bessel functions with arbitrary real order and complex inputs
5
+ Home-page: https://github.com/hchau630/torch-bessel
6
+ Author: Ho Yin Chau
7
+ Requires-Python: >= 3.10
8
+ Description-Content-Type: text/markdown
9
+ License-File: LICENSE
10
+ Requires-Dist: torch
11
+ Dynamic: author
12
+ Dynamic: description
13
+ Dynamic: description-content-type
14
+ Dynamic: home-page
15
+ Dynamic: license-file
16
+ Dynamic: requires-dist
17
+ Dynamic: requires-python
18
+ Dynamic: summary
19
+
20
+ # About
21
+ PyTorch extension package for modified Bessel functions of the second kind with complex inputs
22
+
23
+ # Install
24
+ Currently only supports Linux (with CUDA 12.4) or macOS (Apple silicon, CPU only) with Python >= 3.10, <= 3.12.
25
+ ```
26
+ pip install torch-bessel
27
+ ```
28
+
29
+ UPDATE (May 26, 2026): The package may no longer be compatible with the latest PyTorch versions (>= 2.12.0). If you encounter issues, try installing an earlier version of PyTorch.
30
+
31
+ UPDATE (May 28, 2026): On Apple silicon, importing this package with PyTorch >= 2.9.0 causes a crash (bus error: 10) when the Python program exits, though the program runs normally until then. Consider installing torch < 2.9.0 on Apple silicon.
32
+
33
+ # Example
34
+ ```
35
+ import torch
36
+ import torch_bessel
37
+ real, imag = torch.randn(2, 5, device="cuda")
38
+ z = torch.complex(real.abs(), imag)  # correctness for inputs in the left-half complex plane is not guaranteed.
39
+ torch_bessel.ops.modified_bessel_k0(z)
40
+ ```
41
+
42
+ # Implemented functions
43
+ - `modified_bessel_k0`: Same as `torch.special.modified_bessel_k0`, but also supports backpropagation and complex inputs on CPU and CUDA. Correctness is guaranteed on the right half of the complex plane. On the left half, output appears mostly correct but has small numerical errors for certain inputs; on the negative real line, output is NaN. On CUDA, `torch.chalf` inputs are also supported: the underlying kernel upcasts `chalf` to `cfloat` internally, which uses no extra GPU memory, whereas manually casting `torch.chalf` to `torch.cfloat` before calling `modified_bessel_k0` doubles the GPU memory used.
44
+ - `modified_bessel_k1`: Same as `torch.special.modified_bessel_k1`, but also supports complex inputs on CPU and CUDA. Backpropagation is not implemented, but you can add it yourself by writing a `torch.autograd.Function` that uses the recurrence relations of Bessel functions. The same caveats as for `modified_bessel_k0` apply.
45
+
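As a rough illustration of the manual backpropagation suggested for `modified_bessel_k1`, here is a minimal sketch for real inputs only. It uses the recurrence K1'(x) = -K0(x) - K1(x)/x together with PyTorch's built-in `torch.special` functions rather than this package's complex kernels. The class name `RealModifiedBesselK1` is hypothetical and not part of `torch_bessel`.

```python
import torch


class RealModifiedBesselK1(torch.autograd.Function):
    """Hypothetical differentiable K1 for real inputs (not part of torch_bessel)."""

    @staticmethod
    def forward(ctx, x):
        k1 = torch.special.modified_bessel_k1(x)
        # Save both the input and the output; the backward pass reuses K1(x).
        ctx.save_for_backward(x, k1)
        return k1

    @staticmethod
    def backward(ctx, grad):
        x, k1 = ctx.saved_tensors
        # Recurrence relation: K1'(x) = -K0(x) - K1(x) / x
        deriv = -torch.special.modified_bessel_k0(x) - k1 / x
        return grad * deriv


# Usage: gradient of sum(K1(x)) with respect to x on a few positive points.
x = torch.linspace(0.5, 3.0, 6, dtype=torch.float64, requires_grad=True)
y = RealModifiedBesselK1.apply(x)
(grad_x,) = torch.autograd.grad(y.sum(), x)
print(grad_x)
```

Extending this to complex inputs would presumably mean swapping in the package's complex forward functions and, following PyTorch's convention for holomorphic functions, multiplying the incoming gradient by the complex conjugate of the derivative. That extension is an assumption and is not shown or tested here.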
46
+ # WIP
47
+ - `modified_bessel_kv`: Analogue of `scipy.special.kv`.
48
+
49
+ # Benchmarks
50
+ Benchmarking performed with the `asv` package. Results can be viewed at https://hchau630.github.io/torch-bessel.
@@ -0,0 +1,16 @@
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ setup.py
5
+ benchmarks/__init__.py
6
+ benchmarks/benchmarks.py
7
+ tests/test_extension.py
8
+ torch_bessel/__init__.py
9
+ torch_bessel/ops.py
10
+ torch_bessel.egg-info/PKG-INFO
11
+ torch_bessel.egg-info/SOURCES.txt
12
+ torch_bessel.egg-info/dependency_links.txt
13
+ torch_bessel.egg-info/requires.txt
14
+ torch_bessel.egg-info/top_level.txt
15
+ torch_bessel/csrc/iterator.cpp
16
+ torch_bessel/csrc/torch_bessel.cpp
@@ -0,0 +1,2 @@
1
+ benchmarks
2
+ torch_bessel