gwsim-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103)
  1. gwsim/__init__.py +11 -0
  2. gwsim/__main__.py +8 -0
  3. gwsim/cli/__init__.py +0 -0
  4. gwsim/cli/config.py +88 -0
  5. gwsim/cli/default_config.py +56 -0
  6. gwsim/cli/main.py +101 -0
  7. gwsim/cli/merge.py +150 -0
  8. gwsim/cli/repository/__init__.py +0 -0
  9. gwsim/cli/repository/create.py +91 -0
  10. gwsim/cli/repository/delete.py +51 -0
  11. gwsim/cli/repository/download.py +54 -0
  12. gwsim/cli/repository/list_depositions.py +63 -0
  13. gwsim/cli/repository/main.py +38 -0
  14. gwsim/cli/repository/metadata/__init__.py +0 -0
  15. gwsim/cli/repository/metadata/main.py +24 -0
  16. gwsim/cli/repository/metadata/update.py +58 -0
  17. gwsim/cli/repository/publish.py +52 -0
  18. gwsim/cli/repository/upload.py +74 -0
  19. gwsim/cli/repository/utils.py +47 -0
  20. gwsim/cli/repository/verify.py +61 -0
  21. gwsim/cli/simulate.py +220 -0
  22. gwsim/cli/simulate_utils.py +596 -0
  23. gwsim/cli/utils/__init__.py +85 -0
  24. gwsim/cli/utils/checkpoint.py +178 -0
  25. gwsim/cli/utils/config.py +347 -0
  26. gwsim/cli/utils/hash.py +23 -0
  27. gwsim/cli/utils/retry.py +62 -0
  28. gwsim/cli/utils/simulation_plan.py +439 -0
  29. gwsim/cli/utils/template.py +56 -0
  30. gwsim/cli/utils/utils.py +149 -0
  31. gwsim/cli/validate.py +255 -0
  32. gwsim/data/__init__.py +8 -0
  33. gwsim/data/serialize/__init__.py +9 -0
  34. gwsim/data/serialize/decoder.py +59 -0
  35. gwsim/data/serialize/encoder.py +44 -0
  36. gwsim/data/serialize/serializable.py +33 -0
  37. gwsim/data/time_series/__init__.py +3 -0
  38. gwsim/data/time_series/inject.py +104 -0
  39. gwsim/data/time_series/time_series.py +355 -0
  40. gwsim/data/time_series/time_series_list.py +182 -0
  41. gwsim/detector/__init__.py +8 -0
  42. gwsim/detector/base.py +156 -0
  43. gwsim/detector/detectors/E1_2L_Aligned_Sardinia.interferometer +22 -0
  44. gwsim/detector/detectors/E1_2L_Misaligned_Sardinia.interferometer +22 -0
  45. gwsim/detector/detectors/E1_Triangle_EMR.interferometer +19 -0
  46. gwsim/detector/detectors/E1_Triangle_Sardinia.interferometer +19 -0
  47. gwsim/detector/detectors/E2_2L_Aligned_EMR.interferometer +22 -0
  48. gwsim/detector/detectors/E2_2L_Misaligned_EMR.interferometer +22 -0
  49. gwsim/detector/detectors/E2_Triangle_EMR.interferometer +19 -0
  50. gwsim/detector/detectors/E2_Triangle_Sardinia.interferometer +19 -0
  51. gwsim/detector/detectors/E3_Triangle_EMR.interferometer +19 -0
  52. gwsim/detector/detectors/E3_Triangle_Sardinia.interferometer +19 -0
  53. gwsim/detector/noise_curves/ET_10_HF_psd.txt +3000 -0
  54. gwsim/detector/noise_curves/ET_10_full_cryo_psd.txt +3000 -0
  55. gwsim/detector/noise_curves/ET_15_HF_psd.txt +3000 -0
  56. gwsim/detector/noise_curves/ET_15_full_cryo_psd.txt +3000 -0
  57. gwsim/detector/noise_curves/ET_20_HF_psd.txt +3000 -0
  58. gwsim/detector/noise_curves/ET_20_full_cryo_psd.txt +3000 -0
  59. gwsim/detector/noise_curves/ET_D_psd.txt +3000 -0
  60. gwsim/detector/utils.py +90 -0
  61. gwsim/glitch/__init__.py +7 -0
  62. gwsim/glitch/base.py +69 -0
  63. gwsim/mixin/__init__.py +8 -0
  64. gwsim/mixin/detector.py +203 -0
  65. gwsim/mixin/gwf.py +192 -0
  66. gwsim/mixin/population_reader.py +175 -0
  67. gwsim/mixin/randomness.py +107 -0
  68. gwsim/mixin/time_series.py +295 -0
  69. gwsim/mixin/waveform.py +47 -0
  70. gwsim/noise/__init__.py +19 -0
  71. gwsim/noise/base.py +134 -0
  72. gwsim/noise/bilby_stationary_gaussian.py +117 -0
  73. gwsim/noise/colored_noise.py +275 -0
  74. gwsim/noise/correlated_noise.py +257 -0
  75. gwsim/noise/pycbc_stationary_gaussian.py +112 -0
  76. gwsim/noise/stationary_gaussian.py +44 -0
  77. gwsim/noise/white_noise.py +51 -0
  78. gwsim/repository/__init__.py +0 -0
  79. gwsim/repository/zenodo.py +269 -0
  80. gwsim/signal/__init__.py +11 -0
  81. gwsim/signal/base.py +137 -0
  82. gwsim/signal/cbc.py +61 -0
  83. gwsim/simulator/__init__.py +7 -0
  84. gwsim/simulator/base.py +315 -0
  85. gwsim/simulator/state.py +85 -0
  86. gwsim/utils/__init__.py +11 -0
  87. gwsim/utils/datetime_parser.py +44 -0
  88. gwsim/utils/et_2l_geometry.py +165 -0
  89. gwsim/utils/io.py +167 -0
  90. gwsim/utils/log.py +145 -0
  91. gwsim/utils/population.py +48 -0
  92. gwsim/utils/random.py +69 -0
  93. gwsim/utils/retry.py +75 -0
  94. gwsim/utils/triangular_et_geometry.py +164 -0
  95. gwsim/version.py +7 -0
  96. gwsim/waveform/__init__.py +7 -0
  97. gwsim/waveform/factory.py +83 -0
  98. gwsim/waveform/pycbc_wrapper.py +37 -0
  99. gwsim-0.1.0.dist-info/METADATA +157 -0
  100. gwsim-0.1.0.dist-info/RECORD +103 -0
  101. gwsim-0.1.0.dist-info/WHEEL +4 -0
  102. gwsim-0.1.0.dist-info/entry_points.txt +2 -0
  103. gwsim-0.1.0.dist-info/licenses/LICENSE +21 -0
gwsim/cli/simulate_utils.py
@@ -0,0 +1,596 @@
+ """
+ Utilities for executing simulation plans via CLI.
+ """
+
+ from __future__ import annotations
+
+ import atexit
+ import copy
+ import logging
+ import signal
+ import time
+ from pathlib import Path
+ from typing import Any
+
+ import yaml
+ from tqdm import tqdm
+
+ from gwsim.cli.utils.checkpoint import CheckpointManager
+ from gwsim.cli.utils.config import SimulatorConfig, resolve_class_path
+ from gwsim.cli.utils.hash import compute_file_hash
+ from gwsim.cli.utils.simulation_plan import (
+     SimulationBatch,
+     SimulationPlan,
+     create_batch_metadata,
+ )
+ from gwsim.cli.utils.utils import handle_signal, import_attribute
+ from gwsim.simulator.base import Simulator
+ from gwsim.utils.io import get_file_name_from_template
+
+ logger = logging.getLogger("gwsim")
+ logger.setLevel(logging.DEBUG)
+
+
+ def retry_with_backoff(
+     func,
+     max_retries: int = 3,
+     initial_delay: float = 0.1,
+     backoff_factor: float = 2.0,
+     state_restore_func: Any = None,
+ ) -> Any:
+     """Retry a function with exponential backoff and optional state restoration.
+
+     Args:
+         func: Callable to retry
+         max_retries: Maximum number of retries
+         initial_delay: Initial delay in seconds
+         backoff_factor: Multiplier for delay after each retry
+         state_restore_func: Optional callable to restore state before each retry.
+             Called before each retry attempt (not before first attempt).
+
+     Returns:
+         Result of function call
+
+     Raises:
+         Exception: If all retries fail
+     """
+     delay = initial_delay
+     last_exception: Exception | None = None
+
+     for attempt in range(max_retries + 1):
+         try:
+             return func()
+         except Exception as e:  # pylint: disable=broad-exception-caught
+             last_exception = e
+             if attempt < max_retries:
+                 logger.warning(
+                     "Attempt %d/%d failed: %s. Retrying in %.2fs...",
+                     attempt + 1,
+                     max_retries + 1,
+                     str(e),
+                     delay,
+                     exc_info=e,
+                 )
+                 time.sleep(delay)
+                 delay *= backoff_factor
+
+                 # Restore state before retry if function provided
+                 if state_restore_func is not None:
+                     try:
+                         state_restore_func()
+                         logger.debug("State restored before retry attempt %d", attempt + 2)
+                     except Exception as restore_error:
+                         logger.error("Failed to restore state before retry: %s", restore_error)
+                         raise RuntimeError(f"Cannot retry: failed to restore state: {restore_error}") from restore_error
+             else:
+                 logger.error("All %d attempts failed for batch: %s", max_retries + 1, str(e))
+
+     if last_exception is not None:
+         raise last_exception
+     raise RuntimeError("Unexpected retry failure")
+
+
+ def update_metadata_index(
+     metadata_directory: Path,
+     output_files: list[Path],
+     metadata_file_name: str,
+     encoding: str = "utf-8",
+ ) -> None:
+     """Update the central metadata index file.
+
+     The index maps data file names to their corresponding metadata files,
+     enabling O(1) lookup to find metadata for a given data file.
+
+     Args:
+         metadata_directory: Directory where metadata files are stored
+         output_files: List of output data file Paths
+         metadata_file_name: Name of the metadata file (e.g., "signal-0.metadata.yaml")
+         encoding: File encoding for reading/writing the index file
+     """
+     index_file = metadata_directory / "index.yaml"
+
+     # Load existing index or create new
+     if index_file.exists():
+         try:
+             with index_file.open(encoding=encoding) as f:
+                 index = yaml.safe_load(f) or {}
+         except (OSError, yaml.YAMLError) as e:
+             logger.warning("Failed to load metadata index: %s. Creating new index.", e)
+             index = {}
+     else:
+         index = {}
+
+     # Add entries for all output files
+     for output_file in output_files:
+         index[output_file.name] = metadata_file_name
+         logger.debug("Index entry: %s -> %s", output_file.name, metadata_file_name)
+
+     # Save updated index
+     try:
+         with index_file.open("w") as f:
+             yaml.safe_dump(index, f, default_flow_style=False, sort_keys=True)
+         logger.debug("Updated metadata index: %s", index_file)
+     except (OSError, yaml.YAMLError) as e:
+         logger.error("Failed to save metadata index: %s", e)
+         raise
+
+
+ def instantiate_simulator(
+     simulator_config: SimulatorConfig,
+     simulator_name: str | None = None,
+     global_simulator_arguments: dict[str, Any] | None = None,
+ ) -> Simulator:
+     """Instantiate a simulator from configuration.
+
+     Creates a single simulator instance that will be reused across multiple batches.
+     The simulator maintains state (RNG, counters, etc.) across iterations.
+
+     Global simulator arguments are merged with simulator-specific arguments,
+     with simulator-specific arguments taking precedence.
+
+     Args:
+         simulator_config: Configuration for this simulator
+         simulator_name: Name of the simulator (used for class path resolution)
+         global_simulator_arguments: Global fallback arguments for the simulator
+
+     Returns:
+         Instantiated Simulator
+
+     Raises:
+         ImportError: If simulator class cannot be imported
+         TypeError: If simulator instantiation fails
+     """
+     class_spec = simulator_config.class_
+
+     # Resolve short class names to full paths
+     class_spec = resolve_class_path(class_spec, simulator_name)
+
+     simulator_cls = import_attribute(class_spec)
+
+     # Merge global and simulator-specific arguments
+     # Simulator-specific arguments override global defaults
+     if global_simulator_arguments:
+         merged_arguments = {**global_simulator_arguments, **simulator_config.arguments}
+     else:
+         merged_arguments = simulator_config.arguments
+
+     # Normalize keys: convert hyphens to underscores (YAML uses hyphens, Python uses underscores)
+     normalized_arguments = {k.replace("-", "_"): v for k, v in merged_arguments.items()}
+
+     simulator = simulator_cls(**normalized_arguments)
+
+     logger.info("Instantiated simulator from class %s", class_spec)
+     return simulator
+
+
+ def restore_batch_state(simulator: Simulator, batch: SimulationBatch) -> None:
+     """Restore simulator state from batch metadata if available.
+
+     This is used when reproducing a specific batch. It restores the RNG state,
+     filter memory, and other stateful components that existed before this batch
+     was generated.
+
+     Args:
+         simulator: Simulator instance
+         batch: SimulationBatch potentially containing state snapshot
+
+     Raises:
+         ValueError: If state restoration fails
+     """
+     if batch.has_state_snapshot() and batch.pre_batch_state is not None:
+         logger.debug(
+             "[RESTORE] Batch %d: Restoring state from snapshot - state_keys=%s",
+             batch.batch_index,
+             list(batch.pre_batch_state.keys()),
+         )
+         try:
+             logger.debug(
+                 "[RESTORE] Batch %d: Setting state dict - counter=%s",
+                 batch.batch_index,
+                 batch.pre_batch_state.get("counter"),
+             )
+             simulator.state = batch.pre_batch_state
+             logger.debug(
+                 "[RESTORE] Batch %d: State restored successfully - new_counter=%s",
+                 batch.batch_index,
+                 simulator.counter,
+             )
+         except Exception as e:
+             logger.error("Failed to restore batch state: %s", e)
+             raise ValueError(f"Failed to restore state for batch {batch.batch_index}") from e
+     else:
+         logger.debug(
+             "[RESTORE] Batch %d: No pre-batch state snapshot available (fresh generation)",
+             batch.batch_index,
+         )
+
+
+ def save_batch_metadata(
+     simulator: Simulator,
+     batch: SimulationBatch,
+     metadata_directory: Path,
+     output_files: list[Path],
+     pre_batch_state: dict[str, Any] | None = None,
+ ) -> None:
+     """Save batch metadata including pre-batch simulator state and all output files.
+
+     The metadata file uses batch-indexed naming ({simulator_name}-{batch_index}.metadata.yaml)
+     to provide a single source of truth for all outputs from that batch. This handles
+     cases where a single batch generates multiple output files (e.g., one per detector).
+
+     An index file is also maintained to enable quick lookup of metadata for a given data file.
+
+     Args:
+         simulator: Simulator instance
+         batch: SimulationBatch
+         metadata_directory: Directory to save metadata
+         output_files: List of Path objects for all output files generated by this batch
+         pre_batch_state: State of simulator before batch generation (for reproducibility).
+             If None, uses current simulator state.
+     """
+     metadata_directory.mkdir(parents=True, exist_ok=True)
+
+     # Use provided pre_batch_state or current simulator state
+     state_to_save = pre_batch_state if pre_batch_state is not None else simulator.state
+
+     metadata = create_batch_metadata(
+         simulator_name=batch.simulator_name,
+         batch_index=batch.batch_index,
+         simulator_config=batch.simulator_config,
+         globals_config=batch.globals_config,
+         pre_batch_state=state_to_save,
+         source=batch.source,
+         author=batch.author,
+         email=batch.email,
+     )
+
+     # Add output files to metadata for easy discovery
+     # Store just the file names, not full paths
+     metadata["output_files"] = [f.name for f in output_files]
+
+     # Compute and add file hashes for integrity checking
+     file_hashes = {}
+     for output_file in output_files:
+         try:
+             file_hash = compute_file_hash(output_file)
+             file_hashes[output_file.name] = file_hash
+             logger.debug("Compute hash for %s: %s", output_file.name, file_hash)
+         except OSError as e:
+             logger.warning("Failed to compute hash for %s: %s", output_file.name, e)
+             # Continue without failing - metadata is still useful
+
+     metadata["file_hashes"] = file_hashes
+
+     metadata_file_name = f"{batch.simulator_name}-{batch.batch_index}.metadata.yaml"
+     metadata_file = metadata_directory / metadata_file_name
+     logger.debug("Saving batch metadata to %s with %d output files", metadata_file, len(output_files))
+
+     with metadata_file.open("w") as f:
+         yaml.safe_dump(metadata, f)
+
+     # Update the metadata index for quick lookup
+     update_metadata_index(metadata_directory, output_files, metadata_file_name)
+
+
+ def process_batch(
+     simulator: Simulator,
+     batch_data: object,
+     batch: SimulationBatch,
+     output_directory: Path,
+     overwrite: bool,
+ ) -> list[Path]:
+     """Process and save a single batch of generated data.
+
+     A single batch may generate multiple output files (e.g., one per detector).
+     This function handles both single and multiple output files.
+
+     Args:
+         simulator: Simulator instance
+         batch_data: Generated batch data (may contain multiple outputs)
+         batch: SimulationBatch metadata
+         output_directory: Directory for output files
+         overwrite: Whether to overwrite existing files
+
+     Returns:
+         List of Path objects for all generated output files
+     """
+     output_directory.mkdir(parents=True, exist_ok=True)
+     logger.debug(
+         "[PROCESS] Batch %s: Saving data - counter=%s, file_template=%s",
+         batch.batch_index,
+         simulator.counter,
+         batch.simulator_config.output.file_name,
+     )
+
+     # Build output configuration
+     output_config = batch.simulator_config.output
+     file_name_template = output_config.file_name
+     output_args = output_config.arguments.copy() if output_config.arguments else {}
+
+     # Save data with output directory
+     logger.debug(
+         "Saving batch data for %s batch %d",
+         batch.simulator_name,
+         batch.batch_index,
+     )
+
+     # Resolve the output file names (may be multiple if template contains arrays)
+     output_files = get_file_name_from_template(
+         template=file_name_template,
+         instance=simulator,
+         output_directory=output_directory,
+     )
+
+     # Normalize to list of Paths
+     if isinstance(output_files, Path):
+         output_files_list = [output_files]
+     else:
+         # If it's an array (multiple detectors), flatten it
+         output_files_list = [Path(str(f)) for f in output_files.flatten()]
+
+     logger.debug(
+         "[PROCESS] Batch %s: Resolved filenames - %s", batch.batch_index, [str(f.name) for f in output_files_list]
+     )
+
+     simulator.save_data(
+         data=batch_data,
+         file_name=file_name_template,
+         output_directory=output_directory,
+         overwrite=overwrite,
+         **output_args,
+     )
+
+     logger.debug("[PROCESS] Batch %s: Data saved - counter=%s", batch.batch_index, simulator.counter)
+
+     return output_files_list
+
+
+ def setup_signal_handlers(checkpoint_dirs: list[Path]) -> None:
+     """Set up signal handlers for graceful shutdown.
+
+     Args:
+         checkpoint_dirs: List of checkpoint directories to clean up
+     """
+
+     def cleanup_checkpoints():
+         """Clean up temporary checkpoint files."""
+         for checkpoint_dir in checkpoint_dirs:
+             for backup_file in checkpoint_dir.glob("*.bak"):
+                 try:
+                     backup_file.unlink()
+                     logger.debug("Cleaned up backup file: %s", backup_file)
+                 except OSError as e:
+                     logger.warning("Failed to clean up backup file %s: %s", backup_file, e)
+
+     atexit.register(cleanup_checkpoints)
+     signal.signal(signal.SIGINT, handle_signal(cleanup_checkpoints))
+     signal.signal(signal.SIGTERM, handle_signal(cleanup_checkpoints))
+
+
+ def validate_plan(plan: SimulationPlan) -> None:
+     """Validate simulation plan before execution.
+
+     Args:
+         plan: SimulationPlan to validate
+
+     Raises:
+         ValueError: If plan validation fails
+     """
+     logger.info("Validating simulation plan with %d batches", plan.total_batches)
+
+     if plan.total_batches == 0:
+         raise ValueError("Simulation plan contains no batches")
+
+     # Validate each batch
+     for batch in plan.batches:
+         if not batch.simulator_name:
+             raise ValueError("Batch has empty simulator name")
+         if batch.batch_index < 0:
+             raise ValueError(f"Batch {batch.batch_index} has invalid index")
+
+         # Validate output configuration
+         output_config = batch.simulator_config.output
+         if not output_config.file_name:
+             raise ValueError(f"Batch {batch.simulator_name}-{batch.batch_index} missing file_name")
+
+     logger.info("Simulation plan validation completed successfully")
+
+
+ def execute_plan(  # pylint: disable=too-many-locals
+     plan: SimulationPlan,
+     output_directory: Path,
+     metadata_directory: Path,
+     overwrite: bool,
+     max_retries: int = 3,
+ ) -> None:
+     """Execute a complete simulation plan.
+
+     The key insight: Simulators are stateful objects. Each simulator is instantiated
+     once and then generates multiple batches by calling next() repeatedly. State
+     (RNG, counters, filters) accumulates across batches.
+
+     Checkpoint behavior:
+     1. After each successfully completed batch, save checkpoint with updated state
+     2. Checkpoint contains: completed batch indices, simulator state
+     3. On next run, already-completed batches are skipped (resumption)
+     4. On successful completion of all batches, checkpoint is cleaned up
+
+     Workflow:
+     1. Group batches by simulator name
+     2. Load checkpoint to find already-completed batches
+     3. For each simulator:
+        a. Create ONE simulator instance
+        b. For each batch of that simulator:
+           - Skip if already completed (from checkpoint)
+           - Restore state if reproducing from metadata
+           - Call next(simulator) to generate batch (increments state)
+           - Save batch output and metadata
+           - Save checkpoint with updated state (for resumption)
+
+     Args:
+         plan: SimulationPlan to execute
+         output_directory: Directory for output files
+         metadata_directory: Directory for metadata files
+         overwrite: Whether to overwrite existing files
+         max_retries: Maximum retries per batch
+     """
+     logger.info("Executing simulation plan: %d batches", plan.total_batches)
+
+     validate_plan(plan)
+     setup_signal_handlers([plan.checkpoint_directory] if plan.checkpoint_directory else [])
+
+     # Initialize checkpoint manager for resumption support
+     checkpoint_manager = CheckpointManager(plan.checkpoint_directory)
+     completed_batch_indices = checkpoint_manager.get_completed_batch_indices()
+
+     if completed_batch_indices:
+         logger.info("Loaded checkpoint: %d batches already completed", len(completed_batch_indices))
+     else:
+         logger.debug("No checkpoint found or no batches completed yet")
+
+     # Group batches by simulator name to execute sequentially per simulator
+     simulator_batches: dict[str, list[SimulationBatch]] = {}
+     for batch in plan.batches:
+         if batch.simulator_name not in simulator_batches:
+             simulator_batches[batch.simulator_name] = []
+         simulator_batches[batch.simulator_name].append(batch)
+
+     logger.info("Executing %d simulators", len(simulator_batches))
+
+     with tqdm(total=plan.total_batches, desc="Executing simulation plan") as p_bar:
+         for simulator_name, batches in simulator_batches.items():
+             logger.info("Starting simulator: %s with %d batches", simulator_name, len(batches))
+
+             # Create ONE simulator instance for all batches of this simulator
+             # Extract global simulator arguments from the first batch's global config
+             global_sim_args = batches[0].globals_config.simulator_arguments if batches else {}
+             simulator = instantiate_simulator(batches[0].simulator_config, simulator_name, global_sim_args)
+
+             # Process batches sequentially, maintaining state across them
+             for batch_idx, batch in enumerate(batches):
+                 # Skip batches that were already completed (for resumption after interrupt)
+                 if checkpoint_manager.should_skip_batch(batch.batch_index):
+                     logger.info(
+                         "Skipping batch %d (already completed from checkpoint)",
+                         batch.batch_index,
+                     )
+                     continue
+
+                 try:
+                     logger.debug(
+                         "Executing batch %d/%d for simulator %s",
+                         batch_idx + 1,
+                         len(batches),
+                         simulator_name,
+                     )
+
+                     # Capture pre-batch state first for potential retries
+                     logger.debug(
+                         "[EXECUTE] Batch %s: Before restore - counter=%s, has_state_snapshot=%s",
+                         batch.batch_index,
+                         simulator.counter,
+                         batch.has_state_snapshot(),
+                     )
+                     restore_batch_state(simulator, batch)
+                     logger.debug("[EXECUTE] Batch %s: After restore - counter=%s", batch.batch_index, simulator.counter)
+                     pre_batch_state = copy.deepcopy(simulator.state)
+                     logger.debug(
+                         "[EXECUTE] Batch %s: Captured pre_batch_state - keys=%s",
+                         batch.batch_index,
+                         list(pre_batch_state.keys()),
+                     )
+
+                     def execute_batch(
+                         _simulator=simulator,
+                         _batch=batch,
+                         _output_directory=output_directory,
+                         _pre_batch_state=pre_batch_state,
+                     ):
+                         """Execute a single batch with state management."""
+                         # Generate data by calling next() - this advances simulator state
+                         logger.debug("[BATCH] %s: Before next() - counter=%s", _batch.batch_index, _simulator.counter)
+                         batch_data = _simulator.simulate()
+                         logger.debug("[BATCH] %s: After next() - counter=%s", _batch.batch_index, _simulator.counter)
+
+                         # Save the generated data and get all output file paths
+                         output_files = process_batch(
+                             simulator=_simulator,
+                             batch_data=batch_data,
+                             batch=_batch,
+                             output_directory=_output_directory,
+                             overwrite=overwrite,
+                         )
+
+                         # Only save metadata if data save succeeded
+                         # This ensures metadata only exists for successfully saved data
+                         save_batch_metadata(
+                             _simulator,
+                             _batch,
+                             metadata_directory,
+                             output_files,
+                             pre_batch_state=_pre_batch_state,
+                         )
+                         # Update the state after successful save
+                         _simulator.update_state()
+
+                     def restore_state_for_retry(_simulator=simulator, _pre_batch_state=pre_batch_state):
+                         """Restore simulator state to pre-batch state before retry."""
+                         _simulator.state = copy.deepcopy(_pre_batch_state)
+
+                     # Execute batch with retry mechanism that restores state on failure
+                     retry_with_backoff(
+                         execute_batch,
+                         max_retries=max_retries,
+                         state_restore_func=restore_state_for_retry,
+                     )
+
+                     # After successful completion, save checkpoint with updated state
+                     # At this point, state has been incremented by next() -> update_state()
+                     # Save checkpoint to enable resumption if interrupted before next batch
+                     completed_batch_indices.add(batch.batch_index)
+                     checkpoint_manager.save_checkpoint(
+                         completed_batch_indices=sorted(completed_batch_indices),
+                         last_simulator_name=simulator_name,
+                         last_completed_batch_index=batch.batch_index,
+                         last_simulator_state=copy.deepcopy(simulator.state),
+                     )
+                     logger.debug(
+                         "Checkpoint saved after batch %d - state counter=%s",
+                         batch.batch_index,
+                         simulator.counter,
+                     )
+                     p_bar.update(1)
+
+                 except Exception as e:
+                     logger.error(
+                         "Failed to execute batch %d for simulator %s after %d retries: %s",
+                         batch.batch_index,
+                         simulator_name,
+                         max_retries,
+                         e,
+                     )
+                     raise
+
+     # All batches completed successfully - clean up checkpoint files
+     checkpoint_manager.cleanup()
+     logger.info("All batches completed successfully. Checkpoint files cleaned up.")
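
The retry logic above is generic: retry_with_backoff restores state via state_restore_func before each retry, so a failed attempt does not leave the simulator's state partially advanced. A minimal sketch of that contract, runnable outside the CLI (the flaky_batch and restore helpers here are illustrative stand-ins, not part of the package):

import copy

from gwsim.cli.simulate_utils import retry_with_backoff

state = {"counter": 0}            # hypothetical stand-in for a simulator's state dict
snapshot = copy.deepcopy(state)   # captured up front, like pre_batch_state in execute_plan

calls = {"n": 0}

def flaky_batch():
    calls["n"] += 1
    state["counter"] += 1         # the attempt mutates state before failing
    if calls["n"] < 3:
        raise RuntimeError("transient failure")
    return state["counter"]

def restore():
    # Roll the state back to the pre-attempt snapshot before the next retry.
    state.clear()
    state.update(copy.deepcopy(snapshot))

result = retry_with_backoff(
    flaky_batch,
    max_retries=3,
    initial_delay=0.01,
    state_restore_func=restore,
)
assert result == 1 and calls["n"] == 3  # state was rolled back before each retry

Because restore() runs between attempts, only the successful attempt's mutation survives, which is exactly how execute_plan keeps pre_batch_state consistent with the batch that finally succeeds.
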
gwsim/cli/utils/__init__.py
@@ -0,0 +1,85 @@
+ """Utility functions for the gwsim CLI."""
+
+ from __future__ import annotations
+
+ import base64
+
+ import numpy as np
+ import yaml
+ from astropy.units import Quantity
+
+
+ def represent_quantity(dumper: yaml.SafeDumper, obj: Quantity) -> yaml.nodes.MappingNode:
+     """Represent Quantity for YAML serialization.
+
+     Args:
+         dumper: YAML dumper.
+         obj: Quantity object to represent.
+
+     Returns:
+         YAML node representing the Quantity.
+     """
+     return dumper.represent_mapping("!Quantity", {"value": float(obj.value), "unit": str(obj.unit)})
+
+
+ def construct_quantity(loader: yaml.Loader, node: yaml.MappingNode) -> Quantity:
+     """Construct Quantity from YAML representation.
+
+     Args:
+         loader: YAML loader.
+         node: YAML node to construct from.
+
+     Returns:
+         Quantity object.
+     """
+     data = loader.construct_mapping(node)
+     return Quantity(data["value"], data["unit"])
+
+
+ yaml.SafeDumper.add_multi_representer(Quantity, represent_quantity)
+ yaml.SafeLoader.add_constructor("!Quantity", construct_quantity)
+
+
+ def represent_numpy_array(dumper: yaml.SafeDumper, obj: np.ndarray) -> yaml.nodes.MappingNode:
+     """Represent numpy array for YAML serialization.
+
+     Args:
+         dumper: YAML dumper.
+         obj: Numpy array to represent.
+
+     Returns:
+         YAML node representing the numpy array.
+     """
+     bytes_data = obj.tobytes()
+     encoded_data = base64.b64encode(bytes_data).decode("ascii")
+     data = {
+         "data": encoded_data,
+         "dtype": str(obj.dtype),
+         "shape": list(obj.shape),
+         "encoding": "base64",
+     }
+     return dumper.represent_mapping("!ndarray", data)
+
+
+ def construct_numpy_array(loader: yaml.Loader, node: yaml.MappingNode) -> np.ndarray:
+     """Construct numpy array from YAML representation.
+
+     Args:
+         loader: YAML loader.
+         node: YAML node to construct from.
+
+     Returns:
+         Numpy array.
+     """
+     data = loader.construct_mapping(node)
+     if data.get("encoding") != "base64":
+         raise ValueError("Expected base64 encoding in YAML data")
+     dtype = np.dtype(data["dtype"])
+     shape = tuple(data["shape"])
+     decoded_bytes = base64.b64decode(data["data"])
+     array = np.frombuffer(decoded_bytes, dtype=dtype).reshape(shape)
+     return array
+
+
+ yaml.SafeDumper.add_multi_representer(np.ndarray, represent_numpy_array)
+ yaml.SafeLoader.add_constructor("!ndarray", construct_numpy_array)
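
Because the representers and constructors above are registered on yaml.SafeDumper and yaml.SafeLoader at import time, importing this module is enough to make safe_dump emit tagged !Quantity and !ndarray nodes and to make safe_load rebuild Quantity values. A minimal sketch of the expected behavior (illustrative usage, not code shipped in the package):

import numpy as np
import yaml
from astropy.units import Quantity

import gwsim.cli.utils  # importing the module registers the handlers as a side effect

# Dumping uses the registered representers, so tagged mappings appear in the output.
text = yaml.safe_dump(
    {
        "sampling_rate": Quantity(4096.0, "Hz"),
        "strain": np.arange(4, dtype=np.float64),
    }
)
print(text)  # contains "!Quantity" and "!ndarray" tagged mappings

# Loading a tagged scalar mapping goes through construct_quantity.
restored = yaml.safe_load("sampling_rate: !Quantity {value: 4096.0, unit: Hz}")
assert restored["sampling_rate"] == Quantity(4096.0, "Hz")

Registering on the Safe variants is a deliberate choice: only these explicitly whitelisted tags extend safe_load, rather than falling back to the unrestricted yaml.Loader.
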