torch_bessel-0.0.3-cp312-cp312-macosx_11_0_arm64.whl → torch_bessel-0.0.5-cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
benchmarks/__init__.py ADDED
File without changes (empty file)

benchmarks/benchmarks.py ADDED
@@ -0,0 +1,99 @@
+ import torch
+ from asv_runner.benchmarks.mark import skip_benchmark_if
+
+ import torch_bessel
+
+
+ def _setup(
+     n, is_real, singularity, dtype=torch.float, requires_grad=False, device="cpu"
+ ):
+     kwargs = {"dtype": dtype, "requires_grad": requires_grad, "device": device}
+     real = torch.randn(n, **kwargs).abs()
+     if is_real:
+         args = (real, singularity)
+     else:
+         imag = torch.randn(n, **kwargs)
+         args = (torch.complex(real, imag), singularity)
+     return args
+
+
+ class ModifiedBesselK0ForwardCPU:
+     params = (
+         [10_000, 100_000, 1_000_000],
+         [False, True],
+         [None, 0.0],
+         [torch.float32, torch.float64],
+         [False, True],
+     )
+     param_names = ["n", "is_real", "singularity", "dtype", "requires_grad"]
+
+     def setup(self, n, is_real, singularity, dtype, requires_grad):
+         self.args = _setup(n, is_real, singularity, dtype, requires_grad)
+
+     def time_modified_bessel_k0_forward_cpu(
+         self, n, is_real, singularity, dtype, requires_grad
+     ):
+         torch_bessel.ops.modified_bessel_k0(*self.args)
+
+
+ class ModifiedBesselK0ForwardCUDA:
+     params = (
+         [10_000, 100_000, 1_000_000],
+         [False, True],
+         [None, 0.0],
+         [torch.float32, torch.float64],
+         [False, True],
+     )
+     param_names = ["n", "is_real", "singularity", "dtype", "requires_grad"]
+
+     def setup(self, n, is_real, singularity, dtype, requires_grad):
+         self.args = _setup(n, is_real, singularity, dtype, requires_grad, device="cuda")
+
+     @skip_benchmark_if(not torch.cuda.is_available())
+     def time_modified_bessel_k0_forward_cuda(
+         self, n, is_real, singularity, dtype, requires_grad
+     ):
+         torch.cuda.synchronize()
+         torch_bessel.ops.modified_bessel_k0(*self.args)
+         torch.cuda.synchronize()
+
+
+ class ModifiedBesselK0BackwardCPU:
+     warmup_time = 0.0  # for some reason backward is called multiple times if not 0
+     number = 1  # Avoids calling backward multiple times
+     params = (
+         [10_000, 100_000, 1_000_000],
+         [False, True],
+         [None, 0.0],
+         [torch.float32, torch.float64],
+     )
+     param_names = ["n", "is_real", "singularity", "dtype"]
+
+     def setup(self, n, is_real, singularity, dtype):
+         args = _setup(n, is_real, singularity, dtype, requires_grad=True)
+         self.out = torch_bessel.ops.modified_bessel_k0(*args).norm()
+
+     def time_modified_bessel_k0_backward_cpu(self, n, is_real, singularity, dtype):
+         self.out.backward()
+
+
+ class ModifiedBesselK0BackwardCUDA:
+     warmup_time = 0.0  # for some reason backward is called multiple times if not 0
+     number = 1  # Avoids calling backward multiple times
+     params = (
+         [10_000, 100_000, 1_000_000],
+         [False, True],
+         [None, 0.0],
+         [torch.float32, torch.float64],
+     )
+     param_names = ["n", "is_real", "singularity", "dtype"]
+
+     def setup(self, n, is_real, singularity, dtype):
+         args = _setup(n, is_real, singularity, dtype, requires_grad=True, device="cuda")
+         self.out = torch_bessel.ops.modified_bessel_k0(*args).norm()
+
+     @skip_benchmark_if(not torch.cuda.is_available())
+     def time_modified_bessel_k0_backward_cuda(self, n, is_real, singularity, dtype):
+         torch.cuda.synchronize()
+         self.out.backward()
+         torch.cuda.synchronize()
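Note for readers unfamiliar with `asv` (airspeed velocity) conventions used above: `params` defines a grid of benchmark parameters, `param_names` labels them, `setup` runs before each timed call, and methods prefixed `time_` are the timed bodies. A minimal sketch of invoking one grid point by hand, outside the asv harness (purely illustrative, assuming the package and this benchmarks module are importable):

```python
import torch

from benchmarks.benchmarks import ModifiedBesselK0ForwardCPU

# One point of the parameter grid: n=10_000, complex input (is_real=False),
# singularity fill value 0.0, float32, no autograd.
combo = (10_000, False, 0.0, torch.float32, False)

bench = ModifiedBesselK0ForwardCPU()
bench.setup(*combo)  # builds self.args, as asv does before timing
bench.time_modified_bessel_k0_forward_cpu(*combo)  # the body asv times
```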
torch_bessel/_C.cpython-312-darwin.so CHANGED
Binary file (no diff shown)
torch_bessel/ops.py CHANGED
@@ -1,4 +1,6 @@
  from pathlib import Path
+ from numbers import Number
+ from typing import Union

  import torch
  from torch import Tensor
@@ -17,32 +19,77 @@ torch.ops.load_library(so_files[0])

  class ModifiedBesselK0(torch.autograd.Function):
      @staticmethod
-     def forward(x):
-         return torch.special.modified_bessel_k0(x)
+     def forward(z, singularity):
+         if not z.is_complex():
+             out = (torch.special.modified_bessel_k0(z), None)
+         elif not z.requires_grad:
+             out = (
+                 torch.ops.torch_bessel.modified_bessel_k0_complex_forward.default(z),
+                 None,
+             )
+         else:
+             out = torch.ops.torch_bessel.modified_bessel_k0_complex_forward_backward.default(
+                 z
+             )
+
+         if singularity is None:
+             return (*out, None)
+
+         mask = z != 0
+         return (out[0].where(mask, singularity), out[1], mask)

      @staticmethod
-     def setup_context(ctx, inputs, _):
+     def setup_context(ctx, inputs, outputs):
+         if ctx.needs_input_grad[1]:
+             raise NotImplementedError("Gradient w.r.t. singularity is not implemented")
+
          if ctx.needs_input_grad[0]:
-             ctx.save_for_backward(*inputs)
+             if outputs[1] is None:
+                 ctx.save_for_backward(inputs[0], None, outputs[2])
+             else:
+                 ctx.save_for_backward(None, outputs[1], outputs[2])
+
          ctx.set_materialize_grads(False)

      @staticmethod
-     def backward(ctx, grad):
+     def backward(ctx, grad, _, __):
          if grad is None or not ctx.needs_input_grad[0]:
-             return None
+             return (None, None)
+
+         x, deriv, mask = ctx.saved_tensors
+         if deriv is None:
+             out = -torch.special.modified_bessel_k1(x).mul_(grad)
+         else:
+             out = grad * deriv
+         if mask is not None:
+             out = out.where(mask, 0)
+         return (out, None)
+
+     @staticmethod
+     def vmap(info, in_dims, z, singularity):
+         if singularity is None and in_dims[1] is not None:
+             raise ValueError("in_dims[1] must be None if singularity is not provided.")
+
+         if in_dims[0] is not None:
+             z = z.movedim(in_dims[0], 0)
+
+         if in_dims[1] is not None:
+             singularity = singularity.movedim(in_dims[1], 0)
+
+         out = ModifiedBesselK0.apply(z, singularity)
+         out_dims = [0] * 3 if any(d is not None for d in in_dims) else [None] * 3
+         if out[1] is None:
+             out_dims[1] = None
+         if out[2] is None:
+             out_dims[2] = None

-         (x,) = ctx.saved_tensors
-         return -torch.special.modified_bessel_k1(x).mul_(grad)
+         return (out, tuple(out_dims))


- def modified_bessel_k0(z: Tensor) -> Tensor:
-     if not z.is_complex():
-         return ModifiedBesselK0.apply(z)
-     if not z.requires_grad:
-         return torch.ops.torch_bessel.modified_bessel_k0_complex_forward.default(z)
-     return torch.ops.torch_bessel.modified_bessel_k0_complex_forward_backward.default(
-         z
-     )[0]
+ def modified_bessel_k0(
+     z: Tensor, singularity: Union[Number, Tensor, None] = None
+ ) -> Tensor:
+     return ModifiedBesselK0.apply(z, singularity)[0]


  @torch.library.register_fake("torch_bessel::modified_bessel_k0_complex_forward")
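For context on the API change above: `modified_bessel_k0` now takes an optional `singularity` argument that fills the output where z == 0 (where K0 diverges), and the backward pass zeroes the gradient at those entries using the saved mask. The new `vmap` static method should also make the op composable with `torch.vmap`. A minimal usage sketch, assuming torch-bessel 0.0.5 is installed (tensor values are illustrative):

```python
import torch

import torch_bessel

z = torch.tensor([0.0, 0.5, 1.0], requires_grad=True)

# Without a fill value, K0(0) is inf.
out = torch_bessel.ops.modified_bessel_k0(z)

# With singularity=0.0, entries where z == 0 are replaced by 0.0, and
# their gradient contribution is masked to zero in backward.
out = torch_bessel.ops.modified_bessel_k0(z, 0.0)
out.norm().backward()

# The new vmap rule should let the op compose with torch.vmap as well.
batched = torch.vmap(torch_bessel.ops.modified_bessel_k0)(z.detach().reshape(3, 1))
```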
torch_bessel-0.0.5.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: torch_bessel
- Version: 0.0.3
+ Version: 0.0.5
  Summary: PyTorch extension package for Bessel functions with arbitrary real order and complex inputs
  Home-page: https://github.com/hchau630/torch-bessel
  Author: Ho Yin Chau
@@ -30,12 +30,15 @@ pip install torch-bessel
  import torch_bessel

  real, imag = torch.randn(2, 5, device="cuda")
- z = torch.complex(real.abs(), imag)  # inputs on the left-half complex plane are set to NaNs.
+ z = torch.complex(real.abs(), imag)  # correctness for inputs in the left-half complex plane is not guaranteed.
  torch_bessel.ops.modified_bessel_k0(z)
  ```

  # Implemented functions
- - `modified_bessel_k0`: Same as `torch.special.modified_bessel_k0`, but also handles backpropagation and complex inputs with $\mathrm{Re}(z) \geq 0$ on cpu and cuda.
+ - `modified_bessel_k0`: Same as `torch.special.modified_bessel_k0`, but also handles backpropagation and complex inputs on cpu and cuda. Correctness is guaranteed on the right-half complex plane. On the left-half complex plane, the output appears mostly correct, but with small numerical errors for certain inputs. On the negative real line, the output is NaN.

  # WIP
  - `modified_bessel_kv`: Analogue of `scipy.special.kv`.
+
+ # Benchmarks
+ Benchmarking is performed with the `asv` package. Results can be viewed at https://hchau630.github.io/torch-bessel.
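A brief note on the correctness caveats in the README text above: the principal branch of K0 has a branch cut along the negative real axis, consistent with the NaN output reported there, and K0 diverges logarithmically at the origin, K0(z) ≈ -ln(z/2) - γ for small z, which is the singular behavior the new `singularity` fill value works around. A quick sanity check of the small-argument expansion (illustrative, real input only):

```python
import torch

z = torch.tensor(1e-4)
euler_gamma = 0.5772156649015329  # Euler–Mascheroni constant

approx = -torch.log(z / 2) - euler_gamma  # small-z expansion of K0
exact = torch.special.modified_bessel_k0(z)
print(approx.item(), exact.item())  # both ≈ 9.33
```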
torch_bessel-0.0.5.dist-info/RECORD ADDED
@@ -0,0 +1,10 @@
+ benchmarks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ benchmarks/benchmarks.py,sha256=LZDGAlKEWwo-VCenNTGn98bJzj9FSlQbFK5pPXRvDyI,3319
+ torch_bessel/_C.cpython-312-darwin.so,sha256=ARpLpmh0MCoSGplMTVRxOWtT2zCGI8DbJDY7satv5NM,232280
+ torch_bessel/__init__.py,sha256=oohbWz8vxekl7kqDNSWqDH3ORabf9-Tug1KJryKw51A,230
+ torch_bessel/ops.py,sha256=P4qo1GN6XExuKfI061dsZMpAPuimn6gYt-PwRmR4_Fk,3584
+ torch_bessel-0.0.5.dist-info/LICENSE,sha256=do0DI6wu4mF3VXnEXXPYZqVEatoRSSamgz9t80wU7_o,1068
+ torch_bessel-0.0.5.dist-info/METADATA,sha256=LX6X8j53vJCYKaXLfjlk3X_RwRYHlkoQVYC7CSJJ1vc,1561
+ torch_bessel-0.0.5.dist-info/WHEEL,sha256=QEo1-fvjBiv4iPfIbG-kA4GNtCNK6HOxclrG5aWVgHI,109
+ torch_bessel-0.0.5.dist-info/top_level.txt,sha256=xmyVjWSQ91kX-v8KCzl6wDwfAmbdZNWsP2EH9b9BccQ,24
+ torch_bessel-0.0.5.dist-info/RECORD,,
torch_bessel-0.0.5.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (75.8.0)
+ Generator: setuptools (76.0.0)
  Root-Is-Purelib: false
  Tag: cp312-cp312-macosx_11_0_arm64
torch_bessel-0.0.5.dist-info/top_level.txt CHANGED
@@ -1 +1,2 @@
+ benchmarks
  torch_bessel
torch_bessel-0.0.3.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
- torch_bessel/_C.cpython-312-darwin.so,sha256=d2qpLB8simaybQxmgXOrrHwvsG3yw32yoehTIX5PsO0,232280
- torch_bessel/__init__.py,sha256=oohbWz8vxekl7kqDNSWqDH3ORabf9-Tug1KJryKw51A,230
- torch_bessel/ops.py,sha256=Q9BrLxi15MS53xSt_S9dyE3g8_8_GCFYhfAztIor8Fw,2043
- torch_bessel-0.0.3.dist-info/LICENSE,sha256=do0DI6wu4mF3VXnEXXPYZqVEatoRSSamgz9t80wU7_o,1068
- torch_bessel-0.0.3.dist-info/METADATA,sha256=7hH6thd8bu4oaKYfNq5v6V4JkUe6g6_C9Irl618y3-8,1220
- torch_bessel-0.0.3.dist-info/WHEEL,sha256=VujM3ypTCyUW6hcTDdK2ej0ARVMxlU1Djlh_zWnDgqk,109
- torch_bessel-0.0.3.dist-info/top_level.txt,sha256=cbDIjTj71LuAlVyyYyDt8fOAeLaVeX3Vums5F2FBa-4,13
- torch_bessel-0.0.3.dist-info/RECORD,,