gwsim 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gwsim/__init__.py +11 -0
- gwsim/__main__.py +8 -0
- gwsim/cli/__init__.py +0 -0
- gwsim/cli/config.py +88 -0
- gwsim/cli/default_config.py +56 -0
- gwsim/cli/main.py +101 -0
- gwsim/cli/merge.py +150 -0
- gwsim/cli/repository/__init__.py +0 -0
- gwsim/cli/repository/create.py +91 -0
- gwsim/cli/repository/delete.py +51 -0
- gwsim/cli/repository/download.py +54 -0
- gwsim/cli/repository/list_depositions.py +63 -0
- gwsim/cli/repository/main.py +38 -0
- gwsim/cli/repository/metadata/__init__.py +0 -0
- gwsim/cli/repository/metadata/main.py +24 -0
- gwsim/cli/repository/metadata/update.py +58 -0
- gwsim/cli/repository/publish.py +52 -0
- gwsim/cli/repository/upload.py +74 -0
- gwsim/cli/repository/utils.py +47 -0
- gwsim/cli/repository/verify.py +61 -0
- gwsim/cli/simulate.py +220 -0
- gwsim/cli/simulate_utils.py +596 -0
- gwsim/cli/utils/__init__.py +85 -0
- gwsim/cli/utils/checkpoint.py +178 -0
- gwsim/cli/utils/config.py +347 -0
- gwsim/cli/utils/hash.py +23 -0
- gwsim/cli/utils/retry.py +62 -0
- gwsim/cli/utils/simulation_plan.py +439 -0
- gwsim/cli/utils/template.py +56 -0
- gwsim/cli/utils/utils.py +149 -0
- gwsim/cli/validate.py +255 -0
- gwsim/data/__init__.py +8 -0
- gwsim/data/serialize/__init__.py +9 -0
- gwsim/data/serialize/decoder.py +59 -0
- gwsim/data/serialize/encoder.py +44 -0
- gwsim/data/serialize/serializable.py +33 -0
- gwsim/data/time_series/__init__.py +3 -0
- gwsim/data/time_series/inject.py +104 -0
- gwsim/data/time_series/time_series.py +355 -0
- gwsim/data/time_series/time_series_list.py +182 -0
- gwsim/detector/__init__.py +8 -0
- gwsim/detector/base.py +156 -0
- gwsim/detector/detectors/E1_2L_Aligned_Sardinia.interferometer +22 -0
- gwsim/detector/detectors/E1_2L_Misaligned_Sardinia.interferometer +22 -0
- gwsim/detector/detectors/E1_Triangle_EMR.interferometer +19 -0
- gwsim/detector/detectors/E1_Triangle_Sardinia.interferometer +19 -0
- gwsim/detector/detectors/E2_2L_Aligned_EMR.interferometer +22 -0
- gwsim/detector/detectors/E2_2L_Misaligned_EMR.interferometer +22 -0
- gwsim/detector/detectors/E2_Triangle_EMR.interferometer +19 -0
- gwsim/detector/detectors/E2_Triangle_Sardinia.interferometer +19 -0
- gwsim/detector/detectors/E3_Triangle_EMR.interferometer +19 -0
- gwsim/detector/detectors/E3_Triangle_Sardinia.interferometer +19 -0
- gwsim/detector/noise_curves/ET_10_HF_psd.txt +3000 -0
- gwsim/detector/noise_curves/ET_10_full_cryo_psd.txt +3000 -0
- gwsim/detector/noise_curves/ET_15_HF_psd.txt +3000 -0
- gwsim/detector/noise_curves/ET_15_full_cryo_psd.txt +3000 -0
- gwsim/detector/noise_curves/ET_20_HF_psd.txt +3000 -0
- gwsim/detector/noise_curves/ET_20_full_cryo_psd.txt +3000 -0
- gwsim/detector/noise_curves/ET_D_psd.txt +3000 -0
- gwsim/detector/utils.py +90 -0
- gwsim/glitch/__init__.py +7 -0
- gwsim/glitch/base.py +69 -0
- gwsim/mixin/__init__.py +8 -0
- gwsim/mixin/detector.py +203 -0
- gwsim/mixin/gwf.py +192 -0
- gwsim/mixin/population_reader.py +175 -0
- gwsim/mixin/randomness.py +107 -0
- gwsim/mixin/time_series.py +295 -0
- gwsim/mixin/waveform.py +47 -0
- gwsim/noise/__init__.py +19 -0
- gwsim/noise/base.py +134 -0
- gwsim/noise/bilby_stationary_gaussian.py +117 -0
- gwsim/noise/colored_noise.py +275 -0
- gwsim/noise/correlated_noise.py +257 -0
- gwsim/noise/pycbc_stationary_gaussian.py +112 -0
- gwsim/noise/stationary_gaussian.py +44 -0
- gwsim/noise/white_noise.py +51 -0
- gwsim/repository/__init__.py +0 -0
- gwsim/repository/zenodo.py +269 -0
- gwsim/signal/__init__.py +11 -0
- gwsim/signal/base.py +137 -0
- gwsim/signal/cbc.py +61 -0
- gwsim/simulator/__init__.py +7 -0
- gwsim/simulator/base.py +315 -0
- gwsim/simulator/state.py +85 -0
- gwsim/utils/__init__.py +11 -0
- gwsim/utils/datetime_parser.py +44 -0
- gwsim/utils/et_2l_geometry.py +165 -0
- gwsim/utils/io.py +167 -0
- gwsim/utils/log.py +145 -0
- gwsim/utils/population.py +48 -0
- gwsim/utils/random.py +69 -0
- gwsim/utils/retry.py +75 -0
- gwsim/utils/triangular_et_geometry.py +164 -0
- gwsim/version.py +7 -0
- gwsim/waveform/__init__.py +7 -0
- gwsim/waveform/factory.py +83 -0
- gwsim/waveform/pycbc_wrapper.py +37 -0
- gwsim-0.1.0.dist-info/METADATA +157 -0
- gwsim-0.1.0.dist-info/RECORD +103 -0
- gwsim-0.1.0.dist-info/WHEEL +4 -0
- gwsim-0.1.0.dist-info/entry_points.txt +2 -0
- gwsim-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,596 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utilities for executing simulation plans via CLI.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import atexit
|
|
8
|
+
import copy
|
|
9
|
+
import logging
|
|
10
|
+
import signal
|
|
11
|
+
import time
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
import yaml
|
|
16
|
+
from tqdm import tqdm
|
|
17
|
+
|
|
18
|
+
from gwsim.cli.utils.checkpoint import CheckpointManager
|
|
19
|
+
from gwsim.cli.utils.config import SimulatorConfig, resolve_class_path
|
|
20
|
+
from gwsim.cli.utils.hash import compute_file_hash
|
|
21
|
+
from gwsim.cli.utils.simulation_plan import (
|
|
22
|
+
SimulationBatch,
|
|
23
|
+
SimulationPlan,
|
|
24
|
+
create_batch_metadata,
|
|
25
|
+
)
|
|
26
|
+
from gwsim.cli.utils.utils import handle_signal, import_attribute
|
|
27
|
+
from gwsim.simulator.base import Simulator
|
|
28
|
+
from gwsim.utils.io import get_file_name_from_template
|
|
29
|
+
|
|
30
|
+
logger = logging.getLogger("gwsim")
|
|
31
|
+
logger.setLevel(logging.DEBUG)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def retry_with_backoff(
    func,
    max_retries: int = 3,
    initial_delay: float = 0.1,
    backoff_factor: float = 2.0,
    state_restore_func: Any = None,
) -> Any:
    """Retry *func* with exponential backoff and optional state restoration.

    Args:
        func: Callable to retry.
        max_retries: Maximum number of retries (so at most ``max_retries + 1`` calls).
        initial_delay: Delay before the first retry, in seconds.
        backoff_factor: Multiplier applied to the delay after each retry.
        state_restore_func: Optional callable invoked before each retry attempt
            (never before the first attempt) to roll stateful components back.

    Returns:
        Whatever ``func()`` returns on the first successful attempt.

    Raises:
        RuntimeError: If state restoration itself fails.
        Exception: The last exception raised by ``func`` if every attempt fails.
    """
    log = logging.getLogger("gwsim")
    wait = initial_delay
    failure: Exception | None = None
    total_attempts = max_retries + 1

    for attempt_number in range(1, total_attempts + 1):
        try:
            return func()
        except Exception as exc:  # pylint: disable=broad-exception-caught
            failure = exc
            if attempt_number > max_retries:
                # Final attempt exhausted: report, then re-raise after the loop.
                log.error("All %d attempts failed for batch: %s", total_attempts, str(exc))
                continue
            log.warning(
                "Attempt %d/%d failed: %s. Retrying in %.2fs...",
                attempt_number,
                total_attempts,
                str(exc),
                wait,
                exc_info=exc,
            )
            time.sleep(wait)
            wait *= backoff_factor

            # Restore state before the next retry if a restore hook was given.
            if state_restore_func is not None:
                try:
                    state_restore_func()
                    log.debug("State restored before retry attempt %d", attempt_number + 1)
                except Exception as restore_error:
                    log.error("Failed to restore state before retry: %s", restore_error)
                    raise RuntimeError(f"Cannot retry: failed to restore state: {restore_error}") from restore_error

    if failure is not None:
        raise failure
    raise RuntimeError("Unexpected retry failure")
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def update_metadata_index(
    metadata_directory: Path,
    output_files: list[Path],
    metadata_file_name: str,
    encoding: str = "utf-8",
) -> None:
    """Update the central metadata index file.

    The index maps data file names to their corresponding metadata files,
    enabling O(1) lookup to find metadata for a given data file.

    Args:
        metadata_directory: Directory where metadata files are stored
        output_files: List of output data file Paths
        metadata_file_name: Name of the metadata file (e.g., "signal-0.metadata.yaml")
        encoding: File encoding for reading/writing the index file

    Raises:
        OSError: If writing the updated index fails.
        yaml.YAMLError: If serializing the updated index fails.
    """
    index_file = metadata_directory / "index.yaml"

    # Load existing index or start fresh; a corrupt/unreadable index is
    # recoverable, so log and rebuild rather than abort.
    index: dict[str, str] = {}
    if index_file.exists():
        try:
            with index_file.open(encoding=encoding) as f:
                index = yaml.safe_load(f) or {}
        except (OSError, yaml.YAMLError) as e:
            logger.warning("Failed to load metadata index: %s. Creating new index.", e)
            index = {}

    # Map every data file name to the batch metadata file that describes it.
    for output_file in output_files:
        index[output_file.name] = metadata_file_name
        logger.debug("Index entry: %s -> %s", output_file.name, metadata_file_name)

    # Persist the updated index. Bug fix: honour the `encoding` argument on the
    # write path too (previously only the read path used it, so writes fell
    # back to the platform default encoding).
    try:
        with index_file.open("w", encoding=encoding) as f:
            yaml.safe_dump(index, f, default_flow_style=False, sort_keys=True)
        logger.debug("Updated metadata index: %s", index_file)
    except (OSError, yaml.YAMLError) as e:
        logger.error("Failed to save metadata index: %s", e)
        raise
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def instantiate_simulator(
    simulator_config: SimulatorConfig,
    simulator_name: str | None = None,
    global_simulator_arguments: dict[str, Any] | None = None,
) -> Simulator:
    """Instantiate a simulator from configuration.

    A single simulator instance is created and reused across multiple batches;
    it maintains state (RNG, counters, etc.) across iterations. Global
    simulator arguments act as fallbacks and are overridden by
    simulator-specific arguments.

    Args:
        simulator_config: Configuration for this simulator.
        simulator_name: Name of the simulator (used for class path resolution).
        global_simulator_arguments: Global fallback arguments for the simulator.

    Returns:
        Instantiated Simulator.

    Raises:
        ImportError: If the simulator class cannot be imported.
        TypeError: If simulator instantiation fails.
    """
    # Short class names are expanded to fully-qualified import paths.
    class_spec = resolve_class_path(simulator_config.class_, simulator_name)
    simulator_cls = import_attribute(class_spec)

    # Simulator-specific arguments win over the global defaults.
    arguments: dict[str, Any] = dict(global_simulator_arguments or {})
    arguments.update(simulator_config.arguments)

    # YAML keys use hyphens; Python keyword arguments need underscores.
    keyword_arguments = {key.replace("-", "_"): value for key, value in arguments.items()}

    instance = simulator_cls(**keyword_arguments)

    logger.info("Instantiated simulator from class %s", class_spec)
    return instance
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def restore_batch_state(simulator: Simulator, batch: SimulationBatch) -> None:
    """Restore simulator state from the batch's pre-batch snapshot, if any.

    Used when reproducing a specific batch: puts back the RNG state, filter
    memory and other stateful components that existed before this batch was
    originally generated. A batch without a snapshot is left untouched
    (fresh generation).

    Args:
        simulator: Simulator instance whose state is overwritten.
        batch: SimulationBatch potentially containing a state snapshot.

    Raises:
        ValueError: If applying the snapshot to the simulator fails.
    """
    snapshot = batch.pre_batch_state
    if not (batch.has_state_snapshot() and snapshot is not None):
        # Guard clause: nothing to restore, keep current simulator state.
        logger.debug(
            "[RESTORE] Batch %d: No pre-batch state snapshot available (fresh generation)",
            batch.batch_index,
        )
        return

    logger.debug(
        "[RESTORE] Batch %d: Restoring state from snapshot - state_keys=%s",
        batch.batch_index,
        list(snapshot.keys()),
    )
    try:
        logger.debug(
            "[RESTORE] Batch %d: Setting state dict - counter=%s",
            batch.batch_index,
            snapshot.get("counter"),
        )
        simulator.state = snapshot
        logger.debug(
            "[RESTORE] Batch %d: State restored successfully - new_counter=%s",
            batch.batch_index,
            simulator.counter,
        )
    except Exception as e:
        logger.error("Failed to restore batch state: %s", e)
        raise ValueError(f"Failed to restore state for batch {batch.batch_index}") from e
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def save_batch_metadata(
    simulator: Simulator,
    batch: SimulationBatch,
    metadata_directory: Path,
    output_files: list[Path],
    pre_batch_state: dict[str, Any] | None = None,
) -> None:
    """Save batch metadata, including pre-batch simulator state and all outputs.

    The metadata file uses batch-indexed naming
    (``{simulator_name}-{batch_index}.metadata.yaml``) so it is the single
    source of truth for every output of that batch — one batch may produce
    several files (e.g. one per detector). An index file is maintained
    alongside for quick data-file -> metadata lookup.

    Args:
        simulator: Simulator instance.
        batch: SimulationBatch being described.
        metadata_directory: Directory to save metadata into.
        output_files: Paths of all output files generated by this batch.
        pre_batch_state: Simulator state captured before batch generation
            (for reproducibility). Falls back to the current state when None.
    """
    metadata_directory.mkdir(parents=True, exist_ok=True)

    # Prefer the explicitly captured pre-batch state over the live one.
    snapshot = simulator.state if pre_batch_state is None else pre_batch_state

    metadata = create_batch_metadata(
        simulator_name=batch.simulator_name,
        batch_index=batch.batch_index,
        simulator_config=batch.simulator_config,
        globals_config=batch.globals_config,
        pre_batch_state=snapshot,
        source=batch.source,
        author=batch.author,
        email=batch.email,
    )

    # Record output file names (not full paths) for easy discovery.
    metadata["output_files"] = [output_file.name for output_file in output_files]

    # Attach content hashes for integrity checking; hashing failures are
    # non-fatal because the rest of the metadata is still useful.
    hashes: dict[str, str] = {}
    for output_file in output_files:
        try:
            digest = compute_file_hash(output_file)
        except OSError as e:
            logger.warning("Failed to compute hash for %s: %s", output_file.name, e)
        else:
            hashes[output_file.name] = digest
            logger.debug("Compute hash for %s: %s", output_file.name, digest)
    metadata["file_hashes"] = hashes

    metadata_file_name = f"{batch.simulator_name}-{batch.batch_index}.metadata.yaml"
    metadata_file = metadata_directory / metadata_file_name
    logger.debug("Saving batch metadata to %s with %d output files", metadata_file, len(output_files))

    with metadata_file.open("w") as f:
        yaml.safe_dump(metadata, f)

    # Keep the central index in sync so data files map back to this metadata.
    update_metadata_index(metadata_directory, output_files, metadata_file_name)
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def process_batch(
    simulator: Simulator,
    batch_data: object,
    batch: SimulationBatch,
    output_directory: Path,
    overwrite: bool,
) -> list[Path]:
    """Persist one batch of generated data and return its output files.

    A single batch may generate multiple output files (e.g. one per
    detector); both the single-file and multi-file cases are handled here.

    Args:
        simulator: Simulator instance that generated the data.
        batch_data: Generated batch data (may contain multiple outputs).
        batch: SimulationBatch metadata.
        output_directory: Directory for output files.
        overwrite: Whether to overwrite existing files.

    Returns:
        List of Path objects for all generated output files.
    """
    output_directory.mkdir(parents=True, exist_ok=True)

    # Build the output configuration up front.
    output_config = batch.simulator_config.output
    template = output_config.file_name
    save_kwargs = dict(output_config.arguments) if output_config.arguments else {}

    logger.debug(
        "[PROCESS] Batch %s: Saving data - counter=%s, file_template=%s",
        batch.batch_index,
        simulator.counter,
        template,
    )
    logger.debug(
        "Saving batch data for %s batch %d",
        batch.simulator_name,
        batch.batch_index,
    )

    # Resolve the template into concrete file names; templates containing
    # arrays (e.g. per-detector) expand to several names.
    resolved = get_file_name_from_template(
        template=template,
        instance=simulator,
        output_directory=output_directory,
    )
    if isinstance(resolved, Path):
        resolved_paths = [resolved]
    else:
        # An array of names (multiple detectors) is flattened into Paths.
        resolved_paths = [Path(str(name)) for name in resolved.flatten()]

    logger.debug(
        "[PROCESS] Batch %s: Resolved filenames - %s", batch.batch_index, [str(p.name) for p in resolved_paths]
    )

    simulator.save_data(
        data=batch_data,
        file_name=template,
        output_directory=output_directory,
        overwrite=overwrite,
        **save_kwargs,
    )

    logger.debug("[PROCESS] Batch %s: Data saved - counter=%s", batch.batch_index, simulator.counter)

    return resolved_paths
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
def setup_signal_handlers(checkpoint_dirs: list[Path]) -> None:
    """Register exit and signal hooks that remove stale checkpoint backups.

    Args:
        checkpoint_dirs: Checkpoint directories whose ``*.bak`` files should
            be removed on shutdown.
    """

    def cleanup_checkpoints():
        """Clean up temporary checkpoint files."""
        for directory in checkpoint_dirs:
            for backup in directory.glob("*.bak"):
                try:
                    backup.unlink()
                except OSError as e:
                    logger.warning("Failed to clean up backup file %s: %s", backup, e)
                else:
                    logger.debug("Cleaned up backup file: %s", backup)

    # Run the cleanup on normal interpreter exit and on SIGINT/SIGTERM.
    atexit.register(cleanup_checkpoints)
    for interrupt in (signal.SIGINT, signal.SIGTERM):
        signal.signal(interrupt, handle_signal(cleanup_checkpoints))
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
def validate_plan(plan: SimulationPlan) -> None:
    """Validate a simulation plan before execution.

    Args:
        plan: SimulationPlan to validate.

    Raises:
        ValueError: If the plan is empty or any batch is invalid.
    """
    logger.info("Validating simulation plan with %d batches", plan.total_batches)

    if not plan.total_batches:
        raise ValueError("Simulation plan contains no batches")

    for batch in plan.batches:
        # Every batch needs a simulator name and a non-negative index.
        if not batch.simulator_name:
            raise ValueError("Batch has empty simulator name")
        if batch.batch_index < 0:
            raise ValueError(f"Batch {batch.batch_index} has invalid index")
        # Each batch must know where its output goes.
        if not batch.simulator_config.output.file_name:
            raise ValueError(f"Batch {batch.simulator_name}-{batch.batch_index} missing file_name")

    logger.info("Simulation plan validation completed successfully")
|
|
417
|
+
|
|
418
|
+
|
|
419
|
+
def execute_plan(  # pylint: disable=too-many-locals
    plan: SimulationPlan,
    output_directory: Path,
    metadata_directory: Path,
    overwrite: bool,
    max_retries: int = 3,
) -> None:
    """Execute a complete simulation plan.

    The key insight: Simulators are stateful objects. Each simulator is instantiated
    once and then generates multiple batches by calling next() repeatedly. State
    (RNG, counters, filters) accumulates across batches.

    Checkpoint behavior:
    1. After each successfully completed batch, save checkpoint with updated state
    2. Checkpoint contains: completed batch indices, simulator state
    3. On next run, already-completed batches are skipped (resumption)
    4. On successful completion of all batches, checkpoint is cleaned up

    Workflow:
    1. Group batches by simulator name
    2. Load checkpoint to find already-completed batches
    3. For each simulator:
       a. Create ONE simulator instance
       b. For each batch of that simulator:
          - Skip if already completed (from checkpoint)
          - Restore state if reproducing from metadata
          - Call next(simulator) to generate batch (increments state)
          - Save batch output and metadata
          - Save checkpoint with updated state (for resumption)

    Args:
        plan: SimulationPlan to execute
        output_directory: Directory for output files
        metadata_directory: Directory for metadata files
        overwrite: Whether to overwrite existing files
        max_retries: Maximum retries per batch

    Raises:
        Exception: Re-raises whatever the batch execution raised once all
            retries for that batch are exhausted.
    """
    logger.info("Executing simulation plan: %d batches", plan.total_batches)

    validate_plan(plan)
    setup_signal_handlers([plan.checkpoint_directory] if plan.checkpoint_directory else [])

    # Initialize checkpoint manager for resumption support
    checkpoint_manager = CheckpointManager(plan.checkpoint_directory)
    completed_batch_indices = checkpoint_manager.get_completed_batch_indices()

    if completed_batch_indices:
        logger.info("Loaded checkpoint: %d batches already completed", len(completed_batch_indices))
    else:
        logger.debug("No checkpoint found or no batches completed yet")

    # Group batches by simulator name to execute sequentially per simulator
    simulator_batches: dict[str, list[SimulationBatch]] = {}
    for batch in plan.batches:
        if batch.simulator_name not in simulator_batches:
            simulator_batches[batch.simulator_name] = []
        simulator_batches[batch.simulator_name].append(batch)

    logger.info("Executing %d simulators", len(simulator_batches))

    with tqdm(total=plan.total_batches, desc="Executing simulation plan") as p_bar:
        for simulator_name, batches in simulator_batches.items():
            logger.info("Starting simulator: %s with %d batches", simulator_name, len(batches))

            # Create ONE simulator instance for all batches of this simulator
            # Extract global simulator arguments from the first batch's global config
            global_sim_args = batches[0].globals_config.simulator_arguments if batches else {}
            simulator = instantiate_simulator(batches[0].simulator_config, simulator_name, global_sim_args)

            # Process batches sequentially, maintaining state across them
            for batch_idx, batch in enumerate(batches):
                # Skip batches that were already completed (for resumption after interrupt)
                if checkpoint_manager.should_skip_batch(batch.batch_index):
                    logger.info(
                        "Skipping batch %d (already completed from checkpoint)",
                        batch.batch_index,
                    )
                    continue

                try:
                    logger.debug(
                        "Executing batch %d/%d for simulator %s",
                        batch_idx + 1,
                        len(batches),
                        simulator_name,
                    )

                    # Capture pre-batch state first for potential retries
                    logger.debug(
                        "[EXECUTE] Batch %s: Before restore - counter=%s, has_state_snapshot=%s",
                        batch.batch_index,
                        simulator.counter,
                        batch.has_state_snapshot(),
                    )
                    restore_batch_state(simulator, batch)
                    logger.debug("[EXECUTE] Batch %s: After restore - counter=%s", batch.batch_index, simulator.counter)
                    # Deep-copy so later in-place mutation of simulator.state
                    # cannot corrupt the snapshot used for retries/metadata.
                    pre_batch_state = copy.deepcopy(simulator.state)
                    logger.debug(
                        "[EXECUTE] Batch %s: Captured pre_batch_state - keys=%s",
                        batch.batch_index,
                        list(pre_batch_state.keys()),
                    )

                    # NOTE: default arguments bind the current loop values at
                    # definition time, avoiding the late-binding closure pitfall.
                    def execute_batch(
                        _simulator=simulator,
                        _batch=batch,
                        _output_directory=output_directory,
                        _pre_batch_state=pre_batch_state,
                    ):
                        """Execute a single batch with state management."""
                        # Generate data by calling next() - this advances simulator state
                        logger.debug("[BATCH] %s: Before next() - counter=%s", _batch.batch_index, _simulator.counter)
                        batch_data = _simulator.simulate()
                        logger.debug("[BATCH] %s: After next() - counter=%s", _batch.batch_index, _simulator.counter)

                        # Save the generated data and get all output file paths
                        output_files = process_batch(
                            simulator=_simulator,
                            batch_data=batch_data,
                            batch=_batch,
                            output_directory=_output_directory,
                            overwrite=overwrite,
                        )

                        # Only save metadata if data save succeeded
                        # This ensures metadata only exists for successfully saved data
                        save_batch_metadata(
                            _simulator,
                            _batch,
                            metadata_directory,
                            output_files,
                            pre_batch_state=_pre_batch_state,
                        )
                        # Update the state after successful save
                        _simulator.update_state()

                    def restore_state_for_retry(_simulator=simulator, _pre_batch_state=pre_batch_state):
                        """Restore simulator state to pre-batch state before retry."""
                        _simulator.state = copy.deepcopy(_pre_batch_state)

                    # Execute batch with retry mechanism that restores state on failure
                    retry_with_backoff(
                        execute_batch,
                        max_retries=max_retries,
                        state_restore_func=restore_state_for_retry,
                    )

                    # After successful completion, save checkpoint with updated state
                    # At this point, state has been incremented by next() -> update_state()
                    # Save checkpoint to enable resumption if interrupted before next batch
                    completed_batch_indices.add(batch.batch_index)
                    checkpoint_manager.save_checkpoint(
                        completed_batch_indices=sorted(completed_batch_indices),
                        last_simulator_name=simulator_name,
                        last_completed_batch_index=batch.batch_index,
                        last_simulator_state=copy.deepcopy(simulator.state),
                    )
                    logger.debug(
                        "Checkpoint saved after batch %d - state counter=%s",
                        batch.batch_index,
                        simulator.counter,
                    )
                    p_bar.update(1)

                except Exception as e:
                    # Retries are already exhausted here; log context and propagate.
                    logger.error(
                        "Failed to execute batch %d for simulator %s after %d retries: %s",
                        batch.batch_index,
                        simulator_name,
                        max_retries,
                        e,
                    )
                    raise

    # All batches completed successfully - clean up checkpoint files
    checkpoint_manager.cleanup()
    logger.info("All batches completed successfully. Checkpoint files cleaned up.")
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
"""Utility functions for the gwsim CLI."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import base64
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import yaml
|
|
9
|
+
from astropy.units import Quantity
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def represent_quantity(dumper: yaml.SafeDumper, obj: Quantity) -> yaml.nodes.MappingNode:
    """Serialize an astropy Quantity as a tagged YAML mapping.

    Args:
        dumper: YAML dumper performing the serialization.
        obj: Quantity to serialize.

    Returns:
        A ``!Quantity`` mapping node holding the value and unit.
    """
    payload = {"value": float(obj.value), "unit": str(obj.unit)}
    return dumper.represent_mapping("!Quantity", payload)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def construct_quantity(loader: yaml.Loader, node: yaml.MappingNode) -> Quantity:
    """Rebuild a Quantity from its ``!Quantity`` YAML mapping.

    Args:
        loader: YAML loader performing the deserialization.
        node: Mapping node produced by :func:`represent_quantity`.

    Returns:
        The reconstructed Quantity.
    """
    fields = loader.construct_mapping(node)
    return Quantity(fields["value"], fields["unit"])
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# Register Quantity (de)serialization with PyYAML's safe dumper/loader so
# yaml.safe_dump / yaml.safe_load handle Quantity objects transparently.
yaml.SafeDumper.add_multi_representer(Quantity, represent_quantity)
yaml.SafeLoader.add_constructor("!Quantity", construct_quantity)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def represent_numpy_array(dumper: yaml.SafeDumper, obj: np.ndarray) -> yaml.nodes.MappingNode:
    """Serialize a numpy array as a tagged YAML mapping.

    The raw buffer is base64-encoded; dtype and shape are stored alongside so
    the array can be reconstructed exactly.

    Args:
        dumper: YAML dumper performing the serialization.
        obj: Numpy array to serialize.

    Returns:
        An ``!ndarray`` mapping node.
    """
    payload = {
        "data": base64.b64encode(obj.tobytes()).decode("ascii"),
        "dtype": str(obj.dtype),
        "shape": list(obj.shape),
        "encoding": "base64",
    }
    return dumper.represent_mapping("!ndarray", payload)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def construct_numpy_array(loader: yaml.Loader, node: yaml.MappingNode) -> np.ndarray:
    """Rebuild a numpy array from its ``!ndarray`` YAML mapping.

    Args:
        loader: YAML loader performing the deserialization.
        node: Mapping node produced by :func:`represent_numpy_array`.

    Returns:
        The reconstructed numpy array.

    Raises:
        ValueError: If the stored payload is not base64-encoded.
    """
    fields = loader.construct_mapping(node)
    if fields.get("encoding") != "base64":
        raise ValueError("Expected base64 encoding in YAML data")
    raw = base64.b64decode(fields["data"])
    return np.frombuffer(raw, dtype=np.dtype(fields["dtype"])).reshape(tuple(fields["shape"]))
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
# Register ndarray (de)serialization with PyYAML's safe dumper/loader so
# yaml.safe_dump / yaml.safe_load handle numpy arrays transparently.
yaml.SafeDumper.add_multi_representer(np.ndarray, represent_numpy_array)
yaml.SafeLoader.add_constructor("!ndarray", construct_numpy_array)
|