torch-bessel 0.0.4__cp39-cp39-macosx_11_0_arm64.whl → 0.0.5__cp39-cp39-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmarks.py +2 -2
- torch_bessel/_C.cpython-39-darwin.so +0 -0
- torch_bessel/ops.py +20 -0
- {torch_bessel-0.0.4.dist-info → torch_bessel-0.0.5.dist-info}/METADATA +3 -3
- torch_bessel-0.0.5.dist-info/RECORD +10 -0
- {torch_bessel-0.0.4.dist-info → torch_bessel-0.0.5.dist-info}/WHEEL +1 -1
- torch_bessel-0.0.4.dist-info/RECORD +0 -10
- {torch_bessel-0.0.4.dist-info → torch_bessel-0.0.5.dist-info}/LICENSE +0 -0
- {torch_bessel-0.0.4.dist-info → torch_bessel-0.0.5.dist-info}/top_level.txt +0 -0
benchmarks/benchmarks.py
CHANGED
@@ -36,7 +36,6 @@ class ModifiedBesselK0ForwardCPU:
|
|
36
36
|
torch_bessel.ops.modified_bessel_k0(*self.args)
|
37
37
|
|
38
38
|
|
39
|
-
@skip_benchmark_if(not torch.cuda.is_available())
|
40
39
|
class ModifiedBesselK0ForwardCUDA:
|
41
40
|
params = (
|
42
41
|
[10_000, 100_000, 1_000_000],
|
@@ -50,6 +49,7 @@ class ModifiedBesselK0ForwardCUDA:
|
|
50
49
|
def setup(self, n, is_real, singularity, dtype, requires_grad):
|
51
50
|
self.args = _setup(n, is_real, singularity, dtype, requires_grad, device="cuda")
|
52
51
|
|
52
|
+
@skip_benchmark_if(not torch.cuda.is_available())
|
53
53
|
def time_modified_bessel_k0_forward_cuda(
|
54
54
|
self, n, is_real, singularity, dtype, requires_grad
|
55
55
|
):
|
@@ -77,7 +77,6 @@ class ModifiedBesselK0BackwardCPU:
|
|
77
77
|
self.out.backward()
|
78
78
|
|
79
79
|
|
80
|
-
@skip_benchmark_if(not torch.cuda.is_available())
|
81
80
|
class ModifiedBesselK0BackwardCUDA:
|
82
81
|
warmup_time = 0.0 # for some reason backward is called multiple times if not 0
|
83
82
|
number = 1 # Avoids calling backward multiple times
|
@@ -93,6 +92,7 @@ class ModifiedBesselK0BackwardCUDA:
|
|
93
92
|
args = _setup(n, is_real, singularity, dtype, requires_grad=True, device="cuda")
|
94
93
|
self.out = torch_bessel.ops.modified_bessel_k0(*args).norm()
|
95
94
|
|
95
|
+
@skip_benchmark_if(not torch.cuda.is_available())
|
96
96
|
def time_modified_bessel_k0_backward_cuda(self, n, is_real, singularity, dtype):
|
97
97
|
torch.cuda.synchronize()
|
98
98
|
self.out.backward()
|
Binary file
|
torch_bessel/ops.py
CHANGED
@@ -65,6 +65,26 @@ class ModifiedBesselK0(torch.autograd.Function):
|
|
65
65
|
out = out.where(mask, 0)
|
66
66
|
return (out, None)
|
67
67
|
|
68
|
+
@staticmethod
|
69
|
+
def vmap(info, in_dims, z, singularity):
|
70
|
+
if singularity is None and in_dims[1] is not None:
|
71
|
+
raise ValueError("in_dims[1] must be None if singularity is not provided.")
|
72
|
+
|
73
|
+
if in_dims[0] is not None:
|
74
|
+
z = z.movedim(in_dims[0], 0)
|
75
|
+
|
76
|
+
if in_dims[1] is not None:
|
77
|
+
singularity = singularity.movedim(in_dims[1], 0)
|
78
|
+
|
79
|
+
out = ModifiedBesselK0.apply(z, singularity)
|
80
|
+
out_dims = [0] * 3 if any(d is not None for d in in_dims) else [None] * 3
|
81
|
+
if out[1] is None:
|
82
|
+
out_dims[1] = None
|
83
|
+
if out[2] is None:
|
84
|
+
out_dims[2] = None
|
85
|
+
|
86
|
+
return (out, tuple(out_dims))
|
87
|
+
|
68
88
|
|
69
89
|
def modified_bessel_k0(
|
70
90
|
z: Tensor, singularity: Union[Number, Tensor, None] = None
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.2
|
2
2
|
Name: torch_bessel
|
3
|
-
Version: 0.0.4
|
3
|
+
Version: 0.0.5
|
4
4
|
Summary: PyTorch extension package for Bessel functions with arbitrary real order and complex inputs
|
5
5
|
Home-page: https://github.com/hchau630/torch-bessel
|
6
6
|
Author: Ho Yin Chau
|
@@ -30,12 +30,12 @@ pip install torch-bessel
|
|
30
30
|
import torch_bessel
|
31
31
|
|
32
32
|
real, imag = torch.randn(2, 5, device="cuda")
|
33
|
-
z = torch.complex(real.abs(), imag) # inputs
|
33
|
+
z = torch.complex(real.abs(), imag) # correctness for inputs in the left-half complex plane is not guaranteed.
|
34
34
|
torch_bessel.ops.modified_bessel_k0(z)
|
35
35
|
```
|
36
36
|
|
37
37
|
# Implemented functions
|
38
|
-
- `modified_bessel_k0`: Same as `torch.special.modified_bessel_k0`, but also handles backpropagation and complex inputs
|
38
|
+
- `modified_bessel_k0`: Same as `torch.special.modified_bessel_k0`, but also handles backpropagation and complex inputs on cpu and cuda. Correctness is guaranteed on the right-half complex plane. On the left-half complex plane, function output appears mostly correct, but with small numerical errors for certain inputs. On the negative real line, output is NaN.
|
39
39
|
|
40
40
|
# WIP
|
41
41
|
- `modified_bessel_kv`: Analogue of `scipy.special.kv`.
|
@@ -0,0 +1,10 @@
|
|
1
|
+
benchmarks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
+
benchmarks/benchmarks.py,sha256=LZDGAlKEWwo-VCenNTGn98bJzj9FSlQbFK5pPXRvDyI,3319
|
3
|
+
torch_bessel/_C.cpython-39-darwin.so,sha256=38QT3Q1sn9xVivIA2qYlv9pvGLhq8vLoKIiz_AhenOk,231368
|
4
|
+
torch_bessel/__init__.py,sha256=oohbWz8vxekl7kqDNSWqDH3ORabf9-Tug1KJryKw51A,230
|
5
|
+
torch_bessel/ops.py,sha256=P4qo1GN6XExuKfI061dsZMpAPuimn6gYt-PwRmR4_Fk,3584
|
6
|
+
torch_bessel-0.0.5.dist-info/LICENSE,sha256=do0DI6wu4mF3VXnEXXPYZqVEatoRSSamgz9t80wU7_o,1068
|
7
|
+
torch_bessel-0.0.5.dist-info/METADATA,sha256=LX6X8j53vJCYKaXLfjlk3X_RwRYHlkoQVYC7CSJJ1vc,1561
|
8
|
+
torch_bessel-0.0.5.dist-info/WHEEL,sha256=lcS7ertGdTcwg6KPpE3fBtGreaKMXtC8sl4-ZeZxMCs,107
|
9
|
+
torch_bessel-0.0.5.dist-info/top_level.txt,sha256=xmyVjWSQ91kX-v8KCzl6wDwfAmbdZNWsP2EH9b9BccQ,24
|
10
|
+
torch_bessel-0.0.5.dist-info/RECORD,,
|
@@ -1,10 +0,0 @@
|
|
1
|
-
benchmarks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
-
benchmarks/benchmarks.py,sha256=qRwHDxsrqRQjbExAwtWcrUtggUCL-UGJQQHnqjIoUIM,3311
|
3
|
-
torch_bessel/_C.cpython-39-darwin.so,sha256=Nv-frnFJ35e1qBavsYzKi7hZciGLas_AloaOtrBsXiU,231368
|
4
|
-
torch_bessel/__init__.py,sha256=oohbWz8vxekl7kqDNSWqDH3ORabf9-Tug1KJryKw51A,230
|
5
|
-
torch_bessel/ops.py,sha256=htS41Mnz2eTcBWNCs33PSFee__S4KRfJ5zVopM2VNd8,2908
|
6
|
-
torch_bessel-0.0.4.dist-info/LICENSE,sha256=do0DI6wu4mF3VXnEXXPYZqVEatoRSSamgz9t80wU7_o,1068
|
7
|
-
torch_bessel-0.0.4.dist-info/METADATA,sha256=t73hRVqeL4sosAMOVmAgnJjUfOCsahFtE_jL1yHBy6Q,1347
|
8
|
-
torch_bessel-0.0.4.dist-info/WHEEL,sha256=md3JO_ifs5j508p3TDNMgtQVtnQblpGEt_Wo4W56l8Y,107
|
9
|
-
torch_bessel-0.0.4.dist-info/top_level.txt,sha256=xmyVjWSQ91kX-v8KCzl6wDwfAmbdZNWsP2EH9b9BccQ,24
|
10
|
-
torch_bessel-0.0.4.dist-info/RECORD,,
|
File without changes
|
File without changes
|