bispectrum 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bispectrum/__init__.py +30 -0
- bispectrum/_bessel.py +253 -0
- bispectrum/_cg.py +734 -0
- bispectrum/cn_on_cn.py +163 -0
- bispectrum/data/__init__.py +1 -0
- bispectrum/data/cg_lmax5.json +84832 -0
- bispectrum/dn_on_dn.py +588 -0
- bispectrum/octa_on_octa.py +786 -0
- bispectrum/rotation.py +203 -0
- bispectrum/so2_on_disk.py +454 -0
- bispectrum/so2_on_s1.py +32 -0
- bispectrum/so3_on_s2.py +1119 -0
- bispectrum/torus_on_torus.py +286 -0
- bispectrum-0.3.0.dist-info/METADATA +125 -0
- bispectrum-0.3.0.dist-info/RECORD +17 -0
- bispectrum-0.3.0.dist-info/WHEEL +4 -0
- bispectrum-0.3.0.dist-info/licenses/LICENSE +21 -0
bispectrum/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""Bispectrum analysis for machine learning."""
|
|
2
|
+
|
|
3
|
+
from importlib.metadata import PackageNotFoundError
|
|
4
|
+
from importlib.metadata import version as _pkg_version
|
|
5
|
+
|
|
6
|
+
def _detect_version() -> str:
    """Return the installed distribution version of ``bispectrum``.

    Falls back to a PEP 440 local-version placeholder when no distribution
    metadata can be found for the package.
    """
    try:
        return _pkg_version('bispectrum')
    except PackageNotFoundError:
        # No installed metadata (e.g. imported straight from a source tree).
        return '0.0.0+unknown'


__version__ = _detect_version()
# Keep the package namespace free of the helper.
del _detect_version
|
|
10
|
+
|
|
11
|
+
from bispectrum.cn_on_cn import CnonCn
|
|
12
|
+
from bispectrum.dn_on_dn import DnonDn
|
|
13
|
+
from bispectrum.octa_on_octa import OctaonOcta
|
|
14
|
+
from bispectrum.rotation import random_rotation_matrix, rotate_spherical_function
|
|
15
|
+
from bispectrum.so2_on_disk import SO2onDisk
|
|
16
|
+
from bispectrum.so2_on_s1 import SO2onS1
|
|
17
|
+
from bispectrum.so3_on_s2 import SO3onS2
|
|
18
|
+
from bispectrum.torus_on_torus import TorusOnTorus
|
|
19
|
+
|
|
20
|
+
# Explicit public API of the package; every entry is imported at the top of
# this module from its submodule.
__all__ = [
    # Bispectrum implementations, one per group / domain pair.
    'CnonCn',
    'DnonDn',
    'OctaonOcta',
    'SO2onDisk',
    'SO2onS1',
    'SO3onS2',
    'TorusOnTorus',
    # Helpers re-exported from bispectrum.rotation.
    'random_rotation_matrix',
    'rotate_spherical_function',
]
|
bispectrum/_bessel.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
1
|
+
"""Internal Bessel function utilities.
|
|
2
|
+
|
|
3
|
+
Pure torch implementation — no scipy. Used by SO2onDisk for disk harmonic
|
|
4
|
+
computations. Not part of the public API.
|
|
5
|
+
|
|
6
|
+
Provides:
|
|
7
|
+
bessel_jn — J_n(x) for integer order n >= 0 (torch tensor)
|
|
8
|
+
bessel_jn_zeros — first k positive roots of J_n(x) = 0
|
|
9
|
+
compute_all_bessel_roots — roots for all orders 0..n_max in a single pass
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import math
|
|
13
|
+
|
|
14
|
+
import torch
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def bessel_jn(n: int, x: torch.Tensor) -> torch.Tensor:
|
|
18
|
+
"""Compute J_n(x) for integer order n >= 0 via forward recurrence.
|
|
19
|
+
|
|
20
|
+
Uses torch.special.bessel_j0 and bessel_j1 as base cases and the
|
|
21
|
+
standard recurrence J_{k+1}(x) = (2k/x)*J_k(x) - J_{k-1}(x).
|
|
22
|
+
|
|
23
|
+
Forward recurrence is stable when x >= n, which holds for our use
|
|
24
|
+
case (evaluating at Bessel root * r where r in [0, 1]).
|
|
25
|
+
|
|
26
|
+
Args:
|
|
27
|
+
n: Non-negative integer order.
|
|
28
|
+
x: Argument tensor (any shape).
|
|
29
|
+
|
|
30
|
+
Returns:
|
|
31
|
+
J_n(x) with same shape and dtype as x.
|
|
32
|
+
"""
|
|
33
|
+
if n < 0:
|
|
34
|
+
raise ValueError(f'Order n must be >= 0, got {n}')
|
|
35
|
+
|
|
36
|
+
if n == 0:
|
|
37
|
+
return torch.special.bessel_j0(x)
|
|
38
|
+
if n == 1:
|
|
39
|
+
return torch.special.bessel_j1(x)
|
|
40
|
+
|
|
41
|
+
j_prev = torch.special.bessel_j0(x)
|
|
42
|
+
j_curr = torch.special.bessel_j1(x)
|
|
43
|
+
|
|
44
|
+
for k in range(1, n):
|
|
45
|
+
safe_x = torch.where(x == 0, torch.ones_like(x), x)
|
|
46
|
+
j_next = (2.0 * k / safe_x) * j_curr - j_prev
|
|
47
|
+
j_next = torch.where(x == 0, torch.zeros_like(j_next), j_next)
|
|
48
|
+
j_prev = j_curr
|
|
49
|
+
j_curr = j_next
|
|
50
|
+
|
|
51
|
+
return j_curr
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _jn_scalar(n: int, x: float) -> float:
|
|
55
|
+
"""Fast scalar evaluation of J_n(x) using raw math."""
|
|
56
|
+
if n == 0:
|
|
57
|
+
return torch.special.bessel_j0(torch.tensor(x, dtype=torch.float64)).item()
|
|
58
|
+
if n == 1:
|
|
59
|
+
return torch.special.bessel_j1(torch.tensor(x, dtype=torch.float64)).item()
|
|
60
|
+
|
|
61
|
+
xt = torch.tensor(x, dtype=torch.float64)
|
|
62
|
+
j_prev = torch.special.bessel_j0(xt).item()
|
|
63
|
+
j_curr = torch.special.bessel_j1(xt).item()
|
|
64
|
+
|
|
65
|
+
if x == 0:
|
|
66
|
+
return 0.0
|
|
67
|
+
|
|
68
|
+
for k in range(1, n):
|
|
69
|
+
j_next = (2.0 * k / x) * j_curr - j_prev
|
|
70
|
+
j_prev = j_curr
|
|
71
|
+
j_curr = j_next
|
|
72
|
+
|
|
73
|
+
return j_curr
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _djn_scalar(n: int, x: float) -> float:
    """Return the derivative J_n'(x) at a scalar point as a Python float.

    Uses J_0' = -J_1 for n == 0 and the identity
    J_n' = (J_{n-1} - J_{n+1}) / 2 for every other order.
    """
    if n != 0:
        return 0.5 * (_jn_scalar(n - 1, x) - _jn_scalar(n + 1, x))
    return -_jn_scalar(1, x)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _bisect_newton(n: int, a: float, b: float) -> float:
    """Locate a zero of J_n inside [a, b] with safeguarded Newton iteration.

    Returns an endpoint directly if |J_n| there is below 1e-15, and returns the
    midpoint unchanged if J_n has the same sign at both ends. Otherwise it
    alternates Newton steps (kept only when strictly inside the bracket) with
    bisection. The bracket is shrunk on every round. Iteration stops when
    |J_n| < 1e-15, when the bracket is narrower than 1e-14 relative to its
    lower end, or after 80 rounds.
    """
    lo, hi = a, b
    f_lo = _jn_scalar(n, lo)
    f_hi = _jn_scalar(n, hi)

    if abs(f_lo) < 1e-15:
        return lo
    if abs(f_hi) < 1e-15:
        return hi
    if f_lo * f_hi > 0:
        # Bracket does not isolate a sign change; give back its midpoint.
        return (lo + hi) / 2.0

    guess = (lo + hi) / 2.0
    for _ in range(80):
        f_guess = _jn_scalar(n, guess)
        if abs(f_guess) < 1e-15:
            return guess

        slope = _djn_scalar(n, guess)
        candidate = guess - f_guess / slope if abs(slope) > 1e-30 else guess
        guess = candidate if lo < candidate < hi else (lo + hi) / 2.0

        # Keep the half of the bracket that still contains the sign change.
        f_guess = _jn_scalar(n, guess)
        if f_lo * f_guess < 0:
            hi = guess
        else:
            lo, f_lo = guess, f_guess

        if hi - lo < 1e-14 * max(abs(lo), 1.0):
            break

    return (lo + hi) / 2.0
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def _mcmahon_zeros_j0(num_zeros: int) -> list[float]:
|
|
128
|
+
"""McMahon expansion for J_0 roots — highly accurate for all k."""
|
|
129
|
+
if num_zeros <= 0:
|
|
130
|
+
return []
|
|
131
|
+
s = torch.arange(1, num_zeros + 1, dtype=torch.float64)
|
|
132
|
+
beta = math.pi * (s - 0.25)
|
|
133
|
+
z = beta - 1.0 / (8.0 * beta)
|
|
134
|
+
for _ in range(10):
|
|
135
|
+
fz = bessel_jn(0, z)
|
|
136
|
+
dfz = -bessel_jn(1, z)
|
|
137
|
+
safe_dfz = torch.where(dfz.abs() < 1e-30, torch.ones_like(dfz), dfz)
|
|
138
|
+
dz = fz / safe_dfz
|
|
139
|
+
dz = torch.where(dfz.abs() < 1e-30, torch.zeros_like(dz), dz)
|
|
140
|
+
z = z - dz
|
|
141
|
+
if (dz.abs() / z.abs().clamp(min=1.0)).max() < 1e-14:
|
|
142
|
+
break
|
|
143
|
+
return z.tolist()
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def _bisect_newton_batch(n: int, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Vectorized safeguarded Newton root-finding for J_n in brackets [a, b].

    Every element iterates independently. A Newton step is kept only when it
    lands strictly inside the current bracket; otherwise the bracket is
    bisected. The bracket is then shrunk using the sign of J_n at the new
    point. Iteration stops after 80 rounds or once every bracket is narrower
    than 1e-14 relative to its lower end.

    Args:
        n: Bessel order.
        a: Lower bracket endpoints, shape (num_roots,).
        b: Upper bracket endpoints, shape (num_roots,).

    Returns:
        Roots tensor, shape (num_roots,). Matching the scalar ``_bisect_newton``:
        - If |J_n| < 1e-15 at an original endpoint, that endpoint is returned
          (the lower one takes precedence).
        - If J_n has the same sign at both original endpoints, the midpoint of
          the original bracket is returned.
    """
    # Keep the caller's endpoints; the loop rebinds a and b.
    a0, b0 = a, b
    fa = bessel_jn(n, a)
    fb = bessel_jn(n, b)

    exact_a = fa.abs() < 1e-15
    exact_b = fb.abs() < 1e-15
    no_sign_change = fa * fb > 0

    x = (a + b) / 2.0

    for _ in range(80):
        fx = bessel_jn(n, x)
        if n == 0:
            dfx = -bessel_jn(1, x)
        else:
            dfx = (bessel_jn(n - 1, x) - bessel_jn(n + 1, x)) / 2.0

        newton_ok = dfx.abs() > 1e-30
        # Divide by the true slope where it is usable. Do not use
        # dfx.clamp_min(eps): that maps every negative slope to +eps and
        # silently disables Newton for roots where J_n is decreasing.
        safe_dfx = torch.where(newton_ok, dfx, torch.ones_like(dfx))
        x_newton = x - fx / safe_dfx
        in_bracket = (a < x_newton) & (x_newton < b)
        x = torch.where(in_bracket & newton_ok, x_newton, (a + b) / 2.0)

        fx = bessel_jn(n, x)
        go_left = fa * fx < 0
        # An exact zero collapses the bracket onto x. Otherwise fa would become
        # 0 and every later step would move `a` rightwards, losing the root.
        hit = fx == 0
        b = torch.where(go_left | hit, x, b)
        a = torch.where(go_left, a, x)
        fa = torch.where(go_left, fa, fx)

        converged = (b - a) < 1e-14 * a.abs().clamp(min=1.0)
        if converged.all():
            break

    result = (a + b) / 2.0
    # Fallbacks refer to the ORIGINAL bracket, not the loop-updated one.
    result = torch.where(no_sign_change, (a0 + b0) / 2.0, result)
    result = torch.where(exact_b, b0, result)
    result = torch.where(exact_a, a0, result)
    return result
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def compute_all_bessel_roots(n_max: int, k_max: int) -> dict[int, list[float]]:
    """Find the leading positive zeros of J_0, J_1, ..., J_{n_max} in one sweep.

    Zeros of J_0 are produced directly (McMahon guesses + Newton). Each higher
    order is then solved inside brackets formed by consecutive zeros of the
    order below, using the interlacing j_{n-1,k} < j_{n,k} < j_{n-1,k+1}. Each
    order therefore reuses the previous order's result.

    Args:
        n_max: Maximum Bessel order.
        k_max: Maximum number of roots per order.

    Returns:
        Dict mapping order n -> list of its first k_max roots (fewer if not
        enough brackets were available).
    """
    # Every step up in order turns m roots into at most m - 1 brackets, so
    # J_0 is seeded with n_max extra zeros plus a small safety margin.
    seed = _mcmahon_zeros_j0(k_max + n_max + 5)
    roots_by_order: dict[int, list[float]] = {0: seed[:k_max]}

    lower_roots = torch.tensor(seed, dtype=torch.float64)
    for order in range(1, n_max + 1):
        wanted = k_max + (n_max - order) + 3
        count = min(wanted, len(lower_roots) - 1)
        if count > 0:
            lower_roots = _bisect_newton_batch(
                order, lower_roots[:count], lower_roots[1 : count + 1]
            )
            roots_by_order[order] = lower_roots[:k_max].tolist()
        else:
            lower_roots = torch.tensor([], dtype=torch.float64)
            roots_by_order[order] = []

    return roots_by_order
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def bessel_jn_zeros(n: int, num_zeros: int) -> torch.Tensor:
    """Compute the first `num_zeros` positive roots of J_n(x) = 0.

    For single-order queries. For multi-order queries, use
    compute_all_bessel_roots(), which shares work across orders.

    Args:
        n: Non-negative integer order.
        num_zeros: Number of positive roots to compute. Values <= 0 yield an
            empty tensor.

    Returns:
        1D float64 tensor of shape (num_zeros,) with roots in ascending order.

    Raises:
        ValueError: If n < 0. This matches bessel_jn; a negative order
            previously produced an empty result without any error.
    """
    if n < 0:
        raise ValueError(f'Order n must be >= 0, got {n}')
    if num_zeros <= 0:
        return torch.zeros(0, dtype=torch.float64)

    roots = compute_all_bessel_roots(n, num_zeros).get(n, [])
    return torch.tensor(roots[:num_zeros], dtype=torch.float64)
|