LogPSplinePSD 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
@@ -0,0 +1,21 @@
1
# file generated by setuptools-scm
# don't change, don't track in version control
#
# NOTE(review): this module is produced by setuptools-scm at build time, so
# manual edits here (including these comments) are expected to be overwritten.

__all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]

# TYPE_CHECKING is defined locally and is always False at runtime, so the
# typing imports below are only seen by static type checkers.
TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple
    from typing import Union

    # A version tuple is a tuple whose items are ints or strs, e.g. (0, 0, 3).
    VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
    # At runtime the alias is only used as an annotation target.
    VERSION_TUPLE = object

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

# The dunder and plain names are bound to the same values.
__version__ = version = '0.0.3'
__version_tuple__ = version_tuple = (0, 0, 3)
@@ -0,0 +1,142 @@
1
+ import os
2
+ from typing import List
3
+
4
+ import arviz as az
5
+ import numpy as np
6
+ from xarray import DataArray, Dataset
7
+
8
+ from .psplines import LogPSplines
9
+
10
+
11
def get_weights(
    idata: "az.InferenceData",
    thin: int = 10,
) -> np.ndarray:
    """
    Extract thinned weight samples from arviz InferenceData.

    Parameters
    ----------
    idata : az.InferenceData
        Inference data whose ``posterior`` group holds a ``weights`` variable
        laid out as (chains, draws, n_weights).
    thin : int
        Keep every ``thin``-th sample after the chains are flattened.
        Must be >= 1.

    Returns
    -------
    np.ndarray
        Weight samples, shape (n_samples_thinned, n_weights).

    Raises
    ------
    ValueError
        If ``thin`` is smaller than 1.
    """
    if thin < 1:
        # A zero step fails deep inside numpy slicing, and a negative step
        # would silently reverse the sample order.
        raise ValueError(f"thin must be a positive integer, got {thin}.")

    weight_samples = idata.posterior.weights.values  # (chains, draws, n_weights)
    # Merge the chain and draw axes -> (chains * draws, n_weights).
    weight_samples = weight_samples.reshape(-1, weight_samples.shape[-1])
    return weight_samples[::thin]
40
+
41
+
42
def get_psd_samples_arviz(
    idata: az.InferenceData, spline_model: LogPSplines, thin: int = 10
) -> np.ndarray:
    """
    Reconstruct PSD samples from posterior spline weights.

    For every thinned weight vector ``w`` the PSD is
    ``exp(basis.T @ w + log_parametric_model)``; all samples are evaluated
    with a single matrix product.

    Parameters
    ----------
    idata : az.InferenceData
        Inference data containing weight samples.
    spline_model : LogPSplines
        Spline model for reconstruction (provides ``basis`` and
        ``log_parametric_model``).
    thin : int
        Thinning factor passed to ``get_weights``.

    Returns
    -------
    np.ndarray
        PSD samples, shape (n_samples_thinned, n_frequencies).
    """
    weight_samples = get_weights(idata, thin=thin)  # (n_samples, n_weights)

    # ``basis.T @ w`` per sample, stacked row-wise, equals ``W @ basis``.
    # Assumes a 2-D basis of shape (n_weights, n_frequencies) -- TODO confirm.
    basis = np.asarray(spline_model.basis)
    ln_spline = weight_samples @ basis
    ln_psd = ln_spline + np.asarray(spline_model.log_parametric_model)
    return np.exp(ln_psd)
73
+
74
+
75
def _make_dataset_from_dict(data_dict, coords=None):
    """
    Build an xarray ``Dataset`` from a mapping of variable name -> value.

    A value given as a 2-tuple ``(dims, data)`` whose first item is a list or
    a str becomes ``DataArray(data, dims=dims)``; any other value is wrapped
    as ``DataArray(value)``. ``coords`` is forwarded to ``Dataset``.
    """

    def _as_data_array(value):
        is_dims_pair = (
            isinstance(value, tuple)
            and len(value) == 2
            and isinstance(value[0], (list, str))
        )
        if not is_dims_pair:
            return DataArray(value)
        dims, payload = value
        return DataArray(payload, dims=dims)

    variables = {name: _as_data_array(value) for name, value in data_dict.items()}
    return Dataset(variables, coords=coords)
88
+
89
+
90
def compare_runs(
    run1: az.InferenceData,
    run2: az.InferenceData,
    labels: List[str],
    outdir: str,
) -> Dataset:
    """
    Compare two InferenceData runs: plot their posterior densities and
    tabulate the difference between their ArviZ summaries.

    Parameters
    ----------
    run1 : az.InferenceData
        First run to compare.
    run2 : az.InferenceData
        Second run to compare.
    labels : List[str]
        Labels passed to ``az.plot_density`` as ``data_labels``, one per run.
    outdir : str
        Directory (created if missing) that receives
        ``density_comparison.png`` and ``summary_diff.csv``.

    Returns
    -------
    Dataset
        ``summary(run1) - summary(run2)`` over the summary rows present in
        both runs, converted with ``Dataset.from_dataframe``.

    Raises
    ------
    ValueError
        If the two posteriors share no variables.
    """
    import matplotlib.pyplot as plt

    os.makedirs(outdir, exist_ok=True)

    # Ensure both runs have at least one posterior variable in common.
    shared_vars = set(run1.posterior.data_vars) & set(run2.posterior.data_vars)
    if not shared_vars:
        raise ValueError("No common variables found in the two runs.")

    # plot_density returns Axes; save via their figure instead of relying on
    # pyplot's implicit "current figure".
    axes = az.plot_density(
        [run1.posterior, run2.posterior],
        data_labels=labels,
        shade=0.2,
        hdi_prob=0.94,
    )
    fig = np.ravel(axes)[0].get_figure()
    fig.suptitle("Density Comparison", fontsize=14)
    fig.tight_layout()
    fig.savefig(os.path.join(outdir, "density_comparison.png"))
    plt.close(fig)

    summary1 = az.summary(run1)
    summary2 = az.summary(run2)

    # Difference of the summary statistics for rows present in both runs.
    shared_rows = summary1.index.intersection(summary2.index)
    diff = summary1.loc[shared_rows] - summary2.loc[shared_rows]
    diff.to_csv(os.path.join(outdir, "summary_diff.csv"))

    print("Summary Differences:")
    print(diff)

    # Honour the declared ``-> Dataset`` return type (previously returned None).
    return Dataset.from_dataframe(diff)
@@ -0,0 +1,83 @@
1
+ import dataclasses
2
+
3
+ import jax.numpy as jnp
4
+
5
+
6
@dataclasses.dataclass
class Timeseries:
    """Uniformly sampled series ``y`` observed at times ``t``.

    ``std`` holds the standard deviation that :meth:`standardise` divided
    out of the data (default 1.0).
    """

    t: jnp.ndarray
    y: jnp.ndarray
    std: float = 1.0

    @property
    def n(self):
        """Number of samples on the time axis."""
        return len(self.t)

    @property
    def fs(self) -> float:
        """Sampling frequency derived from the first time step."""
        dt = self.t[1] - self.t[0]
        return float(1 / dt)

    def to_periodogram(self) -> "Periodogram":
        """Return |rfft(y)|^2 / len(y) on the rfft frequency grid, DC bin dropped."""
        n_samples = len(self.y)
        freq_grid = jnp.fft.rfftfreq(n_samples, d=1 / self.fs)
        spectrum = jnp.fft.rfft(self.y)
        power = jnp.abs(spectrum) ** 2 / n_samples
        return Periodogram(freq_grid[1:], power[1:])

    def standardise(self):
        """Return a zero-mean, unit-variance copy of the series.

        The divisor is also stored on ``self.std``, so this instance is
        updated as well as the returned one.
        """
        scale = float(jnp.std(self.y))
        self.std = scale
        centred = self.y - jnp.mean(self.y)
        return Timeseries(self.t, centred / scale, scale)

    def __repr__(self):
        count = len(self.t)
        return f"Timeseries(n={count}, std={self.std:.3f}, fs={self.fs:.3f})"
35
+
36
+
37
@dataclasses.dataclass
class Periodogram:
    """Power values ``power`` sampled on the frequency grid ``freqs``.

    ``filtered`` marks periodograms produced by :meth:`highpass`; scaling by
    ``*`` or ``/`` preserves the flag on the result.
    """

    freqs: jnp.ndarray
    power: jnp.ndarray
    filtered: bool = False

    def __post_init__(self):
        # Reject NaNs at construction so they cannot propagate downstream.
        if jnp.isnan(self.freqs).any() or jnp.isnan(self.power).any():
            raise ValueError("Frequency or power contains NaN values.")

    @property
    def n(self):
        """Number of frequency bins."""
        return len(self.freqs)

    @property
    def fs(self) -> float:
        """Sampling frequency, assuming the last bin is the Nyquist frequency."""
        return float(2 * self.freqs[-1])

    def highpass(self, min_freq: float) -> "Periodogram":
        """Return a new Periodogram keeping only bins strictly above ``min_freq``."""
        mask = self.freqs > min_freq
        return Periodogram(self.freqs[mask], self.power[mask], filtered=True)

    def to_timeseries(self) -> "Timeseries":
        """Compute the inverse real FFT of the power values.

        NOTE(review): this transforms ``power`` (not a complex spectrum) and
        uses an output length of ``2 * (n - 1)``, so it is not an exact
        inverse of ``Timeseries.to_periodogram`` (which drops the DC bin).
        """
        y = jnp.fft.irfft(self.power, n=2 * (self.n - 1))
        # Consecutive samples are 1/fs apart, so the returned Timeseries
        # reports the same sampling frequency as this periodogram.
        t = jnp.arange(len(y)) / self.fs
        return Timeseries(t, y)

    def __mul__(self, other):
        # Scale the power; keep the frequency grid and the filtered flag.
        return Periodogram(self.freqs, self.power * other, filtered=self.filtered)

    def __truediv__(self, other):
        # Scale the power; keep the frequency grid and the filtered flag.
        return Periodogram(self.freqs, self.power / other, filtered=self.filtered)

    def __repr__(self):
        return f"Periodogram(n={self.n}, fs={self.fs:.3f}, filtered={self.filtered})"
77
+
78
+
79
def compute_welsch_psd(
    freqs: jnp.ndarray, power: jnp.ndarray, alpha: float = 2.0
) -> jnp.ndarray:
    """Down-weight ``power`` by the factor ``1 / (1 + (freqs / alpha)**2)``.

    NOTE(review): despite the name, this does not run Welch's averaged
    periodogram method. It applies a frequency-dependent rescaling whose
    factor equals 1/2 at ``freqs == alpha``. ``alpha`` must be non-zero.
    """
    ratio = freqs / alpha
    weight = 1 + ratio**2
    return power / weight
@@ -0,0 +1 @@
1
+ from .ar_data import ARData
@@ -0,0 +1,270 @@
1
+ from typing import Optional, Sequence
2
+
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+
6
+
7
class ARData:
    """
    Simulate an AR(p) process and compare its theoretical PSD with the raw
    periodogram of the simulated series.

    Built-in coefficient sets exist for p = 1..5; any order works when
    ``ar_coefs`` is passed explicitly.

    Both ``psd_theoretical`` and ``periodogram`` are one-sided densities on
    the ``rfftfreq`` grid. Interior bins carry twice the two-sided density;
    the DC bin (and the Nyquist bin when n is even) are not doubled.

    Attributes
    ----------
    ar_coefs : np.ndarray
        1D array of AR coefficients [a1, a2, ..., ap].
    order : int
        Order p of the AR process (equal to ``len(ar_coefs)``).
    sigma : float
        Standard deviation of the white-noise driving the AR process.
    fs : float
        Sampling frequency [Hz].
    duration : float
        Total duration of the time series [s].
    n : int
        Number of samples = int(duration * fs).
    seed : Optional[int]
        Seed for the random number generator (if given).
    ts : np.ndarray
        Simulated time-domain AR(p) series of length n.
    freqs : np.ndarray
        One-sided frequency axis (length n//2 + 1).
    psd_theoretical : np.ndarray
        Theoretical one-sided PSD (power per Hz) sampled at freqs.
    periodogram : np.ndarray
        One-sided raw periodogram (power per Hz) from the simulated ts.
    """

    # Coefficients used when ``ar_coefs`` is not supplied, keyed by order.
    _DEFAULT_COEFS = {
        1: (0.9,),
        2: (1.45, -0.9025),
        3: (0.9, -0.8, 0.7),
        4: (0.9, -0.8, 0.7, -0.6),
        5: (1, -2.2137, 2.9403, -2.1697, 0.9606),
    }

    def __init__(
        self,
        order: int,
        duration: float,
        fs: float,
        sigma: float = 1.0,
        seed: Optional[int] = None,
        ar_coefs: Optional[Sequence[float]] = None,
    ) -> None:
        """
        Parameters
        ----------
        order : int
            Order p of the AR model. When ``ar_coefs`` is None it selects one
            of the built-in coefficient sets (p = 1..5); otherwise it must
            equal ``len(ar_coefs)``.
        duration : float
            Total length of the time series in seconds.
        fs : float
            Sampling frequency in Hz.
        sigma : float, default=1.0
            Standard deviation of the white-noise innovations.
        seed : Optional[int], default=None
            Seed for the random number generator (for reproducible draws).
        ar_coefs : Optional[Sequence[float]], default=None
            Coefficients [a1, a2, ..., ap]. For example, for AR(2) with
            x[t] = a1 x[t-1] + a2 x[t-2] + noise, pass ar_coefs=[a1, a2].

        Raises
        ------
        ValueError
            If ``ar_coefs`` is None and no default set exists for ``order``,
            or if ``len(ar_coefs)`` differs from ``order``.
        """
        if ar_coefs is None:
            if order not in self._DEFAULT_COEFS:
                raise ValueError(
                    f"No default AR coefficients for order={order}; pass "
                    f"ar_coefs explicitly (defaults exist for orders "
                    f"{sorted(self._DEFAULT_COEFS)})."
                )
            ar_coefs = self._DEFAULT_COEFS[order]
        elif len(ar_coefs) != order:
            raise ValueError(
                f"len(ar_coefs)={len(ar_coefs)} does not match order={order}."
            )

        self.ar_coefs = np.asarray(ar_coefs, dtype=float)
        self.order = len(self.ar_coefs)
        self.sigma = float(sigma)
        self.fs = float(fs)
        self.duration = float(duration)
        self.n = int(self.duration * self.fs)
        self.seed = seed

        # 1) Simulate the AR(p) time series
        self.ts = self._generate_timeseries()

        # 2) One-sided frequency axis: rfftfreq gives n//2 + 1 points starting
        #    at 0 Hz, spaced fs/n apart.
        self.freqs = np.fft.rfftfreq(self.n, d=1.0 / self.fs)

        # 3) Theoretical one-sided PSD on that frequency grid
        self.psd_theoretical = self._compute_theoretical_psd()

        # 4) One-sided raw periodogram (power per Hz)
        self.periodogram = self._compute_periodogram()

    def _generate_timeseries(self) -> np.ndarray:
        """
        Generate an AR(p) time series of length n using the recursion

            x[t] = a1*x[t-1] + a2*x[t-2] + ... + ap*x[t-p] + noise[t],

        where noise[t] ~ Normal(0, sigma^2). The first p samples
        x[0..p-1] are left at zero (their noise draws are unused), so the
        recursion starts at t = p.

        Returns
        -------
        ts : np.ndarray
            Simulated AR(p) time series of length n.
        """
        rng = np.random.default_rng(self.seed)
        x = np.zeros(self.n, dtype=float)
        noise = rng.normal(loc=0.0, scale=self.sigma, size=self.n)

        # Iterate from t = p .. n-1
        for t in range(self.order, self.n):
            past_terms = 0.0
            # sum over a_k * x[t-k] for k = 1..p
            for k, a_k in enumerate(self.ar_coefs, start=1):
                past_terms += a_k * x[t - k]
            x[t] = past_terms + noise[t]

        return x

    def _one_sided_weights(self) -> np.ndarray:
        """
        Per-bin factors turning a two-sided density on the rfft grid
        (length n//2 + 1) into a one-sided one.

        Interior bins are doubled to absorb their negative-frequency mirror.
        DC, and Nyquist when n is even, have no distinct mirror and keep a
        factor of 1.
        """
        weights = np.full(self.n // 2 + 1, 2.0)
        weights[0] = 1.0
        if self.n % 2 == 0:
            weights[-1] = 1.0
        return weights

    def _compute_theoretical_psd(self) -> np.ndarray:
        """
        Compute the theoretical one-sided PSD (power per Hz) of the AR(p) process.

        The two-sided density is

            S2(f) = (sigma^2 / fs) / | 1 - sum_k a_k e^{-i 2 pi f k / fs} |^2,

        and the one-sided PSD is S2(f) times the same per-bin factors used for
        the periodogram (2 for interior bins, 1 at DC and at an even-n Nyquist).

        Returns
        -------
        psd_th : np.ndarray
            One-sided theoretical PSD of length n//2 + 1.
        """
        # digital frequency omega = 2 pi (f / fs)
        omega = 2 * np.pi * self.freqs / self.fs

        # Denominator polynomial: 1 - sum_{k=1}^p a_k e^{-i k omega}
        denom = np.ones_like(omega, dtype=complex)
        for k, a_k in enumerate(self.ar_coefs, start=1):
            denom -= a_k * np.exp(-1j * k * omega)
        denom_mag2 = np.abs(denom) ** 2

        two_sided = (self.sigma**2 / self.fs) / denom_mag2
        return two_sided * self._one_sided_weights()

    def _compute_periodogram(self) -> np.ndarray:
        """
        Compute the one-sided raw periodogram of the simulated time series:

            Pxx(f_k) = (1 / (n * fs)) * |H(f_k)|^2,

        then double every bin except DC (k=0) and, when n is even, the
        Nyquist bin (k=n/2).

        Returns
        -------
        pxx : np.ndarray
            One-sided periodogram (power per Hz) of length n//2 + 1.
        """
        # 1) Full FFT
        H_full = np.fft.fft(self.ts)

        # 2) |H|^2 normalised by (n * fs) -> two-sided power per Hz
        Pxx_full = (1.0 / (self.n * self.fs)) * np.abs(H_full) ** 2

        # 3) Keep the first n//2 + 1 bins (non-negative frequencies), then
        #    fold in the negative-frequency half.
        Pxx_one = Pxx_full[: self.n // 2 + 1]
        return Pxx_one * self._one_sided_weights()

    def plot(
        self,
        ax: Optional["plt.Axes"] = None,
        *,
        show_legend: bool = True,
        periodogram_kwargs: Optional[dict] = None,
        theoretical_kwargs: Optional[dict] = None,
    ) -> "plt.Axes":
        """
        Plot the one-sided raw periodogram and the theoretical PSD on the
        same axes, with a logarithmic y-axis (``semilogy``).

        Parameters
        ----------
        ax : Optional[plt.Axes]
            If provided, plot onto this Axes object. Otherwise, create a new figure/axes.
        show_legend : bool, default=True
            Whether to display a legend.
        periodogram_kwargs : Optional[dict], default=None
            Extra kwargs for ``ax.semilogy`` when plotting the periodogram;
            they override the defaults.
        theoretical_kwargs : Optional[dict], default=None
            Extra kwargs for ``ax.semilogy`` when plotting the theoretical
            PSD; they override the defaults.

        Returns
        -------
        ax : plt.Axes
            The Axes object containing the plot.
        """
        if ax is None:
            _, ax = plt.subplots(figsize=(8, 4))

        # Default plotting styles
        p_kwargs = {"label": "Raw Periodogram", "alpha": 0.6, "linewidth": 1.0}
        t_kwargs = {
            "label": "Theoretical PSD",
            "linestyle": "--",
            "color": "C1",
            "linewidth": 2.0,
        }

        if periodogram_kwargs is not None:
            p_kwargs.update(periodogram_kwargs)
        if theoretical_kwargs is not None:
            t_kwargs.update(theoretical_kwargs)

        ax.semilogy(self.freqs, self.periodogram, **p_kwargs)
        ax.semilogy(self.freqs, self.psd_theoretical, **t_kwargs)

        ax.set_xlabel("Frequency [Hz]")
        ax.set_ylabel("PSD [power/Hz]")
        ax.set_title(
            f"AR({self.order}) Process: Periodogram vs Theoretical PSD"
        )

        if show_legend:
            ax.legend()

        ax.grid(True, which="both", ls=":", alpha=0.5)
        return ax
247
+
248
+
249
# Example usage:
if __name__ == "__main__":
    # --- Simulate AR(2) over 8 seconds at 1024 Hz ---
    # ``order`` is a required argument and must equal len(ar_coefs).
    ar2 = ARData(
        order=2,
        ar_coefs=[0.9, -0.5],
        duration=8.0,
        fs=1024.0,
        sigma=1.0,
        seed=42,
    )
    fig, ax = plt.subplots(figsize=(8, 4))
    ar2.plot(ax=ax)
    plt.show()

    # --- Simulate AR(4) over 4 seconds at 2048 Hz ---
    # e.g. coefficients [0.5, -0.3, 0.1, -0.05]
    ar4 = ARData(
        order=4,
        ar_coefs=[0.5, -0.3, 0.1, -0.05],
        duration=4.0,
        fs=2048.0,
        sigma=1.0,
    )
    fig2, ax2 = plt.subplots(figsize=(8, 4))
    ar4.plot(
        ax=ax2,
        periodogram_kwargs={"color": "C2"},
        theoretical_kwargs={"color": "k", "linestyle": "-."},
    )
    plt.show()
@@ -0,0 +1,177 @@
1
+ import os
2
+ from dataclasses import dataclass, field
3
+ from typing import Optional
4
+
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import scipy.signal as signal
8
+ from gwpy.timeseries import TimeSeries
9
+
10
+
11
@dataclass
class LVKData:
    """
    Gravitational-wave strain data together with its per-segment PSDs.

    On initialization the strain is split into overlapping segments of
    ``segment_duration`` seconds (``segment_overlap`` is the overlapping
    fraction, in [0, 1)). A one-sided Welch PSD is computed for every
    segment, and the PSDs, optionally restricted to [min_freq, max_freq],
    are stored together with their median.
    """

    strain: "TimeSeries"
    duration: int
    segment_duration: int
    segment_overlap: float
    min_freq: Optional[float] = None
    max_freq: Optional[float] = None

    # Fields computed in __post_init__; not passed at construction
    fs: float = field(init=False)
    n: int = field(init=False)
    nperseg: int = field(init=False)
    noverlap: int = field(init=False)
    step: int = field(init=False)
    n_segments: int = field(init=False)
    freqs: np.ndarray = field(init=False)
    psds: np.ndarray = field(init=False)
    median_psd: np.ndarray = field(init=False)

    def __post_init__(self):
        # Sampling info
        self.fs = float(self.strain.sample_rate.value)
        self.n = len(self.strain)

        if not 0.0 <= self.segment_overlap < 1.0:
            raise ValueError(
                f"segment_overlap must lie in [0, 1), got {self.segment_overlap}."
            )

        # Samples per segment, overlap in samples, and hop between segment starts
        self.nperseg = int(self.fs * self.segment_duration)
        if self.nperseg < 1:
            raise ValueError(
                "segment_duration * sample rate must give at least one sample "
                "per segment."
            )
        self.noverlap = int(self.nperseg * self.segment_overlap)
        self.step = self.nperseg - self.noverlap  # >= 1 because overlap < 1

        # Number of complete segments; equals (n - nperseg) // step + 1
        self.n_segments = (self.n - self.noverlap) // self.step
        if self.n_segments < 1:
            raise ValueError(
                f"Strain has {self.n} samples, fewer than one segment "
                f"({self.nperseg} samples)."
            )

        data = np.asarray(self.strain.value)

        # Read-only view of shape (n_segments, nperseg): row i starts at i * step.
        segments = np.lib.stride_tricks.sliding_window_view(data, self.nperseg)[
            :: self.step
        ]

        # One-sided PSD for each segment (each row is exactly one Welch window)
        freqs_full, psd_full = signal.welch(
            segments,
            fs=self.fs,
            nperseg=self.nperseg,
            noverlap=self.noverlap,
            axis=-1,
            return_onesided=True,
            scaling="density",
        )

        # Apply frequency mask if requested
        freq_mask = np.ones_like(freqs_full, dtype=bool)
        if self.min_freq is not None:
            freq_mask &= freqs_full >= self.min_freq
        if self.max_freq is not None:
            freq_mask &= freqs_full <= self.max_freq

        self.freqs = freqs_full[freq_mask]
        self.psds = psd_full[:, freq_mask]
        self.median_psd = np.median(self.psds, axis=0)

    @classmethod
    def download(
        cls,
        detector: str = "H1",
        gps_start: int = 1126259462,
        duration: int = 1024,
        channel: Optional[str] = None,
    ) -> "TimeSeries":
        """
        Download open strain data from GWOSC for ``detector`` between
        ``gps_start`` and ``gps_start + duration``.

        If ``channel`` is given, it is assigned to the returned series'
        ``channel`` attribute.
        """
        print(
            f"Downloading {detector} data [{gps_start} - {gps_start + duration}]..."
        )
        strain = TimeSeries.fetch_open_data(
            detector, gps_start, gps_start + duration
        )
        if channel:
            strain.channel = channel
        return strain

    @classmethod
    def load(
        cls,
        detector: str = "H1",
        gps_start: int = 1126259462,
        duration: int = 1024,
        segment_duration: int = 4,
        segment_overlap: float = 0.5,
        min_freq: Optional[float] = None,
        max_freq: Optional[float] = None,
        cache_file: str = "strain_cache.gwf",
        channel: Optional[str] = None,
    ) -> "LVKData":
        """
        Load strain data from ``cache_file`` (downloading and caching it if
        needed), then build an ``LVKData`` with its PSDs.

        ``channel`` is attached to downloaded data and used to read the cache
        back. It defaults to ``f"{detector}:GWOSC-STRAIN"``.

        NOTE: the cache is keyed only by ``cache_file``. A cached strain is not
        checked against ``detector``, ``gps_start`` or ``duration``, so use a
        distinct ``cache_file`` for each request.
        """
        if channel is None:
            # Follow the requested detector rather than always using H1.
            channel = f"{detector}:GWOSC-STRAIN"

        if os.path.exists(cache_file):
            try:
                # Reading GWF data needs a channel name; pass the same name
                # that ``download`` assigns before the cache is written.
                strain = TimeSeries.read(cache_file, channel)
                print(f"Loaded cached strain from '{cache_file}'")
            except Exception as e:
                print(
                    f"Failed to read cache '{cache_file}': {e}. Redownloading..."
                )
                os.remove(cache_file)
                strain = cls.download(
                    detector, gps_start, duration, channel=channel
                )
                strain.write(cache_file)
                print(f"Cached new strain to '{cache_file}'")
        else:
            strain = cls.download(
                detector, gps_start, duration, channel=channel
            )
            strain.write(cache_file)
            print(f"Cached strain to '{cache_file}'")

        return cls(
            strain=strain,
            duration=duration,
            segment_duration=segment_duration,
            segment_overlap=segment_overlap,
            min_freq=min_freq,
            max_freq=max_freq,
        )

    def compute_median_psd(
        self, n_segments: Optional[int] = None
    ) -> np.ndarray:
        """
        Return the median PSD over the first ``n_segments`` segments
        (all segments when None).

        Raises
        ------
        ValueError
            If ``n_segments`` is below 1 or exceeds the available segments.
        """
        if n_segments is None:
            n_segments = self.n_segments
        if n_segments < 1:
            raise ValueError("n_segments must be at least 1.")
        if n_segments > self.n_segments:
            raise ValueError(
                "n_segments exceeds available number of segments."
            )
        return np.median(self.psds[:n_segments, :], axis=0)

    def plot_psd(self) -> "plt.Figure":
        """
        Plot all individual-segment PSDs in gray and the median PSD in red
        on log-log axes, returning the figure.
        """
        fig, ax = plt.subplots(figsize=(8, 5))
        ax.loglog(self.freqs, self.psds.T, color="gray", alpha=0.3)
        ax.loglog(
            self.freqs, self.median_psd, color="r", lw=2, label="Median PSD"
        )
        ax.set_xlabel("Frequency (Hz)")
        ax.set_ylabel("PSD [strain^2/Hz]")
        ax.set_title(f"PSD: {self.strain.channel}")
        ax.grid(True, which="both", ls=":")
        ax.legend()
        return fig
+ return fig