torch-bessel 0.0.6-cp312-cp312-macosx_11_0_arm64.whl → 0.0.7-cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmarks.py +44 -0
- torch_bessel/_C.cpython-312-darwin.so +0 -0
- torch_bessel/ops.py +13 -1
- {torch_bessel-0.0.6.dist-info → torch_bessel-0.0.7.dist-info}/METADATA +3 -2
- torch_bessel-0.0.7.dist-info/RECORD +10 -0
- torch_bessel-0.0.6.dist-info/RECORD +0 -10
- {torch_bessel-0.0.6.dist-info → torch_bessel-0.0.7.dist-info}/WHEEL +0 -0
- {torch_bessel-0.0.6.dist-info → torch_bessel-0.0.7.dist-info}/licenses/LICENSE +0 -0
- {torch_bessel-0.0.6.dist-info → torch_bessel-0.0.7.dist-info}/top_level.txt +0 -0
benchmarks/benchmarks.py
CHANGED
@@ -17,6 +17,17 @@ def _setup(
     return args
 
 
+def _setup_k1(n, is_real, dtype=torch.float, device="cpu"):
+    kwargs = {"dtype": dtype, "device": device}
+    real = torch.randn(n, **kwargs).abs()
+    if is_real:
+        args = (real,)
+    else:
+        imag = torch.randn(n, **kwargs)
+        args = (torch.complex(real, imag),)
+    return args
+
+
 class ModifiedBesselK0ForwardCPU:
     params = (
         [10_000, 100_000, 1_000_000],
@@ -97,3 +108,36 @@ class ModifiedBesselK0BackwardCUDA:
         torch.cuda.synchronize()
         self.out.backward()
         torch.cuda.synchronize()
+
+
+class ModifiedBesselK1ForwardCPU:
+    params = (
+        [10_000, 100_000, 1_000_000],
+        [False, True],
+        [torch.float32, torch.float64],
+    )
+    param_names = ["n", "is_real", "dtype"]
+
+    def setup(self, n, is_real, dtype):
+        self.args = _setup_k1(n, is_real, dtype)
+
+    def time_modified_bessel_k1_forward_cpu(self, n, is_real, dtype):
+        torch_bessel.ops.modified_bessel_k1(*self.args)
+
+
+class ModifiedBesselK1ForwardCUDA:
+    params = (
+        [10_000, 100_000, 1_000_000],
+        [False, True],
+        [torch.float32, torch.float64],
+    )
+    param_names = ["n", "is_real", "dtype"]
+
+    def setup(self, n, is_real, dtype):
+        self.args = _setup_k1(n, is_real, dtype, device="cuda")
+
+    @skip_benchmark_if(not torch.cuda.is_available())
+    def time_modified_bessel_k1_forward_cuda(self, n, is_real, dtype):
+        torch.cuda.synchronize()
+        torch_bessel.ops.modified_bessel_k1(*self.args)
+        torch.cuda.synchronize()
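The two new benchmark classes appear to follow the airspeed velocity (asv) convention used by the existing K0 classes: `params`/`param_names` define the parameter grid, `setup` builds the inputs, and each `time_*` method times a single call. A minimal standalone sketch of what the CPU benchmark measures (assuming torch-bessel 0.0.7 is installed; the size and variable names here are illustrative):

import torch
import torch_bessel

n = 100_000
real = torch.randn(n, dtype=torch.float64).abs()              # is_real=True inputs
z = torch.complex(real, torch.randn(n, dtype=torch.float64))  # is_real=False inputs
torch_bessel.ops.modified_bessel_k1(z)                        # the call being timed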
torch_bessel/_C.cpython-312-darwin.so
CHANGED
Binary file (no text diff shown)
torch_bessel/ops.py
CHANGED
@@ -5,7 +5,7 @@ from typing import Union
 import torch
 from torch import Tensor
 
-__all__ = ["modified_bessel_k0"]
+__all__ = ["modified_bessel_k0", "modified_bessel_k1"]
 
 # load C extension before calling torch.library API, see
 # https://pytorch.org/tutorials/advanced/cpp_custom_ops.html
@@ -92,6 +92,13 @@ def modified_bessel_k0(
     return ModifiedBesselK0.apply(z, singularity)[0]
 
 
+def modified_bessel_k1(z: Tensor) -> Tensor:
+    # non-differentiable for now
+    if not z.is_complex():
+        return torch.special.modified_bessel_k1(z)
+    return torch.ops.torch_bessel.modified_bessel_k1_complex_forward.default(z)
+
+
 @torch.library.register_fake("torch_bessel::modified_bessel_k0_complex_forward")
 def _(z):
     return torch.empty_like(z)
@@ -104,6 +111,11 @@ def _(z):
     return torch.empty_like(z), torch.empty_like(z)
 
 
+@torch.library.register_fake("torch_bessel::modified_bessel_k1_complex_forward")
+def _(z):
+    return torch.empty_like(z)
+
+
 def modified_bessel_k0_backward(ctx, grad, _):
     if ctx.needs_input_grad[0]:
         return grad * ctx.saved_tensors[0]
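The new `modified_bessel_k1` wrapper dispatches on dtype: real tensors go to `torch.special.modified_bessel_k1`, complex tensors to the custom `torch_bessel::modified_bessel_k1_complex_forward` op, and the `register_fake` entry supplies a meta implementation so the op can be traced with FakeTensor/`torch.compile`. A short, hedged usage sketch, assuming version 0.0.7 is installed (values are illustrative):

import torch
import torch_bessel

x = torch.linspace(0.1, 5.0, 8, dtype=torch.float64)
print(torch_bessel.ops.modified_bessel_k1(x))   # real path: torch.special.modified_bessel_k1

z = torch.complex(x, torch.randn(8, dtype=torch.float64))
print(torch_bessel.ops.modified_bessel_k1(z))   # complex path: compiled custom op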
{torch_bessel-0.0.6.dist-info → torch_bessel-0.0.7.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: torch_bessel
-Version: 0.0.6
+Version: 0.0.7
 Summary: PyTorch extension package for Bessel functions with arbitrary real order and complex inputs
 Home-page: https://github.com/hchau630/torch-bessel
 Author: Ho Yin Chau
@@ -18,7 +18,7 @@ Dynamic: requires-python
 Dynamic: summary
 
 # About
-PyTorch extension package for Bessel functions
+PyTorch extension package for modified Bessel functions of the second kind with complex inputs
 
 # Install
 Currently only supports Linux (with CUDA 12.4) or MacOS (Apple silicon, cpu only) with python >= 3.9, <= 3.12.
@@ -37,6 +37,7 @@ torch_bessel.ops.modified_bessel_k0(z)
 
 # Implemented functions
 - `modified_bessel_k0`: Same as `torch.special.modified_bessel_k0`, but also handles backpropagation and complex inputs on cpu and cuda. Correctness is guaranteed on the right-half complex plane for double types, and almost guaranteed for float types, though it appears there are a very small handful of inputs which result in NaNs which needs to be fixed. On cuda, `torch.chalf` inputs are also supported, though the underlying cuda kernel just upcasts `chalf` to `cfloat` (note that this uses no extra GPU memory, as opposed to manually casting torch.chalf to torch.cfloat before calling `modified_bessel_k0` which doubles the GPU memory used). On the left-half complex plane, function output appears mostly correct, but with small numerical errors for certain inputs. On the negative real line, output is NaN.
+- `modified_bessel_k1`: Same as `torch.special.modified_bessel_k1`, but also handles complex inputs on cpu and cuda. Backpropagation not implemented, but this can be easily manually implemented yourself by writing a torch.autograd.Function using the recurrence properties of bessel functions. Same caveats as `modified_bessel_k0` apply.
 
 # WIP
 - `modified_bessel_kv`: Analogue of `scipy.special.kv`.
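The new `modified_bessel_k1` bullet notes that backpropagation is not implemented but could be added by writing a `torch.autograd.Function` that uses the Bessel recurrences. A rough, hedged sketch of that idea (not part of the package; the class name is illustrative): from K_v'(z) = -(K_{v-1}(z) + K_{v+1}(z))/2 and K_{v+1}(z) = K_{v-1}(z) + (2v/z) K_v(z) one gets K_1'(z) = -K_0(z) - K_1(z)/z. The complex-gradient (Wirtinger) handling below is an assumption and would need to be checked against how the package treats it for `modified_bessel_k0`.

import torch
import torch_bessel


class ModifiedBesselK1WithGrad(torch.autograd.Function):
    # Illustrative only: forward wraps the package op,
    # backward uses the recurrence-derived derivative.
    @staticmethod
    def forward(ctx, z):
        out = torch_bessel.ops.modified_bessel_k1(z)
        ctx.save_for_backward(z, out)
        return out

    @staticmethod
    def backward(ctx, grad):
        z, k1 = ctx.saved_tensors
        k0 = torch_bessel.ops.modified_bessel_k0(z)
        dk1 = -k0 - k1 / z      # K1'(z) = -K0(z) - K1(z)/z
        if z.is_complex():
            dk1 = dk1.conj()    # assumption: conjugate per PyTorch's complex autograd convention
        return grad * dk1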
torch_bessel-0.0.7.dist-info/RECORD
ADDED
@@ -0,0 +1,10 @@
+benchmarks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+benchmarks/benchmarks.py,sha256=_Bi3QFknID_n99lrqMTW36YFhCEIxxQzu5hmDd1inkY,4595
+torch_bessel/_C.cpython-312-darwin.so,sha256=0SvorSYRDqUrLH5GZMgWQ50F0-o6IggJQxjjcAtPQy8,233368
+torch_bessel/__init__.py,sha256=oohbWz8vxekl7kqDNSWqDH3ORabf9-Tug1KJryKw51A,230
+torch_bessel/ops.py,sha256=U-o-3brRTM4UX8cJqZInjsf2zLSaGzGX2yUZPYpq19s,3968
+torch_bessel-0.0.7.dist-info/licenses/LICENSE,sha256=do0DI6wu4mF3VXnEXXPYZqVEatoRSSamgz9t80wU7_o,1068
+torch_bessel-0.0.7.dist-info/METADATA,sha256=KG8_9O1eMuktubsS24MLnFBAk9xGKmsErvH24BuxAYQ,2374
+torch_bessel-0.0.7.dist-info/WHEEL,sha256=CltXN3lQvXbHxKDtiDwW0RNzF8s2WyBuPbOAX_ZeQlA,109
+torch_bessel-0.0.7.dist-info/top_level.txt,sha256=xmyVjWSQ91kX-v8KCzl6wDwfAmbdZNWsP2EH9b9BccQ,24
+torch_bessel-0.0.7.dist-info/RECORD,,
@@ -1,10 +0,0 @@
|
|
1
|
-
benchmarks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
-
benchmarks/benchmarks.py,sha256=LZDGAlKEWwo-VCenNTGn98bJzj9FSlQbFK5pPXRvDyI,3319
|
3
|
-
torch_bessel/_C.cpython-312-darwin.so,sha256=YZJPORMvIiAa-JColLS3KoyD0HI60ow1aRyuzjlsif8,232280
|
4
|
-
torch_bessel/__init__.py,sha256=oohbWz8vxekl7kqDNSWqDH3ORabf9-Tug1KJryKw51A,230
|
5
|
-
torch_bessel/ops.py,sha256=P4qo1GN6XExuKfI061dsZMpAPuimn6gYt-PwRmR4_Fk,3584
|
6
|
-
torch_bessel-0.0.6.dist-info/licenses/LICENSE,sha256=do0DI6wu4mF3VXnEXXPYZqVEatoRSSamgz9t80wU7_o,1068
|
7
|
-
torch_bessel-0.0.6.dist-info/METADATA,sha256=uViUBBZq-V6HuSKg2N4NMiWJYtJwUxmLxkB0dtTVzS4,2034
|
8
|
-
torch_bessel-0.0.6.dist-info/WHEEL,sha256=CltXN3lQvXbHxKDtiDwW0RNzF8s2WyBuPbOAX_ZeQlA,109
|
9
|
-
torch_bessel-0.0.6.dist-info/top_level.txt,sha256=xmyVjWSQ91kX-v8KCzl6wDwfAmbdZNWsP2EH9b9BccQ,24
|
10
|
-
torch_bessel-0.0.6.dist-info/RECORD,,
|
{torch_bessel-0.0.6.dist-info → torch_bessel-0.0.7.dist-info}/WHEEL
File without changes
{torch_bessel-0.0.6.dist-info → torch_bessel-0.0.7.dist-info}/licenses/LICENSE
File without changes
{torch_bessel-0.0.6.dist-info → torch_bessel-0.0.7.dist-info}/top_level.txt
File without changes