deepextractor 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepextractor/__init__.py +28 -0
- deepextractor/_version.py +24 -0
- deepextractor/api.py +60 -0
- deepextractor/generation/__init__.py +1 -0
- deepextractor/generation/generate_spectrograms.py +158 -0
- deepextractor/generation/generate_timeseries.py +162 -0
- deepextractor/generation/glitch_functions.py +145 -0
- deepextractor/model.py +155 -0
- deepextractor/models/__init__.py +19 -0
- deepextractor/models/architectures.py +314 -0
- deepextractor/py.typed +0 -0
- deepextractor/training/__init__.py +1 -0
- deepextractor/training/train_fn.py +60 -0
- deepextractor/training/trainer.py +265 -0
- deepextractor/utils/__init__.py +45 -0
- deepextractor/utils/checkpoints.py +102 -0
- deepextractor/utils/io.py +165 -0
- deepextractor/utils/metrics.py +63 -0
- deepextractor/utils/signal.py +85 -0
- deepextractor/utils/stft.py +64 -0
- deepextractor/utils/visualization.py +121 -0
- deepextractor-0.1.0.dist-info/METADATA +200 -0
- deepextractor-0.1.0.dist-info/RECORD +27 -0
- deepextractor-0.1.0.dist-info/WHEEL +5 -0
- deepextractor-0.1.0.dist-info/entry_points.txt +5 -0
- deepextractor-0.1.0.dist-info/licenses/LICENSE +21 -0
- deepextractor-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""
|
|
2
|
+
DeepExtractor: Deep learning for gravitational-wave glitch reconstruction.
|
|
3
|
+
|
|
4
|
+
Quick usage::
|
|
5
|
+
|
|
6
|
+
import deepextractor
|
|
7
|
+
reconstructed = deepextractor.reconstruct(noisy_strain)
|
|
8
|
+
|
|
9
|
+
# Or with explicit model control:
|
|
10
|
+
model = deepextractor.DeepExtractorModel(checkpoint="DeepExtractor_257")
|
|
11
|
+
signal = model.reconstruct(noisy_strain)
|
|
12
|
+
background = model.background(noisy_strain)
|
|
13
|
+
|
|
14
|
+
Paper: https://arxiv.org/abs/2501.18423
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from importlib.metadata import PackageNotFoundError, version
|
|
18
|
+
|
|
19
|
+
try:
|
|
20
|
+
__version__ = version("deepextractor")
|
|
21
|
+
except PackageNotFoundError:
|
|
22
|
+
# Running from source without pip install -e .
|
|
23
|
+
__version__ = "0.0.0.dev"
|
|
24
|
+
|
|
25
|
+
from deepextractor.model import DeepExtractorModel
|
|
26
|
+
from deepextractor.api import extract, reconstruct
|
|
27
|
+
|
|
28
|
+
__all__ = ["__version__", "DeepExtractorModel", "reconstruct", "extract"]
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# file generated by vcs-versioning
# don't change, don't track in version control
from __future__ import annotations

__all__ = [
    "__version__",
    "__version_tuple__",
    "version",
    "version_tuple",
    "__commit_id__",
    "commit_id",
]

# Type declarations for the generated values assigned below.
version: str
__version__: str
__version_tuple__: tuple[int | str, ...]
version_tuple: tuple[int | str, ...]
commit_id: str | None
__commit_id__: str | None

__version__ = version = '0.1.0'
__version_tuple__ = version_tuple = (0, 1, 0)

# No VCS commit id was recorded for this build.
__commit_id__ = commit_id = None
|
deepextractor/api.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Top-level convenience functions for DeepExtractor inference.
|
|
3
|
+
|
|
4
|
+
For one-shot use. For repeated inference on many signals, instantiate
|
|
5
|
+
:class:`DeepExtractorModel` directly to amortise the model load cost.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
from deepextractor.model import DeepExtractorModel
|
|
11
|
+
from deepextractor.utils.checkpoints import CHECKPOINT_BILBY
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def reconstruct(
|
|
15
|
+
noisy_input: np.ndarray,
|
|
16
|
+
checkpoint: str = "DeepExtractor_257",
|
|
17
|
+
checkpoint_filename: str = CHECKPOINT_BILBY,
|
|
18
|
+
checkpoint_dir: str | None = None,
|
|
19
|
+
device: str | None = None,
|
|
20
|
+
scaler_path: str | None = None,
|
|
21
|
+
) -> np.ndarray:
|
|
22
|
+
"""
|
|
23
|
+
Extract the transient signal from a noisy gravitational-wave strain.
|
|
24
|
+
|
|
25
|
+
Loads a DeepExtractor model, runs inference, and returns the reconstructed
|
|
26
|
+
signal. For repeated calls, prefer instantiating :class:`DeepExtractorModel`
|
|
27
|
+
directly to avoid reloading weights on each call.
|
|
28
|
+
|
|
29
|
+
Parameters
|
|
30
|
+
----------
|
|
31
|
+
noisy_input : np.ndarray
|
|
32
|
+
1-D array of shape ``(T,)`` or 2-D batch of shape ``(N, T)``.
|
|
33
|
+
checkpoint : str
|
|
34
|
+
Model name. Default ``"DeepExtractor_257"``.
|
|
35
|
+
checkpoint_filename : str
|
|
36
|
+
Checkpoint filename. Defaults to the bilby-noise checkpoint.
|
|
37
|
+
checkpoint_dir : str | None
|
|
38
|
+
Local checkpoint directory. Falls back to HuggingFace Hub if None.
|
|
39
|
+
device : str | None
|
|
40
|
+
Torch device string. Auto-detected if None.
|
|
41
|
+
scaler_path : str | None
|
|
42
|
+
Path to scaler .pkl. Uses bundled asset if None.
|
|
43
|
+
|
|
44
|
+
Returns
|
|
45
|
+
-------
|
|
46
|
+
np.ndarray
|
|
47
|
+
Reconstructed signal, same shape as ``noisy_input``.
|
|
48
|
+
"""
|
|
49
|
+
model = DeepExtractorModel(
|
|
50
|
+
checkpoint=checkpoint,
|
|
51
|
+
checkpoint_filename=checkpoint_filename,
|
|
52
|
+
checkpoint_dir=checkpoint_dir,
|
|
53
|
+
device=device,
|
|
54
|
+
scaler_path=scaler_path,
|
|
55
|
+
)
|
|
56
|
+
return model.reconstruct(noisy_input)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# `extract` and `reconstruct` are synonyms at the API level.
|
|
60
|
+
extract = reconstruct
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Synthetic glitch signal generators and data generation scripts."""
|
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Convert time-domain .npy arrays to STFT spectrograms (magnitude + phase).
|
|
3
|
+
|
|
4
|
+
Also provides a utility to concatenate chunked spectrogram files.
|
|
5
|
+
|
|
6
|
+
Usage::
|
|
7
|
+
|
|
8
|
+
deepextractor-specgen --input-dir data/pycbc_noise/time_domain/ --output-dir data/pycbc_noise/spectrogram_domain/
|
|
9
|
+
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import argparse
|
|
13
|
+
import os
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
import torch
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Default STFT parameters (257x257 output shape)
DEFAULT_N_FFT = 256 * 2          # FFT size -> n_fft // 2 + 1 = 257 frequency bins
DEFAULT_WIN_LENGTH = DEFAULT_N_FFT // 8   # analysis window length (samples)
DEFAULT_HOP_LENGTH = DEFAULT_WIN_LENGTH // 2  # 50% overlap between frames
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def apply_stft_and_save(
    array_path, save_path, n_fft, hop_length, win_length, window, chunk_size=5000
):
    """
    Apply STFT to a .npy array in chunks and save magnitude/phase stacks.

    Parameters
    ----------
    array_path : str
        Path to a .npy file holding a 2-D float array of shape ``(N, T)``.
    save_path : str
        Output path; ``np.save`` appends ``.npy`` when no suffix is given.
    n_fft, hop_length, win_length : int
        STFT parameters forwarded to ``torch.stft``.
    window : torch.Tensor
        Window tensor of length ``win_length``.
    chunk_size : int
        Number of rows transformed per chunk to bound peak memory.
    """
    array = np.load(array_path)
    print(f"Loaded {array_path}, shape: {array.shape}")

    # Ceiling division: a trailing partial chunk counts too. (The previous
    # floor division under-reported the total in the progress message
    # whenever shape[0] was not a multiple of chunk_size.)
    total_chunks = -(-array.shape[0] // chunk_size)
    stft_list = []

    for i in range(0, array.shape[0], chunk_size):
        chunk = array[i : i + chunk_size]
        tensor = torch.tensor(chunk, dtype=torch.float32)
        stft_result = torch.stft(
            tensor,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window=window,
            return_complex=True,
        )
        magnitude = torch.abs(stft_result)
        phase = torch.angle(stft_result)
        # Channel axis: index 0 = magnitude, index 1 = phase.
        stft_mag_phase = torch.stack([magnitude, phase], dim=1)
        stft_list.append(stft_mag_phase)

        # Free intermediates eagerly; empty_cache is a no-op without CUDA.
        del tensor, stft_result, magnitude, phase
        torch.cuda.empty_cache()

        print(f"Processed chunk {i // chunk_size + 1}/{max(total_chunks, 1)}")

    stft_final = torch.cat(stft_list, dim=0)
    stft_numpy = stft_final.cpu().numpy()
    np.save(save_path, stft_numpy)
    print(f"STFT saved to {save_path}.npy, final shape: {stft_numpy.shape}")

    del array, stft_list, stft_final, stft_numpy
    torch.cuda.empty_cache()
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def load_and_concatenate_chunks(data_dir, base_filename, total_chunks):
    """
    Load and concatenate chunked numpy arrays saved as ``{base}_chunk_{i}.npy``.

    Missing chunk files are skipped with a warning, but at least one chunk
    must exist.

    Parameters
    ----------
    data_dir : str
        Directory holding the chunk files.
    base_filename : str
        Common filename prefix of the chunks.
    total_chunks : int
        Number of chunk indices to look for (0 .. total_chunks-1).

    Returns
    -------
    np.ndarray
        All found chunks concatenated along axis 0.

    Raises
    ------
    FileNotFoundError
        If none of the expected chunk files is present. (Previously this
        surfaced as an opaque ``ValueError`` from ``np.concatenate``.)
    """
    stft_list = []
    for i in range(total_chunks):
        chunk_filename = f"{base_filename}_chunk_{i}.npy"
        chunk_path = os.path.join(data_dir, chunk_filename)
        if os.path.exists(chunk_path):
            print(f"Loading {chunk_filename}...")
            stft_list.append(np.load(chunk_path))
        else:
            print(f"Chunk {chunk_filename} not found. Skipping.")
    if not stft_list:
        raise FileNotFoundError(
            f"No chunk files matching {base_filename}_chunk_*.npy found in {data_dir}"
        )
    print("Concatenating chunks...")
    return np.concatenate(stft_list, axis=0)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def main():
    """CLI entry point: generate STFT spectrograms or combine chunk files."""
    parser = argparse.ArgumentParser(
        description="Convert time-domain .npy arrays to STFT spectrogram arrays",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--input-dir",
        type=str,
        required=True,
        help="Directory containing the time-domain .npy files.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="Directory to save the spectrogram .npy files.",
    )
    parser.add_argument("--n-fft", type=int, default=DEFAULT_N_FFT)
    parser.add_argument("--win-length", type=int, default=DEFAULT_WIN_LENGTH)
    parser.add_argument("--hop-length", type=int, default=DEFAULT_HOP_LENGTH)
    parser.add_argument("--chunk-size", type=int, default=5000)
    parser.add_argument(
        "--combine-chunks",
        action="store_true",
        help="Combine pre-existing chunk files instead of generating new spectrograms.",
    )
    parser.add_argument(
        "--chunks-glitch-train", type=int, default=16,
        help="Number of chunks for glitch_train (used with --combine-chunks).",
    )
    parser.add_argument(
        "--chunks-background-train", type=int, default=16,
        help="Number of chunks for background_train (used with --combine-chunks).",
    )
    parser.add_argument(
        "--chunks-glitch-val", type=int, default=2,
        help="Number of chunks for glitch_val (used with --combine-chunks).",
    )
    parser.add_argument(
        "--chunks-background-val", type=int, default=2,
        help="Number of chunks for background_val (used with --combine-chunks).",
    )
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    # Single shared analysis window for all datasets.
    window = torch.hann_window(args.win_length)

    if args.combine_chunks:
        # Mode 1: merge previously generated chunk files found in output-dir.
        for base, n_chunks in [
            ("glitch_train_scaled_mag_phase", args.chunks_glitch_train),
            ("background_train_scaled_mag_phase", args.chunks_background_train),
            ("glitch_val_scaled_mag_phase", args.chunks_glitch_val),
            ("background_val_scaled_mag_phase", args.chunks_background_val),
        ]:
            combined = load_and_concatenate_chunks(args.output_dir, base, n_chunks)
            out_path = os.path.join(args.output_dir, f"{base}_combined.npy")
            np.save(out_path, combined)
            print(f"Saved combined {base} to {out_path}")
        print("All combined datasets saved.")
    else:
        # Mode 2: transform each (input filename, output basename) pair.
        datasets = [
            ("glitch_train_scaled.npy", "glitch_train_scaled_mag_phase"),
            ("background_train_scaled.npy", "background_train_scaled_mag_phase"),
            ("glitch_val_scaled.npy", "glitch_val_scaled_mag_phase"),
            ("background_val_scaled.npy", "background_val_scaled_mag_phase"),
        ]
        for in_name, out_name in datasets:
            in_path = os.path.join(args.input_dir, in_name)
            out_path = os.path.join(args.output_dir, out_name)
            apply_stft_and_save(
                in_path, out_path,
                args.n_fft, args.hop_length, args.win_length, window,
                args.chunk_size,
            )
        print("All STFT results saved.")


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Generate synthetic time-domain training data.
|
|
3
|
+
|
|
4
|
+
Usage::
|
|
5
|
+
|
|
6
|
+
deepextractor-generate --output-dir data/ --num-train 250000 --bilby-noise
|
|
7
|
+
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import argparse
|
|
11
|
+
import os
|
|
12
|
+
import pickle
|
|
13
|
+
import random
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
from sklearn.preprocessing import StandardScaler
|
|
17
|
+
from tqdm import tqdm
|
|
18
|
+
|
|
19
|
+
from deepextractor.generation.glitch_functions import (
|
|
20
|
+
generate_chirp,
|
|
21
|
+
generate_gaussian_pulse,
|
|
22
|
+
generate_sine,
|
|
23
|
+
generate_sine_gaussian,
|
|
24
|
+
ringdown,
|
|
25
|
+
)
|
|
26
|
+
from deepextractor.utils.signal import whitened_snr_scaling
|
|
27
|
+
|
|
28
|
+
# Generation constants.
SAMPLE_RATE = 4096                    # Hz
T = 2.0                               # segment duration, seconds
T_INJ = T / 2                         # nominal injection centre time, seconds
LENGTH = int(T * SAMPLE_RATE)         # samples per segment
SNR_MIN, SNR_MAX = 1, 250             # target injection SNR range
MINIMUM_FREQUENCY = 20.0              # detector low-frequency cutoff, Hz
# Empirical rescaling applied to target SNRs when using bilby-whitened noise.
SNR_SCALING_FACTOR_BILBY = 31.970149253731343

SIGNAL_TYPES = ["chirp", "sine", "sine_gaussian", "gaussian_pulse", "ringdown"]
# Dispatch table: signal-type name -> generator(duration) -> (t, waveform).
SIGNAL_FUNCTION_MAP = {
    "chirp": generate_chirp,
    "sine": generate_sine,
    "sine_gaussian": generate_sine_gaussian,
    "gaussian_pulse": generate_gaussian_pulse,
    "ringdown": ringdown,
}
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def generate_gaussian_noise(mean, std_dev, num_samples, sample_shape, bilby_noise=False,
                            sample_rate=SAMPLE_RATE, duration=T,
                            minimum_frequency=MINIMUM_FREQUENCY):
    """
    Generate Gaussian noise samples (pycbc or bilby flavour).

    With ``bilby_noise=False`` (default), white Gaussian noise is drawn
    directly with numpy. With ``bilby_noise=True``, each sample is produced
    from the detector power spectral density and whitened via bilby, which
    requires the ``[generative]`` optional dependencies.
    """
    # Fast path: plain white Gaussian noise.
    if not bilby_noise:
        print("Generating pycbc noise...")
        return np.random.normal(loc=mean, scale=std_dev, size=(num_samples, *sample_shape))

    try:
        import bilby
    except ImportError as e:
        raise ImportError(
            "bilby is required for bilby noise generation. "
            "Install it with: pip install deepextractor[generative]"
        ) from e

    samples = []
    for _ in tqdm(range(num_samples), desc="Generating bilby noise..."):
        # A fresh interferometer per sample gives independent realisations.
        ifos = bilby.gw.detector.InterferometerList(["L1"])
        for ifo in ifos:
            ifo.minimum_frequency = minimum_frequency
        ifos.set_strain_data_from_power_spectral_densities(
            sampling_frequency=sample_rate,
            duration=duration,
            start_time=0,
        )
        samples.append(list(ifos[0].whitened_time_domain_strain))
    return np.asarray(samples)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def generate_synthetic_data(gaussian_noise_samples, bilby_noise=False, phase="train",
|
|
77
|
+
t_min=0.125, t_max=2.0, snr_min=SNR_MIN, snr_max=SNR_MAX):
|
|
78
|
+
"""Generate synthetic noisy glitch and background data arrays."""
|
|
79
|
+
noisy_glitch_ts = []
|
|
80
|
+
pure_noise_ts = []
|
|
81
|
+
|
|
82
|
+
for i in tqdm(range(len(gaussian_noise_samples)),
|
|
83
|
+
desc=f"Generating Synthetic {phase.capitalize()} Data"):
|
|
84
|
+
background = gaussian_noise_samples[i]
|
|
85
|
+
noisy_glitch = background.copy()
|
|
86
|
+
n_injs = np.random.randint(1, 30)
|
|
87
|
+
for _ in range(n_injs):
|
|
88
|
+
snr_to_scale = np.random.uniform(snr_min, snr_max)
|
|
89
|
+
if bilby_noise:
|
|
90
|
+
snr_to_scale = snr_to_scale / SNR_SCALING_FACTOR_BILBY
|
|
91
|
+
duration = np.random.uniform(t_min, t_max)
|
|
92
|
+
s_type = random.choice(SIGNAL_TYPES)
|
|
93
|
+
_, signal_injection = SIGNAL_FUNCTION_MAP[s_type](duration)
|
|
94
|
+
len_glitch = len(signal_injection)
|
|
95
|
+
id_start = int((T_INJ * SAMPLE_RATE / LENGTH) * len(background)) - len_glitch // 2
|
|
96
|
+
glitch = signal_injection - np.mean(signal_injection)
|
|
97
|
+
glitch = whitened_snr_scaling(glitch, snr=snr_to_scale)
|
|
98
|
+
shift_int = np.random.randint(-id_start, len(background) - id_start - len_glitch)
|
|
99
|
+
noisy_glitch[id_start + shift_int:id_start + len_glitch + shift_int] += glitch
|
|
100
|
+
|
|
101
|
+
noisy_glitch_ts.append(noisy_glitch)
|
|
102
|
+
pure_noise_ts.append(background)
|
|
103
|
+
|
|
104
|
+
noisy_glitch_ts = np.asarray(noisy_glitch_ts)
|
|
105
|
+
pure_noise_ts = np.asarray(pure_noise_ts)
|
|
106
|
+
|
|
107
|
+
mask = ~np.any(
|
|
108
|
+
np.isnan(noisy_glitch_ts) | np.isinf(noisy_glitch_ts)
|
|
109
|
+
| (np.abs(noisy_glitch_ts) > np.finfo(np.float64).max),
|
|
110
|
+
axis=1,
|
|
111
|
+
)
|
|
112
|
+
return noisy_glitch_ts[mask], pure_noise_ts[mask]
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def main():
    """CLI entry point: generate, standardise, and save synthetic datasets."""
    parser = argparse.ArgumentParser(
        description="Generate synthetic time-domain training data",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--output-dir", type=str, default="data/", help="Root output directory.")
    parser.add_argument("--num-train", type=int, default=250000)
    parser.add_argument("--num-val", type=int, default=25000)
    parser.add_argument(
        "--bilby-noise", action="store_true", help="Use bilby noise instead of pycbc."
    )
    args = parser.parse_args()

    bilby_noise = args.bilby_noise
    # Output layout: <output-dir>/<noise flavour>/time_domain/
    noise_ext = "bilby_noise/" if bilby_noise else "pycbc_noise/"
    ext = "bilby" if bilby_noise else "pycbc"
    noise_type_path = os.path.join(args.output_dir, noise_ext)
    domain_path = os.path.join(noise_type_path, "time_domain")
    os.makedirs(domain_path, exist_ok=True)

    mean = 0
    # NOTE(review): std of sqrt(SAMPLE_RATE) — presumably matches the scale of
    # whitened strain at this sampling rate; confirm against the whitening code.
    std_dev = np.sqrt(SAMPLE_RATE)

    train_noise = generate_gaussian_noise(mean, std_dev, args.num_train, (LENGTH,), bilby_noise)
    val_noise = generate_gaussian_noise(mean, std_dev, args.num_val, (LENGTH,), bilby_noise)

    glitch_train, bg_train = generate_synthetic_data(train_noise, bilby_noise, "train")
    glitch_val, bg_val = generate_synthetic_data(val_noise, bilby_noise, "val")

    # Fit the scaler on the glitch training set only; apply it to the rest so
    # train/val and glitch/background share one normalisation.
    scaler = StandardScaler()
    glitch_train_scaled = scaler.fit_transform(glitch_train.reshape(-1, 1)).reshape(glitch_train.shape)
    bg_train_scaled = scaler.transform(bg_train.reshape(-1, 1)).reshape(bg_train.shape)
    glitch_val_scaled = scaler.transform(glitch_val.reshape(-1, 1)).reshape(glitch_val.shape)
    bg_val_scaled = scaler.transform(bg_val.reshape(-1, 1)).reshape(bg_val.shape)

    # Persist the fitted scaler so inference can invert the scaling.
    with open(os.path.join(noise_type_path, f"scaler_{ext}.pkl"), "wb") as f:
        pickle.dump(scaler, f)

    np.save(os.path.join(domain_path, "glitch_train_scaled"), glitch_train_scaled)
    np.save(os.path.join(domain_path, "background_train_scaled"), bg_train_scaled)
    np.save(os.path.join(domain_path, "glitch_val_scaled"), glitch_val_scaled)
    np.save(os.path.join(domain_path, "background_val_scaled"), bg_val_scaled)

    print("Done. Data saved to", domain_path)


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Synthetic glitch signal generators.
|
|
3
|
+
|
|
4
|
+
The CDVGAN and gengli generators require optional dependencies:
|
|
5
|
+
pip install deepextractor[generative]
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from scipy.signal import chirp, gausspulse
|
|
10
|
+
|
|
11
|
+
from deepextractor.utils.signal import quality_factor_conversion, rescale
|
|
12
|
+
|
|
13
|
+
# Sampling configuration shared by all generators in this module.
SRATE = 4096
NYQUIST_FREQ = SRATE // 2


def generate_chirp(duration, sample_rate=4096, f0_min=1, f0_max=NYQUIST_FREQ,
                   f1_min=1, f1_max=NYQUIST_FREQ):
    """Return ``(t, signal)`` for a linear chirp with random start/end frequencies."""
    # Draw order (f0 then f1) is fixed for RNG reproducibility.
    start_freq = np.random.uniform(f0_min, f0_max)
    end_freq = np.random.uniform(f1_min, f1_max)
    n_samples = int(sample_rate * duration)
    t = np.linspace(0, duration, n_samples, endpoint=False)
    return t, chirp(t, f0=start_freq, f1=end_freq, t1=duration, method="linear")
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def generate_sine(duration, sample_rate=4096, freq_min=1, freq_max=NYQUIST_FREQ):
    """Return ``(t, signal)`` for a pure sinusoid with a random frequency."""
    freq = np.random.uniform(freq_min, freq_max)
    n_samples = int(sample_rate * duration)
    t = np.linspace(0, duration, n_samples, endpoint=False)
    return t, np.sin(2.0 * np.pi * freq * t)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def generate_sine_gaussian(duration, sample_rate=4096, freq_min=1, freq_max=NYQUIST_FREQ):
    """Return ``(t, signal)`` for a sine-Gaussian: a sinusoid under a Gaussian envelope."""
    # Draw order (tau then frequency) is fixed for RNG reproducibility.
    tau = np.random.uniform(duration / 200, duration / 4)
    freq = np.random.uniform(freq_min, freq_max)
    n_samples = int(sample_rate * duration)
    t = np.linspace(-duration / 2, duration / 2, n_samples, endpoint=False)
    envelope = np.exp(-(t ** 2) / (2 * tau ** 2))
    return t, np.sin(2 * np.pi * freq * t) * envelope
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def generate_gaussian_pulse(duration, sample_rate=4096, fc_min=1, fc_max=NYQUIST_FREQ,
                            bw_min=0.1, bw_max=1.0, bwr_min=-10, bwr_max=0,
                            tpr_min=0.5, tpr_max=2.0):
    """
    Generate a Gaussian-modulated pulse with randomly drawn parameters.

    Parameters
    ----------
    duration : float
        Duration in seconds.
    sample_rate : int
        Sampling rate in Hz.
    fc_min, fc_max : float
        Range for the center frequency (Hz).
    bw_min, bw_max : float
        Range for the fractional bandwidth.
    bwr_min, bwr_max : float
        Range for the bandwidth reference level (dB).
    tpr_min, tpr_max : float
        Range for the taper reference level (dB).

    Returns
    -------
    tuple
        ``(t, signal)`` arrays of length ``int(sample_rate * duration)``.
    """
    # Draw order (bw, fc, bwr, tpr) is fixed for RNG reproducibility.
    bw = np.random.uniform(bw_min, bw_max)
    fc = np.random.uniform(fc_min, fc_max)
    bwr = np.random.uniform(bwr_min, bwr_max)
    tpr = np.random.uniform(tpr_min, tpr_max)
    n_samples = int(sample_rate * duration)
    t = np.linspace(-duration / 2, duration / 2, n_samples, endpoint=False)
    return t, gausspulse(t, fc=fc, bw=bw, bwr=bwr, tpr=tpr)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def ringdown(duration, sample_rate=4096, n_signals=1):
    """Return ``(t, signal)`` for a damped-sinusoid (ringdown) waveform.

    Frequency, start time, and quality factor are drawn at random; the
    waveform is zero before its start time and, with probability 0.5,
    time-reversed ("ring-up").
    """
    # NOTE(review): unlike the other generators this linspace keeps the
    # endpoint (no endpoint=False), so the sample spacing differs slightly
    # from 1/sample_rate — confirm this is intended.
    t = np.linspace(0, duration, int(sample_rate * duration))
    phi = np.random.uniform(0, 2 * np.pi)
    A = 1.0
    f_0 = np.random.uniform(10, NYQUIST_FREQ, n_signals)
    # Start time drawn from the central half of the segment.
    t_0 = np.random.uniform(t[-1] / 4, 3 * t[-1] / 4, n_signals)
    Q = np.random.uniform(5, 150, n_signals)
    # Damping time from the quality factor, floored at 10 ms.
    tau = np.maximum(quality_factor_conversion(Q, f_0), 0.01)
    # Broadcast per-signal parameters against the time axis.
    f_0 = np.expand_dims(f_0, axis=1)
    t_0 = np.expand_dims(t_0, axis=1)
    tau = np.expand_dims(tau, axis=1)
    h_1 = A * np.exp(-1.0 * ((t - t_0) / (tau))) * np.sin(2 * np.pi * f_0 * (t - t_0) + phi)
    # Zero everything before the start time.
    h_1 = ((t - t_0) > 0) * h_1
    h_1 = rescale(h_1)
    # Randomly flip in time to produce a "ring-up" variant.
    if np.random.rand() < 0.5:
        h_1 = np.flip(h_1)
    return t, h_1[0]
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def generate_gengli_glitch(ifo):
    """
    Generate a glitch sample using the gengli library.

    Requires the ``[generative]`` optional dependencies:
    ``pip install deepextractor[generative]``
    """
    try:
        import gengli
    except ImportError as e:
        raise ImportError(
            "gengli is required for this function. "
            "Install it with: pip install deepextractor[generative]"
        ) from e
    # One fixed-configuration draw from the generator for this detector.
    generator = gengli.glitch_generator(ifo)
    sample = generator.get_glitch(1, srate=4096, snr=10, alpha=0.2, fhigh=1024)
    return None, sample
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def generate_cdvgan_glitch(gtype, cdvgan_generator):
|
|
110
|
+
"""
|
|
111
|
+
Generate a glitch sample using a pretrained CDVGAN TensorFlow model.
|
|
112
|
+
|
|
113
|
+
Parameters
|
|
114
|
+
----------
|
|
115
|
+
gtype : str
|
|
116
|
+
Glitch type: one of ``'blip'``, ``'tomte'``, ``'bbh'``,
|
|
117
|
+
``'simplex'``, ``'uniform'``.
|
|
118
|
+
cdvgan_generator : tf.keras.Model
|
|
119
|
+
The loaded CDVGAN generator model.
|
|
120
|
+
"""
|
|
121
|
+
try:
|
|
122
|
+
import tensorflow as tf
|
|
123
|
+
except ImportError as e:
|
|
124
|
+
raise ImportError(
|
|
125
|
+
"TensorFlow is required for CDVGAN glitches. "
|
|
126
|
+
"Install it with: pip install deepextractor[generative]"
|
|
127
|
+
) from e
|
|
128
|
+
|
|
129
|
+
latent_dim = 100
|
|
130
|
+
random_ints = np.random.randint(0, 100, size=(1, 3))
|
|
131
|
+
simplex_classes = random_ints / np.sum(random_ints, axis=1).reshape(1, 1)
|
|
132
|
+
uniform_classes = np.random.uniform(low=0.0, high=1.0, size=(1, 3))
|
|
133
|
+
|
|
134
|
+
class_vector_map = {
|
|
135
|
+
"blip": [[1, 0, 0]],
|
|
136
|
+
"tomte": [[0, 1, 0]],
|
|
137
|
+
"bbh": [[0, 0, 1]],
|
|
138
|
+
"simplex": simplex_classes,
|
|
139
|
+
"uniform": uniform_classes,
|
|
140
|
+
}
|
|
141
|
+
|
|
142
|
+
class_vector = np.array(class_vector_map[gtype])
|
|
143
|
+
latent_vector = tf.random.normal(shape=(1, latent_dim))
|
|
144
|
+
generated_glitch = cdvgan_generator([latent_vector, class_vector]).numpy()
|
|
145
|
+
return None, generated_glitch
|