gwsim 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gwsim/__init__.py +11 -0
- gwsim/__main__.py +8 -0
- gwsim/cli/__init__.py +0 -0
- gwsim/cli/config.py +88 -0
- gwsim/cli/default_config.py +56 -0
- gwsim/cli/main.py +101 -0
- gwsim/cli/merge.py +150 -0
- gwsim/cli/repository/__init__.py +0 -0
- gwsim/cli/repository/create.py +91 -0
- gwsim/cli/repository/delete.py +51 -0
- gwsim/cli/repository/download.py +54 -0
- gwsim/cli/repository/list_depositions.py +63 -0
- gwsim/cli/repository/main.py +38 -0
- gwsim/cli/repository/metadata/__init__.py +0 -0
- gwsim/cli/repository/metadata/main.py +24 -0
- gwsim/cli/repository/metadata/update.py +58 -0
- gwsim/cli/repository/publish.py +52 -0
- gwsim/cli/repository/upload.py +74 -0
- gwsim/cli/repository/utils.py +47 -0
- gwsim/cli/repository/verify.py +61 -0
- gwsim/cli/simulate.py +220 -0
- gwsim/cli/simulate_utils.py +596 -0
- gwsim/cli/utils/__init__.py +85 -0
- gwsim/cli/utils/checkpoint.py +178 -0
- gwsim/cli/utils/config.py +347 -0
- gwsim/cli/utils/hash.py +23 -0
- gwsim/cli/utils/retry.py +62 -0
- gwsim/cli/utils/simulation_plan.py +439 -0
- gwsim/cli/utils/template.py +56 -0
- gwsim/cli/utils/utils.py +149 -0
- gwsim/cli/validate.py +255 -0
- gwsim/data/__init__.py +8 -0
- gwsim/data/serialize/__init__.py +9 -0
- gwsim/data/serialize/decoder.py +59 -0
- gwsim/data/serialize/encoder.py +44 -0
- gwsim/data/serialize/serializable.py +33 -0
- gwsim/data/time_series/__init__.py +3 -0
- gwsim/data/time_series/inject.py +104 -0
- gwsim/data/time_series/time_series.py +355 -0
- gwsim/data/time_series/time_series_list.py +182 -0
- gwsim/detector/__init__.py +8 -0
- gwsim/detector/base.py +156 -0
- gwsim/detector/detectors/E1_2L_Aligned_Sardinia.interferometer +22 -0
- gwsim/detector/detectors/E1_2L_Misaligned_Sardinia.interferometer +22 -0
- gwsim/detector/detectors/E1_Triangle_EMR.interferometer +19 -0
- gwsim/detector/detectors/E1_Triangle_Sardinia.interferometer +19 -0
- gwsim/detector/detectors/E2_2L_Aligned_EMR.interferometer +22 -0
- gwsim/detector/detectors/E2_2L_Misaligned_EMR.interferometer +22 -0
- gwsim/detector/detectors/E2_Triangle_EMR.interferometer +19 -0
- gwsim/detector/detectors/E2_Triangle_Sardinia.interferometer +19 -0
- gwsim/detector/detectors/E3_Triangle_EMR.interferometer +19 -0
- gwsim/detector/detectors/E3_Triangle_Sardinia.interferometer +19 -0
- gwsim/detector/noise_curves/ET_10_HF_psd.txt +3000 -0
- gwsim/detector/noise_curves/ET_10_full_cryo_psd.txt +3000 -0
- gwsim/detector/noise_curves/ET_15_HF_psd.txt +3000 -0
- gwsim/detector/noise_curves/ET_15_full_cryo_psd.txt +3000 -0
- gwsim/detector/noise_curves/ET_20_HF_psd.txt +3000 -0
- gwsim/detector/noise_curves/ET_20_full_cryo_psd.txt +3000 -0
- gwsim/detector/noise_curves/ET_D_psd.txt +3000 -0
- gwsim/detector/utils.py +90 -0
- gwsim/glitch/__init__.py +7 -0
- gwsim/glitch/base.py +69 -0
- gwsim/mixin/__init__.py +8 -0
- gwsim/mixin/detector.py +203 -0
- gwsim/mixin/gwf.py +192 -0
- gwsim/mixin/population_reader.py +175 -0
- gwsim/mixin/randomness.py +107 -0
- gwsim/mixin/time_series.py +295 -0
- gwsim/mixin/waveform.py +47 -0
- gwsim/noise/__init__.py +19 -0
- gwsim/noise/base.py +134 -0
- gwsim/noise/bilby_stationary_gaussian.py +117 -0
- gwsim/noise/colored_noise.py +275 -0
- gwsim/noise/correlated_noise.py +257 -0
- gwsim/noise/pycbc_stationary_gaussian.py +112 -0
- gwsim/noise/stationary_gaussian.py +44 -0
- gwsim/noise/white_noise.py +51 -0
- gwsim/repository/__init__.py +0 -0
- gwsim/repository/zenodo.py +269 -0
- gwsim/signal/__init__.py +11 -0
- gwsim/signal/base.py +137 -0
- gwsim/signal/cbc.py +61 -0
- gwsim/simulator/__init__.py +7 -0
- gwsim/simulator/base.py +315 -0
- gwsim/simulator/state.py +85 -0
- gwsim/utils/__init__.py +11 -0
- gwsim/utils/datetime_parser.py +44 -0
- gwsim/utils/et_2l_geometry.py +165 -0
- gwsim/utils/io.py +167 -0
- gwsim/utils/log.py +145 -0
- gwsim/utils/population.py +48 -0
- gwsim/utils/random.py +69 -0
- gwsim/utils/retry.py +75 -0
- gwsim/utils/triangular_et_geometry.py +164 -0
- gwsim/version.py +7 -0
- gwsim/waveform/__init__.py +7 -0
- gwsim/waveform/factory.py +83 -0
- gwsim/waveform/pycbc_wrapper.py +37 -0
- gwsim-0.1.0.dist-info/METADATA +157 -0
- gwsim-0.1.0.dist-info/RECORD +103 -0
- gwsim-0.1.0.dist-info/WHEEL +4 -0
- gwsim-0.1.0.dist-info/entry_points.txt +2 -0
- gwsim-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,439 @@
|
|
|
1
|
+
"""Utility functions for creating and managing simulation plans."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import datetime
|
|
6
|
+
import getpass
|
|
7
|
+
import logging
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
import yaml
|
|
13
|
+
|
|
14
|
+
from gwsim.cli.utils.config import Config, GlobalsConfig, SimulatorConfig
|
|
15
|
+
from gwsim.utils.log import get_dependency_versions
|
|
16
|
+
|
|
17
|
+
logger = logging.getLogger("gwsim")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class SimulationBatch:
    """One unit of simulation work belonging to a single named simulator.

    A simulator may contribute several batches to a plan, for example successive
    noise segments or successive gravitational-wave signal realizations.

    Indexing: this class only requires ``batch_index`` to be non-negative.
    NOTE(review): ``create_plan_from_config`` numbers batches sequentially across all
    simulators and ``merge_plans`` renumbers them across plans, so in practice the index
    is not restarted per simulator; confirm which convention downstream code expects.

    A batch can be reproduced in two ways:
      - from configuration alone (simulator config + global config, plus max_samples),
        starting from a fresh state;
      - from a pre-batch state snapshot (RNG state, etc.) for exact reproduction.
    """

    # Name of the simulator that owns this batch (e.g. 'noise', 'signal', 'glitch').
    simulator_name: str
    # Configuration of the owning simulator.
    simulator_config: SimulatorConfig
    # Configuration shared by every simulator in the run.
    globals_config: GlobalsConfig
    # Non-negative index of this batch (see the indexing note above).
    batch_index: int

    # Path of the metadata file this batch was rebuilt from, if any.
    metadata_file: Path | None = None
    # Parsed contents of ``metadata_file``, if any.
    batch_metadata: dict[str, Any] | None = None

    # State captured before the batch was generated: simulator-specific state that cannot
    # be known a priori (RNG state, internal simulator state such as filter memory, ...).
    # When present it is used for exact reproduction; otherwise the batch is rebuilt
    # from configuration.
    pre_batch_state: dict[str, Any] | None = None

    # Provenance: 'config' (fresh), 'metadata_config' (from a saved config) or
    # 'metadata_state' (from a saved state snapshot).
    source: str = "config"

    # Author of this batch and the author's email address.
    author: str | None = None
    email: str | None = None

    def __post_init__(self):
        """Validate the identifying fields right after construction.

        Raises:
            ValueError: If ``simulator_name`` is empty or ``batch_index`` is negative.
        """
        problem = None
        if not self.simulator_name:
            problem = "simulator_name must not be empty"
        elif self.batch_index < 0:
            problem = "batch_index must be non-negative"
        if problem is not None:
            raise ValueError(problem)

    def is_metadata_based(self) -> bool:
        """Tell whether this batch was reconstructed from saved metadata.

        Returns:
            True when ``source`` is 'metadata_config' or 'metadata_state', else False.
        """
        origin = self.source
        return origin == "metadata_config" or origin == "metadata_state"

    def has_state_snapshot(self) -> bool:
        """Tell whether a pre-batch state snapshot is attached for exact reproduction.

        Returns:
            True when ``pre_batch_state`` is not None (even if it is empty), else False.
        """
        snapshot = self.pre_batch_state
        return snapshot is not None
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
@dataclass
class SimulationPlan:
    """An ordered collection of simulation batches plus run-level settings."""

    # Batches to simulate, in the order they were added.
    batches: list[SimulationBatch] = field(default_factory=list)
    # Config the plan was built from, when it was built from a config (None otherwise).
    source_config: Config | None = None
    # Directory for checkpoint files.
    checkpoint_directory: Path = Path("checkpoints")
    # Number of entries in ``batches``; kept in sync by ``__post_init__`` and ``add_batch``.
    total_batches: int = 0

    def __post_init__(self) -> None:
        """Derive ``total_batches`` from batches supplied at construction time.

        Without this, ``SimulationPlan(batches=[...])`` would report ``total_batches == 0``
        until the next ``add_batch`` call. An explicitly passed non-zero
        ``total_batches`` is left untouched for backward compatibility.
        """
        if not self.total_batches:
            self.total_batches = len(self.batches)

    def add_batch(self, batch: SimulationBatch) -> None:
        """Append a batch to the plan and update ``total_batches``.

        Args:
            batch: SimulationBatch to add.
        """
        self.batches.append(batch)
        self.total_batches = len(self.batches)
        logger.debug(
            "Added batch %d: simulator=%s, source=%s",
            batch.batch_index,
            batch.simulator_name,
            batch.source,
        )

    def get_batches_for_simulator(self, simulator_name: str) -> list[SimulationBatch]:
        """Return all batches belonging to one simulator.

        Args:
            simulator_name: Name of the simulator.

        Returns:
            The matching batches, in the order they appear in the plan.
        """
        return [b for b in self.batches if b.simulator_name == simulator_name]
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def parse_batch_metadata(metadata_file: Path) -> dict[str, Any]:
    """Load one batch metadata YAML file and check that it holds a mapping.

    The file is read as UTF-8 with ``yaml.safe_load``, so arbitrary Python object tags
    are not constructed.

    Args:
        metadata_file: Path to a ``*.metadata.yaml`` file.

    Returns:
        The top-level mapping stored in the file.

    Raises:
        FileNotFoundError: If ``metadata_file`` does not exist.
        ValueError: If the file is not valid YAML or its top level is not a mapping.
    """
    if not metadata_file.exists():
        raise FileNotFoundError(f"Metadata file not found: {metadata_file}")

    try:
        with metadata_file.open(encoding="utf-8") as stream:
            parsed = yaml.safe_load(stream)
    except yaml.YAMLError as error:
        raise ValueError(f"Failed to parse metadata YAML: {error}") from error

    # An empty file loads as None, and a scalar or list top level is also rejected here.
    if isinstance(parsed, dict):
        return parsed
    raise ValueError(f"Metadata must be a dictionary, got {type(parsed)}")
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def create_batch_metadata(
    simulator_name: str,
    batch_index: int,
    simulator_config: SimulatorConfig,
    globals_config: GlobalsConfig,
    pre_batch_state: dict[str, Any] | None = None,
    source: str = "config",
    author: str | None = None,
    email: str | None = None,
    timestamp: datetime.datetime | None = None,
) -> dict[str, Any]:
    """Create metadata for a simulation batch.

    The metadata records what is needed to reproduce a specific batch:

    1. Configuration: simulator and global configs.
    2. State snapshot: the pre-batch state (RNG, etc.), when provided.
    3. Provenance: author, email, timestamp, and gwsim/dependency versions.

    Args:
        simulator_name: Name of the simulator.
        batch_index: Index of the batch.
        simulator_config: Configuration for this simulator.
        globals_config: Global configuration.
        pre_batch_state: Optional state snapshot taken before batch generation.
        source: Source of this batch: 'config', 'metadata_config', or 'metadata_state'.
        author: Author name. Defaults to the login name from ``getpass.getuser()``;
            if that cannot be determined, ``None`` is recorded instead of failing.
        email: Optional author email.
        timestamp: When the batch was created. Defaults to the current UTC time.

    Returns:
        Metadata dictionary with keys ``simulator_name``, ``batch_index``,
        ``simulator_config``, ``globals_config``, ``author``, ``email``, ``timestamp``
        (ISO 8601 string), ``versions``, ``source`` and, if given, ``pre_batch_state``.
    """
    if author is None:
        try:
            author = getpass.getuser()
        except (OSError, KeyError, ImportError) as error:
            # getpass.getuser() raises when no login environment variable is set and the
            # user cannot be looked up (e.g. arbitrary UIDs in containers). Authorship is
            # optional metadata, so record None rather than aborting the simulation.
            logger.warning("Could not determine the current user for batch metadata: %s", error)

    if timestamp is None:
        timestamp = datetime.datetime.now(datetime.timezone.utc)

    metadata: dict[str, Any] = {
        "simulator_name": simulator_name,
        "batch_index": batch_index,
        # NOTE(review): mode="python" keeps non-JSON types (e.g. Path) as Python objects;
        # confirm the YAML writer used downstream can represent them.
        "simulator_config": simulator_config.model_dump(mode="python"),
        "globals_config": globals_config.model_dump(mode="python"),
        "author": author,
        "email": email,
        "timestamp": timestamp.isoformat(),
        "versions": get_dependency_versions(),
    }

    if pre_batch_state is not None:
        metadata["pre_batch_state"] = pre_batch_state

    metadata["source"] = source

    return metadata
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def _normalize_argument_keys(arguments: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of ``arguments`` with hyphens in keys replaced by underscores."""
    return {key.replace("-", "_"): value for key, value in arguments.items()}


def _resolve_max_samples(
    simulator_name: str, local_arguments: dict[str, Any], global_arguments: dict[str, Any]
) -> int:
    """Return the number of batches to plan for one simulator.

    ``max_samples`` is looked up in the simulator's own arguments first, then in the
    (already normalized) global simulator arguments, and defaults to 1.

    Raises:
        TypeError: If ``max_samples`` is not an integer.
        ValueError: If ``max_samples`` is negative.
    """
    import operator  # local import keeps this helper self-contained

    local = _normalize_argument_keys(local_arguments)
    raw_value = local.get("max_samples", global_arguments.get("max_samples", 1))
    try:
        # operator.index accepts exactly the integer-like values range() accepts.
        count = operator.index(raw_value)
    except TypeError as error:
        raise TypeError(
            f"max_samples for simulator '{simulator_name}' must be an integer, got {raw_value!r}"
        ) from error
    if count < 0:
        raise ValueError(f"max_samples for simulator '{simulator_name}' must be non-negative, got {count}")
    return count


def create_plan_from_config(
    config: Config, checkpoint_dir: Path, author: str | None = None, email: str | None = None
) -> SimulationPlan:
    """Create a simulation plan from a configuration file.

    This is the standard workflow: start fresh with a config. Each simulator contributes
    ``max_samples`` batches, taken from its own ``arguments`` or, failing that, from the
    global ``simulator_arguments`` (default 1). Keys may use hyphens (YAML style), e.g.
    ``max-samples``.

    Note: State (like RNG state) will be captured during simulation and stored in
    metadata for exact reproduction of individual batches.

    Args:
        config: Parsed Config object.
        checkpoint_dir: Directory for checkpoints.
        author: Optional author name for metadata.
        email: Optional author email for metadata.

    Returns:
        SimulationPlan with all batches defined across all simulators.

    Raises:
        TypeError: If a resolved ``max_samples`` is not an integer.
        ValueError: If a resolved ``max_samples`` is negative.

    Example:
        >>> from gwsim.cli.utils.config import load_config
        >>> cfg = load_config(Path("config.yaml"))
        >>> plan = create_plan_from_config(cfg, Path("checkpoints"))
        >>> print(f"Total batches: {plan.total_batches}")
        >>> # Get all batches from a specific simulator
        >>> noise_batches = plan.get_batches_for_simulator("noise")
    """
    # Global arguments are the same for every simulator, so normalize them once.
    global_sim_args = _normalize_argument_keys(config.globals.simulator_arguments)

    # Resolve every simulator's batch count up front so an invalid configuration fails
    # before any part of the plan is built.
    batch_counts = {
        name: _resolve_max_samples(name, sim_config.arguments, global_sim_args)
        for name, sim_config in config.simulators.items()
    }

    plan = SimulationPlan(
        source_config=config,
        checkpoint_directory=checkpoint_dir,
    )

    # NOTE(review): indices run sequentially across *all* simulators rather than
    # restarting at 0 for each one; confirm this is the convention checkpointing and
    # output naming expect.
    batch_index = 0
    for simulator_name, simulator_config in config.simulators.items():
        for _ in range(batch_counts[simulator_name]):
            plan.add_batch(
                SimulationBatch(
                    simulator_name=simulator_name,
                    simulator_config=simulator_config,
                    globals_config=config.globals,
                    batch_index=batch_index,
                    source="config",
                    author=author,
                    email=email,
                )
            )
            batch_index += 1

    logger.info("Created simulation plan from config: %d batches", plan.total_batches)
    return plan
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def create_plan_from_metadata_files(
    metadata_files: list[Path],
    checkpoint_dir: Path,
    author: str | None = None,
    email: str | None = None,
) -> SimulationPlan:
    """Create a simulation plan from individual batch metadata files.

    Each file describes one batch. When a file contains a ``pre_batch_state`` (even an
    empty one) the batch is marked 'metadata_state' for exact reproduction from the saved
    state; otherwise it is marked 'metadata_config' and is rebuilt from the saved
    configuration only. This matches ``SimulationBatch.has_state_snapshot``.

    File names are not interpreted: the simulator name and batch index are read from each
    file's contents. Files are conventionally named ``SIMULATOR-BATCH_INDEX.metadata.yaml``.

    Args:
        metadata_files: List of paths to individual metadata YAML files.
        checkpoint_dir: Directory for checkpoints.
        author: Optional author name for metadata.
        email: Optional author email for metadata.

    Returns:
        SimulationPlan with batches reconstructed from metadata.

    Raises:
        FileNotFoundError: If any metadata file doesn't exist.
        ValueError: If a metadata file is not valid YAML; lacks the configs,
            ``simulator_name`` or ``batch_index``; has a ``batch_index`` that is not a
            non-negative integer; or its configs are rejected by the config models
            (assumed to raise ValueError subclasses, as pydantic's ValidationError does).

    Example:
        >>> files = [Path("signal-0.metadata.yaml"), Path("signal-1.metadata.yaml")]
        >>> plan = create_plan_from_metadata_files(files, Path("checkpoints"))
        >>> # Reproduces specific batches with exact state snapshots
    """
    plan = SimulationPlan(checkpoint_directory=checkpoint_dir)

    # NOTE(review): files are processed in lexicographic path order, so e.g.
    # "noise-10.metadata.yaml" precedes "noise-2.metadata.yaml"; confirm batch order
    # within the plan does not matter for metadata-based reproduction.
    for metadata_file in sorted(metadata_files):
        # Raises FileNotFoundError for a missing file and ValueError for invalid YAML.
        metadata = parse_batch_metadata(metadata_file)

        # Reconstruct configs from metadata; report failures with the offending file.
        try:
            globals_config = GlobalsConfig(**metadata["globals_config"])
            simulator_config = SimulatorConfig(**metadata["simulator_config"])
        except (KeyError, TypeError, ValueError) as e:
            raise ValueError(f"Invalid metadata in {metadata_file}: missing or malformed config: {e}") from e

        simulator_name = metadata.get("simulator_name")
        batch_index = metadata.get("batch_index")
        pre_batch_state = metadata.get("pre_batch_state")

        if not simulator_name or batch_index is None:
            raise ValueError(f"Invalid metadata in {metadata_file}: missing simulator_name or batch_index")
        # Reject strings, floats, bools and negatives here, with file context, instead of
        # letting SimulationBatch fail later with a less informative error.
        if isinstance(batch_index, bool) or not isinstance(batch_index, int) or batch_index < 0:
            raise ValueError(
                f"Invalid metadata in {metadata_file}: batch_index must be a non-negative integer, "
                f"got {batch_index!r}"
            )

        has_state = pre_batch_state is not None
        batch = SimulationBatch(
            simulator_name=simulator_name,
            simulator_config=simulator_config,
            globals_config=globals_config,
            batch_index=batch_index,
            metadata_file=metadata_file,
            batch_metadata=metadata,
            pre_batch_state=pre_batch_state,
            source="metadata_state" if has_state else "metadata_config",
            author=author,
            email=email,
        )
        plan.add_batch(batch)

    logger.info(
        "Created simulation plan from %d metadata files",
        len(metadata_files),
    )
    return plan
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
def create_plan_from_metadata(
    metadata_dir: Path,
    checkpoint_dir: Path,
    author: str | None = None,
    email: str | None = None,
) -> SimulationPlan:
    """Create a simulation plan from a directory of metadata files.

    Every file directly inside ``metadata_dir`` whose name ends in ``.metadata.yaml`` is
    loaded. The search does not descend into subdirectories. Batch identity
    (simulator name, batch index) is read from each file's contents.

    Args:
        metadata_dir: Directory containing metadata YAML files.
        checkpoint_dir: Directory for checkpoints.
        author: Optional author name for metadata.
        email: Optional author email for metadata.

    Returns:
        SimulationPlan with batches reconstructed from metadata. It is empty, with a
        warning logged, when no matching files are found.

    Raises:
        FileNotFoundError: If metadata_dir doesn't exist.
        NotADirectoryError: If metadata_dir exists but is not a directory.
        ValueError: If metadata files are malformed.

    Example:
        >>> plan = create_plan_from_metadata(Path("metadata"), Path("checkpoints"))
        >>> # Reproduces batches with exact state snapshots
    """
    if not metadata_dir.exists():
        raise FileNotFoundError(f"Metadata directory not found: {metadata_dir}")
    if not metadata_dir.is_dir():
        # Globbing a regular file silently yields nothing; fail loudly instead.
        raise NotADirectoryError(f"Metadata path is not a directory: {metadata_dir}")

    # Find all metadata files in the directory (non-recursive).
    metadata_files = sorted(metadata_dir.glob("*.metadata.yaml"))
    if not metadata_files:
        logger.warning("No '*.metadata.yaml' files found in %s; the plan will be empty", metadata_dir)

    plan = create_plan_from_metadata_files(metadata_files, checkpoint_dir, author=author, email=email)
    logger.info(
        "Created simulation plan from metadata directory: %d batches from %s",
        plan.total_batches,
        metadata_dir,
    )
    return plan
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
def merge_plans(*plans: SimulationPlan) -> SimulationPlan:
    """Merge multiple simulation plans into one.

    Useful for combining config-based and metadata-based workflows. Batches are taken in
    argument order and renumbered sequentially from 0. The input plans are NOT modified:
    each merged batch is a shallow copy of the original with only ``batch_index`` changed.
    Metadata-based batches therefore keep their original index in
    ``batch.batch_metadata["batch_index"]``.

    The merged plan uses the first plan's ``checkpoint_directory`` (or
    ``Path("checkpoints")`` when no plans are given) and has no ``source_config``.

    Args:
        *plans: SimulationPlan objects to merge.

    Returns:
        Merged SimulationPlan.

    Example:
        >>> plan_config = create_plan_from_config(cfg, Path("checkpoints"))
        >>> plan_metadata = create_plan_from_metadata_files([meta1, meta2], Path("checkpoints"))
        >>> combined_plan = merge_plans(plan_config, plan_metadata)
    """
    from dataclasses import replace  # local import keeps this block self-contained

    merged = SimulationPlan()

    all_batches = (batch for plan in plans for batch in plan.batches)
    for new_index, batch in enumerate(all_batches):
        # Renumber to keep indices unique across the merged plan, copying rather than
        # mutating so the source plans keep their own indices.
        merged.add_batch(replace(batch, batch_index=new_index))

    merged.checkpoint_directory = plans[0].checkpoint_directory if plans else Path("checkpoints")
    logger.info("Merged %d plans into one: %d total batches", len(plans), merged.total_batches)
    return merged
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""Template validation utilities for gwsim CLI."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import re
|
|
7
|
+
|
|
8
|
+
logger = logging.getLogger("gwsim")
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class TemplateValidator:
    """Validate ``str.format``-style template strings for simulators.

    Templates are checked by formatting them with a fixed set of representative dummy
    fields (see ``_create_dummy_state``). ``str.format`` treats ``{{``/``}}`` as escaped
    literal braces, so double-brace placeholders such as ``{{ x }}`` are not validated as
    fields by this class.
    """

    @staticmethod
    def validate_template(template: str, simulator_name: str) -> tuple[bool, list[str]]:
        """Check that ``template`` can be formatted with the known state fields.

        Args:
            template: Template string using ``str.format`` syntax.
            simulator_name: Name of the simulator; used only in the debug log message.

        Returns:
            ``(is_valid, errors)``. ``is_valid`` is True when formatting succeeded;
            ``errors`` lists human-readable problems and is empty when valid.
        """
        errors: list[str] = []

        try:
            # Formatting with dummy values surfaces unknown fields, malformed syntax and
            # format specs that don't fit the value types.
            template.format(**TemplateValidator._create_dummy_state())
        except KeyError as e:
            errors.append(f"Missing template field: {e}")
        except (ValueError, IndexError) as e:
            # IndexError: positional fields such as "{0}" or "{}" have nothing to bind to,
            # because only keyword arguments are supplied.
            errors.append(f"Template formatting error: {e}")
        except (AttributeError, TypeError) as e:
            errors.append(f"Template validation error: {e}")
        else:
            logger.debug("Template validation passed for %s: %s", simulator_name, template)

        return not errors, errors

    @staticmethod
    def extract_template_fields(template: str) -> set[str]:
        """Return the replacement-field names used in a ``str.format`` template.

        Escaped braces (``{{``/``}}``) are ignored, and conversion flags (``!r``) and
        format specs (``:04d``) are stripped, e.g.
        ``"{a!r}-{b:04d}-{{c}}"`` -> ``{"a", "b"}``. Attribute/index access is kept
        verbatim (``"{a.b}"`` -> ``{"a.b"}``).
        """
        # Drop escaped braces first so "{{literal}}" is not mistaken for a field.
        unescaped = template.replace("{{", "").replace("}}", "")
        return set(re.findall(r"\{([^{}!:]+)", unescaped))

    @staticmethod
    def _create_dummy_state() -> dict:
        """Return representative values for the fields a template may reference.

        Only these keys are accepted by ``validate_template``; any other field name is
        reported as missing.
        """
        return {
            "counter": 1,
            "start_time": 1696291200,
            "duration": 4096,
            "detector": "H1",
            "batch_id": "test",
            "sample_rate": 4096,
            "end_time": 1696295296,
        }
|
gwsim/cli/utils/utils.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utility functions used in the command line tools.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import importlib
|
|
8
|
+
import logging
|
|
9
|
+
import re
|
|
10
|
+
import sys
|
|
11
|
+
from collections.abc import Callable
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger("gwsim")
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def import_attribute(full_path: str) -> Any:
    """
    Import and return an attribute (class, function, constant, ...) from its dotted path.

    Args:
        full_path (str): Dotted path whose last component is the attribute name and whose
            remainder is an importable module, e.g., 'my_package.my_module.my_attribute'.

    Returns:
        Any: The attribute.

    Raises:
        ValueError: If ``full_path`` has no module part (e.g. ``'attr'`` or ``'.attr'``).
        ImportError: If the module part cannot be imported (typically ModuleNotFoundError).
        AttributeError: If the module has no such attribute.
    """
    module_path, _, attribute_name = full_path.rpartition(".")
    if not module_path:
        # Without this check, a dot-less path fails with an opaque tuple-unpacking error.
        raise ValueError(f"Expected a dotted path of the form 'module.attribute', got '{full_path}'")
    module = importlib.import_module(module_path)
    return getattr(module, attribute_name)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# Matches a ``{{ name }}`` placeholder (whitespace inside the braces optional); group 1 is the name.
_PLACEHOLDER_PATTERN = re.compile(r"\{\{\s*(\w+)\s*\}\}")


def _substitute_placeholders(template: str, lookup: Callable[[str], Any], exclude: set[str] | None) -> str:
    """Replace every ``{{ name }}`` in ``template`` with ``str(lookup(name))``.

    Shared by both ``get_file_name_from_template*`` functions. Placeholders whose name is
    in ``exclude`` are left untouched. Errors raised by ``lookup`` propagate unchanged.
    """
    excluded = set() if exclude is None else exclude

    def replace(matched: re.Match) -> str:
        label = matched.group(1)  # \w+ cannot contain whitespace, so no strip is needed
        if label in excluded:
            return matched.group(0)  # keep the original placeholder text
        return str(lookup(label))

    return _PLACEHOLDER_PATTERN.sub(replace, template)


def get_file_name_from_template_with_dict(
    template: str, values: dict[str, Any], exclude: set[str] | None = None
) -> str:
    """Get the file name from a template string.
    The template string should use a double curly bracket to indicate the placeholder.
    For example, in '{{ x }}-{{ y }}.txt', x and y are interpreted as placeholders,
    and the values are retrieved from the values dictionary.

    Args:
        template (str): A template string.
        values (dict[str, Any]): A dictionary of values.
        exclude (set[str] | None): Set of attribute names to exclude from expansion. Defaults to None.

    Returns:
        str: The file name with the placeholders substituted by the values from the dictionary.

    Raises:
        ValueError: If a non-excluded placeholder has no key in ``values``.
    """

    def lookup(label: str) -> Any:
        try:
            return values[label]
        except KeyError as e:
            raise ValueError(f"Key '{label}' not found in values dictionary") from e

    return _substitute_placeholders(template, lookup, exclude)


def get_file_name_from_template(template: str, instance: object, exclude: set[str] | None = None) -> str:
    """Get the file name from a template string.
    The template string should use a double curly bracket to indicate the placeholder.
    For example, in '{{ x }}-{{ y }}.txt', x and y are interpreted as placeholders,
    and the values are retrieved from the instance.

    Args:
        template (str): A template string.
        instance (object): An instance.
        exclude (set[str] | None): Set of attribute names to exclude from expansion. Defaults to None.

    Returns:
        str: The file name with the placeholders substituted by the values of the attributes of the instance.

    Raises:
        ValueError: If a non-excluded placeholder names an attribute the instance lacks.
    """

    def lookup(label: str) -> Any:
        try:
            return getattr(instance, label)
        except AttributeError as e:
            raise ValueError(f"Attribute '{label}' not found in instance of type {type(instance).__name__}") from e

    return _substitute_placeholders(template, lookup, exclude)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def handle_signal(cleanup_fn: Callable) -> Callable:
    """Build a signal handler that runs a clean-up callback and then exits.

    Args:
        cleanup_fn (Callable): Zero-argument clean-up function invoked when a signal arrives.

    Returns:
        Callable: A handler with the ``(signum, frame)`` signature expected by
        ``signal.signal``. It logs the signal number, calls ``cleanup_fn`` and exits with
        status 1. If ``cleanup_fn`` raises, that exception propagates and the exit is not
        reached.
    """

    def _on_signal(signal_number, _frame):
        logger.error("Received signal %s, exiting...", signal_number)
        cleanup_fn()
        sys.exit(1)

    return _on_signal
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def save_file_safely(file_name: str | Path, backup_file_name: str | Path, save_function: Callable, **kwargs) -> None:
    """A helper function to save file safely by first creating a backup.

    This function is designed for saving a checkpoint file that has a fixed file name.
    If an existing `file_name` is detected, it is first moved to `backup_file_name`
    (replacing any stale backup, on every platform) before calling `save_function`.

    If `save_function` fails for any reason, including a ``KeyboardInterrupt`` or
    ``SystemExit`` raised mid-save, the backup (if present) is moved back to `file_name`.
    The exception is then re-raised. On success the backup is deleted; failing to delete
    it is only logged and does not undo the successful save.

    `save_function` needs to have an argument `file_name` to define the name of the output file.
    Additional arguments can be provided through **kwargs.

    Args:
        file_name (str | Path): File name of the output.
        backup_file_name (str | Path): File name of the backup.
        save_function (Callable): A callable to perform the saving.
        **kwargs: Extra keyword arguments forwarded to `save_function`.

    Raises:
        OSError: If the existing file cannot be moved to the backup location.
        Any exception raised by `save_function` is re-raised after the restore attempt.
    """
    file_name = Path(file_name)
    backup_file_name = Path(backup_file_name)

    if file_name.is_file():
        # Path.replace overwrites an existing target on all platforms (rename fails on Windows).
        file_name.replace(backup_file_name)
        logger.debug("Existing file backed up to: %s", backup_file_name)

    try:
        save_function(file_name=file_name, **kwargs)
    except BaseException as e:
        # BaseException so that interrupts during the write also restore the last good
        # file; the exception is always re-raised below.
        logger.error("Failed to save file: %s", e)

        if backup_file_name.is_file():
            try:
                # Overwrites any partially written output left by save_function.
                backup_file_name.replace(file_name)
                logger.warning("Restored file from backup due to a failure.")
            except OSError as restore_error:
                logger.error("Failed to restore backup file: %s", restore_error)
        raise

    # The save succeeded: a failure to delete the backup must not roll the new file back.
    if backup_file_name.is_file():
        try:
            backup_file_name.unlink()
            logger.debug("Backup file deleted after successful save.")
        except OSError as cleanup_error:
            logger.warning("Saved file but could not delete backup %s: %s", backup_file_name, cleanup_error)
|