LogPSplinePSD 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- log_psplines/__init__.py +0 -0
- log_psplines/_version.py +21 -0
- log_psplines/arviz_utils.py +142 -0
- log_psplines/datatypes.py +83 -0
- log_psplines/example_datasets/__init__.py +1 -0
- log_psplines/example_datasets/ar_data.py +270 -0
- log_psplines/example_datasets/lvk_data.py +177 -0
- log_psplines/initialisation.py +212 -0
- log_psplines/line_locator.py +123 -0
- log_psplines/mcmc.py +143 -0
- log_psplines/parameteric_approximation.py +120 -0
- log_psplines/plotting/__init__.py +5 -0
- log_psplines/plotting/diagnostics.py +73 -0
- log_psplines/plotting/pdgrm.py +81 -0
- log_psplines/plotting/plot_basis.py +50 -0
- log_psplines/plotting/utils.py +118 -0
- log_psplines/psd_diagnostics.py +382 -0
- log_psplines/psplines.py +113 -0
- log_psplines/samplers/__init__.py +5 -0
- log_psplines/samplers/base_sampler.py +165 -0
- log_psplines/samplers/metropolis_hastings.py +602 -0
- log_psplines/samplers/nuts.py +123 -0
- logpsplinepsd-0.0.3.dist-info/METADATA +148 -0
- logpsplinepsd-0.0.3.dist-info/RECORD +27 -0
- logpsplinepsd-0.0.3.dist-info/WHEEL +5 -0
- logpsplinepsd-0.0.3.dist-info/licenses/LICENSE +28 -0
- logpsplinepsd-0.0.3.dist-info/top_level.txt +1 -0
log_psplines/__init__.py
ADDED
|
File without changes
|
log_psplines/_version.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
# file generated by setuptools-scm
# don't change, don't track in version control

__all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]

# Static type checkers treat TYPE_CHECKING as True; at runtime it stays False,
# so the typing import below is never executed.
TYPE_CHECKING = False
if not TYPE_CHECKING:
    VERSION_TUPLE = object
else:
    from typing import Tuple, Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

# Public version string and its tuple form; each pair of names aliases one object.
version = __version__ = '0.0.3'
version_tuple = __version_tuple__ = (0, 0, 3)
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import List
|
|
3
|
+
|
|
4
|
+
import arviz as az
|
|
5
|
+
import numpy as np
|
|
6
|
+
from xarray import DataArray, Dataset
|
|
7
|
+
|
|
8
|
+
from .psplines import LogPSplines
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def get_weights(
    idata: az.InferenceData,
    thin: int = 10,
) -> np.ndarray:
    """
    Extract (thinned) weight samples from arviz InferenceData.

    The ``weights`` posterior variable is stacked across chains so that every
    draw of every chain becomes one row; then every ``thin``-th row is kept.

    Parameters
    ----------
    idata : az.InferenceData
        Inference data whose ``posterior`` group holds a ``weights`` variable
        laid out as (chains, draws, n_weights).
    thin : int
        Thinning factor: keep every ``thin``-th stacked sample. Must be >= 1.

    Returns
    -------
    np.ndarray
        Weight samples, shape (n_samples_thinned, n_weights).

    Raises
    ------
    ValueError
        If ``thin`` is smaller than 1 (a negative step would silently reverse
        the sample order, and zero is not a valid slice step).
    """
    if thin < 1:
        raise ValueError(f"thin must be a positive integer, got {thin}.")

    # Get weight samples and flatten chains
    weight_samples = np.asarray(
        idata.posterior.weights.values
    )  # (chains, draws, n_weights)
    weight_samples = weight_samples.reshape(
        -1, weight_samples.shape[-1]
    )  # (chains*draws, n_weights)

    # Thin samples
    return weight_samples[::thin]
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def get_psd_samples_arviz(
    idata: az.InferenceData, spline_model: LogPSplines, thin: int = 10
) -> np.ndarray:
    """
    Extract PSD samples from arviz InferenceData.

    For each (thinned) weight sample ``w`` the log-PSD is
    ``basis.T @ w + log_parametric_model`` and the PSD is its exponential.
    All samples are evaluated in a single matrix product.

    Parameters
    ----------
    idata : az.InferenceData
        Inference data containing weight samples
    spline_model : LogPSplines
        Spline model for reconstruction; its ``basis`` is used as
        (n_weights, n_frequencies), matching ``basis.T @ weights``.
    thin : int
        Thinning factor

    Returns
    -------
    np.ndarray
        PSD samples, shape (n_samples_thinned, n_frequencies)
    """
    # (n_samples_thinned, n_weights)
    weight_samples = get_weights(idata, thin=thin)

    # Row-wise equivalent of `basis.T @ weights` for every sample at once:
    # (n_samples, n_weights) @ (n_weights, n_freq) -> (n_samples, n_freq).
    ln_spline = weight_samples @ spline_model.basis
    ln_psd = ln_spline + spline_model.log_parametric_model

    # np.asarray guarantees a numpy result even if the basis is a jax array.
    return np.asarray(np.exp(ln_psd))
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _make_dataset_from_dict(data_dict, coords=None):
    """
    Build an xarray ``Dataset`` from a mapping of variable name -> value.

    A value is treated as a ``(dims, data)`` pair when it is a 2-tuple whose
    first element is a ``list`` or ``str``; it then becomes
    ``DataArray(data, dims=dims)``. Any other value is wrapped as
    ``DataArray(value)`` with default dimension names. ``coords`` is passed
    straight to the ``Dataset`` constructor.
    """

    def _is_dims_data_pair(value):
        # Only list/str dims are recognised; other 2-tuples are plain data.
        return (
            isinstance(value, tuple)
            and len(value) == 2
            and isinstance(value[0], (list, str))
        )

    def _as_data_array(value):
        if not _is_dims_data_pair(value):
            return DataArray(value)
        dims, data = value
        return DataArray(data, dims=dims)

    variables = {name: _as_data_array(value) for name, value in data_dict.items()}
    return Dataset(variables, coords=coords)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def compare_runs(
    run1: az.InferenceData,
    run2: az.InferenceData,
    labels: List[str],
    outdir: str,
) -> Dataset:
    """
    Compare two InferenceData runs, save diagnostics, and return the differences.

    Writes a density-comparison plot of the posterior variables shared by
    both runs to ``<outdir>/density_comparison.png``. Writes the difference
    of the ``az.summary`` tables (run1 minus run2, on the rows common to both)
    to ``<outdir>/summary_diff.csv`` and prints it.

    Parameters
    ----------
    run1 : az.InferenceData
        First run to compare
    run2 : az.InferenceData
        Second run to compare
    labels : List[str]
        Legend labels for ``run1`` and ``run2`` in the density plot.
    outdir : str
        Output directory; created if it does not exist.

    Returns
    -------
    Dataset
        The summary-difference table as a Dataset: one data variable per
        ``az.summary`` column, indexed along a ``variable`` dimension.

    Raises
    ------
    ValueError
        If the two posteriors share no variables.
    """
    import matplotlib.pyplot as plt

    os.makedirs(outdir, exist_ok=True)

    # Ensure both runs have the same variables
    shared_vars = sorted(
        set(run1.posterior.data_vars) & set(run2.posterior.data_vars)
    )
    if not shared_vars:
        raise ValueError("No common variables found in the two runs.")

    # Plot density, restricted to variables present in both posteriors
    az.plot_density(
        [run1.posterior, run2.posterior],
        data_labels=labels,
        var_names=shared_vars,
        shade=0.2,
        hdi_prob=0.94,
    )
    plt.suptitle("Density Comparison", fontsize=14)
    plt.tight_layout()
    plt.savefig(os.path.join(outdir, "density_comparison.png"))
    plt.close()

    # Get summaries
    summary1 = az.summary(run1)
    summary2 = az.summary(run2)

    # Compute difference in summaries on the rows both tables contain
    common_rows = summary1.index.intersection(summary2.index)
    diff = summary1.loc[common_rows] - summary2.loc[common_rows]
    diff.to_csv(os.path.join(outdir, "summary_diff.csv"))

    print("Summary Differences:")
    print(diff)

    # Honour the declared return type; name the index dimension explicitly
    # so it does not default to "index".
    return Dataset.from_dataframe(diff.rename_axis("variable"))
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@dataclasses.dataclass
class Timeseries:
    """
    A sampled time series.

    Attributes
    ----------
    t : jnp.ndarray
        Sample times. Assumed uniformly spaced: ``fs`` uses only the first
        two entries.
    y : jnp.ndarray
        Sample values, one per entry of ``t``.
    std : float
        Scale recorded by :meth:`standardise` (the standard deviation divided
        out of ``y``); 1.0 by default.
    """

    t: jnp.ndarray
    y: jnp.ndarray
    std: float = 1.0

    @property
    def n(self):
        """Number of samples."""
        return len(self.t)

    @property
    def fs(self) -> float:
        """Sampling frequency computed from the time array."""
        return float(1 / (self.t[1] - self.t[0]))

    def to_periodogram(self) -> "Periodogram":
        """
        Compute the one-sided periodogram of the timeseries.

        Power is ``|rfft(y)|**2 / len(y)``; the zero-frequency bin is dropped.
        """
        freq = jnp.fft.rfftfreq(len(self.y), d=1 / self.fs)
        power = jnp.abs(jnp.fft.rfft(self.y)) ** 2 / len(self.y)
        return Periodogram(freq[1:], power[1:])

    def standardise(self) -> "Timeseries":
        """
        Return a standardised copy with zero mean and unit variance.

        ``self`` is left unchanged. The returned series stores the original
        standard deviation in ``std`` so the scaling can be undone.

        Raises
        ------
        ValueError
            If ``y`` is constant (zero standard deviation), which would
            otherwise silently produce inf/NaN values.
        """
        std = float(jnp.std(self.y))
        if std == 0.0:
            raise ValueError(
                "Cannot standardise a constant timeseries (std = 0)."
            )
        y = (self.y - jnp.mean(self.y)) / std
        return Timeseries(self.t, y, std)

    def __repr__(self):
        return f"Timeseries(n={len(self.t)}, std={self.std:.3f}, fs={self.fs:.3f})"
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclasses.dataclass
class Periodogram:
    """
    Power values on a one-sided frequency grid.

    Attributes
    ----------
    freqs : jnp.ndarray
        Frequency grid. ``fs`` assumes its last entry is the Nyquist frequency.
    power : jnp.ndarray
        Power at each frequency, one value per entry of ``freqs``.
    filtered : bool
        True when the periodogram came from :meth:`highpass`. The flag is
        preserved by scalar/array multiplication and division.
    """

    freqs: jnp.ndarray
    power: jnp.ndarray
    filtered: bool = False

    def __post_init__(self):
        # Reject NaNs at construction so invalid spectra never propagate.
        if jnp.isnan(self.freqs).any() or jnp.isnan(self.power).any():
            raise ValueError("Frequency or power contains NaN values.")

    @property
    def n(self):
        """Number of frequency bins."""
        return len(self.freqs)

    @property
    def fs(self) -> float:
        """Sampling frequency, assuming the last frequency is Nyquist (fs / 2)."""
        return float(2 * self.freqs[-1])

    def highpass(self, min_freq: float) -> "Periodogram":
        """Return a new Periodogram with frequencies strictly above ``min_freq``."""
        mask = self.freqs > min_freq
        return Periodogram(self.freqs[mask], self.power[mask], filtered=True)

    def to_timeseries(self) -> "Timeseries":
        """
        Compute the inverse real FFT of the periodogram's power values.

        The output has ``2 * (n - 1)`` samples. Its time axis is uniformly
        spaced at ``1 / fs``, so the result's own ``fs`` equals this
        periodogram's ``fs``.
        """
        y = jnp.fft.irfft(self.power, n=2 * (self.n - 1))
        # t_k = k / fs. A linspace ending at 1/fs would compress the whole
        # series into a single sampling interval.
        t = jnp.arange(len(y)) / self.fs
        return Timeseries(t, y)

    def _with_power(self, power) -> "Periodogram":
        """Copy onto the same grid with new power, keeping the ``filtered`` flag."""
        return Periodogram(self.freqs, power, filtered=self.filtered)

    def __mul__(self, other):
        return self._with_power(self.power * other)

    # Scaling is commutative, so `2.0 * pdgrm` behaves like `pdgrm * 2.0`.
    __rmul__ = __mul__

    def __truediv__(self, other):
        return self._with_power(self.power / other)

    def __repr__(self):
        return f"Periodogram(n={self.n}, fs={self.fs:.3f}, filtered={self.filtered})"
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def compute_welsch_psd(
    freqs: jnp.ndarray, power: jnp.ndarray, alpha: float = 2.0
) -> jnp.ndarray:
    """
    Down-weight ``power`` by the rational factor ``1 / (1 + (f / alpha)**2)``.

    Despite the name, this is not Welch's averaged-periodogram method. It
    attenuates each bin by a factor that is 1 at f = 0, 1/2 at f = alpha,
    and falls off as ``(alpha / f)**2`` at high frequency.

    Parameters
    ----------
    freqs : jnp.ndarray
        Frequencies of the bins.
    power : jnp.ndarray
        Power at each frequency (broadcast against ``freqs``).
    alpha : float
        Corner frequency of the weighting; must be non-zero.

    Returns
    -------
    jnp.ndarray
        The weighted power.

    Raises
    ------
    ValueError
        If ``alpha`` is zero (the weighting would be undefined).
    """
    if alpha == 0:
        raise ValueError("alpha must be non-zero.")
    return power / (1 + (freqs / alpha) ** 2)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .ar_data import ARData
|
|
@@ -0,0 +1,270 @@
|
|
|
1
|
+
from typing import Optional, Sequence
|
|
2
|
+
|
|
3
|
+
import matplotlib.pyplot as plt
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ARData:
    """
    Simulate an AR(p) process and compute its theoretical PSD as well as the
    raw periodogram of the simulated series.

    Any order p works when ``ar_coefs`` is given explicitly. Built-in
    coefficient sets exist for p = 1..5.

    Attributes
    ----------
    ar_coefs : np.ndarray
        1D array of AR coefficients [a1, a2, ..., ap].
    order : int
        Order p of the AR process (``len(ar_coefs)``).
    sigma : float
        Standard deviation of the white-noise driving the AR process.
    fs : float
        Sampling frequency [Hz].
    duration : float
        Total duration of the time series [s].
    n : int
        Number of samples = int(duration * fs).
    seed : Optional[int]
        Seed for the random number generator (if given).
    ts : np.ndarray
        Simulated time-domain AR(p) series of length n.
    freqs : np.ndarray
        One-sided frequency axis (length n//2 + 1).
    psd_theoretical : np.ndarray
        Theoretical one-sided PSD (power per Hz) sampled at freqs. It uses the
        same folding convention as ``periodogram``, so the two are comparable.
    periodogram : np.ndarray
        One-sided raw periodogram (power per Hz) from the simulated ts.
    """

    # Built-in coefficient sets used when ``ar_coefs`` is not supplied.
    _DEFAULT_AR_COEFS = {
        1: (0.9,),
        2: (1.45, -0.9025),
        3: (0.9, -0.8, 0.7),
        4: (0.9, -0.8, 0.7, -0.6),
        # NOTE(review): these values look like an AR *polynomial*
        # [1, -a1, -a2, -a3, -a4] (leading 1 included) rather than
        # coefficients in this class's x[t] = sum_k a_k x[t-k] convention.
        # Kept unchanged to preserve existing behaviour -- confirm intent.
        5: (1, -2.2137, 2.9403, -2.1697, 0.9606),
    }

    def __init__(
        self,
        order: int,
        duration: float,
        fs: float,
        sigma: float = 1.0,
        seed: Optional[int] = None,
        ar_coefs: Optional[Sequence[float]] = None,
    ) -> None:
        """
        Parameters
        ----------
        order : int
            Order p of the AR process. It selects a built-in coefficient set
            when ``ar_coefs`` is None; otherwise it must equal
            ``len(ar_coefs)``.
        duration : float
            Total length of the time series in seconds.
        fs : float
            Sampling frequency in Hz.
        sigma : float, default=1.0
            Standard deviation of the white‐noise innovations.
        seed : Optional[int], default=None
            Seed for the random number generator (if you want reproducible draws).
        ar_coefs : Optional[Sequence[float]], default=None
            Coefficients [a1, a2, ..., ap] for an AR(p) model.
            For example, for AR(2) with x[t] = a1 x[t-1] + a2 x[t-2] + noise,
            pass ar_coefs=[a1, a2].

        Raises
        ------
        ValueError
            If ``ar_coefs`` is None and ``order`` has no built-in coefficients,
            or if ``len(ar_coefs) != order``.
        """
        if ar_coefs is None:
            ar_coefs = self._default_ar_coefs(order)
        elif len(ar_coefs) != order:
            raise ValueError(
                f"len(ar_coefs)={len(ar_coefs)} does not match order={order}."
            )

        self.ar_coefs = np.asarray(ar_coefs, dtype=float)
        self.order = len(self.ar_coefs)
        self.sigma = float(sigma)
        self.fs = float(fs)
        self.duration = float(duration)
        self.n = int(self.duration * self.fs)
        self.seed = seed

        # 1) Simulate the AR(p) time series
        self.ts = self._generate_timeseries()

        # 2) Build the one‐sided frequency axis
        # rfftfreq returns [0, fs/n, 2*fs/n, ..., fs/2] with n//2 + 1 points
        self.freqs = np.fft.rfftfreq(self.n, d=1.0 / self.fs)

        # 3) Compute theoretical PSD on that frequency grid
        self.psd_theoretical = self._compute_theoretical_psd()

        # 4) Compute the one‐sided raw periodogram (power per Hz)
        self.periodogram = self._compute_periodogram()

    @classmethod
    def _default_ar_coefs(cls, order: int) -> Sequence[float]:
        """Return the built-in coefficients for ``order``, or raise ValueError."""
        try:
            return cls._DEFAULT_AR_COEFS[order]
        except KeyError:
            raise ValueError(
                f"No default AR coefficients for order={order}; pass ar_coefs "
                f"explicitly (built-in orders: {sorted(cls._DEFAULT_AR_COEFS)})."
            ) from None

    def _fold_to_one_sided(self, spectrum: np.ndarray) -> np.ndarray:
        """
        Convert the first n//2 + 1 bins of a two-sided spectrum to one-sided.

        Every bin except DC is doubled to account for negative frequencies.
        When n is even, the Nyquist bin (last index) has no negative twin and
        is not doubled either. Returns a new array.
        """
        one_sided = np.array(spectrum, dtype=float, copy=True)
        stop = -1 if self.n % 2 == 0 else None
        one_sided[1:stop] *= 2.0
        return one_sided

    def _generate_timeseries(self) -> np.ndarray:
        """
        Generate an AR(p) time series of length n using the recursion

            x[t] = a1*x[t-1] + a2*x[t-2] + ... + ap*x[t-p] + noise[t],

        where noise[t] ~ Normal(0, sigma^2). The first p samples are left at
        zero and the recursion starts at t = p.

        Returns
        -------
        ts : np.ndarray
            Simulated AR(p) time series of length n.
        """
        rng = np.random.default_rng(self.seed)
        x = np.zeros(self.n, dtype=float)
        noise = rng.normal(loc=0.0, scale=self.sigma, size=self.n)

        # Iterate from t = p .. n-1
        for t in range(self.order, self.n):
            past_terms = 0.0
            # sum over a_k * x[t-k]
            for k, a_k in enumerate(self.ar_coefs, start=1):
                past_terms += a_k * x[t - k]
            x[t] = past_terms + noise[t]

        return x

    def _compute_theoretical_psd(self) -> np.ndarray:
        """
        Compute the theoretical one‐sided PSD (power per Hz) of the AR(p) process.

        The two-sided density is

            S(f) = (sigma^2 / fs) / | 1 - sum_k a_k e^{-i 2π f k / fs} |^2,

        evaluated at ``freqs``. It is folded to one-sided with the same rule
        as the periodogram: interior bins are doubled, DC and (even n)
        Nyquist are not.

        Returns
        -------
        psd_th : np.ndarray
            One‐sided theoretical PSD of length n//2 + 1.
        """
        # digital‐frequency omega = 2π (f / fs)
        omega = 2 * np.pi * self.freqs / self.fs

        # Denominator polynomial: 1 - sum_{k=1}^p a_k e^{-i k omega}
        denom = np.ones_like(omega, dtype=complex)
        for k, a_k in enumerate(self.ar_coefs, start=1):
            denom -= a_k * np.exp(-1j * k * omega)
        denom_mag2 = np.abs(denom) ** 2

        psd_two_sided = (self.sigma**2 / self.fs) / denom_mag2
        return self._fold_to_one_sided(psd_two_sided)

    def _compute_periodogram(self) -> np.ndarray:
        """
        Compute the one‐sided raw periodogram of the simulated time series:

            Pxx(f_k) = (1 / (n * fs)) * |H(f_k)|^2,

        then double all bins except DC (k=0) and, when n is even, Nyquist
        (k=n/2).

        Returns
        -------
        pxx : np.ndarray
            One‐sided periodogram (power per Hz) of length n//2 + 1.
        """
        # 1) Full FFT
        H_full = np.fft.fft(self.ts)

        # 2) Compute |H|^2 and normalize by (n * fs) → gives power per Hz
        Pxx_full = (1.0 / (self.n * self.fs)) * np.abs(H_full) ** 2

        # 3) Keep only the first (n//2 + 1) bins and fold to one-sided.
        #    For odd n there is no Nyquist bin, so every non-DC bin is doubled.
        return self._fold_to_one_sided(Pxx_full[: self.n // 2 + 1])

    def plot(
        self,
        ax: Optional[plt.Axes] = None,
        *,
        show_legend: bool = True,
        periodogram_kwargs: Optional[dict] = None,
        theoretical_kwargs: Optional[dict] = None,
    ) -> plt.Axes:
        """
        Plot the one‐sided raw periodogram and the theoretical PSD
        on the same axes (log-scaled y axis).

        Parameters
        ----------
        ax : Optional[plt.Axes]
            If provided, plot onto this Axes object. Otherwise, create a new figure/axes.
        show_legend : bool, default=True
            Whether to display a legend.
        periodogram_kwargs : Optional[dict], default=None
            Additional kwargs passed to ``ax.semilogy`` for the periodogram.
        theoretical_kwargs : Optional[dict], default=None
            Additional kwargs passed to ``ax.semilogy`` for the theoretical PSD.

        Returns
        -------
        ax : plt.Axes
            The Axes object containing the plot.
        """
        if ax is None:
            _, ax = plt.subplots(figsize=(8, 4))

        # Default plotting styles (user kwargs override them)
        p_kwargs = {"label": "Raw Periodogram", "alpha": 0.6, "linewidth": 1.0}
        t_kwargs = {
            "label": "Theoretical PSD",
            "linestyle": "--",
            "color": "C1",
            "linewidth": 2.0,
        }
        if periodogram_kwargs is not None:
            p_kwargs.update(periodogram_kwargs)
        if theoretical_kwargs is not None:
            t_kwargs.update(theoretical_kwargs)

        ax.semilogy(self.freqs, self.periodogram, **p_kwargs)
        ax.semilogy(self.freqs, self.psd_theoretical, **t_kwargs)

        ax.set_xlabel("Frequency [Hz]")
        ax.set_ylabel("PSD [power/Hz]")
        ax.set_title(
            f"AR({self.order}) Process: Periodogram vs Theoretical PSD"
        )

        if show_legend:
            ax.legend()

        ax.grid(True, which="both", ls=":", alpha=0.5)
        return ax
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
# Example usage:
if __name__ == "__main__":
    # --- Simulate AR(2) over 8 seconds at 1024 Hz ---
    # `order` is a required parameter of ARData and must equal len(ar_coefs).
    ar2 = ARData(
        order=2,
        ar_coefs=[0.9, -0.5],
        duration=8.0,
        fs=1024.0,
        sigma=1.0,
        seed=42,
    )
    fig, ax = plt.subplots(figsize=(8, 4))
    ar2.plot(ax=ax)
    plt.show()

    # --- Simulate AR(4) over 4 seconds at 2048 Hz ---
    # e.g. coefficients [0.5, -0.3, 0.1, -0.05]
    ar4 = ARData(
        order=4,
        ar_coefs=[0.5, -0.3, 0.1, -0.05],
        duration=4.0,
        fs=2048.0,
        sigma=1.0,
    )
    fig2, ax2 = plt.subplots(figsize=(8, 4))
    ar4.plot(
        ax=ax2,
        periodogram_kwargs={"color": "C2"},
        theoretical_kwargs={"color": "k", "linestyle": "-."},
    )
    plt.show()
|
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import matplotlib.pyplot as plt
|
|
6
|
+
import numpy as np
|
|
7
|
+
import scipy.signal as signal
|
|
8
|
+
from gwpy.timeseries import TimeSeries
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class LVKData:
    """
    A dataclass for downloading, loading, and computing PSDs of
    gravitational-wave strain data.

    On initialization the strain is cut into segments of ``segment_duration``
    seconds that overlap by the fraction ``segment_overlap``. A one-sided PSD
    (``scipy.signal.welch``, density scaling) is computed for every segment,
    optionally restricted to [``min_freq``, ``max_freq``], together with the
    median PSD across segments.
    """

    strain: TimeSeries
    duration: int
    segment_duration: int
    segment_overlap: float
    min_freq: Optional[float] = None
    max_freq: Optional[float] = None

    # Fields computed in __post_init__; not passed at construction
    fs: float = field(init=False)
    n: int = field(init=False)
    nperseg: int = field(init=False)
    noverlap: int = field(init=False)
    step: int = field(init=False)
    n_segments: int = field(init=False)
    freqs: np.ndarray = field(init=False)
    psds: np.ndarray = field(init=False)  # (n_segments, n_freqs)
    median_psd: np.ndarray = field(init=False)

    def __post_init__(self):
        """
        Derive the segmentation and compute per-segment PSDs.

        Raises
        ------
        ValueError
            If ``segment_overlap`` is outside [0, 1), a segment would contain
            no samples, or the strain is shorter than one segment.
        """
        # Sampling info
        self.fs = float(self.strain.sample_rate.value)
        self.n = len(self.strain)

        # overlap == 1 would give a zero step (division by zero below).
        if not 0.0 <= self.segment_overlap < 1.0:
            raise ValueError(
                f"segment_overlap must lie in [0, 1), got {self.segment_overlap}."
            )

        # Number of samples per segment and overlap in samples
        self.nperseg = int(self.fs * self.segment_duration)
        if self.nperseg < 1:
            raise ValueError(
                f"segment_duration={self.segment_duration} s gives no samples "
                f"at fs={self.fs} Hz."
            )
        self.noverlap = int(self.nperseg * self.segment_overlap)
        self.step = self.nperseg - self.noverlap

        # Number of complete segments that fit in the data
        self.n_segments = (self.n - self.noverlap) // self.step
        if self.n_segments < 1:
            raise ValueError(
                f"Strain has {self.n} samples, fewer than one segment "
                f"({self.nperseg} samples)."
            )

        # Extract raw numpy array of strain values
        data = np.asarray(self.strain.value)

        # Read-only (n_segments, nperseg) view of overlapping windows. Unlike a
        # hand-built as_strided view, it cannot address memory past `data`.
        segments = np.lib.stride_tricks.sliding_window_view(
            data, self.nperseg
        )[:: self.step][: self.n_segments]

        # Compute one-sided PSD for each segment
        freqs_full, psd_full = signal.welch(
            segments,
            fs=self.fs,
            nperseg=self.nperseg,
            noverlap=self.noverlap,
            axis=-1,
            return_onesided=True,
            scaling="density",
        )

        # Apply frequency mask if requested
        freq_mask = np.ones_like(freqs_full, dtype=bool)
        if self.min_freq is not None:
            freq_mask &= freqs_full >= self.min_freq
        if self.max_freq is not None:
            freq_mask &= freqs_full <= self.max_freq

        self.freqs = freqs_full[freq_mask]
        self.psds = psd_full[:, freq_mask]
        self.median_psd = np.median(self.psds, axis=0)

    @classmethod
    def download(
        cls,
        detector: str = "H1",
        gps_start: int = 1126259462,
        duration: int = 1024,
        channel: Optional[str] = None,
    ) -> TimeSeries:
        """
        Download open strain data from GWOSC for a given detector and GPS range.

        If ``channel`` is given, it is assigned as the returned series' channel
        name (used when the series is written to and read back from a cache).
        """
        print(
            f"Downloading {detector} data [{gps_start} - {gps_start + duration}]..."
        )
        strain = TimeSeries.fetch_open_data(
            detector, gps_start, gps_start + duration
        )
        if channel:
            strain.channel = channel
        return strain

    @staticmethod
    def _cache_matches(
        strain: TimeSeries, gps_start: int, duration: int
    ) -> bool:
        """True if ``strain`` starts at ``gps_start`` and lasts ``duration`` s (within one sample)."""
        tol = float(strain.dt.value)
        return (
            abs(float(strain.t0.value) - gps_start) <= tol
            and abs(float(strain.duration.value) - duration) <= tol
        )

    @classmethod
    def load(
        cls,
        detector: str = "H1",
        gps_start: int = 1126259462,
        duration: int = 1024,
        segment_duration: int = 4,
        segment_overlap: float = 0.5,
        min_freq: Optional[float] = None,
        max_freq: Optional[float] = None,
        cache_file: str = "strain_cache.gwf",
        channel: Optional[str] = None,
    ) -> "LVKData":
        """
        Load strain data from cache or download if needed, then compute PSDs.

        ``channel`` defaults to ``"<detector>:GWOSC-STRAIN"``. It names the
        series in the cache file and is required to read it back. A cache that
        cannot be read, or that covers a different GPS span, is replaced by a
        fresh download.
        """
        if channel is None:
            channel = f"{detector}:GWOSC-STRAIN"

        strain = None
        had_cache = os.path.exists(cache_file)
        if had_cache:
            try:
                cached = TimeSeries.read(cache_file, channel)
            except Exception as e:  # best-effort cache: fall back to download
                print(
                    f"Failed to read cache '{cache_file}': {e}. Redownloading..."
                )
            else:
                if cls._cache_matches(cached, gps_start, duration):
                    strain = cached
                    print(f"Loaded cached strain from '{cache_file}'")
                else:
                    print(
                        f"Cache '{cache_file}' covers a different GPS span. "
                        "Redownloading..."
                    )
            if strain is None:
                os.remove(cache_file)

        if strain is None:
            strain = cls.download(
                detector, gps_start, duration, channel=channel
            )
            strain.write(cache_file)
            if had_cache:
                print(f"Cached new strain to '{cache_file}'")
            else:
                print(f"Cached strain to '{cache_file}'")

        return cls(
            strain=strain,
            duration=duration,
            segment_duration=segment_duration,
            segment_overlap=segment_overlap,
            min_freq=min_freq,
            max_freq=max_freq,
        )

    def compute_median_psd(
        self, n_segments: Optional[int] = None
    ) -> np.ndarray:
        """
        Return the median PSD computed over the first `n_segments` segments.

        Raises
        ------
        ValueError
            If ``n_segments`` is below 1 or exceeds the available segments.
        """
        if n_segments is None:
            n_segments = self.n_segments
        if n_segments > self.n_segments:
            raise ValueError(
                "n_segments exceeds available number of segments."
            )
        if n_segments < 1:
            raise ValueError("n_segments must be at least 1.")
        return np.median(self.psds[:n_segments, :], axis=0)

    def plot_psd(self) -> plt.Figure:
        """
        Plot all individual-segment PSDs in gray and the median PSD in red.
        """
        fig, ax = plt.subplots(figsize=(8, 5))
        ax.loglog(self.freqs, self.psds.T, color="gray", alpha=0.3)
        ax.loglog(
            self.freqs, self.median_psd, color="r", lw=2, label="Median PSD"
        )
        ax.set_xlabel("Frequency (Hz)")
        ax.set_ylabel("PSD [strain^2/Hz]")
        ax.set_title(f"PSD: {self.strain.channel}")
        ax.grid(True, which="both", ls=":")
        ax.legend()
        return fig
|