gwsim 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103)
  1. gwsim/__init__.py +11 -0
  2. gwsim/__main__.py +8 -0
  3. gwsim/cli/__init__.py +0 -0
  4. gwsim/cli/config.py +88 -0
  5. gwsim/cli/default_config.py +56 -0
  6. gwsim/cli/main.py +101 -0
  7. gwsim/cli/merge.py +150 -0
  8. gwsim/cli/repository/__init__.py +0 -0
  9. gwsim/cli/repository/create.py +91 -0
  10. gwsim/cli/repository/delete.py +51 -0
  11. gwsim/cli/repository/download.py +54 -0
  12. gwsim/cli/repository/list_depositions.py +63 -0
  13. gwsim/cli/repository/main.py +38 -0
  14. gwsim/cli/repository/metadata/__init__.py +0 -0
  15. gwsim/cli/repository/metadata/main.py +24 -0
  16. gwsim/cli/repository/metadata/update.py +58 -0
  17. gwsim/cli/repository/publish.py +52 -0
  18. gwsim/cli/repository/upload.py +74 -0
  19. gwsim/cli/repository/utils.py +47 -0
  20. gwsim/cli/repository/verify.py +61 -0
  21. gwsim/cli/simulate.py +220 -0
  22. gwsim/cli/simulate_utils.py +596 -0
  23. gwsim/cli/utils/__init__.py +85 -0
  24. gwsim/cli/utils/checkpoint.py +178 -0
  25. gwsim/cli/utils/config.py +347 -0
  26. gwsim/cli/utils/hash.py +23 -0
  27. gwsim/cli/utils/retry.py +62 -0
  28. gwsim/cli/utils/simulation_plan.py +439 -0
  29. gwsim/cli/utils/template.py +56 -0
  30. gwsim/cli/utils/utils.py +149 -0
  31. gwsim/cli/validate.py +255 -0
  32. gwsim/data/__init__.py +8 -0
  33. gwsim/data/serialize/__init__.py +9 -0
  34. gwsim/data/serialize/decoder.py +59 -0
  35. gwsim/data/serialize/encoder.py +44 -0
  36. gwsim/data/serialize/serializable.py +33 -0
  37. gwsim/data/time_series/__init__.py +3 -0
  38. gwsim/data/time_series/inject.py +104 -0
  39. gwsim/data/time_series/time_series.py +355 -0
  40. gwsim/data/time_series/time_series_list.py +182 -0
  41. gwsim/detector/__init__.py +8 -0
  42. gwsim/detector/base.py +156 -0
  43. gwsim/detector/detectors/E1_2L_Aligned_Sardinia.interferometer +22 -0
  44. gwsim/detector/detectors/E1_2L_Misaligned_Sardinia.interferometer +22 -0
  45. gwsim/detector/detectors/E1_Triangle_EMR.interferometer +19 -0
  46. gwsim/detector/detectors/E1_Triangle_Sardinia.interferometer +19 -0
  47. gwsim/detector/detectors/E2_2L_Aligned_EMR.interferometer +22 -0
  48. gwsim/detector/detectors/E2_2L_Misaligned_EMR.interferometer +22 -0
  49. gwsim/detector/detectors/E2_Triangle_EMR.interferometer +19 -0
  50. gwsim/detector/detectors/E2_Triangle_Sardinia.interferometer +19 -0
  51. gwsim/detector/detectors/E3_Triangle_EMR.interferometer +19 -0
  52. gwsim/detector/detectors/E3_Triangle_Sardinia.interferometer +19 -0
  53. gwsim/detector/noise_curves/ET_10_HF_psd.txt +3000 -0
  54. gwsim/detector/noise_curves/ET_10_full_cryo_psd.txt +3000 -0
  55. gwsim/detector/noise_curves/ET_15_HF_psd.txt +3000 -0
  56. gwsim/detector/noise_curves/ET_15_full_cryo_psd.txt +3000 -0
  57. gwsim/detector/noise_curves/ET_20_HF_psd.txt +3000 -0
  58. gwsim/detector/noise_curves/ET_20_full_cryo_psd.txt +3000 -0
  59. gwsim/detector/noise_curves/ET_D_psd.txt +3000 -0
  60. gwsim/detector/utils.py +90 -0
  61. gwsim/glitch/__init__.py +7 -0
  62. gwsim/glitch/base.py +69 -0
  63. gwsim/mixin/__init__.py +8 -0
  64. gwsim/mixin/detector.py +203 -0
  65. gwsim/mixin/gwf.py +192 -0
  66. gwsim/mixin/population_reader.py +175 -0
  67. gwsim/mixin/randomness.py +107 -0
  68. gwsim/mixin/time_series.py +295 -0
  69. gwsim/mixin/waveform.py +47 -0
  70. gwsim/noise/__init__.py +19 -0
  71. gwsim/noise/base.py +134 -0
  72. gwsim/noise/bilby_stationary_gaussian.py +117 -0
  73. gwsim/noise/colored_noise.py +275 -0
  74. gwsim/noise/correlated_noise.py +257 -0
  75. gwsim/noise/pycbc_stationary_gaussian.py +112 -0
  76. gwsim/noise/stationary_gaussian.py +44 -0
  77. gwsim/noise/white_noise.py +51 -0
  78. gwsim/repository/__init__.py +0 -0
  79. gwsim/repository/zenodo.py +269 -0
  80. gwsim/signal/__init__.py +11 -0
  81. gwsim/signal/base.py +137 -0
  82. gwsim/signal/cbc.py +61 -0
  83. gwsim/simulator/__init__.py +7 -0
  84. gwsim/simulator/base.py +315 -0
  85. gwsim/simulator/state.py +85 -0
  86. gwsim/utils/__init__.py +11 -0
  87. gwsim/utils/datetime_parser.py +44 -0
  88. gwsim/utils/et_2l_geometry.py +165 -0
  89. gwsim/utils/io.py +167 -0
  90. gwsim/utils/log.py +145 -0
  91. gwsim/utils/population.py +48 -0
  92. gwsim/utils/random.py +69 -0
  93. gwsim/utils/retry.py +75 -0
  94. gwsim/utils/triangular_et_geometry.py +164 -0
  95. gwsim/version.py +7 -0
  96. gwsim/waveform/__init__.py +7 -0
  97. gwsim/waveform/factory.py +83 -0
  98. gwsim/waveform/pycbc_wrapper.py +37 -0
  99. gwsim-0.1.0.dist-info/METADATA +157 -0
  100. gwsim-0.1.0.dist-info/RECORD +103 -0
  101. gwsim-0.1.0.dist-info/WHEEL +4 -0
  102. gwsim-0.1.0.dist-info/entry_points.txt +2 -0
  103. gwsim-0.1.0.dist-info/licenses/LICENSE +21 -0
gwsim/noise/base.py ADDED
@@ -0,0 +1,134 @@
1
+ """Base class for noise simulators."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from typing import cast
7
+
8
+ import numpy as np
9
+
10
+ from gwsim.cli.utils.utils import get_file_name_from_template_with_dict
11
+ from gwsim.mixin.detector import DetectorMixin
12
+ from gwsim.mixin.gwf import GWFOutputMixin
13
+ from gwsim.mixin.randomness import RandomnessMixin
14
+ from gwsim.mixin.time_series import TimeSeriesMixin
15
+ from gwsim.simulator.base import Simulator
16
+ from gwsim.simulator.state import StateAttribute
17
+ from gwsim.utils.random import get_state
18
+
19
+
class NoiseSimulator(
    RandomnessMixin, DetectorMixin, TimeSeriesMixin, GWFOutputMixin, Simulator
):  # pylint: disable=duplicate-code
    """Base class for noise simulators.

    Combines randomness, detector, time-series and GWF-output mixins on top of the
    generic :class:`Simulator`, and provides batch saving plus the per-segment state
    update (counter, start time, RNG state).
    """

    # Start time (GPS seconds) of the next segment; tracked as simulator state.
    start_time = StateAttribute(0)

    def __init__(
        self,
        sampling_frequency: float,
        duration: float,
        start_time: float = 0,
        max_samples: int | None = None,
        seed: int | None = None,
        detectors: list[str] | None = None,
        **kwargs,
    ) -> None:
        """Initialize the base noise simulator.

        Args:
            sampling_frequency: Sampling frequency of the noise in Hz.
            duration: Duration of each noise segment in seconds.
            start_time: Start time of the first noise segment in GPS seconds. Default is 0
            max_samples: Maximum number of samples to generate. None means infinite.
            seed: Seed for the random number generator. If None, the RNG is not initialized.
            detectors: List of detector names. Default is None.
            **kwargs: Additional arguments absorbed by subclasses and mixins.
        """
        super().__init__(
            sampling_frequency=sampling_frequency,
            duration=duration,
            start_time=start_time,
            max_samples=max_samples,
            seed=seed,
            detectors=detectors,
            **kwargs,
        )

    def save_batch(self, batch: np.ndarray, file_name: str | Path, overwrite: bool = False, **kwargs) -> None:
        """Save a batch of noise data to a file.

        The output format is chosen from the file extension; only ``.gwf`` is
        supported. If ``file_name`` contains a ``{detector}`` placeholder, one file
        per detector is written, row ``i`` of ``batch`` going to ``self.detectors[i]``.

        Args:
            batch: Batch of noise data to save. A 1-D array is treated as a single
                channel when the ``{detector}`` placeholder is used.
            file_name: Name of the output file (may contain ``{detector}``).
            overwrite: Whether to overwrite existing files. Default is False.
            **kwargs: Additional arguments for the output mixin.

        Raises:
            NotImplementedError: If the file extension is not ``.gwf``.
            ValueError: If ``file_name`` contains ``{detector}`` but no detectors are
                set, or the number of channels in ``batch`` does not match the
                number of detectors.
        """
        suffix = Path(file_name).suffix.lower()
        if suffix == ".gwf":
            save_function = self.save_batch_to_gwf
        else:
            raise NotImplementedError(f"Output format {suffix} not supported by the output mixin.")

        # Spaces are stripped so that e.g. "{ detector }" is also recognised.
        if "{detector}" not in str(file_name).replace(" ", ""):
            save_function(
                batch=batch,
                file_path=file_name,
                overwrite=overwrite,
                **kwargs,
            )
            return

        # The placeholder requires detector names to expand it.
        if self.detectors is None:
            raise ValueError(
                "The file_name contains the {detector} placeholder, but the simulator does not have detectors set."
            )
        # Promote a single channel to shape (1, n_samples) so it indexes like multi-channel data.
        if len(batch.shape) == 1:
            batch = batch[None, :]
        if batch.shape[0] != len(self.detectors):
            raise ValueError(
                f"The batch has {batch.shape[0]} channels, but the simulator has {len(self.detectors)} detectors."
            )

        # Write each detector's channel to its own file, using the same writer as
        # the single-file path.
        for i, detector in enumerate(self.detectors):
            detector_file_name = get_file_name_from_template_with_dict(
                template=str(file_name),
                values={
                    "detector": detector,
                },
            )
            save_function(
                batch=batch[i, :],
                file_path=detector_file_name,
                overwrite=overwrite,
                **kwargs,
            )

    @property
    def metadata(self) -> dict:
        """Get a dictionary of metadata.
        This can be overridden by the subclass.

        Returns:
            dict: A dictionary of metadata.
        """
        # Get metadata from all parent classes using cooperative inheritance
        metadata = super().metadata

        return metadata

    def update_state(self) -> None:
        """Update internal state after each sample generation.

        Increments ``counter``, advances ``start_time`` by ``duration`` so the next
        segment follows on directly, and stores the value returned by
        ``gwsim.utils.random.get_state()`` in ``rng_state``. Subclasses that keep
        extra state should override this and call ``super().update_state()``.
        """
        self.counter = cast(int, self.counter) + 1
        self.start_time += self.duration
        self.rng_state = get_state()
gwsim/noise/bilby_stationary_gaussian.py ADDED
@@ -0,0 +1,117 @@
1
+ """Stationary Gaussian noise simulator using Bilby."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ import numpy as np
8
+ from bilby.core.utils.random import Generator as BilbyGenerator
9
+ from bilby.core.utils.series import infft
10
+ from bilby.gw.detector.psd import PowerSpectralDensity
11
+ from numpy.random import Generator
12
+
13
+ from gwsim.noise.stationary_gaussian import StationaryGaussianNoiseSimulator
14
+
15
+
class BilbyStationaryGaussianNoiseSimulator(
    StationaryGaussianNoiseSimulator
):  # pylint: disable=too-many-ancestors, duplicate-code
    """Stationary Gaussian noise simulator using Bilby.

    Noise segments are drawn with ``PowerSpectralDensity.get_noise_realisation`` and
    converted to the time domain with ``bilby.core.utils.series.infft``. Assigning
    ``rng`` also installs that generator as ``bilby.core.utils.random.Generator.rng``,
    a class attribute shared by every bilby user in the process.
    """

    def __init__(
        self,
        frequency_array: np.ndarray[Any, np.dtype[Any]] | None = None,
        psd_array: np.ndarray[Any, np.dtype[Any]] | None = None,
        psd_file: str | None = None,
        sampling_frequency: float = 4096,
        duration: float = 4,
        start_time: float = 0,
        max_samples: int | None = None,
        seed: int | None = None,
        detectors: list[str] | None = None,
        **kwargs,
    ) -> None:
        """Initialize Bilby stationary Gaussian noise simulator.

        The PSD is built from ``psd_array`` when it is given, otherwise it is read
        from ``psd_file``.

        Args:
            frequency_array: Frequency array for the PSD. If None, the grid
                ``k / duration`` for ``k = 0 .. (sampling_frequency * duration) // 2``
                is used.
            psd_array: PSD values corresponding to the frequency array.
            psd_file: Path to a file containing the PSD. Used only when
                ``psd_array`` is None.
            sampling_frequency: Sampling frequency in Hz. Default is 4096.
            duration: Duration of each segment in seconds. Default is 4.
            start_time: Start time in GPS seconds. Default is 0.
            max_samples: Maximum number of samples. None means infinite.
            seed: Random seed. If None, RNG is not initialized.
            detectors: List of detector names. Default is None.
            **kwargs: Additional arguments.

        Raises:
            ValueError: If neither ``psd_array`` nor ``psd_file`` is provided.
        """
        super().__init__(
            sampling_frequency=sampling_frequency,
            duration=duration,
            start_time=start_time,
            max_samples=max_samples,
            seed=seed,
            detectors=detectors,
            **kwargs,
        )
        self.frequency_array = frequency_array
        self.psd_array = psd_array
        self.psd_file = psd_file
        self._setup_psd()

    @property
    def frequency_array(self) -> np.ndarray[Any, np.dtype[Any]] | None:
        """Get the frequency array."""
        return self._frequency_array

    @frequency_array.setter
    def frequency_array(self, value: np.ndarray[Any, np.dtype[Any]] | None) -> None:
        """Set the frequency array; None selects the default grid derived from
        ``sampling_frequency`` and ``duration``."""
        if value is None:
            self._frequency_array = np.arange(int(self.sampling_frequency * self.duration // 2) + 1) / self.duration
        else:
            self._frequency_array = value

    def _setup_psd(self) -> None:
        """Build ``self.psd`` from the PSD array or the PSD file.

        Raises:
            ValueError: If neither a PSD array nor a PSD file is available.
        """
        # The frequency_array setter never stores None, so in practice this branch
        # is selected purely by psd_array being provided.
        if self.frequency_array is not None and self.psd_array is not None:
            self.psd = PowerSpectralDensity(frequency_array=self.frequency_array, psd_array=self.psd_array)
        elif self.psd_file is not None:
            self.psd = PowerSpectralDensity.from_power_spectral_density_file(self.psd_file)
        else:
            raise ValueError("Either frequency_array and psd_array or psd_file must be provided.")

    @property
    def rng(self) -> Generator | None:
        """Get the random number generator.

        Returns:
            Random number generator instance or None if no seed was set.
        """
        # Fall back to None if the setter has not run yet, so the "None when
        # unseeded" contract holds regardless of how the mixin initializes it.
        return getattr(self, "_rng", None)

    @rng.setter
    def rng(self, value: Generator | None) -> None:
        """Set the random number generator.

        A non-None generator is also assigned to bilby's shared
        ``Generator.rng`` class attribute; None leaves bilby's generator untouched.

        Args:
            value: Random number generator instance.
        """
        self._rng = value
        # Override the bilby RNG
        if value is not None:
            BilbyGenerator.rng = value

    def simulate(self, *args, **kwargs) -> np.ndarray:
        """Simulate a noise segment.

        Positional and keyword arguments are accepted for interface compatibility
        and ignored.

        Returns:
            np.ndarray: Simulated noise segment as a numpy array.

        Raises:
            RuntimeError: If the random number generator has not been initialized.
        """
        if self.rng is None:
            raise RuntimeError("Random number generator not initialized. Set seed in constructor.")
        # Draw one frequency-domain realization from the PSD (bilby draws from its
        # shared Generator.rng, set by the rng setter) and inverse-FFT it.
        # NOTE(review): this overrides ``simulate`` while ColoredNoiseSimulator
        # overrides ``_simulate``; confirm the base Simulator's bookkeeping still runs.
        frequency_domain_strain, _frequencies = self.psd.get_noise_realisation(
            sampling_frequency=self.sampling_frequency, duration=self.duration
        )
        return infft(frequency_domain_strain=frequency_domain_strain, sampling_frequency=self.sampling_frequency)
gwsim/noise/colored_noise.py ADDED
@@ -0,0 +1,275 @@
1
+ """Colored noise simulator for gravitational wave detectors."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from pathlib import Path
7
+
8
+ import numpy as np
9
+ from scipy.interpolate import interp1d
10
+ from scipy.signal.windows import tukey
11
+
12
+ from gwsim.data.time_series.time_series import TimeSeries
13
+ from gwsim.data.time_series.time_series_list import TimeSeriesList
14
+ from gwsim.noise.base import NoiseSimulator
15
+ from gwsim.simulator.state import StateAttribute
16
+
# Named "gwsim" logger; logging returns the same instance to every caller using this name.
logger = logging.getLogger("gwsim")

# Directory searched (recursively) for a PSD file whose given path does not exist:
# ``detector/noise_curves`` inside the ``gwsim`` package directory.
DEFAULT_PSD_PATH = Path(__file__).parent.parent.joinpath("detector/noise_curves")
21
+
22
+
class ColoredNoiseSimulator(NoiseSimulator):  # pylint: disable=too-many-instance-attributes
    """Colored noise simulator for gravitational wave detectors.

    This class generates noise time series with a specified power spectral density (PSD).
    It uses an overlap-add method with windowing to produce smooth, continuous time series
    across segment boundaries.

    Each realization spans ``_t_window`` (2048) seconds. Consecutive realizations overlap
    by half a window and are cross-faded with the complementary raised-cosine weights
    ``_w0`` (fade-out, 1 -> 0) and ``_w1`` (fade-in, 0 -> 1). The overlap is then divided
    by ``sqrt(w0**2 + w1**2)``, which keeps the variance of a sum of independent
    realizations equal to that of a single one.

    The simulator maintains state between batches to ensure continuity of the noise
    time series across multiple calls to simulate(). Samples that were generated but
    not yet emitted are kept in ``previous_strain``, and each frame starts exactly at
    the sample following the previous frame.
    """

    # Strain generated but not yet emitted, shape (n_detectors, n_samples). Its last
    # ``_n_overlap`` samples already carry the ``_w0`` fade-out and await overlap-add
    # with the next realization. An all-zero buffer means "no history: start fresh".
    previous_strain = StateAttribute(default=None)

    def __init__(  # pylint: disable=too-many-arguments,too-many-positional-arguments,duplicate-code
        self,
        psd_file: str | Path,
        detectors: list[str],
        sampling_frequency: float = 4096,
        duration: float = 4,
        start_time: float = 0,
        max_samples: int | None = None,
        seed: int | None = None,
        low_frequency_cutoff: float = 2.0,
        high_frequency_cutoff: float | None = None,
        **kwargs,
    ):
        """Initialize the colored noise simulator.

        Args:
            psd_file: Path to file containing Power Spectral Density array with shape (N, 2),
                where the first column is frequency (Hz) and the second is PSD values.
                If the path does not exist, a file with the same name is searched for
                under ``DEFAULT_PSD_PATH``.
            detectors: List of detector names (e.g., ['H1', 'L1']).
            sampling_frequency: Sampling frequency in Hz. Default is 4096.
            duration: Duration of each noise segment in seconds. Default is 4.
            start_time: GPS start time for the time series. Default is 0.
            max_samples: Maximum number of samples to generate. None means infinite.
            seed: Seed for random number generation. If None, RNG is not initialized.
            low_frequency_cutoff: Lower frequency cutoff in Hz. Default is 2.0.
            high_frequency_cutoff: Upper frequency cutoff in Hz. None, or a value above
                Nyquist, falls back to ``sampling_frequency // 2``.
            **kwargs: Additional arguments passed to parent classes.

        Raises:
            ValueError: If detectors list is empty.
            ValueError: If duration is shorter than half of the 2048 s window.
            ValueError: If the PSD data does not have shape (N, 2).
            FileNotFoundError: If the PSD file cannot be located.
        """
        if not detectors:
            raise ValueError("detectors must contain at least one detector.")

        super().__init__(  # pylint: disable=duplicate-code
            sampling_frequency=sampling_frequency,
            duration=duration,
            start_time=start_time,
            max_samples=max_samples,
            seed=seed,
            detectors=detectors,
            **kwargs,
        )

        self.psd_file = psd_file
        self.low_frequency_cutoff = low_frequency_cutoff
        self.high_frequency_cutoff = (
            high_frequency_cutoff
            if (high_frequency_cutoff is not None and high_frequency_cutoff <= sampling_frequency / 2)
            else sampling_frequency // 2
        )

        # Initialize noise generation properties
        self._n_det = len(detectors)
        self._initialize_window_properties()
        self._initialize_frequency_properties()
        self._initialize_psd()

        # An all-zero buffer signals that the first simulate call must start fresh.
        self.previous_strain = np.zeros((self._n_det, self._n_chunk))
        # Pending (not yet emitted) strain produced by _simulate, committed in update_state.
        self._temp_strain_buffer: np.ndarray | None = None

    def _initialize_window_properties(self) -> None:
        """Initialize window properties for connecting noise realizations.

        Raises:
            ValueError: If the duration is too short for proper noise generation.
        """
        # Window length in seconds; realizations overlap by half a window.
        self._t_window = 2048
        self._f_window = 1.0 / self._t_window
        self._t_overlap = self._t_window / 2.0
        self._n_overlap = int(self._t_overlap * self.sampling_frequency.value)

        # Complementary raised-cosine weights over the overlap: _w0 fades 1 -> 0,
        # _w1 fades 0 -> 1, and _w0 + _w1 == 1.
        t_overlap_array = np.linspace(0, self._t_overlap, self._n_overlap)
        self._w0 = 0.5 + np.cos(2 * np.pi * self._f_window * t_overlap_array) / 2
        self._w1 = 0.5 + np.sin(2 * np.pi * self._f_window * t_overlap_array - np.pi / 2) / 2

        # Frames shorter than half a window are rejected.
        if self.duration.value < self._t_window / 2:
            raise ValueError(
                f"Duration ({self.duration.value:.1f} seconds) must be at least "
                f"{self._t_window / 2:.1f} seconds to ensure noise continuity."
            )

    def _initialize_frequency_properties(self) -> None:
        """Initialize frequency and time properties for noise generation."""
        self._t_chunk = self._t_window
        self._df_chunk = 1.0 / self._t_chunk
        self._n_chunk = int(self._t_chunk * self.sampling_frequency.value)
        # Index range [k_min, k_max) of the one-sided frequency grid that receives noise.
        self._k_min_chunk = int(self.low_frequency_cutoff / self._df_chunk)
        self._k_max_chunk = int(self.high_frequency_cutoff / self._df_chunk) + 1
        self._frequency_chunk = np.arange(0.0, self._n_chunk / 2.0 + 1) * self._df_chunk
        self._n_freq_chunk = len(self._frequency_chunk[self._k_min_chunk : self._k_max_chunk])
        self._dt = 1.0 / self.sampling_frequency.value

    def _load_spectral_data(self, file_path: str | Path) -> np.ndarray:  # pylint: disable=duplicate-code
        """Load spectral data from file.

        If ``file_path`` does not exist, the first file with the same name found
        recursively under ``DEFAULT_PSD_PATH`` is used instead.

        Args:
            file_path: Path to file containing spectral data.

        Returns:
            Loaded array.

        Raises:
            ValueError: If file format is not supported.
            TypeError: If file_path is not a string or Path.
            FileNotFoundError: If the file exists neither at ``file_path`` nor under
                ``DEFAULT_PSD_PATH``.
        """
        if not isinstance(file_path, (str, Path)):
            raise TypeError("file_path must be a string or Path.")

        path = Path(file_path)
        if not path.exists():
            bundled = next(DEFAULT_PSD_PATH.rglob(path.name), None)
            if bundled is None:
                raise FileNotFoundError(
                    f"PSD file '{file_path}' not found, and no file named '{path.name}' "
                    f"exists under {DEFAULT_PSD_PATH}."
                )
            path = bundled

        suffix = path.suffix.lower()
        if suffix == ".npy":
            return np.load(path)
        if suffix == ".txt":
            return np.loadtxt(path)
        if suffix == ".csv":
            return np.loadtxt(path, delimiter=",")
        raise ValueError(f"Unsupported file format: {path.suffix}. Use .npy, .txt, or .csv.")

    def _initialize_psd(self) -> None:
        """Initialize PSD interpolation for the frequency range.

        Raises:
            ValueError: If PSD array doesn't have shape (N, 2).
        """
        psd_data = self._load_spectral_data(self.psd_file)

        # Check ndim first so 1-D data yields the documented ValueError, not IndexError.
        if psd_data.ndim != 2 or psd_data.shape[1] != 2:
            raise ValueError("PSD file must have shape (N, 2).")

        # Interpolate the PSD to the relevant frequencies
        freqs = self._frequency_chunk[self._k_min_chunk : self._k_max_chunk]
        psd_interp = interp1d(psd_data[:, 0], psd_data[:, 1], bounds_error=False, fill_value="extrapolate")(freqs)

        # Add a roll-off at the edges using a Tukey window
        window = tukey(self._n_freq_chunk, alpha=1e-3)
        self._psd = psd_interp * window

    def _generate_single_realization(self) -> np.ndarray:
        """Generate a single noise realization in the time domain.

        Returns:
            Time series array with shape (n_detectors, n_samples).

        Raises:
            RuntimeError: If the random number generator has not been initialized.
        """
        if self.rng is None:
            raise RuntimeError("Random number generator not initialized. Set seed in constructor.")

        freq_series = np.zeros((self._n_det, self._frequency_chunk.size), dtype=np.complex128)

        # Unit-variance complex white noise, scaled so E|h(f)|^2 = PSD / (2 df).
        white_strain = (
            self.rng.standard_normal((self._n_det, self._n_freq_chunk))
            + 1j * self.rng.standard_normal((self._n_det, self._n_freq_chunk))
        ) / np.sqrt(2)
        colored_strain = white_strain * np.sqrt(self._psd * 0.5 / self._df_chunk)
        freq_series[:, self._k_min_chunk : self._k_max_chunk] += colored_strain

        # Transform to time domain; the df * N factor undoes irfft's 1/N normalization.
        time_series = np.fft.irfft(freq_series, n=self._n_chunk, axis=1) * self._df_chunk * self._n_chunk

        return time_series

    def _append_realization(self, strain_buffer: np.ndarray) -> np.ndarray:
        """Overlap-add a fresh realization onto the end of ``strain_buffer``.

        The last ``_n_overlap`` samples of ``strain_buffer`` must already carry the
        ``_w0`` fade-out; they are modified in place.

        Args:
            strain_buffer: Array of shape (n_detectors, n_samples).

        Returns:
            A new buffer extended by ``_n_chunk - _n_overlap`` samples whose last
            ``_n_overlap`` samples carry the new realization's ``_w0`` fade-out.
        """
        new_strain = self._generate_single_realization()
        new_strain[:, : self._n_overlap] *= self._w1
        new_strain[:, -self._n_overlap :] *= self._w0
        strain_buffer[:, -self._n_overlap :] += new_strain[:, : self._n_overlap]
        # Restore unit variance across the cross-fade of independent realizations.
        strain_buffer[:, -self._n_overlap :] *= 1 / np.sqrt(self._w0**2 + self._w1**2)
        return np.concatenate((strain_buffer, new_strain[:, self._n_overlap :]), axis=1)

    def _simulate(self, *args, **kwargs) -> TimeSeriesList:
        """Simulate colored noise for all detectors.

        The frame starts at the first sample not yet emitted. On a fresh start,
        the first realization is generated and its first ``_n_chunk`` samples are
        discarded. The unemitted remainder is staged in ``_temp_strain_buffer`` and
        committed to ``previous_strain`` by :meth:`update_state`.

        Returns:
            TimeSeriesList containing a single TimeSeries with shape (n_detectors, n_samples).

        Raises:
            ValueError: If ``previous_strain`` holds fewer than ``_n_overlap`` samples.
        """
        n_frame = int(self.duration.value * self.sampling_frequency.value)

        pending = np.asarray(self.previous_strain)
        if pending.shape[-1] < self._n_overlap:
            raise ValueError(
                f"previous_strain has only {pending.shape[-1]} samples per detector, "
                f"but expected at least {self._n_overlap}."
            )

        if np.all(pending == 0):
            # Fresh start: the first chunk is a warm-up realization that is never emitted.
            strain_buffer = self._generate_single_realization()
            strain_buffer[:, -self._n_overlap :] *= self._w0
            frame_start = self._n_chunk
        else:
            # Continue from the carried-over samples; their tail is already faded out.
            # Copy so the stored state is never modified in place.
            strain_buffer = pending.copy()
            frame_start = 0

        # Keep at least one pending overlap region after the frame for the next call.
        while strain_buffer.shape[-1] - frame_start - self._n_overlap < n_frame:
            strain_buffer = self._append_realization(strain_buffer)

        frame_end = frame_start + n_frame
        # Copies decouple the returned data from the carried-over buffer.
        output_strain = strain_buffer[:, frame_start:frame_end].copy()
        self._temp_strain_buffer = strain_buffer[:, frame_end:].copy()

        return TimeSeriesList(
            [TimeSeries(data=output_strain, start_time=self.start_time, sampling_frequency=self.sampling_frequency)]
        )

    def update_state(self) -> None:
        """Update internal state after each sample generation.

        Commits the unemitted strain staged by :meth:`_simulate` to
        ``previous_strain`` so the next frame continues seamlessly.
        """
        # Call parent's update_state first (increments counter, advances start_time, saves rng_state)
        super().update_state()

        if self._temp_strain_buffer is not None:
            self.previous_strain = self._temp_strain_buffer
            self._temp_strain_buffer = None

    @property
    def metadata(self) -> dict:
        """Get metadata including colored noise configuration.

        Returns:
            Dictionary containing metadata.
        """
        meta = super().metadata
        meta["colored_noise"] = {
            "arguments": {
                "psd_file": str(self.psd_file),
                "low_frequency_cutoff": self.low_frequency_cutoff,
                "high_frequency_cutoff": self.high_frequency_cutoff,
            }
        }
        return meta