torch-bessel 0.0.3-cp310-cp310-macosx_11_0_arm64.whl → 0.0.4-cp310-cp310-macosx_11_0_arm64.whl
- benchmarks/__init__.py +0 -0
- benchmarks/benchmarks.py +99 -0
- torch_bessel/_C.cpython-310-darwin.so +0 -0
- torch_bessel/ops.py +46 -19
- {torch_bessel-0.0.3.dist-info → torch_bessel-0.0.4.dist-info}/METADATA +4 -1
- torch_bessel-0.0.4.dist-info/RECORD +10 -0
- {torch_bessel-0.0.3.dist-info → torch_bessel-0.0.4.dist-info}/top_level.txt +1 -0
- torch_bessel-0.0.3.dist-info/RECORD +0 -8
- {torch_bessel-0.0.3.dist-info → torch_bessel-0.0.4.dist-info}/LICENSE +0 -0
- {torch_bessel-0.0.3.dist-info → torch_bessel-0.0.4.dist-info}/WHEEL +0 -0
benchmarks/__init__.py ADDED
(empty file)
benchmarks/benchmarks.py ADDED
@@ -0,0 +1,99 @@
+import torch
+from asv_runner.benchmarks.mark import skip_benchmark_if
+
+import torch_bessel
+
+
+def _setup(
+    n, is_real, singularity, dtype=torch.float, requires_grad=False, device="cpu"
+):
+    kwargs = {"dtype": dtype, "requires_grad": requires_grad, "device": device}
+    real = torch.randn(n, **kwargs).abs()
+    if is_real:
+        args = (real, singularity)
+    else:
+        imag = torch.randn(n, **kwargs)
+        args = (torch.complex(real, imag), singularity)
+    return args
+
+
+class ModifiedBesselK0ForwardCPU:
+    params = (
+        [10_000, 100_000, 1_000_000],
+        [False, True],
+        [None, 0.0],
+        [torch.float32, torch.float64],
+        [False, True],
+    )
+    param_names = ["n", "is_real", "singularity", "dtype", "requires_grad"]
+
+    def setup(self, n, is_real, singularity, dtype, requires_grad):
+        self.args = _setup(n, is_real, singularity, dtype, requires_grad)
+
+    def time_modified_bessel_k0_forward_cpu(
+        self, n, is_real, singularity, dtype, requires_grad
+    ):
+        torch_bessel.ops.modified_bessel_k0(*self.args)
+
+
+@skip_benchmark_if(not torch.cuda.is_available())
+class ModifiedBesselK0ForwardCUDA:
+    params = (
+        [10_000, 100_000, 1_000_000],
+        [False, True],
+        [None, 0.0],
+        [torch.float32, torch.float64],
+        [False, True],
+    )
+    param_names = ["n", "is_real", "singularity", "dtype", "requires_grad"]
+
+    def setup(self, n, is_real, singularity, dtype, requires_grad):
+        self.args = _setup(n, is_real, singularity, dtype, requires_grad, device="cuda")
+
+    def time_modified_bessel_k0_forward_cuda(
+        self, n, is_real, singularity, dtype, requires_grad
+    ):
+        torch.cuda.synchronize()
+        torch_bessel.ops.modified_bessel_k0(*self.args)
+        torch.cuda.synchronize()
+
+
+class ModifiedBesselK0BackwardCPU:
+    warmup_time = 0.0  # for some reason backward is called multiple times if not 0
+    number = 1  # Avoids calling backward multiple times
+    params = (
+        [10_000, 100_000, 1_000_000],
+        [False, True],
+        [None, 0.0],
+        [torch.float32, torch.float64],
+    )
+    param_names = ["n", "is_real", "singularity", "dtype"]
+
+    def setup(self, n, is_real, singularity, dtype):
+        args = _setup(n, is_real, singularity, dtype, requires_grad=True)
+        self.out = torch_bessel.ops.modified_bessel_k0(*args).norm()
+
+    def time_modified_bessel_k0_backward_cpu(self, n, is_real, singularity, dtype):
+        self.out.backward()
+
+
+@skip_benchmark_if(not torch.cuda.is_available())
+class ModifiedBesselK0BackwardCUDA:
+    warmup_time = 0.0  # for some reason backward is called multiple times if not 0
+    number = 1  # Avoids calling backward multiple times
+    params = (
+        [10_000, 100_000, 1_000_000],
+        [False, True],
+        [None, 0.0],
+        [torch.float32, torch.float64],
+    )
+    param_names = ["n", "is_real", "singularity", "dtype"]
+
+    def setup(self, n, is_real, singularity, dtype):
+        args = _setup(n, is_real, singularity, dtype, requires_grad=True, device="cuda")
+        self.out = torch_bessel.ops.modified_bessel_k0(*args).norm()
+
+    def time_modified_bessel_k0_backward_cuda(self, n, is_real, singularity, dtype):
+        torch.cuda.synchronize()
+        self.out.backward()
+        torch.cuda.synchronize()
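The classes above follow asv's conventions: `setup` builds the inputs and the `time_*` method is the timed body. As a minimal sketch (not part of the package), one grid point can be exercised by hand, assuming torch_bessel is installed and the benchmarks/ directory is importable:

# Hypothetical manual driver for one benchmark case from the grid above.
import torch

from benchmarks.benchmarks import ModifiedBesselK0ForwardCPU

# One point from params: n, is_real, singularity, dtype, requires_grad
case = (10_000, False, None, torch.float32, False)

bench = ModifiedBesselK0ForwardCPU()
bench.setup(*case)                                # builds bench.args
bench.time_modified_bessel_k0_forward_cpu(*case)  # the body asv times

Under asv itself the same cases are collected and timed with `asv run`; per the METADATA change below, published results live at https://hchau630.github.io/torch-bessel.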
torch_bessel/_C.cpython-310-darwin.so CHANGED
Binary file (no text diff)
torch_bessel/ops.py CHANGED
@@ -1,4 +1,6 @@
 from pathlib import Path
+from numbers import Number
+from typing import Union
 
 import torch
 from torch import Tensor
@@ -17,32 +19,57 @@ torch.ops.load_library(so_files[0])
 
 class ModifiedBesselK0(torch.autograd.Function):
     @staticmethod
-    def forward(
-
+    def forward(z, singularity):
+        if not z.is_complex():
+            out = (torch.special.modified_bessel_k0(z), None)
+        elif not z.requires_grad:
+            out = (
+                torch.ops.torch_bessel.modified_bessel_k0_complex_forward.default(z),
+                None,
+            )
+        else:
+            out = torch.ops.torch_bessel.modified_bessel_k0_complex_forward_backward.default(
+                z
+            )
+
+        if singularity is None:
+            return (*out, None)
+
+        mask = z != 0
+        return (out[0].where(mask, singularity), out[1], mask)
 
     @staticmethod
-    def setup_context(ctx, inputs,
+    def setup_context(ctx, inputs, outputs):
+        if ctx.needs_input_grad[1]:
+            raise NotImplementedError("Gradient w.r.t. singularity is not implemented")
+
         if ctx.needs_input_grad[0]:
-
+            if outputs[1] is None:
+                ctx.save_for_backward(inputs[0], None, outputs[2])
+            else:
+                ctx.save_for_backward(None, outputs[1], outputs[2])
+
         ctx.set_materialize_grads(False)
 
     @staticmethod
-    def backward(ctx, grad):
+    def backward(ctx, grad, _, __):
         if grad is None or not ctx.needs_input_grad[0]:
-            return None
-
-
-
-
-
-
-
-
-
-
-
-
-
+            return (None, None)
+
+        x, deriv, mask = ctx.saved_tensors
+        if deriv is None:
+            out = -torch.special.modified_bessel_k1(x).mul_(grad)
+        else:
+            out = grad * deriv
+        if mask is not None:
+            out = out.where(mask, 0)
+        return (out, None)
+
+
+def modified_bessel_k0(
+    z: Tensor, singularity: Union[Number, Tensor, None] = None
+) -> Tensor:
+    return ModifiedBesselK0.apply(z, singularity)[0]
 
 
 @torch.library.register_fake("torch_bessel::modified_bessel_k0_complex_forward")
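Taken together, the new `singularity` argument patches the value of K0 at z == 0 (where K0 diverges) and zeroes the corresponding gradient via the saved mask. A minimal usage sketch based on the diff above, assuming torch_bessel 0.0.4 is installed (input values are illustrative):

import torch
import torch_bessel

# K0 has a singularity at z = 0; `singularity` supplies the value to
# return there instead of inf.
z = torch.tensor([0.0 + 0.0j, 1.0 + 1.0j], requires_grad=True)
out = torch_bessel.ops.modified_bessel_k0(z, singularity=0.0)

out.norm().backward()  # the masked entry contributes zero gradient
print(out, z.grad)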
{torch_bessel-0.0.3.dist-info → torch_bessel-0.0.4.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: torch_bessel
-Version: 0.0.3
+Version: 0.0.4
 Summary: PyTorch extension package for Bessel functions with arbitrary real order and complex inputs
 Home-page: https://github.com/hchau630/torch-bessel
 Author: Ho Yin Chau
@@ -39,3 +39,6 @@ torch_bessel.ops.modified_bessel_k0(z)
 
 # WIP
 - `modified_bessel_kv`: Analogue of `scipy.special.kv`.
+
+# Benchmarks
+Benchmarking performed with the `asv` package. Results can be viewed at https://hchau630.github.io/torch-bessel.
torch_bessel-0.0.4.dist-info/RECORD ADDED
@@ -0,0 +1,10 @@
+benchmarks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+benchmarks/benchmarks.py,sha256=qRwHDxsrqRQjbExAwtWcrUtggUCL-UGJQQHnqjIoUIM,3311
+torch_bessel/_C.cpython-310-darwin.so,sha256=ACQdYDXJFXSFdRL8QVsC293diLaUpKSvzLTR6K5bO4c,231080
+torch_bessel/__init__.py,sha256=oohbWz8vxekl7kqDNSWqDH3ORabf9-Tug1KJryKw51A,230
+torch_bessel/ops.py,sha256=htS41Mnz2eTcBWNCs33PSFee__S4KRfJ5zVopM2VNd8,2908
+torch_bessel-0.0.4.dist-info/LICENSE,sha256=do0DI6wu4mF3VXnEXXPYZqVEatoRSSamgz9t80wU7_o,1068
+torch_bessel-0.0.4.dist-info/METADATA,sha256=t73hRVqeL4sosAMOVmAgnJjUfOCsahFtE_jL1yHBy6Q,1347
+torch_bessel-0.0.4.dist-info/WHEEL,sha256=ezfKMaDztqf77C8lvQ0NCnZxkTaOaKLprqJ8q932MhU,109
+torch_bessel-0.0.4.dist-info/top_level.txt,sha256=xmyVjWSQ91kX-v8KCzl6wDwfAmbdZNWsP2EH9b9BccQ,24
+torch_bessel-0.0.4.dist-info/RECORD,,
@@ -1,8 +0,0 @@
|
|
1
|
-
torch_bessel/_C.cpython-310-darwin.so,sha256=4CEZxUFE2_LvYHrqV1bTeXn0P9LieU1eSAnRBnP6l10,231080
|
2
|
-
torch_bessel/__init__.py,sha256=oohbWz8vxekl7kqDNSWqDH3ORabf9-Tug1KJryKw51A,230
|
3
|
-
torch_bessel/ops.py,sha256=Q9BrLxi15MS53xSt_S9dyE3g8_8_GCFYhfAztIor8Fw,2043
|
4
|
-
torch_bessel-0.0.3.dist-info/LICENSE,sha256=do0DI6wu4mF3VXnEXXPYZqVEatoRSSamgz9t80wU7_o,1068
|
5
|
-
torch_bessel-0.0.3.dist-info/METADATA,sha256=7hH6thd8bu4oaKYfNq5v6V4JkUe6g6_C9Irl618y3-8,1220
|
6
|
-
torch_bessel-0.0.3.dist-info/WHEEL,sha256=ezfKMaDztqf77C8lvQ0NCnZxkTaOaKLprqJ8q932MhU,109
|
7
|
-
torch_bessel-0.0.3.dist-info/top_level.txt,sha256=cbDIjTj71LuAlVyyYyDt8fOAeLaVeX3Vums5F2FBa-4,13
|
8
|
-
torch_bessel-0.0.3.dist-info/RECORD,,
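Each RECORD row has the standard wheel layout `path,sha256=<digest>,size`, where the digest is an unpadded urlsafe-base64 SHA-256 of the file. A small sketch (file path illustrative) for checking one entry against an unpacked wheel:

import base64
import hashlib

# Recompute the RECORD-style digest and size for one packaged file.
with open("torch_bessel/ops.py", "rb") as f:
    data = f.read()

digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=")
print(f"sha256={digest.decode()},{len(data)}")  # compare with the RECORD row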
{torch_bessel-0.0.3.dist-info → torch_bessel-0.0.4.dist-info}/LICENSE RENAMED
File without changes
{torch_bessel-0.0.3.dist-info → torch_bessel-0.0.4.dist-info}/WHEEL RENAMED
File without changes