bispectrum 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bispectrum/__init__.py ADDED
@@ -0,0 +1,30 @@
1
"""Bispectrum analysis for machine learning."""

from importlib.metadata import PackageNotFoundError
from importlib.metadata import version as _pkg_version

# Resolve the installed distribution's version at import time; fall back to a
# sentinel when the package metadata is unavailable (e.g. a source checkout
# that was never pip-installed).
try:
    __version__ = _pkg_version('bispectrum')
except PackageNotFoundError:
    __version__ = '0.0.0+unknown'

# Public API re-exports: one class per (group, domain) pairing, plus the
# rotation helpers from bispectrum.rotation.
from bispectrum.cn_on_cn import CnonCn
from bispectrum.dn_on_dn import DnonDn
from bispectrum.octa_on_octa import OctaonOcta
from bispectrum.rotation import random_rotation_matrix, rotate_spherical_function
from bispectrum.so2_on_disk import SO2onDisk
from bispectrum.so2_on_s1 import SO2onS1
from bispectrum.so3_on_s2 import SO3onS2
from bispectrum.torus_on_torus import TorusOnTorus

# Explicit public surface for `from bispectrum import *`.
__all__ = [
    'CnonCn',
    'DnonDn',
    'OctaonOcta',
    'SO2onDisk',
    'SO2onS1',
    'SO3onS2',
    'TorusOnTorus',
    'random_rotation_matrix',
    'rotate_spherical_function',
]
bispectrum/_bessel.py ADDED
@@ -0,0 +1,253 @@
1
+ """Internal Bessel function utilities.
2
+
3
+ Pure torch implementation — no scipy. Used by SO2onDisk for disk harmonic
4
+ computations. Not part of the public API.
5
+
6
+ Provides:
7
+ bessel_jn — J_n(x) for integer order n >= 0 (torch tensor)
8
+ bessel_jn_zeros — first k positive roots of J_n(x) = 0
9
+ compute_all_bessel_roots — roots for all orders 0..n_max in a single pass
10
+ """
11
+
12
+ import math
13
+
14
+ import torch
15
+
16
+
17
def bessel_jn(n: int, x: torch.Tensor) -> torch.Tensor:
    """Compute J_n(x) for integer order n >= 0 via forward recurrence.

    Uses torch.special.bessel_j0 and bessel_j1 as base cases and the
    standard recurrence J_{k+1}(x) = (2k/x)*J_k(x) - J_{k-1}(x).

    Forward recurrence is stable when x >= n, which holds for our use
    case (evaluating at Bessel root * r where r in [0, 1]).

    Args:
        n: Non-negative integer order.
        x: Argument tensor (any shape).

    Returns:
        J_n(x) with same shape and dtype as x.

    Raises:
        ValueError: If n is negative.
    """
    if n < 0:
        raise ValueError(f'Order n must be >= 0, got {n}')

    if n == 0:
        return torch.special.bessel_j0(x)
    if n == 1:
        return torch.special.bessel_j1(x)

    j_prev = torch.special.bessel_j0(x)
    j_curr = torch.special.bessel_j1(x)

    # Hoisted out of the loop: the zero mask and the division-safe argument
    # are loop-invariant, so there is no need to rebuild them per order.
    zero_mask = x == 0
    safe_x = torch.where(zero_mask, torch.ones_like(x), x)

    for k in range(1, n):
        j_next = (2.0 * k / safe_x) * j_curr - j_prev
        # J_k(0) = 0 for every k >= 1, so patch the x == 0 entries directly.
        j_next = torch.where(zero_mask, torch.zeros_like(j_next), j_next)
        j_prev = j_curr
        j_curr = j_next

    return j_curr
52
+
53
+
54
def _jn_scalar(n: int, x: float) -> float:
    """Scalar J_n(x) via the same forward recurrence as bessel_jn.

    Works on plain Python floats after the torch base-case evaluations,
    which is faster for scalar root polishing than tensor arithmetic.
    """
    xt = torch.tensor(x, dtype=torch.float64)
    if n == 0:
        return torch.special.bessel_j0(xt).item()
    if n == 1:
        return torch.special.bessel_j1(xt).item()
    if x == 0:
        # J_n(0) = 0 for every order n >= 1.
        return 0.0

    below = torch.special.bessel_j0(xt).item()
    here = torch.special.bessel_j1(xt).item()
    for order in range(1, n):
        below, here = here, (2.0 * order / x) * here - below
    return here
74
+
75
+
76
def _djn_scalar(n: int, x: float) -> float:
    """Scalar derivative J_n'(x).

    Uses J_0'(x) = -J_1(x) for order zero and the general identity
    J_n'(x) = (J_{n-1}(x) - J_{n+1}(x)) / 2 otherwise.
    """
    if n > 0:
        return (_jn_scalar(n - 1, x) - _jn_scalar(n + 1, x)) / 2.0
    return -_jn_scalar(1, x)
81
+
82
+
83
def _bisect_newton(n: int, a: float, b: float) -> float:
    """Locate the root of J_n inside [a, b] with safeguarded Newton steps.

    A Newton step is attempted each iteration; whenever it would leave
    the bracket (or the derivative is ~0), the midpoint is used instead.
    The bracket is tightened after every evaluation, so the loop still
    converges even where Newton misbehaves.
    """
    f_lo = _jn_scalar(n, a)
    f_hi = _jn_scalar(n, b)

    # Endpoints that are already (numerically) roots.
    if abs(f_lo) < 1e-15:
        return a
    if abs(f_hi) < 1e-15:
        return b
    # No sign change: nothing bracketed — fall back to the midpoint.
    if f_lo * f_hi > 0:
        return (a + b) / 2.0

    x = (a + b) / 2.0
    for _ in range(80):
        fx = _jn_scalar(n, x)
        if abs(fx) < 1e-15:
            return x

        slope = _djn_scalar(n, x)
        candidate = x - fx / slope if abs(slope) > 1e-30 else x
        # Accept the Newton step only while it stays strictly inside.
        x = candidate if a < candidate < b else (a + b) / 2.0

        fx = _jn_scalar(n, x)
        if f_lo * fx < 0:
            b, f_hi = x, fx
        else:
            a, f_lo = x, fx

        # Relative bracket-width convergence test.
        if (b - a) < 1e-14 * max(abs(a), 1.0):
            break

    return (a + b) / 2.0
125
+
126
+
127
def _mcmahon_zeros_j0(num_zeros: int) -> list[float]:
    """First ``num_zeros`` positive roots of J_0 via McMahon + Newton polish.

    The McMahon asymptotic expansion already gives every J_0 root to a
    few digits; a handful of vectorized Newton iterations then refines
    them to full float64 precision.
    """
    if num_zeros <= 0:
        return []
    idx = torch.arange(1, num_zeros + 1, dtype=torch.float64)
    asym = math.pi * (idx - 0.25)
    z = asym - 1.0 / (8.0 * asym)
    for _ in range(10):
        val = bessel_jn(0, z)
        slope = -bessel_jn(1, z)  # J_0'(x) = -J_1(x)
        degenerate = slope.abs() < 1e-30
        step = val / torch.where(degenerate, torch.ones_like(slope), slope)
        step = torch.where(degenerate, torch.zeros_like(step), step)
        z = z - step
        # Stop once every root moved by less than ~1 ulp relatively.
        if (step.abs() / z.abs().clamp(min=1.0)).max() < 1e-14:
            break
    return z.tolist()
144
+
145
+
146
def _bisect_newton_batch(n: int, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Vectorized root-finding for J_n in brackets [a, b].

    Safeguarded Newton: each element takes a Newton step when its
    derivative is usable and the step stays inside its bracket, and
    bisects otherwise. The bracket is tightened every iteration, so the
    loop converges even for lanes where Newton is rejected.

    Args:
        n: Bessel order.
        a: Lower bracket endpoints, shape (num_roots,).
        b: Upper bracket endpoints, shape (num_roots,).

    Returns:
        Roots tensor, shape (num_roots,).
    """
    fa = bessel_jn(n, a)
    fb = bessel_jn(n, b)

    exact_a = fa.abs() < 1e-15
    exact_b = fb.abs() < 1e-15
    no_sign_change = fa * fb > 0

    x = (a + b) / 2.0

    for _ in range(80):
        fx = bessel_jn(n, x)
        # J_n'(x): special-case n == 0 to avoid needing J_{-1}.
        if n == 0:
            dfx = -bessel_jn(1, x)
        else:
            dfx = (bessel_jn(n - 1, x) - bessel_jn(n + 1, x)) / 2.0

        newton_ok = dfx.abs() > 1e-30
        # BUG FIX: the previous `dfx.clamp_min(1e-30).copysign(dfx)` collapsed
        # every NEGATIVE derivative to magnitude 1e-30 (clamp_min of a negative
        # value returns 1e-30), producing huge bogus Newton steps that were
        # always rejected and silently degraded those lanes to pure bisection.
        # Use the true derivative where Newton is enabled and a harmless
        # placeholder denominator where it is not.
        safe_dfx = torch.where(newton_ok, dfx, torch.ones_like(dfx))
        x_newton = torch.where(newton_ok, x - fx / safe_dfx, x)
        in_bracket = (a < x_newton) & (x_newton < b)
        x = torch.where(in_bracket & newton_ok, x_newton, (a + b) / 2.0)

        # Tighten each bracket around the sign change.
        fx = bessel_jn(n, x)
        go_left = fa * fx < 0
        b = torch.where(go_left, x, b)
        fb = torch.where(go_left, fx, fb)
        a = torch.where(~go_left, x, a)
        fa = torch.where(~go_left, fx, fa)

        converged = (b - a) < 1e-14 * a.abs().clamp(min=1.0)
        if converged.all():
            break

    result = (a + b) / 2.0
    # Endpoints that were already roots, and unbracketed lanes, are resolved
    # the same way the scalar variant does it.
    result = torch.where(exact_a, a, result)
    result = torch.where(exact_b, b, result)
    result = torch.where(no_sign_change & ~exact_a & ~exact_b, (a + b) / 2.0, result)
    return result
194
+
195
+
196
def compute_all_bessel_roots(n_max: int, k_max: int) -> dict[int, list[float]]:
    """Compute Bessel roots for all orders 0..n_max using interlacing.

    The interlacing property j_{n-1,k} < j_{n,k} < j_{n-1,k+1} turns the
    roots of order n-1 into brackets for the roots of order n, so each
    order is solved with one vectorized Newton-bisection pass. Extra
    roots are carried at low orders so every higher order still has
    enough brackets.

    Args:
        n_max: Maximum Bessel order.
        k_max: Maximum number of roots per order.

    Returns:
        Dict mapping order n -> list of first k_max roots (or fewer if
        not enough brackets exist).
    """
    j0_roots = _mcmahon_zeros_j0(k_max + n_max + 5)
    roots_by_order: dict[int, list[float]] = {0: j0_roots[:k_max]}
    lower = torch.tensor(j0_roots, dtype=torch.float64)

    for order in range(1, n_max + 1):
        # Carry a margin of extra roots so orders above this one still
        # have enough brackets to work with.
        wanted = k_max + (n_max - order) + 3
        brackets = min(wanted, len(lower) - 1)
        if brackets <= 0:
            roots_by_order[order] = []
            lower = torch.tensor([], dtype=torch.float64)
            continue
        found = _bisect_newton_batch(order, lower[:brackets], lower[1 : brackets + 1])
        roots_by_order[order] = found[:k_max].tolist()
        lower = found

    return roots_by_order
233
+
234
+
235
def bessel_jn_zeros(n: int, num_zeros: int) -> torch.Tensor:
    """Compute the first `num_zeros` positive roots of J_n(x) = 0.

    Thin single-order wrapper over compute_all_bessel_roots(); prefer
    the batch function directly when several orders are needed.

    Args:
        n: Non-negative integer order.
        num_zeros: Number of positive roots to compute.

    Returns:
        1D float64 tensor of shape (num_zeros,) with roots in ascending order.
    """
    if num_zeros <= 0:
        return torch.zeros(0, dtype=torch.float64)

    roots = compute_all_bessel_roots(n, num_zeros).get(n, [])
    return torch.tensor(roots[:num_zeros], dtype=torch.float64)