gwsim 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gwsim/__init__.py +11 -0
- gwsim/__main__.py +8 -0
- gwsim/cli/__init__.py +0 -0
- gwsim/cli/config.py +88 -0
- gwsim/cli/default_config.py +56 -0
- gwsim/cli/main.py +101 -0
- gwsim/cli/merge.py +150 -0
- gwsim/cli/repository/__init__.py +0 -0
- gwsim/cli/repository/create.py +91 -0
- gwsim/cli/repository/delete.py +51 -0
- gwsim/cli/repository/download.py +54 -0
- gwsim/cli/repository/list_depositions.py +63 -0
- gwsim/cli/repository/main.py +38 -0
- gwsim/cli/repository/metadata/__init__.py +0 -0
- gwsim/cli/repository/metadata/main.py +24 -0
- gwsim/cli/repository/metadata/update.py +58 -0
- gwsim/cli/repository/publish.py +52 -0
- gwsim/cli/repository/upload.py +74 -0
- gwsim/cli/repository/utils.py +47 -0
- gwsim/cli/repository/verify.py +61 -0
- gwsim/cli/simulate.py +220 -0
- gwsim/cli/simulate_utils.py +596 -0
- gwsim/cli/utils/__init__.py +85 -0
- gwsim/cli/utils/checkpoint.py +178 -0
- gwsim/cli/utils/config.py +347 -0
- gwsim/cli/utils/hash.py +23 -0
- gwsim/cli/utils/retry.py +62 -0
- gwsim/cli/utils/simulation_plan.py +439 -0
- gwsim/cli/utils/template.py +56 -0
- gwsim/cli/utils/utils.py +149 -0
- gwsim/cli/validate.py +255 -0
- gwsim/data/__init__.py +8 -0
- gwsim/data/serialize/__init__.py +9 -0
- gwsim/data/serialize/decoder.py +59 -0
- gwsim/data/serialize/encoder.py +44 -0
- gwsim/data/serialize/serializable.py +33 -0
- gwsim/data/time_series/__init__.py +3 -0
- gwsim/data/time_series/inject.py +104 -0
- gwsim/data/time_series/time_series.py +355 -0
- gwsim/data/time_series/time_series_list.py +182 -0
- gwsim/detector/__init__.py +8 -0
- gwsim/detector/base.py +156 -0
- gwsim/detector/detectors/E1_2L_Aligned_Sardinia.interferometer +22 -0
- gwsim/detector/detectors/E1_2L_Misaligned_Sardinia.interferometer +22 -0
- gwsim/detector/detectors/E1_Triangle_EMR.interferometer +19 -0
- gwsim/detector/detectors/E1_Triangle_Sardinia.interferometer +19 -0
- gwsim/detector/detectors/E2_2L_Aligned_EMR.interferometer +22 -0
- gwsim/detector/detectors/E2_2L_Misaligned_EMR.interferometer +22 -0
- gwsim/detector/detectors/E2_Triangle_EMR.interferometer +19 -0
- gwsim/detector/detectors/E2_Triangle_Sardinia.interferometer +19 -0
- gwsim/detector/detectors/E3_Triangle_EMR.interferometer +19 -0
- gwsim/detector/detectors/E3_Triangle_Sardinia.interferometer +19 -0
- gwsim/detector/noise_curves/ET_10_HF_psd.txt +3000 -0
- gwsim/detector/noise_curves/ET_10_full_cryo_psd.txt +3000 -0
- gwsim/detector/noise_curves/ET_15_HF_psd.txt +3000 -0
- gwsim/detector/noise_curves/ET_15_full_cryo_psd.txt +3000 -0
- gwsim/detector/noise_curves/ET_20_HF_psd.txt +3000 -0
- gwsim/detector/noise_curves/ET_20_full_cryo_psd.txt +3000 -0
- gwsim/detector/noise_curves/ET_D_psd.txt +3000 -0
- gwsim/detector/utils.py +90 -0
- gwsim/glitch/__init__.py +7 -0
- gwsim/glitch/base.py +69 -0
- gwsim/mixin/__init__.py +8 -0
- gwsim/mixin/detector.py +203 -0
- gwsim/mixin/gwf.py +192 -0
- gwsim/mixin/population_reader.py +175 -0
- gwsim/mixin/randomness.py +107 -0
- gwsim/mixin/time_series.py +295 -0
- gwsim/mixin/waveform.py +47 -0
- gwsim/noise/__init__.py +19 -0
- gwsim/noise/base.py +134 -0
- gwsim/noise/bilby_stationary_gaussian.py +117 -0
- gwsim/noise/colored_noise.py +275 -0
- gwsim/noise/correlated_noise.py +257 -0
- gwsim/noise/pycbc_stationary_gaussian.py +112 -0
- gwsim/noise/stationary_gaussian.py +44 -0
- gwsim/noise/white_noise.py +51 -0
- gwsim/repository/__init__.py +0 -0
- gwsim/repository/zenodo.py +269 -0
- gwsim/signal/__init__.py +11 -0
- gwsim/signal/base.py +137 -0
- gwsim/signal/cbc.py +61 -0
- gwsim/simulator/__init__.py +7 -0
- gwsim/simulator/base.py +315 -0
- gwsim/simulator/state.py +85 -0
- gwsim/utils/__init__.py +11 -0
- gwsim/utils/datetime_parser.py +44 -0
- gwsim/utils/et_2l_geometry.py +165 -0
- gwsim/utils/io.py +167 -0
- gwsim/utils/log.py +145 -0
- gwsim/utils/population.py +48 -0
- gwsim/utils/random.py +69 -0
- gwsim/utils/retry.py +75 -0
- gwsim/utils/triangular_et_geometry.py +164 -0
- gwsim/version.py +7 -0
- gwsim/waveform/__init__.py +7 -0
- gwsim/waveform/factory.py +83 -0
- gwsim/waveform/pycbc_wrapper.py +37 -0
- gwsim-0.1.0.dist-info/METADATA +157 -0
- gwsim-0.1.0.dist-info/RECORD +103 -0
- gwsim-0.1.0.dist-info/WHEEL +4 -0
- gwsim-0.1.0.dist-info/entry_points.txt +2 -0
- gwsim-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
"""Checkpoint management for simulation recovery."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
from gwsim.data.serialize.decoder import Decoder
|
|
11
|
+
from gwsim.data.serialize.encoder import Encoder
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger("gwsim")
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class CheckpointManager:
    """Manages checkpoint files for simulation recovery.

    A checkpoint is created after each successfully completed batch,
    allowing resumption from that point if the simulation is interrupted.

    Checkpoint file format::

        {
            "completed_batch_indices": [0, 1, 2, ...],
            "last_simulator_name": "signal",
            "last_completed_batch_index": 2,
            "last_simulator_state": {...}
        }

    The checkpoint is written atomically:
        1. Write to .tmp file
        2. Backup existing checkpoint to .bak
        3. Move .tmp over the checkpoint file

    This ensures we never have a corrupted checkpoint.

    Note: all moves use ``Path.replace`` rather than ``Path.rename``.
    ``rename`` raises ``FileExistsError`` on Windows when the target
    exists (e.g. when the backup step failed but we still want to commit
    the new checkpoint), whereas ``replace`` overwrites atomically on
    all platforms.
    """

    def __init__(self, checkpoint_directory: Path):
        """Initialize checkpoint manager.

        Args:
            checkpoint_directory: Directory to store checkpoint files.
                Created (including parents) if it does not exist.
        """
        self.checkpoint_directory = Path(checkpoint_directory)
        self.checkpoint_directory.mkdir(parents=True, exist_ok=True)
        # Live checkpoint, scratch file for atomic writes, and last-good backup.
        self.checkpoint_file = self.checkpoint_directory / "simulation.checkpoint.json"
        self.checkpoint_tmp = self.checkpoint_directory / "simulation.checkpoint.json.tmp"
        self.checkpoint_backup = self.checkpoint_directory / "simulation.checkpoint.json.bak"

    def load_checkpoint(self) -> dict[str, Any] | None:
        """Load checkpoint from file if it exists.

        Returns:
            Checkpoint dict with keys:
                - completed_batch_indices: List of completed batch indices
                - last_simulator_name: Name of last simulator
                - last_completed_batch_index: Index of last completed batch
                - last_simulator_state: State dict of last simulator
            None if no checkpoint exists or checkpoint is corrupted
        """
        # Try to restore from backup if checkpoint doesn't exist but backup does
        if not self.checkpoint_file.exists() and self.checkpoint_backup.exists():
            logger.warning("Checkpoint file missing but backup exists. Restoring from backup...")
            try:
                # replace() is portable even if the checkpoint file reappears
                # between the existence check and the move (rename() would
                # raise FileExistsError on Windows in that race).
                self.checkpoint_backup.replace(self.checkpoint_file)
                logger.info("Checkpoint restored from backup")
            except OSError as e:
                logger.error("Failed to restore checkpoint from backup: %s", e)
                return None

        if not self.checkpoint_file.exists():
            logger.debug("No checkpoint file found")
            return None

        try:
            with self.checkpoint_file.open("r") as f:
                # Decoder reconstructs project types serialized by Encoder.
                checkpoint = json.load(f, cls=Decoder)
            logger.debug(
                "Loaded checkpoint: last_batch=%s, completed=%d batches",
                checkpoint.get("last_completed_batch_index"),
                len(checkpoint.get("completed_batch_indices", [])),
            )
            return checkpoint
        except (OSError, json.JSONDecodeError) as e:
            # A corrupt/unreadable checkpoint is treated as "no checkpoint".
            logger.error("Failed to load checkpoint: %s", e)
            return None

    def save_checkpoint(
        self,
        completed_batch_indices: list[int],
        last_simulator_name: str,
        last_completed_batch_index: int,
        last_simulator_state: dict[str, Any],
    ) -> None:
        """Save checkpoint after completing a batch.

        Args:
            completed_batch_indices: List of all completed batch indices so far
            last_simulator_name: Name of the simulator that completed the batch
            last_completed_batch_index: Index of the batch that just completed
            last_simulator_state: State dict of the simulator after completion

        Raises:
            OSError: If checkpoint cannot be written
        """
        checkpoint = {
            "completed_batch_indices": completed_batch_indices,
            "last_simulator_name": last_simulator_name,
            "last_completed_batch_index": last_completed_batch_index,
            "last_simulator_state": last_simulator_state,
        }

        # Write to temp file first (atomic write pattern)
        try:
            with self.checkpoint_tmp.open("w") as f:
                json.dump(checkpoint, f, indent=2, cls=Encoder)

            # Backup existing checkpoint if it exists
            if self.checkpoint_file.exists():
                try:
                    # replace() overwrites any stale backup in one portable,
                    # atomic step (no separate unlink needed).
                    self.checkpoint_file.replace(self.checkpoint_backup)
                except OSError as e:
                    # Best effort: a failed backup must not block the new
                    # checkpoint from being committed below.
                    logger.warning("Failed to backup previous checkpoint: %s", e)

            # Commit: move temp over the checkpoint file. replace() still
            # succeeds if the backup step above failed and the old
            # checkpoint file is still present.
            self.checkpoint_tmp.replace(self.checkpoint_file)

            logger.debug(
                "Checkpoint saved: batch_index=%d, completed=%d batches",
                last_completed_batch_index,
                len(completed_batch_indices),
            )
        except OSError as e:
            logger.error("Failed to save checkpoint: %s", e)
            # Clean up temp file if it exists
            if self.checkpoint_tmp.exists():
                try:
                    self.checkpoint_tmp.unlink()
                except OSError:
                    pass
            raise

    def cleanup(self) -> None:
        """Clean up checkpoint files after successful completion."""
        # Remove both checkpoint and backup after successful completion
        try:
            if self.checkpoint_file.exists():
                self.checkpoint_file.unlink()
                logger.debug("Cleaned up checkpoint file")
            if self.checkpoint_backup.exists():
                self.checkpoint_backup.unlink()
                logger.debug("Cleaned up checkpoint backup file")
        except OSError as e:
            # Leftover files are harmless; a warning is sufficient.
            logger.warning("Failed to clean up checkpoint files: %s", e)

    def get_completed_batch_indices(self) -> set[int]:
        """Get set of completed batch indices from checkpoint.

        Returns:
            Set of batch indices that have already been completed
            (empty set when no valid checkpoint exists)
        """
        checkpoint = self.load_checkpoint()
        if checkpoint is None:
            return set()
        return set(checkpoint.get("completed_batch_indices", []))

    def should_skip_batch(self, batch_index: int) -> bool:
        """Check if a batch has already been completed.

        Args:
            batch_index: Index of batch to check

        Returns:
            True if batch was already completed, False otherwise
        """
        return batch_index in self.get_completed_batch_indices()
|
|
@@ -0,0 +1,347 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utility functions to load and save configuration files.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
import yaml
|
|
12
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger("gwsim")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class SimulatorOutputConfig(BaseModel):
    """Configuration for simulator output handling.

    Describes where a simulator writes its results and with what file
    naming scheme. Unknown keys are accepted (``extra="allow"``) so
    output handlers can define their own settings.
    """

    # Template for generated file names; `{{ variable }}` placeholders are
    # substituted elsewhere (template rendering is not done in this model).
    file_name: str = Field(..., description="Output file name template (supports {{ variable }} placeholders)")
    # Free-form options forwarded to the output handler.
    arguments: dict[str, Any] = Field(
        default_factory=dict, description="Output-specific arguments (e.g., channel name)"
    )
    # Per-simulator override of the global output directory; None means
    # "fall back to the global setting or the derived default".
    output_directory: str | None = Field(
        default=None, description="Optional directory override for this simulator's output"
    )
    # Per-simulator override of the global metadata directory.
    metadata_directory: str | None = Field(
        default=None, description="Optional directory override for this simulator's metadata"
    )

    # Allow unknown fields
    model_config = ConfigDict(extra="allow")
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class SimulatorConfig(BaseModel):
    """Configuration for a single simulator.

    Each entry under the top-level ``simulators`` mapping is parsed into
    one of these. The YAML key ``class`` maps to the ``class_`` attribute
    (``class`` is a reserved word in Python).
    """

    # Simulator class; either a bare name (resolved within gwsim by
    # resolve_class_path) or a dotted import path for third-party classes.
    class_: str = Field(alias="class", description="Simulator class name or full import path")
    # Keyword arguments handed to the simulator's constructor.
    arguments: dict[str, Any] = Field(default_factory=dict, description="Arguments passed to simulator constructor")
    # Output settings; defaults to a counter-based HDF5 file name template.
    output: SimulatorOutputConfig = Field(
        default_factory=lambda: SimulatorOutputConfig(file_name="output-{{counter}}.hdf5"),
        description="Output configuration for this simulator",
    )

    # extra="allow": tolerate unknown keys; populate_by_name=True: accept
    # both the alias ("class") and the attribute name ("class_").
    model_config = ConfigDict(extra="allow", populate_by_name=True)

    @field_validator("class_", mode="before")
    @classmethod
    def validate_class_name(cls, v: str) -> str:
        """Validate class specification is non-empty.

        Runs before type coercion (mode="before"), so ``v`` may be any
        YAML value; reject anything that is not a non-blank string.
        """
        if not isinstance(v, str) or not v.strip():
            raise ValueError("'class' must be a non-empty string")
        return v
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class GlobalsConfig(BaseModel):
    """Global configuration applying to all simulators.

    This configuration provides universal directory settings and fallback arguments
    for simulators and output handlers. The simulator_arguments and output_arguments
    are agnostic to simulator type, supporting both time-series and population simulators.

    YAML uses kebab-case keys (e.g. ``working-directory``); aliases map
    them onto snake_case attributes.
    """

    # Base directory against which relative output/metadata paths resolve.
    working_directory: str = Field(
        default=".", alias="working-directory", description="Base working directory for all output"
    )
    # Default output directory; per-simulator output.output_directory wins.
    output_directory: str | None = Field(
        default=None, alias="output-directory", description="Default output directory (can be overridden per simulator)"
    )
    # Default metadata directory; per-simulator override wins.
    metadata_directory: str | None = Field(
        default=None,
        alias="metadata-directory",
        description="Default metadata directory (can be overridden per simulator)",
    )
    # Fallback constructor arguments shared by every simulator; merged with
    # (and overridden by) each simulator's own `arguments` mapping.
    simulator_arguments: dict[str, Any] = Field(
        default_factory=dict,
        alias="simulator-arguments",
        description="Global default arguments for simulators (e.g., sampling-frequency, duration, seed). "
        "Simulator-specific arguments override these.",
    )
    # Fallback arguments shared by every output handler; merged separately
    # from simulator_arguments.
    output_arguments: dict[str, Any] = Field(
        default_factory=dict,
        alias="output-arguments",
        description="Global default arguments for output handlers (e.g., channel names). "
        "Simulator-specific output arguments override these.",
    )

    # Accept unknown keys and allow population via attribute names as well
    # as the kebab-case aliases.
    model_config = ConfigDict(extra="allow", populate_by_name=True)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class Config(BaseModel):
    """Top-level configuration model.

    Mirrors the root of the YAML configuration file: an optional
    ``globals`` section plus a required, non-empty ``simulators`` mapping
    of simulator name -> SimulatorConfig.
    """

    # Optional global defaults; an empty GlobalsConfig is used when absent.
    globals: GlobalsConfig = Field(default_factory=GlobalsConfig, description="Global configuration")
    # Required mapping of simulator name to its configuration.
    simulators: dict[str, SimulatorConfig] = Field(..., description="Dictionary of simulators")

    model_config = ConfigDict(extra="allow", populate_by_name=True)

    @field_validator("simulators", mode="before")
    @classmethod
    def validate_simulators_not_empty(cls, v: dict[str, Any]) -> dict[str, Any]:
        """Ensure simulators section is not empty.

        mode="before" so the check also fires when YAML supplies an empty
        mapping or null before pydantic coerces values.
        """
        if not v:
            raise ValueError("'simulators' section cannot be empty")
        return v
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def load_config(file_name: Path, encoding: str = "utf-8") -> Config:
    """Load configuration file with validation.

    Args:
        file_name (Path): File name.
        encoding (str, optional): File encoding. Defaults to "utf-8".

    Returns:
        Config: Validated configuration dataclass.

    Raises:
        FileNotFoundError: If the configuration file does not exist.
        ValueError: If the configuration is invalid or cannot be parsed.
    """
    if not file_name.exists():
        raise FileNotFoundError(f"Configuration file not found: {file_name}")

    # Parse the YAML document; surface syntax problems as ValueError.
    try:
        with file_name.open(encoding=encoding) as stream:
            parsed = yaml.safe_load(stream)
    except yaml.YAMLError as e:
        raise ValueError(f"Failed to parse YAML configuration: {e}") from e

    # An empty file (None) or a scalar/list at top level is not acceptable.
    if not isinstance(parsed, dict):
        raise ValueError("Configuration must be a YAML dictionary")

    # Hand off to the pydantic model for structural validation.
    try:
        validated = Config(**parsed)
        logger.info("Configuration loaded and validated: %s simulators", len(validated.simulators))
        return validated
    except ValueError as e:
        raise ValueError(f"Configuration validation failed: {e}") from e
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def save_config(
    file_name: Path, config: Config, overwrite: bool = False, encoding: str = "utf-8", backup: bool = True
) -> None:
    """Save configuration to YAML file safely.

    Args:
        file_name: Path to save configuration to
        config: Config dataclass instance
        overwrite: If True, overwrite existing file
        encoding: File encoding (default: utf-8)
        backup: If True and overwriting, create backup

    Raises:
        FileExistsError: If file exists and overwrite=False
        ValueError: If writing the configuration fails for any reason
    """
    target_exists = file_name.exists()
    if target_exists and not overwrite:
        raise FileExistsError(f"File already exists: {file_name}. Use overwrite=True to overwrite.")

    # Preserve the previous file contents before overwriting, if requested.
    if target_exists and overwrite and backup:
        backup_path = file_name.with_suffix(f"{file_name.suffix}.backup")
        logger.info("Creating backup: %s", backup_path)
        backup_path.write_text(file_name.read_text(encoding=encoding), encoding=encoding)

    # Atomic write: serialize to a sibling .tmp file, then replace.
    temp_file = file_name.with_suffix(f"{file_name.suffix}.tmp")
    try:
        # Dump with YAML aliases (kebab-case keys) and keep explicit nulls.
        serializable = config.model_dump(by_alias=True, exclude_none=False)

        with temp_file.open("w", encoding=encoding) as handle:
            yaml.safe_dump(serializable, handle, default_flow_style=False, sort_keys=False)

        temp_file.replace(file_name)
        logger.info("Configuration saved to: %s", file_name)

    except Exception as e:
        # Remove the partial temp file so no stale artifact is left behind.
        if temp_file.exists():
            temp_file.unlink()
        raise ValueError(f"Failed to save configuration: {e}") from e
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def validate_config(config: dict) -> None:
    """Validate configuration structure and provide helpful error messages.

    Args:
        config (dict): Configuration dictionary to validate

    Raises:
        ValueError: If configuration is invalid with detailed error message
    """
    # The 'simulators' section is the only mandatory top-level key.
    if "simulators" not in config:
        raise ValueError("Invalid configuration: Must contain 'simulators' section with simulator definitions")

    simulator_section = config["simulators"]
    if not isinstance(simulator_section, dict):
        raise ValueError("'simulators' must be a dictionary")
    if not simulator_section:
        raise ValueError("'simulators' section cannot be empty")

    for simulator_name, entry in simulator_section.items():
        if not isinstance(entry, dict):
            raise ValueError(f"Simulator '{simulator_name}' configuration must be a dictionary")

        # 'class' is required and must identify the simulator class.
        if "class" not in entry:
            raise ValueError(f"Simulator '{simulator_name}' missing required 'class' field")
        class_spec = entry["class"]
        if not isinstance(class_spec, str) or not class_spec.strip():
            raise ValueError(f"Simulator '{simulator_name}' 'class' must be a non-empty string")

        # Optional sections must at least have the right container type.
        if "arguments" in entry and not isinstance(entry["arguments"], dict):
            raise ValueError(f"Simulator '{simulator_name}' 'arguments' must be a dictionary")
        if "output" in entry and not isinstance(entry["output"], dict):
            raise ValueError(f"Simulator '{simulator_name}' 'output' must be a dictionary")

    # 'globals' is optional, but when present it must be a mapping.
    if "globals" in config and not isinstance(config["globals"], dict):
        raise ValueError("'globals' must be a dictionary")

    logger.info("Configuration validation passed")
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def resolve_class_path(class_spec: str, section_name: str | None) -> str:
    """Resolve class specification to full module path.

    Args:
        class_spec: Either 'ClassName' or 'third_party.module.ClassName'
        section_name: Section name (e.g., 'noise', 'signal', 'glitch')

    Returns:
        Full path like 'gwsim.noise.ClassName' or 'third_party.module.ClassName'

    Examples:
        resolve_class_path("WhiteNoise", "noise") -> "gwsim.noise.WhiteNoise"
        resolve_class_path("numpy.random.Generator", "noise") -> "numpy.random.Generator"
    """
    # A dotted spec (or a missing section) is treated as an explicit import
    # path and returned untouched.
    if "." in class_spec or not section_name:
        return class_spec
    # Bare class name: assume it is re-exported from the gwsim subpackage
    # matching the section (imported in that package's __init__.py).
    return f"gwsim.{section_name}.{class_spec}"
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def merge_parameters(globals_config: GlobalsConfig, simulator_args: dict[str, Any]) -> dict[str, Any]:
    """Merge global and simulator-specific parameters.

    Flattens simulator_arguments from globals into the result, then applies
    simulator-specific overrides.

    Args:
        globals_config: GlobalsConfig dataclass instance
        simulator_args: Simulator-specific arguments dict

    Returns:
        Merged parameters with simulator args taking precedence

    Note:
        Simulator_arguments from globals_config are flattened into the result.
        Directory settings (working-directory, output-directory, metadata-directory)
        are included. Output_arguments are not included (handled separately).
    """
    # Directory settings from globals; only truthy (non-None, non-empty)
    # values are carried over.
    directory_settings = {
        "working-directory": globals_config.working_directory,
        "output-directory": globals_config.output_directory,
        "metadata-directory": globals_config.metadata_directory,
    }
    merged: dict[str, Any] = {key: value for key, value in directory_settings.items() if value}

    # Layer global simulator defaults, then simulator-specific overrides
    # (later updates win).
    merged.update(globals_config.simulator_arguments)
    merged.update(simulator_args)

    return merged
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def _resolve_directory(overrides: tuple[str | None, ...], working_dir: Path, default: Path) -> Path:
    """Return the first truthy override resolved against working_dir, else default."""
    for override in overrides:
        if override:
            path = Path(override)
            # Relative overrides are anchored at the working directory;
            # absolute overrides are used verbatim.
            return path if path.is_absolute() else working_dir / path
    return default


def get_output_directories(
    globals_config: GlobalsConfig,
    simulator_config: SimulatorConfig,
    simulator_name: str,
    working_directory: Path | None = None,
) -> tuple[Path, Path]:
    """Get output and metadata directories for a simulator.

    Args:
        globals_config: Global configuration
        simulator_config: Simulator-specific configuration
        simulator_name: Name of the simulator
        working_directory: Override working directory (for testing)

    Returns:
        Tuple of (output_directory, metadata_directory)

    Priority (highest to lowest):
        1. Simulator output.output_directory / output.metadata_directory
        2. Global output-directory / metadata-directory
        3. working-directory / output / {simulator_name} for output,
           working-directory / metadata / {simulator_name} for metadata

    Examples:
        >>> globals_cfg = GlobalsConfig(working_directory="/data")
        >>> sim_cfg = SimulatorConfig(class_="Noise")
        >>> get_output_directories(globals_cfg, sim_cfg, "noise")
        (PosixPath('/data/output/noise'), PosixPath('/data/metadata/noise'))
    """
    working_dir = working_directory or Path(globals_config.working_directory)

    # Same precedence rule for both directories: simulator override first,
    # then global default, then the conventional per-simulator subdirectory.
    output_directory = _resolve_directory(
        (simulator_config.output.output_directory, globals_config.output_directory),
        working_dir,
        working_dir / "output" / simulator_name,
    )
    metadata_directory = _resolve_directory(
        (simulator_config.output.metadata_directory, globals_config.metadata_directory),
        working_dir,
        working_dir / "metadata" / simulator_name,
    )

    return output_directory, metadata_directory
|
gwsim/cli/utils/hash.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""Contains utility functions for hashing operations."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import hashlib
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def compute_file_hash(file_path: str | Path, algorithm: str = "sha256") -> str:
    """Compute the hash of a file using the specified algorithm.

    The file is streamed in fixed-size chunks so arbitrarily large files
    can be hashed with constant memory.

    Args:
        file_path: Path to the file.
        algorithm: Hashing algorithm to use (default is 'sha256'); any
            name accepted by hashlib.new().

    Returns:
        The computed hash as a hexadecimal string, prefixed with the
        algorithm name (e.g. "sha256:ab12...").
    """
    digest = hashlib.new(algorithm)
    with open(file_path, "rb") as stream:
        # Read until EOF in 8 KiB chunks.
        while chunk := stream.read(8192):
            digest.update(chunk)
    return f"{algorithm}:{digest.hexdigest()}"
|
gwsim/cli/utils/retry.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
"""Utility functions for retrying operations with exponential backoff."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import time
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger("gwsim")
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class RetryManager:
    """Manages retry logic with exponential backoff.

    An operation is attempted up to ``max_retries + 1`` times; the delay
    before retry ``k`` (0-based) is ``base_delay * 2**k`` seconds.
    """

    def __init__(
        self,
        max_retries: int = 3,
        base_delay: float = 1.0,
        retryable_exceptions: tuple[type[Exception], ...] = (
            OSError,
            PermissionError,
            FileNotFoundError,
            RuntimeError,
            ValueError,
        ),
    ):
        """Initialize the RetryManager.

        Args:
            max_retries: Maximum number of retries (so the operation runs at
                most ``max_retries + 1`` times).
            base_delay: Base delay in seconds for exponential backoff.
            retryable_exceptions: Tuple of exception types that are considered
                retryable. PermissionError and FileNotFoundError are OSError
                subclasses and are listed explicitly for documentation value.
        """
        self.max_retries = max_retries
        self.base_delay = base_delay
        # NOTE: attribute name is singular for backward compatibility with
        # existing callers, even though it holds a tuple of types.
        self.retryable_exception = retryable_exceptions

    def retry_with_backoff(self, operation: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        """Call ``operation(*args, **kwargs)``, retrying retryable failures.

        Args:
            operation: Callable to invoke.
            *args: Positional arguments forwarded to the operation.
            **kwargs: Keyword arguments forwarded to the operation.

        Returns:
            Whatever the operation returns on its first successful attempt.

        Raises:
            Exception: The last retryable exception once all retries are
                exhausted; non-retryable exceptions propagate immediately.
        """
        for attempt in range(self.max_retries + 1):
            try:
                return operation(*args, **kwargs)
            except self.retryable_exception as e:
                if attempt == self.max_retries:
                    # Out of retries: surface the final failure to the caller.
                    logger.error("Operation failed after %s retries: %s", self.max_retries, e)
                    raise

                # Exponential backoff: base_delay, 2x, 4x, ...
                delay = self.base_delay * (2**attempt)
                logger.warning("Attempt %s failed: %s. Retrying in %ss...", attempt + 1, e, delay)
                time.sleep(delay)
        # Only reachable with max_retries < 0 (zero attempts); mirrors the
        # implicit-None behavior of the original implementation.

    def is_retryable_exception(self, exception: Exception) -> bool:
        """Check if an exception is retryable.

        Args:
            exception: The exception to check

        Returns:
            bool: True if the exception is retryable, False otherwise
        """
        return isinstance(exception, self.retryable_exception)
|