spotgp 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spotgp-0.1.0/PKG-INFO +50 -0
- spotgp-0.1.0/README.md +25 -0
- spotgp-0.1.0/pyproject.toml +36 -0
- spotgp-0.1.0/setup.cfg +4 -0
- spotgp-0.1.0/spotgp.egg-info/PKG-INFO +50 -0
- spotgp-0.1.0/spotgp.egg-info/SOURCES.txt +31 -0
- spotgp-0.1.0/spotgp.egg-info/dependency_links.txt +1 -0
- spotgp-0.1.0/spotgp.egg-info/requires.txt +18 -0
- spotgp-0.1.0/spotgp.egg-info/top_level.txt +1 -0
- spotgp-0.1.0/src/__init__.py +11 -0
- spotgp-0.1.0/src/analytic_kernel.py +404 -0
- spotgp-0.1.0/src/banded_cholesky.py +275 -0
- spotgp-0.1.0/src/envelope.py +1013 -0
- spotgp-0.1.0/src/gp_solver.py +2597 -0
- spotgp-0.1.0/src/latitude.py +138 -0
- spotgp-0.1.0/src/lightcurve.py +917 -0
- spotgp-0.1.0/src/mcmc.py +901 -0
- spotgp-0.1.0/src/numerical_kernel.py +233 -0
- spotgp-0.1.0/src/params.py +383 -0
- spotgp-0.1.0/src/plotting.py +134 -0
- spotgp-0.1.0/src/psd.py +72 -0
- spotgp-0.1.0/src/spot_model.py +469 -0
- spotgp-0.1.0/src/visibility.py +496 -0
- spotgp-0.1.0/tests/test_analytic_kernel.py +135 -0
- spotgp-0.1.0/tests/test_banded_cholesky.py +95 -0
- spotgp-0.1.0/tests/test_envelope.py +211 -0
- spotgp-0.1.0/tests/test_gp_solver.py +111 -0
- spotgp-0.1.0/tests/test_lightcurve.py +77 -0
- spotgp-0.1.0/tests/test_mcmc.py +47 -0
- spotgp-0.1.0/tests/test_numerical_kernel.py +62 -0
- spotgp-0.1.0/tests/test_params.py +70 -0
- spotgp-0.1.0/tests/test_psd.py +53 -0
- spotgp-0.1.0/tests/test_spot_model.py +211 -0
spotgp-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: spotgp
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Gaussian Process kernels for stellar variability from starspot models
|
|
5
|
+
Author: Jessica Birky
|
|
6
|
+
License: MIT
|
|
7
|
+
Requires-Python: >=3.8
|
|
8
|
+
Description-Content-Type: text/markdown
|
|
9
|
+
Requires-Dist: numpy
|
|
10
|
+
Requires-Dist: scipy
|
|
11
|
+
Requires-Dist: matplotlib
|
|
12
|
+
Requires-Dist: astropy
|
|
13
|
+
Requires-Dist: scikit-learn
|
|
14
|
+
Requires-Dist: tqdm
|
|
15
|
+
Provides-Extra: jax
|
|
16
|
+
Requires-Dist: jax; extra == "jax"
|
|
17
|
+
Requires-Dist: jaxopt; extra == "jax"
|
|
18
|
+
Provides-Extra: docs
|
|
19
|
+
Requires-Dist: sphinx>=7.0; extra == "docs"
|
|
20
|
+
Requires-Dist: sphinx-book-theme>=1.0; extra == "docs"
|
|
21
|
+
Requires-Dist: sphinx-copybutton>=0.5; extra == "docs"
|
|
22
|
+
Requires-Dist: myst-nb>=1.0; extra == "docs"
|
|
23
|
+
Requires-Dist: sphinxcontrib-mermaid>=0.9; extra == "docs"
|
|
24
|
+
Requires-Dist: pygments-styles>=0.3; extra == "docs"
|
|
25
|
+
|
|
26
|
+
# `spotgp`
|
|
27
|
+
|
|
28
|
+
[](https://github.com/jbirky/spotgp/actions/workflows/tests.yml)
|
|
29
|
+
[](https://codecov.io/gh/jbirky/spotgp)
|
|
30
|
+
[](https://spotgp.readthedocs.io/en/latest/?badge=latest)
|
|
31
|
+
|
|
32
|
+
**`spotgp`**: Gaussian Process kernels for stellar starspot variability implemented in `JAX`.
|
|
33
|
+
|
|
34
|
+
<br>
|
|
35
|
+
|
|
36
|
+

|
|
37
|
+
|
|
38
|
+
## Installation
|
|
39
|
+
|
|
40
|
+
```bash
|
|
41
|
+
git clone https://github.com/jbirky/spotgp.git
|
|
42
|
+
cd spotgp
|
|
43
|
+
pip install -e .
|
|
44
|
+
```
|
|
45
|
+
|
|
46
|
+
For JAX acceleration:
|
|
47
|
+
|
|
48
|
+
```bash
|
|
49
|
+
pip install -e ".[jax]"
|
|
50
|
+
```
|
spotgp-0.1.0/README.md
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
# `spotgp`
|
|
2
|
+
|
|
3
|
+
[](https://github.com/jbirky/spotgp/actions/workflows/tests.yml)
|
|
4
|
+
[](https://codecov.io/gh/jbirky/spotgp)
|
|
5
|
+
[](https://spotgp.readthedocs.io/en/latest/?badge=latest)
|
|
6
|
+
|
|
7
|
+
**`spotgp`**: Gaussian Process kernels for stellar starspot variability implemented in `JAX`.
|
|
8
|
+
|
|
9
|
+
<br>
|
|
10
|
+
|
|
11
|
+

|
|
12
|
+
|
|
13
|
+
## Installation
|
|
14
|
+
|
|
15
|
+
```bash
|
|
16
|
+
git clone https://github.com/jbirky/spotgp.git
|
|
17
|
+
cd spotgp
|
|
18
|
+
pip install -e .
|
|
19
|
+
```
|
|
20
|
+
|
|
21
|
+
For JAX acceleration:
|
|
22
|
+
|
|
23
|
+
```bash
|
|
24
|
+
pip install -e ".[jax]"
|
|
25
|
+
```
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=64", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "spotgp"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "Gaussian Process kernels for stellar variability from starspot models"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.8"
|
|
11
|
+
license = {text = "MIT"}
|
|
12
|
+
authors = [
|
|
13
|
+
{name = "Jessica Birky"},
|
|
14
|
+
]
|
|
15
|
+
dependencies = [
|
|
16
|
+
"numpy",
|
|
17
|
+
"scipy",
|
|
18
|
+
"matplotlib",
|
|
19
|
+
"astropy",
|
|
20
|
+
"scikit-learn",
|
|
21
|
+
"tqdm",
|
|
22
|
+
]
|
|
23
|
+
|
|
24
|
+
[project.optional-dependencies]
|
|
25
|
+
jax = ["jax", "jaxopt"]
|
|
26
|
+
docs = [
|
|
27
|
+
"sphinx>=7.0",
|
|
28
|
+
"sphinx-book-theme>=1.0",
|
|
29
|
+
"sphinx-copybutton>=0.5",
|
|
30
|
+
"myst-nb>=1.0",
|
|
31
|
+
"sphinxcontrib-mermaid>=0.9",
|
|
32
|
+
"pygments-styles>=0.3",
|
|
33
|
+
]
|
|
34
|
+
|
|
35
|
+
[tool.setuptools.packages.find]
|
|
36
|
+
include = ["src*"]
|
spotgp-0.1.0/setup.cfg
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: spotgp
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Gaussian Process kernels for stellar variability from starspot models
|
|
5
|
+
Author: Jessica Birky
|
|
6
|
+
License: MIT
|
|
7
|
+
Requires-Python: >=3.8
|
|
8
|
+
Description-Content-Type: text/markdown
|
|
9
|
+
Requires-Dist: numpy
|
|
10
|
+
Requires-Dist: scipy
|
|
11
|
+
Requires-Dist: matplotlib
|
|
12
|
+
Requires-Dist: astropy
|
|
13
|
+
Requires-Dist: scikit-learn
|
|
14
|
+
Requires-Dist: tqdm
|
|
15
|
+
Provides-Extra: jax
|
|
16
|
+
Requires-Dist: jax; extra == "jax"
|
|
17
|
+
Requires-Dist: jaxopt; extra == "jax"
|
|
18
|
+
Provides-Extra: docs
|
|
19
|
+
Requires-Dist: sphinx>=7.0; extra == "docs"
|
|
20
|
+
Requires-Dist: sphinx-book-theme>=1.0; extra == "docs"
|
|
21
|
+
Requires-Dist: sphinx-copybutton>=0.5; extra == "docs"
|
|
22
|
+
Requires-Dist: myst-nb>=1.0; extra == "docs"
|
|
23
|
+
Requires-Dist: sphinxcontrib-mermaid>=0.9; extra == "docs"
|
|
24
|
+
Requires-Dist: pygments-styles>=0.3; extra == "docs"
|
|
25
|
+
|
|
26
|
+
# `spotgp`
|
|
27
|
+
|
|
28
|
+
[](https://github.com/jbirky/spotgp/actions/workflows/tests.yml)
|
|
29
|
+
[](https://codecov.io/gh/jbirky/spotgp)
|
|
30
|
+
[](https://spotgp.readthedocs.io/en/latest/?badge=latest)
|
|
31
|
+
|
|
32
|
+
**`spotgp`**: Gaussian Process kernels for stellar starspot variability implemented in `JAX`.
|
|
33
|
+
|
|
34
|
+
<br>
|
|
35
|
+
|
|
36
|
+

|
|
37
|
+
|
|
38
|
+
## Installation
|
|
39
|
+
|
|
40
|
+
```bash
|
|
41
|
+
git clone https://github.com/jbirky/spotgp.git
|
|
42
|
+
cd spotgp
|
|
43
|
+
pip install -e .
|
|
44
|
+
```
|
|
45
|
+
|
|
46
|
+
For JAX acceleration:
|
|
47
|
+
|
|
48
|
+
```bash
|
|
49
|
+
pip install -e ".[jax]"
|
|
50
|
+
```
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
README.md
|
|
2
|
+
pyproject.toml
|
|
3
|
+
spotgp.egg-info/PKG-INFO
|
|
4
|
+
spotgp.egg-info/SOURCES.txt
|
|
5
|
+
spotgp.egg-info/dependency_links.txt
|
|
6
|
+
spotgp.egg-info/requires.txt
|
|
7
|
+
spotgp.egg-info/top_level.txt
|
|
8
|
+
src/__init__.py
|
|
9
|
+
src/analytic_kernel.py
|
|
10
|
+
src/banded_cholesky.py
|
|
11
|
+
src/envelope.py
|
|
12
|
+
src/gp_solver.py
|
|
13
|
+
src/latitude.py
|
|
14
|
+
src/lightcurve.py
|
|
15
|
+
src/mcmc.py
|
|
16
|
+
src/numerical_kernel.py
|
|
17
|
+
src/params.py
|
|
18
|
+
src/plotting.py
|
|
19
|
+
src/psd.py
|
|
20
|
+
src/spot_model.py
|
|
21
|
+
src/visibility.py
|
|
22
|
+
tests/test_analytic_kernel.py
|
|
23
|
+
tests/test_banded_cholesky.py
|
|
24
|
+
tests/test_envelope.py
|
|
25
|
+
tests/test_gp_solver.py
|
|
26
|
+
tests/test_lightcurve.py
|
|
27
|
+
tests/test_mcmc.py
|
|
28
|
+
tests/test_numerical_kernel.py
|
|
29
|
+
tests/test_params.py
|
|
30
|
+
tests/test_psd.py
|
|
31
|
+
tests/test_spot_model.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
src
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from .envelope import *
|
|
2
|
+
from .latitude import *
|
|
3
|
+
from .visibility import *
|
|
4
|
+
from .spot_model import *
|
|
5
|
+
from .lightcurve import *
|
|
6
|
+
from .analytic_kernel import *
|
|
7
|
+
from .numerical_kernel import *
|
|
8
|
+
from .psd import *
|
|
9
|
+
from .gp_solver import *
|
|
10
|
+
from .mcmc import *
|
|
11
|
+
from .plotting import *
|
|
@@ -0,0 +1,404 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
import jax.numpy as jnp
|
|
3
|
+
import numpy as np
|
|
4
|
+
from functools import partial
|
|
5
|
+
|
|
6
|
+
try:
|
|
7
|
+
from .params import resolve_hparam
|
|
8
|
+
from .envelope import (
|
|
9
|
+
EnvelopeFunction,
|
|
10
|
+
TrapezoidAsymmetricEnvelope,
|
|
11
|
+
SkewedGaussianEnvelope,
|
|
12
|
+
ExponentialEnvelope,
|
|
13
|
+
compute_R_Gamma_numerical,
|
|
14
|
+
)
|
|
15
|
+
from .spot_model import (
|
|
16
|
+
VisibilityFunction, EdgeOnVisibilityFunction, SpotEvolutionModel,
|
|
17
|
+
_cn_squared_coefficients_jax, _gauss_legendre_grid,
|
|
18
|
+
)
|
|
19
|
+
except ImportError:
|
|
20
|
+
from params import resolve_hparam
|
|
21
|
+
from envelope import (
|
|
22
|
+
EnvelopeFunction,
|
|
23
|
+
TrapezoidAsymmetricEnvelope,
|
|
24
|
+
SkewedGaussianEnvelope,
|
|
25
|
+
ExponentialEnvelope,
|
|
26
|
+
compute_R_Gamma_numerical,
|
|
27
|
+
)
|
|
28
|
+
from spot_model import (
|
|
29
|
+
VisibilityFunction, EdgeOnVisibilityFunction, SpotEvolutionModel,
|
|
30
|
+
_cn_squared_coefficients_jax, _gauss_legendre_grid,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
# Public names exported by ``from .analytic_kernel import *`` (the package
# ``__init__`` star-imports this module). ``compute_R_Gamma_numerical`` is not
# defined here; it is re-exported from the envelope module imported above.
__all__ = ["AnalyticKernel", "compute_R_Gamma_numerical"]
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class AnalyticKernel:
    """
    JAX-accelerated analytic GP kernel for stellar rotation variability.

    The kernel combines the envelope autocorrelation ``R_Gamma`` with the
    squared Fourier coefficients |c_n|^2 of the spot visibility function,
    averaged over spot latitude with a latitude probability density.

    Parameters
    ----------
    model_or_hparam : SpotEvolutionModel or dict
        Either a SpotEvolutionModel instance (new API) or a raw hparam dict
        (backward-compatible old API). A dict is passed through
        ``resolve_hparam`` and converted with ``SpotEvolutionModel.from_hparam``.
    n_harmonics : int
        Number of Fourier harmonics for the visibility function (default 3).
    n_lat : int
        Number of latitude quadrature points (default 64). The trapezoid
        rule requires ``n_lat >= 2``.
    lat_range : tuple or None
        (min, max) latitude in radians. If None (default), the ``lat_range``
        of the spot model's latitude distribution is used.
    quadrature : str
        Latitude integration method: "trapezoid" or "gauss-legendre".

    Raises
    ------
    ValueError
        If ``quadrature`` is unknown, or if ``n_lat < 2`` with "trapezoid".
    """

    def __init__(self, model_or_hparam, n_harmonics=3, n_lat=64,
                 lat_range=None, quadrature="trapezoid"):

        # ── Accept SpotEvolutionModel or legacy hparam dict ────────────────
        if isinstance(model_or_hparam, SpotEvolutionModel):
            self.spot_model = model_or_hparam
            self.hparam = model_or_hparam.to_hparam()
        else:
            # Backward compat: dict input
            self.hparam = resolve_hparam(model_or_hparam)
            self.spot_model = SpotEvolutionModel.from_hparam(self.hparam)

        # ── Unpack commonly-used params ────────────────────────────────────
        self.envelope = self.spot_model.envelope
        self.visibility = self.spot_model.visibility

        self.peq = self.spot_model.peq
        self.kappa = self.spot_model.kappa
        self.inc = self.spot_model.inc
        self.lspot = self.spot_model.lspot
        self.sigma_k = self.spot_model.sigma_k
        self.tau_spot = self.spot_model.tau_spot

        # ── Envelope-type attributes (backward compat) ────────────────────
        if isinstance(self.envelope, SkewedGaussianEnvelope):
            self.envelope_type = "skew_normal"
            self.sigma_sn = self.envelope.sigma_sn
            self.n_sn = self.envelope.n_sn
            self.tau_em = self.tau_spot
            self.tau_dec = self.tau_spot
            self.asymmetric = False
            # Re-use grids precomputed by the envelope object
            self._R_Gamma_lag_grid = self.envelope._R_lag_grid
            self._R_Gamma_vals = self.envelope._R_vals
            self._Gh_sq_omega_grid = self.envelope._Gh_omega_grid
            self._Gh_sq_vals = self.envelope._Gh_sq_vals

        elif isinstance(self.envelope, TrapezoidAsymmetricEnvelope):
            self.envelope_type = "trapezoid_asymmetric"
            self.asymmetric = True
            self.tau_em = self.envelope.tau_em
            self.tau_dec = self.envelope.tau_dec
            self._te = min(self.tau_em, self.tau_dec)
            self._td = max(self.tau_em, self.tau_dec)

        elif isinstance(self.envelope, ExponentialEnvelope):
            self.envelope_type = "exponential"
            self.asymmetric = False
            self.tau_em = self.tau_spot
            self.tau_dec = self.tau_spot

        else:
            # Default: symmetric trapezoid (or any other future type)
            self.envelope_type = "trapezoid_symmetric"
            self.asymmetric = False
            self.tau_em = self.tau_spot
            self.tau_dec = self.tau_spot

        # ── Kernel config ──────────────────────────────────────────────────
        self.n_harmonics = n_harmonics
        self.n_lat = n_lat
        self.lat_range = (lat_range if lat_range is not None
                          else self.spot_model.latitude_distribution.lat_range)
        self.quadrature = quadrature

        if quadrature == "gauss-legendre":
            # Build nodes on the *resolved* range: the ``lat_range`` argument
            # defaults to None and cannot be indexed.
            self._quad_nodes, self._quad_weights = _gauss_legendre_grid(
                n_lat, self.lat_range[0], self.lat_range[1])
        elif quadrature == "trapezoid":
            if n_lat < 2:
                # A single node has zero spacing and would yield 0/0 = NaN.
                raise ValueError(
                    f"Trapezoid quadrature needs n_lat >= 2, got {n_lat!r}.")
            # Trapezoid nodes/weights are built on demand from lat_range.
            self._quad_nodes = None
            self._quad_weights = None
        else:
            raise ValueError(
                f"Unknown quadrature method: {quadrature!r}. "
                "Use 'trapezoid' or 'gauss-legendre'.")

    # ── Core kernel helpers ─────────────────────────────────────────────────

    def omega0(self, phi):
        """Latitude-dependent rotation angular frequency [rad/day]."""
        return self.visibility.omega0(phi)

    def R_Gamma(self, lag):
        """Autocorrelation of the squared envelope (delegates to envelope)."""
        return self.envelope.R_Gamma(jnp.asarray(lag))

    def cn_squared(self, phi):
        """Squared Fourier visibility coefficients at latitude phi."""
        return self.visibility.cn_squared(phi, self.n_harmonics)

    @staticmethod
    def _cosine_series(cn_sq, w0, lag_flat):
        """Evaluate ``cn_sq[0] + 2 * sum_{n>=1} cn_sq[n] * cos(n * w0 * lag)``.

        ``cn_sq`` holds |c_n|^2 for n = 0..N and ``lag_flat`` is 1-D; the
        result has the same length as ``lag_flat``.
        """
        ns = jnp.arange(1, len(cn_sq))
        cosine_terms = jnp.sum(
            cn_sq[1:] * jnp.cos(ns * w0 * lag_flat[:, None]), axis=1)
        return cn_sq[0] + 2 * cosine_terms

    def _latitude_weights(self, lat_dist):
        """Return ``(phi_grid, weights, norm)`` for latitude averaging.

        ``weights[i]`` is ``lat_dist(phi_grid[i])`` times the quadrature
        weight of node ``i`` (Gauss-Legendre weights, or trapezoid weights
        ``dphi`` with halved endpoints), and ``norm = sum(weights)``. Because
        numerator and normalisation share the same rule,
        ``sum(weights * f(phi_grid)) / norm`` is a properly normalised
        ``lat_dist``-weighted mean of ``f``.
        """
        if self.quadrature == "gauss-legendre":
            raw_nodes = self._quad_nodes
            quad_weights = jnp.asarray(self._quad_weights)
        else:
            phi_min, phi_max = self.lat_range
            raw_nodes = jnp.linspace(phi_min, phi_max, self.n_lat)
            dphi = raw_nodes[1] - raw_nodes[0]
            # Trapezoid rule: interior nodes weigh dphi, endpoints dphi / 2.
            quad_weights = dphi * (
                jnp.ones(self.n_lat).at[0].set(0.5).at[-1].set(0.5))
        # lat_dist is an arbitrary Python callable, so evaluate node by node.
        user_weights = jnp.array([lat_dist(float(p)) for p in raw_nodes])
        weights = user_weights * quad_weights
        return jnp.asarray(raw_nodes), weights, jnp.sum(weights)

    # ── Single-latitude kernel ──────────────────────────────────────────────

    def kernel_single_latitude(self, lag, phi):
        """Single-spot kernel at a fixed latitude ``phi``.

        Returns ``R_Gamma(lag) * (|c_0|^2 + 2 sum_n |c_n|^2 cos(n w0 lag))``
        as a flattened 1-D array. No ``sigma_k ** 2`` amplitude is applied.
        """
        lag = jnp.asarray(lag, dtype=float).ravel()
        return self.R_Gamma(lag) * self._cosine_series(
            self.cn_squared(phi), self.omega0(phi), lag)

    # ── Full kernel (latitude-averaged) ────────────────────────────────────

    def kernel(self, lag, lat_dist=None):
        """
        Full GP kernel averaged over latitude.

        Uses jax.lax.scan for memory-efficient accumulation: only one
        lag-sized buffer is live at a time — O(M) instead of O(n_lat·M).

        When the visibility function is an EdgeOnVisibilityFunction, the
        latitude-averaged \\|c_n\\|^2 are known constants and the latitude
        loop is bypassed entirely.

        Parameters
        ----------
        lag : array_like
            Time lags [days]. Any shape; it is flattened internally.
        lat_dist : callable or None
            Latitude probability density. If None, the spot model's
            latitude distribution is used.

        Returns
        -------
        K : ndarray, same shape as lag input.
        """
        lag = jnp.asarray(lag, dtype=float)
        orig_shape = lag.shape
        lag_flat = lag.ravel()
        R = self.R_Gamma(lag_flat)

        # Fast path: EdgeOnVisibilityFunction has closed-form latitude-
        # averaged |c_n|^2, so no quadrature loop is needed.
        if isinstance(self.visibility, EdgeOnVisibilityFunction):
            cn_sq = self.visibility.cn_squared(0.0, self.n_harmonics)
            w0 = self.visibility.omega0(0.0)
            K = self.sigma_k ** 2 * R * self._cosine_series(cn_sq, w0, lag_flat)
            return np.asarray(K.reshape(orig_shape))

        if lat_dist is None:
            lat_dist = self.spot_model.latitude_distribution

        phi_grid, weights, norm = self._latitude_weights(lat_dist)

        def _scan_body(K_acc, node):
            phi, w = node
            contrib = self._cosine_series(
                self.cn_squared(phi), self.omega0(phi), lag_flat)
            return K_acc + w * contrib, None

        K_sum, _ = jax.lax.scan(
            _scan_body, jnp.zeros_like(lag_flat), (phi_grid, weights))
        K = R * (K_sum / norm) * self.sigma_k ** 2

        return np.asarray(K.reshape(orig_shape))

    def kernel_solid_body(self, lag, lat_dist=None):
        """Kernel for solid-body rotation (kappa=0).

        The |c_n|^2 are averaged over latitude first, then combined with a
        single rotation frequency ``2 pi / peq``. Returns an ndarray with the
        same shape as ``lag``.
        """
        lag = jnp.asarray(lag, dtype=float)
        orig_shape = lag.shape
        lag_flat = lag.ravel()

        if lat_dist is None:
            lat_dist = self.spot_model.latitude_distribution

        phi_grid, weights, norm = self._latitude_weights(lat_dist)
        all_cn_sq = jax.vmap(
            lambda phi: _cn_squared_coefficients_jax(
                self.inc, phi, self.n_harmonics))(phi_grid)
        cn_sq_avg = jnp.sum(weights[:, None] * all_cn_sq, axis=0) / norm

        w0 = 2 * jnp.pi / self.peq
        R = self.R_Gamma(lag_flat)
        K = R * self._cosine_series(cn_sq_avg, w0, lag_flat) * self.sigma_k ** 2
        return np.asarray(K.reshape(orig_shape))

    # ── Power spectral density ──────────────────────────────────────────────

    def compute_psd(self, omega, lat_dist=None):
        """
        Analytic power spectral density.

        Parameters
        ----------
        omega : array_like
            Angular frequencies [rad/day].
        lat_dist : callable or None
            Latitude probability density. If None, the spot model's
            latitude distribution is used.

        Returns
        -------
        freq : ndarray [cycles/day]
        power : ndarray

        Notes
        -----
        The results are also stored on ``self.psd_omega``, ``self.psd_freq``
        and ``self.psd_power``.
        """
        omega = jnp.asarray(omega, dtype=float)

        if lat_dist is None:
            lat_dist = self.spot_model.latitude_distribution

        # |Gamma_hat|^2 depends on the envelope type
        if isinstance(self.envelope, (SkewedGaussianEnvelope, ExponentialEnvelope)):
            # These envelopes expose Gamma_hat_sq directly.
            gamma_hat_sq = self.envelope.Gamma_hat_sq
        else:
            # Trapezoid types use the closed-form _Gamma_hat.
            def gamma_hat_sq(w):
                return self.envelope.Gamma_hat(w) ** 2

        def _psd_at_lat(phi):
            cn_sq = self.cn_squared(phi)
            w0 = self.omega0(phi)

            def _harmonic(n):
                return cn_sq[n] * (gamma_hat_sq(omega - n * w0)
                                   + gamma_hat_sq(omega + n * w0))

            ns = jnp.arange(1, len(cn_sq))
            harmonic_contribs = jax.vmap(_harmonic)(ns)
            return (cn_sq[0] * gamma_hat_sq(omega)
                    + jnp.sum(harmonic_contribs, axis=0))

        phi_grid, weights, norm = self._latitude_weights(lat_dist)
        all_contribs = jax.vmap(_psd_at_lat)(phi_grid)
        psd = jnp.sum(weights[:, None] * all_contribs, axis=0) / norm
        psd = psd * self.sigma_k ** 2

        self.psd_omega = np.asarray(omega)
        self.psd_freq = np.asarray(omega / (2 * jnp.pi))
        self.psd_power = np.asarray(psd)

        return self.psd_freq, self.psd_power

    def build_jax(self, n_lag=256):
        """
        Pre-compile and warm up JAX JIT computation for this kernel.

        ``jax.lax.scan`` (used inside ``kernel()``) triggers XLA compilation
        on its first call for a given array shape. That compilation can take
        several seconds and is easy to mistake for slow runtime. Call
        ``build_jax()`` once after constructing the kernel to pay that cost
        upfront — subsequent calls to ``kernel()`` and ``compute_psd()`` with
        the same shape will be fast.

        Parameters
        ----------
        n_lag : int
            Length of the dummy lag array used to drive compilation (default
            256). The actual value does not matter as long as it is
            representative of the sizes you will use at runtime.

        Returns
        -------
        self : AnalyticKernel
            Returns ``self`` so the call can be chained:
            ``ak = AnalyticKernel(model).build_jax()``.
        """
        import time

        dummy_lag = jnp.linspace(0.0, float(self.peq) * 3.0, n_lag)
        dummy_omega = jnp.linspace(0.0, 4.0 * float(np.pi / self.peq), n_lag)

        # perf_counter is the monotonic clock intended for measuring durations.
        t0 = time.perf_counter()
        jax.block_until_ready(self.kernel(dummy_lag))
        jax.block_until_ready(self.compute_psd(dummy_omega))
        print(f"JAX kernel compiled in {np.round(time.perf_counter() - t0, 2)}s")

        t0 = time.perf_counter()
        jax.block_until_ready(self.kernel(dummy_lag))
        jax.block_until_ready(self.compute_psd(dummy_omega))
        print(f"JAX kernel recompute in {np.round(time.perf_counter() - t0, 2)}s")

        return self

    def __call__(self, lag, **kwargs):
        """Evaluate the kernel at the given lags (alias for ``kernel``)."""
        return self.kernel(lag, **kwargs)
|