spotgp 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spotgp-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,50 @@
1
+ Metadata-Version: 2.4
2
+ Name: spotgp
3
+ Version: 0.1.0
4
+ Summary: Gaussian Process kernels for stellar variability from starspot models
5
+ Author: Jessica Birky
6
+ License: MIT
7
+ Requires-Python: >=3.8
8
+ Description-Content-Type: text/markdown
9
+ Requires-Dist: numpy
10
+ Requires-Dist: scipy
11
+ Requires-Dist: matplotlib
12
+ Requires-Dist: astropy
13
+ Requires-Dist: scikit-learn
14
+ Requires-Dist: tqdm
15
+ Provides-Extra: jax
16
+ Requires-Dist: jax; extra == "jax"
17
+ Requires-Dist: jaxopt; extra == "jax"
18
+ Provides-Extra: docs
19
+ Requires-Dist: sphinx>=7.0; extra == "docs"
20
+ Requires-Dist: sphinx-book-theme>=1.0; extra == "docs"
21
+ Requires-Dist: sphinx-copybutton>=0.5; extra == "docs"
22
+ Requires-Dist: myst-nb>=1.0; extra == "docs"
23
+ Requires-Dist: sphinxcontrib-mermaid>=0.9; extra == "docs"
24
+ Requires-Dist: pygments-styles>=0.3; extra == "docs"
25
+
26
+ # `spotgp`
27
+
28
+ [![Tests](https://github.com/jbirky/spotgp/actions/workflows/tests.yml/badge.svg)](https://github.com/jbirky/spotgp/actions/workflows/tests.yml)
29
+ [![codecov](https://codecov.io/gh/jbirky/spotgp/branch/main/graph/badge.svg)](https://codecov.io/gh/jbirky/spotgp)
30
+ [![Documentation Status](https://readthedocs.org/projects/spotgp/badge/?version=latest)](https://spotgp.readthedocs.io/en/latest/?badge=latest)
31
+
32
+ **`spotgp`**: Gaussian Process kernels for stellar starspot variability implemented in `JAX`.
33
+
34
+ <br>
35
+
36
+ ![Lightcurve animation](docs/tutorials/lightcurve_animation.gif)
37
+
38
+ ## Installation
39
+
40
+ ```bash
41
+ git clone https://github.com/jbirky/spotgp.git
42
+ cd spotgp
43
+ pip install -e .
44
+ ```
45
+
46
+ For JAX acceleration:
47
+
48
+ ```bash
49
+ pip install -e ".[jax]"
50
+ ```
spotgp-0.1.0/README.md ADDED
@@ -0,0 +1,25 @@
1
+ # `spotgp`
2
+
3
+ [![Tests](https://github.com/jbirky/spotgp/actions/workflows/tests.yml/badge.svg)](https://github.com/jbirky/spotgp/actions/workflows/tests.yml)
4
+ [![codecov](https://codecov.io/gh/jbirky/spotgp/branch/main/graph/badge.svg)](https://codecov.io/gh/jbirky/spotgp)
5
+ [![Documentation Status](https://readthedocs.org/projects/spotgp/badge/?version=latest)](https://spotgp.readthedocs.io/en/latest/?badge=latest)
6
+
7
+ **`spotgp`**: Gaussian Process kernels for stellar starspot variability implemented in `JAX`.
8
+
9
+ <br>
10
+
11
+ ![Lightcurve animation](docs/tutorials/lightcurve_animation.gif)
12
+
13
+ ## Installation
14
+
15
+ ```bash
16
+ git clone https://github.com/jbirky/spotgp.git
17
+ cd spotgp
18
+ pip install -e .
19
+ ```
20
+
21
+ For JAX acceleration:
22
+
23
+ ```bash
24
+ pip install -e ".[jax]"
25
+ ```
@@ -0,0 +1,36 @@
1
+ [build-system]
2
+ requires = ["setuptools>=64", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "spotgp"
7
+ version = "0.1.0"
8
+ description = "Gaussian Process kernels for stellar variability from starspot models"
9
+ readme = "README.md"
10
+ requires-python = ">=3.8"
11
+ license = {text = "MIT"}
12
+ authors = [
13
+ {name = "Jessica Birky"},
14
+ ]
15
+ dependencies = [
16
+ "numpy",
17
+ "scipy",
18
+ "matplotlib",
19
+ "astropy",
20
+ "scikit-learn",
21
+ "tqdm",
22
+ ]
23
+
24
+ [project.optional-dependencies]
25
+ jax = ["jax", "jaxopt"]
26
+ docs = [
27
+ "sphinx>=7.0",
28
+ "sphinx-book-theme>=1.0",
29
+ "sphinx-copybutton>=0.5",
30
+ "myst-nb>=1.0",
31
+ "sphinxcontrib-mermaid>=0.9",
32
+ "pygments-styles>=0.3",
33
+ ]
34
+
35
+ [tool.setuptools.packages.find]
36
+ include = ["src*"]
spotgp-0.1.0/setup.cfg ADDED
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,50 @@
1
+ Metadata-Version: 2.4
2
+ Name: spotgp
3
+ Version: 0.1.0
4
+ Summary: Gaussian Process kernels for stellar variability from starspot models
5
+ Author: Jessica Birky
6
+ License: MIT
7
+ Requires-Python: >=3.8
8
+ Description-Content-Type: text/markdown
9
+ Requires-Dist: numpy
10
+ Requires-Dist: scipy
11
+ Requires-Dist: matplotlib
12
+ Requires-Dist: astropy
13
+ Requires-Dist: scikit-learn
14
+ Requires-Dist: tqdm
15
+ Provides-Extra: jax
16
+ Requires-Dist: jax; extra == "jax"
17
+ Requires-Dist: jaxopt; extra == "jax"
18
+ Provides-Extra: docs
19
+ Requires-Dist: sphinx>=7.0; extra == "docs"
20
+ Requires-Dist: sphinx-book-theme>=1.0; extra == "docs"
21
+ Requires-Dist: sphinx-copybutton>=0.5; extra == "docs"
22
+ Requires-Dist: myst-nb>=1.0; extra == "docs"
23
+ Requires-Dist: sphinxcontrib-mermaid>=0.9; extra == "docs"
24
+ Requires-Dist: pygments-styles>=0.3; extra == "docs"
25
+
26
+ # `spotgp`
27
+
28
+ [![Tests](https://github.com/jbirky/spotgp/actions/workflows/tests.yml/badge.svg)](https://github.com/jbirky/spotgp/actions/workflows/tests.yml)
29
+ [![codecov](https://codecov.io/gh/jbirky/spotgp/branch/main/graph/badge.svg)](https://codecov.io/gh/jbirky/spotgp)
30
+ [![Documentation Status](https://readthedocs.org/projects/spotgp/badge/?version=latest)](https://spotgp.readthedocs.io/en/latest/?badge=latest)
31
+
32
+ **`spotgp`**: Gaussian Process kernels for stellar starspot variability implemented in `JAX`.
33
+
34
+ <br>
35
+
36
+ ![Lightcurve animation](docs/tutorials/lightcurve_animation.gif)
37
+
38
+ ## Installation
39
+
40
+ ```bash
41
+ git clone https://github.com/jbirky/spotgp.git
42
+ cd spotgp
43
+ pip install -e .
44
+ ```
45
+
46
+ For JAX acceleration:
47
+
48
+ ```bash
49
+ pip install -e ".[jax]"
50
+ ```
@@ -0,0 +1,31 @@
1
+ README.md
2
+ pyproject.toml
3
+ spotgp.egg-info/PKG-INFO
4
+ spotgp.egg-info/SOURCES.txt
5
+ spotgp.egg-info/dependency_links.txt
6
+ spotgp.egg-info/requires.txt
7
+ spotgp.egg-info/top_level.txt
8
+ src/__init__.py
9
+ src/analytic_kernel.py
10
+ src/banded_cholesky.py
11
+ src/envelope.py
12
+ src/gp_solver.py
13
+ src/latitude.py
14
+ src/lightcurve.py
15
+ src/mcmc.py
16
+ src/numerical_kernel.py
17
+ src/params.py
18
+ src/plotting.py
19
+ src/psd.py
20
+ src/spot_model.py
21
+ src/visibility.py
22
+ tests/test_analytic_kernel.py
23
+ tests/test_banded_cholesky.py
24
+ tests/test_envelope.py
25
+ tests/test_gp_solver.py
26
+ tests/test_lightcurve.py
27
+ tests/test_mcmc.py
28
+ tests/test_numerical_kernel.py
29
+ tests/test_params.py
30
+ tests/test_psd.py
31
+ tests/test_spot_model.py
@@ -0,0 +1,18 @@
1
+ numpy
2
+ scipy
3
+ matplotlib
4
+ astropy
5
+ scikit-learn
6
+ tqdm
7
+
8
+ [docs]
9
+ sphinx>=7.0
10
+ sphinx-book-theme>=1.0
11
+ sphinx-copybutton>=0.5
12
+ myst-nb>=1.0
13
+ sphinxcontrib-mermaid>=0.9
14
+ pygments-styles>=0.3
15
+
16
+ [jax]
17
+ jax
18
+ jaxopt
@@ -0,0 +1 @@
1
+ src
@@ -0,0 +1,11 @@
1
+ from .envelope import *
2
+ from .latitude import *
3
+ from .visibility import *
4
+ from .spot_model import *
5
+ from .lightcurve import *
6
+ from .analytic_kernel import *
7
+ from .numerical_kernel import *
8
+ from .psd import *
9
+ from .gp_solver import *
10
+ from .mcmc import *
11
+ from .plotting import *
@@ -0,0 +1,404 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import numpy as np
4
+ from functools import partial
5
+
6
+ try:
7
+ from .params import resolve_hparam
8
+ from .envelope import (
9
+ EnvelopeFunction,
10
+ TrapezoidAsymmetricEnvelope,
11
+ SkewedGaussianEnvelope,
12
+ ExponentialEnvelope,
13
+ compute_R_Gamma_numerical,
14
+ )
15
+ from .spot_model import (
16
+ VisibilityFunction, EdgeOnVisibilityFunction, SpotEvolutionModel,
17
+ _cn_squared_coefficients_jax, _gauss_legendre_grid,
18
+ )
19
+ except ImportError:
20
+ from params import resolve_hparam
21
+ from envelope import (
22
+ EnvelopeFunction,
23
+ TrapezoidAsymmetricEnvelope,
24
+ SkewedGaussianEnvelope,
25
+ ExponentialEnvelope,
26
+ compute_R_Gamma_numerical,
27
+ )
28
+ from spot_model import (
29
+ VisibilityFunction, EdgeOnVisibilityFunction, SpotEvolutionModel,
30
+ _cn_squared_coefficients_jax, _gauss_legendre_grid,
31
+ )
32
+
33
+ __all__ = ["AnalyticKernel", "compute_R_Gamma_numerical"]
34
+
35
+
36
class AnalyticKernel:
    """
    JAX-accelerated analytic GP kernel for stellar rotation variability.

    The kernel combines the autocorrelation of the squared spot envelope
    (``R_Gamma``), the squared Fourier coefficients of the visibility function
    at each latitude (``cn_squared``) and the latitude-dependent rotation
    frequency (``omega0``), averaged over a latitude distribution with either
    trapezoid or Gauss-Legendre quadrature.

    Parameters
    ----------
    model_or_hparam : SpotEvolutionModel or dict
        Either a SpotEvolutionModel instance (new API) or a raw hparam dict
        (backward-compatible old API).
    n_harmonics : int
        Number of Fourier harmonics for the visibility function (default 3).
    n_lat : int
        Number of latitude quadrature points (default 64).
    lat_range : tuple or None
        (min, max) latitude in radians. If None (default), the ``lat_range``
        of the model's latitude distribution is used.
    quadrature : str
        Latitude integration method: "trapezoid" or "gauss-legendre".

    Raises
    ------
    ValueError
        If ``quadrature`` is not "trapezoid" or "gauss-legendre".
    """

    def __init__(self, model_or_hparam, n_harmonics=3, n_lat=64,
                 lat_range=None, quadrature="trapezoid"):

        # ── Accept SpotEvolutionModel or legacy hparam dict ────────────────
        if isinstance(model_or_hparam, SpotEvolutionModel):
            self.spot_model = model_or_hparam
            self.hparam = model_or_hparam.to_hparam()
        else:
            # Backward compat: dict input
            self.hparam = resolve_hparam(model_or_hparam)
            self.spot_model = SpotEvolutionModel.from_hparam(self.hparam)

        # ── Unpack commonly-used params ────────────────────────────────────
        self.envelope = self.spot_model.envelope
        self.visibility = self.spot_model.visibility

        self.peq = self.spot_model.peq
        self.kappa = self.spot_model.kappa
        self.inc = self.spot_model.inc
        self.lspot = self.spot_model.lspot
        self.sigma_k = self.spot_model.sigma_k
        self.tau_spot = self.spot_model.tau_spot

        # ── Envelope-type attributes (backward compat) ────────────────────
        if isinstance(self.envelope, SkewedGaussianEnvelope):
            self.envelope_type = "skew_normal"
            self.sigma_sn = self.envelope.sigma_sn
            self.n_sn = self.envelope.n_sn
            self.tau_em = self.tau_spot
            self.tau_dec = self.tau_spot
            self.asymmetric = False
            # Re-use grids from the envelope object
            self._R_Gamma_lag_grid = self.envelope._R_lag_grid
            self._R_Gamma_vals = self.envelope._R_vals
            self._Gh_sq_omega_grid = self.envelope._Gh_omega_grid
            self._Gh_sq_vals = self.envelope._Gh_sq_vals

        elif isinstance(self.envelope, TrapezoidAsymmetricEnvelope):
            self.envelope_type = "trapezoid_asymmetric"
            self.asymmetric = True
            self.tau_em = self.envelope.tau_em
            self.tau_dec = self.envelope.tau_dec
            self._te = min(self.tau_em, self.tau_dec)
            self._td = max(self.tau_em, self.tau_dec)

        elif isinstance(self.envelope, ExponentialEnvelope):
            self.envelope_type = "exponential"
            self.asymmetric = False
            self.tau_em = self.tau_spot
            self.tau_dec = self.tau_spot

        else:
            # Default: symmetric trapezoid (or any other future type)
            self.envelope_type = "trapezoid_symmetric"
            self.asymmetric = False
            self.tau_em = self.tau_spot
            self.tau_dec = self.tau_spot

        # ── Kernel config ──────────────────────────────────────────────────
        self.n_harmonics = n_harmonics
        self.n_lat = n_lat
        self.lat_range = (lat_range if lat_range is not None
                          else self.spot_model.latitude_distribution.lat_range)
        self.quadrature = quadrature

        if quadrature == "gauss-legendre":
            # Use the resolved range: the ``lat_range`` argument may be None.
            self._quad_nodes, self._quad_weights = _gauss_legendre_grid(
                n_lat, self.lat_range[0], self.lat_range[1])
        elif quadrature == "trapezoid":
            self._quad_nodes = None
            self._quad_weights = None
        else:
            raise ValueError(
                f"Unknown quadrature method: {quadrature!r}. "
                "Use 'trapezoid' or 'gauss-legendre'.")

    # ── Core kernel helpers ─────────────────────────────────────────────────

    def omega0(self, phi):
        """Latitude-dependent rotation angular frequency [rad/day]."""
        return self.visibility.omega0(phi)

    def R_Gamma(self, lag):
        """Autocorrelation of the squared envelope (delegates to envelope)."""
        return self.envelope.R_Gamma(jnp.asarray(lag))

    def cn_squared(self, phi):
        """Squared Fourier visibility coefficients at latitude phi."""
        return self.visibility.cn_squared(phi, self.n_harmonics)

    @staticmethod
    def _harmonic_sum(cn_sq, w0, lag):
        """
        Evaluate ``cn_sq[0] + 2 * sum_{n>=1} cn_sq[n] * cos(n * w0 * lag)``.

        Parameters
        ----------
        cn_sq : array_like, shape (n_harmonics + 1,)
            Squared Fourier coefficients, zeroth harmonic first.
        w0 : float
            Fundamental angular frequency [rad/day].
        lag : jax.Array, shape (M,)
            One-dimensional array of time lags [days].

        Returns
        -------
        jax.Array, shape (M,)
        """
        cn_sq = jnp.asarray(cn_sq)
        ns = jnp.arange(1, cn_sq.shape[0])
        # (M, 1) * (H,) -> (M, H), summed over harmonics.
        cosine_terms = jnp.sum(
            cn_sq[1:] * jnp.cos(ns * w0 * lag[:, None]), axis=1)
        return cn_sq[0] + 2 * cosine_terms

    def _latitude_weights(self, lat_dist):
        """
        Build the latitude grid and normalized quadrature weights.

        The weights combine the latitude density ``lat_dist`` with the
        quadrature rule chosen at construction. They are normalized to sum
        to one, so ``sum(weights * f(phi_grid))`` approximates the
        ``lat_dist``-weighted average of ``f`` over ``self.lat_range``.

        Parameters
        ----------
        lat_dist : callable
            Latitude density, called once per grid node with a Python float
            (so it does not need to be vectorized or JAX-traceable).

        Returns
        -------
        phi_grid : jax.Array, shape (n_nodes,)
        weights : jax.Array, shape (n_nodes,)
        """
        if self.quadrature == "gauss-legendre":
            phi_grid = self._quad_nodes
            rule = self._quad_weights
        else:
            phi_min, phi_max = self.lat_range
            phi_grid = jnp.linspace(phi_min, phi_max, self.n_lat)
            dphi = phi_grid[1] - phi_grid[0]
            # Composite trapezoid rule: interior nodes weigh dphi and the two
            # endpoints dphi/2. Applying the same rule to the numerator and
            # to the normalization means a latitude-independent integrand is
            # reproduced exactly.
            rule = jnp.full(phi_grid.shape, dphi)
            rule = rule.at[0].set(0.5 * dphi).at[-1].set(0.5 * dphi)
        density = jnp.array([lat_dist(float(p)) for p in phi_grid])
        weights = density * rule
        return phi_grid, weights / jnp.sum(weights)

    # ── Single-latitude kernel ──────────────────────────────────────────────

    def kernel_single_latitude(self, lag, phi):
        """
        Single-spot kernel at a fixed latitude.

        Parameters
        ----------
        lag : array_like
            Time lags [days]; flattened to 1-D.
        phi : float
            Latitude [rad].

        Returns
        -------
        jax.Array, shape (lag.size,)
            ``R_Gamma(lag) * (c_0^2 + 2 sum_n c_n^2 cos(n omega0 lag))``
            (not scaled by ``sigma_k**2``).
        """
        lag = jnp.asarray(lag, dtype=float).ravel()
        return self.R_Gamma(lag) * self._harmonic_sum(
            self.cn_squared(phi), self.omega0(phi), lag)

    # ── Full kernel (latitude-averaged) ────────────────────────────────────

    def kernel(self, lag, lat_dist=None):
        r"""
        Full GP kernel averaged over latitude.

        Uses jax.lax.scan for memory-efficient accumulation: only one
        lag-sized buffer is live at a time — O(M) instead of O(n_lat·M).

        When the visibility function is an EdgeOnVisibilityFunction, the
        latitude-averaged \|c_n\|^2 are known constants and the latitude
        loop is bypassed entirely; ``lat_dist`` is not used on that path.

        Parameters
        ----------
        lag : array_like
            Time lags [days]. Any shape (scalar, 1D, 2D, ...).
        lat_dist : callable or None
            Latitude probability density. If None, the model's latitude
            distribution is used.

        Returns
        -------
        K : ndarray, same shape as lag input.
        """
        lag = jnp.asarray(lag, dtype=float)
        orig_shape = lag.shape
        lag_flat = lag.ravel()

        # Fast path: EdgeOnVisibilityFunction has closed-form latitude-
        # averaged |c_n|^2, so no quadrature loop is needed.
        if isinstance(self.visibility, EdgeOnVisibilityFunction):
            cn_sq = self.visibility.cn_squared(0.0, self.n_harmonics)
            w0 = self.visibility.omega0(0.0)
            K = (self.sigma_k ** 2 * self.R_Gamma(lag_flat)
                 * self._harmonic_sum(cn_sq, w0, lag_flat))
            return np.asarray(K.reshape(orig_shape))

        if lat_dist is None:
            lat_dist = self.spot_model.latitude_distribution

        phi_grid, weights = self._latitude_weights(lat_dist)

        def _scan_body(K_acc, node):
            phi, w = node
            contrib = self._harmonic_sum(
                self.cn_squared(phi), self.omega0(phi), lag_flat)
            return K_acc + w * contrib, None

        K, _ = jax.lax.scan(
            _scan_body, jnp.zeros_like(lag_flat), (phi_grid, weights))
        K = self.R_Gamma(lag_flat) * K * self.sigma_k ** 2

        return np.asarray(K.reshape(orig_shape))

    def kernel_solid_body(self, lag, lat_dist=None):
        """
        Kernel for solid-body rotation (kappa=0).

        All latitudes rotate at ``2*pi/peq``, so the latitude average is
        taken over the squared visibility coefficients only.

        Parameters
        ----------
        lag : array_like
            Time lags [days]. Any shape.
        lat_dist : callable or None
            Latitude probability density. If None, the model's latitude
            distribution is used.

        Returns
        -------
        ndarray, same shape as lag input.
        """
        lag = jnp.asarray(lag, dtype=float)
        orig_shape = lag.shape
        lag_flat = lag.ravel()

        if lat_dist is None:
            lat_dist = self.spot_model.latitude_distribution

        phi_grid, weights = self._latitude_weights(lat_dist)
        all_cn_sq = jax.vmap(
            lambda phi: _cn_squared_coefficients_jax(
                self.inc, phi, self.n_harmonics))(phi_grid)
        cn_sq_avg = jnp.sum(weights[:, None] * all_cn_sq, axis=0)

        w0 = 2 * jnp.pi / self.peq
        K = (self.R_Gamma(lag_flat)
             * self._harmonic_sum(cn_sq_avg, w0, lag_flat)
             * self.sigma_k ** 2)
        return np.asarray(K.reshape(orig_shape))

    # ── Power spectral density ──────────────────────────────────────────────

    def compute_psd(self, omega, lat_dist=None):
        """
        Analytic power spectral density.

        Side effect: stores ``psd_omega``, ``psd_freq`` and ``psd_power`` on
        the instance.

        Parameters
        ----------
        omega : array_like
            Angular frequencies [rad/day]. Any shape.
        lat_dist : callable or None
            Latitude probability density. If None, the model's latitude
            distribution is used.

        Returns
        -------
        freq : ndarray [cycles/day], same shape as omega
        power : ndarray, same shape as omega
        """
        omega = jnp.asarray(omega, dtype=float)
        omega_flat = omega.ravel()

        if lat_dist is None:
            lat_dist = self.spot_model.latitude_distribution

        # Squared envelope Fourier transform, by envelope type.
        if isinstance(self.envelope, (SkewedGaussianEnvelope, ExponentialEnvelope)):
            Gamma_hat_sq = self.envelope.Gamma_hat_sq
        else:
            # Trapezoid types expose the closed-form Gamma_hat.
            def Gamma_hat_sq(w):
                return self.envelope.Gamma_hat(w) ** 2

        def _psd_at_lat(phi):
            cn_sq = self.cn_squared(phi)
            w0 = self.omega0(phi)

            def _harmonic(n):
                # Each harmonic contributes at +/- n * omega0.
                return cn_sq[n] * (Gamma_hat_sq(omega_flat - n * w0)
                                   + Gamma_hat_sq(omega_flat + n * w0))

            ns = jnp.arange(1, len(cn_sq))
            harmonic_contribs = jax.vmap(_harmonic)(ns)
            return (cn_sq[0] * Gamma_hat_sq(omega_flat)
                    + jnp.sum(harmonic_contribs, axis=0))

        phi_grid, weights = self._latitude_weights(lat_dist)
        all_contribs = jax.vmap(_psd_at_lat)(phi_grid)
        psd = jnp.sum(weights[:, None] * all_contribs, axis=0)
        psd = (psd * self.sigma_k ** 2).reshape(omega.shape)

        self.psd_omega = np.asarray(omega)
        self.psd_freq = np.asarray(omega / (2 * jnp.pi))
        self.psd_power = np.asarray(psd)

        return self.psd_freq, self.psd_power

    def build_jax(self, n_lag=256):
        """
        Pre-compile and warm up JAX JIT computation for this kernel.

        ``jax.lax.scan`` (used inside ``kernel()``) triggers XLA compilation
        on its first call for a given array shape. That compilation can take
        several seconds and is easy to mistake for slow runtime. Call
        ``build_jax()`` once after constructing the kernel to pay that cost
        upfront — subsequent calls to ``kernel()`` and ``compute_psd()`` with
        the same shape will be fast.

        Prints the wall-clock time of a first and a second evaluation.

        Parameters
        ----------
        n_lag : int
            Length of the dummy lag array used to drive compilation (default
            256). The actual value does not matter as long as it is
            representative of the sizes you will use at runtime.

        Returns
        -------
        self : AnalyticKernel
            Returns ``self`` so the call can be chained:
            ``ak = AnalyticKernel(model).build_jax()``.
        """
        import time

        # Lags spanning three equatorial periods; frequencies up to twice the
        # equatorial rotation frequency.
        dummy_lag = jnp.linspace(0.0, float(self.peq) * 3.0, n_lag)
        dummy_omega = jnp.linspace(0.0, 4.0 * float(np.pi / self.peq), n_lag)

        t0 = time.time()
        jax.block_until_ready(self.kernel(dummy_lag))
        jax.block_until_ready(self.compute_psd(dummy_omega))
        print(f"JAX kernel compiled in {np.round(time.time() - t0, 2)}s")

        t0 = time.time()
        jax.block_until_ready(self.kernel(dummy_lag))
        jax.block_until_ready(self.compute_psd(dummy_omega))
        print(f"JAX kernel recompute in {np.round(time.time() - t0, 2)}s")

        return self

    def __call__(self, lag, **kwargs):
        """Evaluate the kernel at the given lags (alias for ``kernel``)."""
        return self.kernel(lag, **kwargs)