deepextractor 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,28 @@
1
+ """
2
+ DeepExtractor: Deep learning for gravitational-wave glitch reconstruction.
3
+
4
+ Quick usage::
5
+
6
+ import deepextractor
7
+ reconstructed = deepextractor.reconstruct(noisy_strain)
8
+
9
+ # Or with explicit model control:
10
+ model = deepextractor.DeepExtractorModel(checkpoint="DeepExtractor_257")
11
+ signal = model.reconstruct(noisy_strain)
12
+ background = model.background(noisy_strain)
13
+
14
+ Paper: https://arxiv.org/abs/2501.18423
15
+ """
16
+
17
+ from importlib.metadata import PackageNotFoundError, version
18
+
19
+ try:
20
+ __version__ = version("deepextractor")
21
+ except PackageNotFoundError:
22
+ # Running from source without pip install -e .
23
+ __version__ = "0.0.0.dev"
24
+
25
+ from deepextractor.model import DeepExtractorModel
26
+ from deepextractor.api import extract, reconstruct
27
+
28
+ __all__ = ["__version__", "DeepExtractorModel", "reconstruct", "extract"]
@@ -0,0 +1,24 @@
1
# file generated by vcs-versioning
# don't change, don't track in version control
from __future__ import annotations

# Both dunder and plain aliases are exported; the tooling keeps them in sync.
__all__ = [
    "__version__",
    "__version_tuple__",
    "version",
    "version_tuple",
    "__commit_id__",
    "commit_id",
]

# Declarations for type checkers; the actual assignments follow below.
version: str
__version__: str
__version_tuple__: tuple[int | str, ...]
version_tuple: tuple[int | str, ...]
commit_id: str | None
__commit_id__: str | None

__version__ = version = '0.1.0'
__version_tuple__ = version_tuple = (0, 1, 0)

# None: no VCS commit was recorded for this build — presumably built
# outside a git checkout (e.g. from an sdist).
__commit_id__ = commit_id = None
deepextractor/api.py ADDED
@@ -0,0 +1,60 @@
1
+ """
2
+ Top-level convenience functions for DeepExtractor inference.
3
+
4
+ For one-shot use. For repeated inference on many signals, instantiate
5
+ :class:`DeepExtractorModel` directly to amortise the model load cost.
6
+ """
7
+
8
+ import numpy as np
9
+
10
+ from deepextractor.model import DeepExtractorModel
11
+ from deepextractor.utils.checkpoints import CHECKPOINT_BILBY
12
+
13
+
14
def reconstruct(
    noisy_input: np.ndarray,
    checkpoint: str = "DeepExtractor_257",
    checkpoint_filename: str = CHECKPOINT_BILBY,
    checkpoint_dir: str | None = None,
    device: str | None = None,
    scaler_path: str | None = None,
) -> np.ndarray:
    """
    Reconstruct the transient signal buried in a noisy strain series.

    One-shot convenience wrapper: a fresh DeepExtractor model is built on
    every call and inference is run immediately. When processing many
    signals, construct :class:`DeepExtractorModel` yourself so the weight
    loading cost is paid only once.

    Parameters
    ----------
    noisy_input : np.ndarray
        1-D array of shape ``(T,)`` or 2-D batch of shape ``(N, T)``.
    checkpoint : str
        Model name. Default ``"DeepExtractor_257"``.
    checkpoint_filename : str
        Checkpoint filename. Defaults to the bilby-noise checkpoint.
    checkpoint_dir : str | None
        Local checkpoint directory. Falls back to HuggingFace Hub if None.
    device : str | None
        Torch device string. Auto-detected if None.
    scaler_path : str | None
        Path to scaler .pkl. Uses bundled asset if None.

    Returns
    -------
    np.ndarray
        Reconstructed signal, same shape as ``noisy_input``.
    """
    # Collect the model configuration, build the model, and run inference
    # in a single pass; nothing is cached between calls.
    model_kwargs = dict(
        checkpoint=checkpoint,
        checkpoint_filename=checkpoint_filename,
        checkpoint_dir=checkpoint_dir,
        device=device,
        scaler_path=scaler_path,
    )
    return DeepExtractorModel(**model_kwargs).reconstruct(noisy_input)


# `extract` and `reconstruct` are synonyms at the API level.
extract = reconstruct
@@ -0,0 +1 @@
1
+ """Synthetic glitch signal generators and data generation scripts."""
@@ -0,0 +1,158 @@
1
+ """
2
+ Convert time-domain .npy arrays to STFT spectrograms (magnitude + phase).
3
+
4
+ Also provides a utility to concatenate chunked spectrogram files.
5
+
6
+ Usage::
7
+
8
+ deepextractor-specgen --input-dir data/pycbc_noise/time_domain/ --output-dir data/pycbc_noise/spectrogram_domain/
9
+
10
+ """
11
+
12
+ import argparse
13
+ import os
14
+
15
+ import numpy as np
16
+ import torch
17
+
18
+
19
# Default STFT parameters (257x257 output shape)
# n_fft = 512 -> n_fft // 2 + 1 = 257 frequency bins.
DEFAULT_N_FFT = 256 * 2
# 64-sample window with a 32-sample hop; for 8192-sample (2 s @ 4096 Hz)
# inputs this yields 8192 // 32 + 1 = 257 time frames.
DEFAULT_WIN_LENGTH = DEFAULT_N_FFT // 8
DEFAULT_HOP_LENGTH = DEFAULT_WIN_LENGTH // 2
23
+
24
+
25
def apply_stft_and_save(
    array_path, save_path, n_fft, hop_length, win_length, window, chunk_size=5000
):
    """Apply an STFT to a .npy array in chunks and save magnitude + phase.

    Parameters
    ----------
    array_path : str
        Path to a 2-D ``(num_samples, length)`` .npy file of time series.
    save_path : str
        Output path (:func:`numpy.save` appends ``.npy``).
    n_fft, hop_length, win_length : int
        STFT parameters forwarded to :func:`torch.stft`.
    window : torch.Tensor
        Window tensor of length ``win_length``.
    chunk_size : int
        Rows transformed per batch, to bound peak memory.

    Saves an array of shape ``(num_samples, 2, freq_bins, frames)`` where
    channel 0 holds the magnitude and channel 1 the phase.

    Raises
    ------
    ValueError
        If the input array contains no rows.
    """
    array = np.load(array_path)
    print(f"Loaded {array_path}, shape: {array.shape}")

    # Ceiling division: the previous floor division under-counted whenever
    # the row count was not a multiple of chunk_size, so the progress
    # message could read e.g. "2/1".
    total_chunks = (array.shape[0] + chunk_size - 1) // chunk_size
    stft_list = []

    for i in range(0, array.shape[0], chunk_size):
        chunk = array[i : i + chunk_size]
        tensor = torch.tensor(chunk, dtype=torch.float32)
        stft_result = torch.stft(
            tensor,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window=window,
            return_complex=True,
        )
        magnitude = torch.abs(stft_result)
        phase = torch.angle(stft_result)
        # Stack as (batch, 2, freq, time): channel 0 = |STFT|, channel 1 = angle.
        stft_mag_phase = torch.stack([magnitude, phase], dim=1)
        stft_list.append(stft_mag_phase)

        # Free the intermediates eagerly to keep peak memory low;
        # empty_cache() is a no-op on CPU-only runs.
        del tensor, stft_result, magnitude, phase
        torch.cuda.empty_cache()

        print(f"Processed chunk {i // chunk_size + 1}/{total_chunks}")

    if not stft_list:
        # Fail with a clear message instead of torch.cat's cryptic error.
        raise ValueError(f"{array_path} contains no samples to transform")

    stft_final = torch.cat(stft_list, dim=0)
    stft_numpy = stft_final.cpu().numpy()
    np.save(save_path, stft_numpy)
    print(f"STFT saved to {save_path}.npy, final shape: {stft_numpy.shape}")

    del array, stft_list, stft_final, stft_numpy
    torch.cuda.empty_cache()
63
+
64
+
65
def load_and_concatenate_chunks(data_dir, base_filename, total_chunks):
    """Load and concatenate chunked numpy arrays saved as ``{base}_chunk_{i}.npy``.

    Missing chunk files are skipped with a message. If *none* of the expected
    chunks exist, a :class:`FileNotFoundError` is raised instead of letting
    ``np.concatenate`` fail on an empty list with a cryptic ValueError.

    Parameters
    ----------
    data_dir : str
        Directory containing the chunk files.
    base_filename : str
        Common filename prefix of the chunks.
    total_chunks : int
        Number of chunk indices to look for (0 .. total_chunks - 1).

    Returns
    -------
    np.ndarray
        All found chunks concatenated along axis 0.
    """
    stft_list = []
    for i in range(total_chunks):
        chunk_filename = f"{base_filename}_chunk_{i}.npy"
        chunk_path = os.path.join(data_dir, chunk_filename)
        if os.path.exists(chunk_path):
            print(f"Loading {chunk_filename}...")
            stft_list.append(np.load(chunk_path))
        else:
            print(f"Chunk {chunk_filename} not found. Skipping.")
    if not stft_list:
        raise FileNotFoundError(
            f"No chunk files matching {base_filename}_chunk_*.npy found in {data_dir}"
        )
    print("Concatenating chunks...")
    return np.concatenate(stft_list, axis=0)
78
+
79
+
80
def main():
    """CLI entry point: generate spectrograms, or combine pre-existing chunk files."""
    parser = argparse.ArgumentParser(
        description="Convert time-domain .npy arrays to STFT spectrogram arrays",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--input-dir",
        type=str,
        required=True,
        help="Directory containing the time-domain .npy files.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="Directory to save the spectrogram .npy files.",
    )
    # STFT parameters; the defaults produce a 257x257 magnitude/phase image.
    parser.add_argument("--n-fft", type=int, default=DEFAULT_N_FFT)
    parser.add_argument("--win-length", type=int, default=DEFAULT_WIN_LENGTH)
    parser.add_argument("--hop-length", type=int, default=DEFAULT_HOP_LENGTH)
    parser.add_argument("--chunk-size", type=int, default=5000)
    parser.add_argument(
        "--combine-chunks",
        action="store_true",
        help="Combine pre-existing chunk files instead of generating new spectrograms.",
    )
    parser.add_argument(
        "--chunks-glitch-train", type=int, default=16,
        help="Number of chunks for glitch_train (used with --combine-chunks).",
    )
    parser.add_argument(
        "--chunks-background-train", type=int, default=16,
        help="Number of chunks for background_train (used with --combine-chunks).",
    )
    parser.add_argument(
        "--chunks-glitch-val", type=int, default=2,
        help="Number of chunks for glitch_val (used with --combine-chunks).",
    )
    parser.add_argument(
        "--chunks-background-val", type=int, default=2,
        help="Number of chunks for background_val (used with --combine-chunks).",
    )
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    # One Hann window, shared by every dataset's STFT.
    window = torch.hann_window(args.win_length)

    if args.combine_chunks:
        # Combine mode: stitch previously saved chunk files back together.
        # NOTE(review): chunk files are read from --output-dir, not
        # --input-dir — presumably because they were written there by an
        # earlier run; confirm this is intended.
        for base, n_chunks in [
            ("glitch_train_scaled_mag_phase", args.chunks_glitch_train),
            ("background_train_scaled_mag_phase", args.chunks_background_train),
            ("glitch_val_scaled_mag_phase", args.chunks_glitch_val),
            ("background_val_scaled_mag_phase", args.chunks_background_val),
        ]:
            combined = load_and_concatenate_chunks(args.output_dir, base, n_chunks)
            out_path = os.path.join(args.output_dir, f"{base}_combined.npy")
            np.save(out_path, combined)
            print(f"Saved combined {base} to {out_path}")
        print("All combined datasets saved.")
    else:
        # Generate mode: transform each fixed-name time-domain dataset.
        datasets = [
            ("glitch_train_scaled.npy", "glitch_train_scaled_mag_phase"),
            ("background_train_scaled.npy", "background_train_scaled_mag_phase"),
            ("glitch_val_scaled.npy", "glitch_val_scaled_mag_phase"),
            ("background_val_scaled.npy", "background_val_scaled_mag_phase"),
        ]
        for in_name, out_name in datasets:
            in_path = os.path.join(args.input_dir, in_name)
            out_path = os.path.join(args.output_dir, out_name)
            apply_stft_and_save(
                in_path, out_path,
                args.n_fft, args.hop_length, args.win_length, window,
                args.chunk_size,
            )
        print("All STFT results saved.")


if __name__ == "__main__":
    main()
@@ -0,0 +1,162 @@
1
+ """
2
+ Generate synthetic time-domain training data.
3
+
4
+ Usage::
5
+
6
+ deepextractor-generate --output-dir data/ --num-train 250000 --bilby-noise
7
+
8
+ """
9
+
10
+ import argparse
11
+ import os
12
+ import pickle
13
+ import random
14
+
15
+ import numpy as np
16
+ from sklearn.preprocessing import StandardScaler
17
+ from tqdm import tqdm
18
+
19
+ from deepextractor.generation.glitch_functions import (
20
+ generate_chirp,
21
+ generate_gaussian_pulse,
22
+ generate_sine,
23
+ generate_sine_gaussian,
24
+ ringdown,
25
+ )
26
+ from deepextractor.utils.signal import whitened_snr_scaling
27
+
28
+ SAMPLE_RATE = 4096
29
+ T = 2.0
30
+ T_INJ = T / 2
31
+ LENGTH = int(T * SAMPLE_RATE)
32
+ SNR_MIN, SNR_MAX = 1, 250
33
+ MINIMUM_FREQUENCY = 20.0
34
+ SNR_SCALING_FACTOR_BILBY = 31.970149253731343
35
+
36
+ SIGNAL_TYPES = ["chirp", "sine", "sine_gaussian", "gaussian_pulse", "ringdown"]
37
+ SIGNAL_FUNCTION_MAP = {
38
+ "chirp": generate_chirp,
39
+ "sine": generate_sine,
40
+ "sine_gaussian": generate_sine_gaussian,
41
+ "gaussian_pulse": generate_gaussian_pulse,
42
+ "ringdown": ringdown,
43
+ }
44
+
45
+
46
+ def generate_gaussian_noise(mean, std_dev, num_samples, sample_shape, bilby_noise=False,
47
+ sample_rate=SAMPLE_RATE, duration=T,
48
+ minimum_frequency=MINIMUM_FREQUENCY):
49
+ """Generate Gaussian noise samples (pycbc or bilby)."""
50
+ if bilby_noise:
51
+ try:
52
+ import bilby
53
+ except ImportError as e:
54
+ raise ImportError(
55
+ "bilby is required for bilby noise generation. "
56
+ "Install it with: pip install deepextractor[generative]"
57
+ ) from e
58
+ gaussian_noise_samples = []
59
+ for i in tqdm(range(num_samples), desc="Generating bilby noise..."):
60
+ ifos = bilby.gw.detector.InterferometerList(["L1"])
61
+ for ifo in ifos:
62
+ ifo.minimum_frequency = minimum_frequency
63
+ ifos.set_strain_data_from_power_spectral_densities(
64
+ sampling_frequency=sample_rate,
65
+ duration=duration,
66
+ start_time=0,
67
+ )
68
+ white_time_domain_strain = list(ifos[0].whitened_time_domain_strain)
69
+ gaussian_noise_samples.append(white_time_domain_strain)
70
+ return np.asarray(gaussian_noise_samples)
71
+ else:
72
+ print("Generating pycbc noise...")
73
+ return np.random.normal(loc=mean, scale=std_dev, size=(num_samples, *sample_shape))
74
+
75
+
76
def generate_synthetic_data(gaussian_noise_samples, bilby_noise=False, phase="train",
                            t_min=0.125, t_max=2.0, snr_min=SNR_MIN, snr_max=SNR_MAX):
    """Generate synthetic noisy glitch and background data arrays.

    For each background noise sample, injects a random number (1-29) of
    randomly chosen glitch waveforms at random SNRs, durations, and time
    shifts, then filters out any rows containing NaN/inf values.

    Parameters
    ----------
    gaussian_noise_samples : np.ndarray
        2-D array of background noise, one time series per row.
    bilby_noise : bool
        If True, SNR draws are divided by SNR_SCALING_FACTOR_BILBY.
    phase : str
        Label used only in the progress-bar description.
    t_min, t_max : float
        Range for each injection's duration (seconds).
    snr_min, snr_max : float
        Range for each injection's SNR draw.

    Returns
    -------
    tuple of np.ndarray
        ``(noisy_glitch, background)``, row-aligned, with non-finite rows
        removed from both.
    """
    noisy_glitch_ts = []
    pure_noise_ts = []

    for i in tqdm(range(len(gaussian_noise_samples)),
                  desc=f"Generating Synthetic {phase.capitalize()} Data"):
        background = gaussian_noise_samples[i]
        # Copy so injections never mutate the pure-noise target.
        noisy_glitch = background.copy()
        n_injs = np.random.randint(1, 30)
        for _ in range(n_injs):
            snr_to_scale = np.random.uniform(snr_min, snr_max)
            if bilby_noise:
                # Bring the SNR onto the bilby-whitened scale.
                snr_to_scale = snr_to_scale / SNR_SCALING_FACTOR_BILBY
            duration = np.random.uniform(t_min, t_max)
            s_type = random.choice(SIGNAL_TYPES)
            _, signal_injection = SIGNAL_FUNCTION_MAP[s_type](duration)
            len_glitch = len(signal_injection)
            # Index that centres the glitch on T_INJ (the segment middle).
            id_start = int((T_INJ * SAMPLE_RATE / LENGTH) * len(background)) - len_glitch // 2
            # Zero-mean the glitch, then scale it to the drawn SNR.
            glitch = signal_injection - np.mean(signal_injection)
            glitch = whitened_snr_scaling(glitch, snr=snr_to_scale)
            # Random shift keeping the glitch fully inside the segment.
            # NOTE(review): if len_glitch approaches len(background) the
            # randint range can become empty and raise — confirm t_max
            # always keeps glitches shorter than the segment.
            shift_int = np.random.randint(-id_start, len(background) - id_start - len_glitch)
            noisy_glitch[id_start + shift_int:id_start + len_glitch + shift_int] += glitch

        noisy_glitch_ts.append(noisy_glitch)
        pure_noise_ts.append(background)

    noisy_glitch_ts = np.asarray(noisy_glitch_ts)
    pure_noise_ts = np.asarray(pure_noise_ts)

    # Drop rows (in both arrays, keeping alignment) where any sample is
    # NaN, inf, or overflows float64.
    mask = ~np.any(
        np.isnan(noisy_glitch_ts) | np.isinf(noisy_glitch_ts)
        | (np.abs(noisy_glitch_ts) > np.finfo(np.float64).max),
        axis=1,
    )
    return noisy_glitch_ts[mask], pure_noise_ts[mask]
113
+
114
+
115
def main():
    """CLI entry point: generate, normalise, and save synthetic train/val data."""
    parser = argparse.ArgumentParser(
        description="Generate synthetic time-domain training data",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--output-dir", type=str, default="data/", help="Root output directory.")
    parser.add_argument("--num-train", type=int, default=250000)
    parser.add_argument("--num-val", type=int, default=25000)
    parser.add_argument(
        "--bilby-noise", action="store_true", help="Use bilby noise instead of pycbc."
    )
    args = parser.parse_args()

    # Output layout: {output_dir}/{bilby_noise|pycbc_noise}/time_domain/
    bilby_noise = args.bilby_noise
    noise_ext = "bilby_noise/" if bilby_noise else "pycbc_noise/"
    ext = "bilby" if bilby_noise else "pycbc"
    noise_type_path = os.path.join(args.output_dir, noise_ext)
    domain_path = os.path.join(noise_type_path, "time_domain")
    os.makedirs(domain_path, exist_ok=True)

    # Zero-mean Gaussian noise with std sqrt(fs) — presumably chosen to
    # mimic whitened detector noise at this sampling rate; TODO confirm.
    mean = 0
    std_dev = np.sqrt(SAMPLE_RATE)

    train_noise = generate_gaussian_noise(mean, std_dev, args.num_train, (LENGTH,), bilby_noise)
    val_noise = generate_gaussian_noise(mean, std_dev, args.num_val, (LENGTH,), bilby_noise)

    glitch_train, bg_train = generate_synthetic_data(train_noise, bilby_noise, "train")
    glitch_val, bg_val = generate_synthetic_data(val_noise, bilby_noise, "val")

    # Fit the scaler on the noisy-glitch training data only, then reuse it
    # everywhere so all four arrays share one normalisation.
    scaler = StandardScaler()
    glitch_train_scaled = scaler.fit_transform(glitch_train.reshape(-1, 1)).reshape(glitch_train.shape)
    bg_train_scaled = scaler.transform(bg_train.reshape(-1, 1)).reshape(bg_train.shape)
    glitch_val_scaled = scaler.transform(glitch_val.reshape(-1, 1)).reshape(glitch_val.shape)
    bg_val_scaled = scaler.transform(bg_val.reshape(-1, 1)).reshape(bg_val.shape)

    # Persist the fitted scaler so inference can apply / invert the scaling.
    with open(os.path.join(noise_type_path, f"scaler_{ext}.pkl"), "wb") as f:
        pickle.dump(scaler, f)

    np.save(os.path.join(domain_path, "glitch_train_scaled"), glitch_train_scaled)
    np.save(os.path.join(domain_path, "background_train_scaled"), bg_train_scaled)
    np.save(os.path.join(domain_path, "glitch_val_scaled"), glitch_val_scaled)
    np.save(os.path.join(domain_path, "background_val_scaled"), bg_val_scaled)

    print("Done. Data saved to", domain_path)


if __name__ == "__main__":
    main()
+ main()
@@ -0,0 +1,145 @@
1
+ """
2
+ Synthetic glitch signal generators.
3
+
4
+ The CDVGAN and gengli generators require optional dependencies:
5
+ pip install deepextractor[generative]
6
+ """
7
+
8
+ import numpy as np
9
+ from scipy.signal import chirp, gausspulse
10
+
11
+ from deepextractor.utils.signal import quality_factor_conversion, rescale
12
+
13
SRATE = 4096                # sampling rate in Hz
NYQUIST_FREQ = SRATE // 2   # Nyquist frequency (Hz): upper bound for random draws
15
+
16
+
17
def generate_chirp(duration, sample_rate=4096, f0_min=1, f0_max=NYQUIST_FREQ,
                   f1_min=1, f1_max=NYQUIST_FREQ):
    """Return ``(t, signal)`` for a linear chirp with random start/end frequencies."""
    # Draw the start frequency first, then the end frequency (RNG order matters
    # for reproducibility under a fixed seed).
    start_freq = np.random.uniform(f0_min, f0_max)
    end_freq = np.random.uniform(f1_min, f1_max)
    n_samples = int(sample_rate * duration)
    t = np.linspace(0, duration, n_samples, endpoint=False)
    return t, chirp(t, f0=start_freq, f1=end_freq, t1=duration, method="linear")
24
+
25
+
26
def generate_sine(duration, sample_rate=4096, freq_min=1, freq_max=NYQUIST_FREQ):
    """Return ``(t, signal)`` for a pure sinusoid at a random frequency."""
    freq = np.random.uniform(freq_min, freq_max)
    n_samples = int(sample_rate * duration)
    t = np.linspace(0, duration, n_samples, endpoint=False)
    return t, np.sin(2 * np.pi * freq * t)
31
+
32
+
33
def generate_sine_gaussian(duration, sample_rate=4096, freq_min=1, freq_max=NYQUIST_FREQ):
    """Return ``(t, signal)`` for a sine-Gaussian: a sinusoid under a Gaussian envelope."""
    # Envelope width is drawn first, then the carrier frequency (RNG order
    # preserved for seeded reproducibility).
    tau = np.random.uniform(duration / 200, duration / 4)
    freq = np.random.uniform(freq_min, freq_max)
    half = duration / 2
    t = np.linspace(-half, half, int(sample_rate * duration), endpoint=False)
    envelope = np.exp(-(t ** 2) / (2 * tau ** 2))
    return t, np.sin(2 * np.pi * freq * t) * envelope
39
+
40
+
41
def generate_gaussian_pulse(duration, sample_rate=4096, fc_min=1, fc_max=NYQUIST_FREQ,
                            bw_min=0.1, bw_max=1.0, bwr_min=-10, bwr_max=0,
                            tpr_min=0.5, tpr_max=2.0):
    """
    Generate a Gaussian pulse with randomly drawn parameters.

    Parameters
    ----------
    duration : float
        Duration in seconds.
    sample_rate : int
        Sampling rate in Hz.
    fc_min, fc_max : float
        Range for the center frequency (Hz).
    bw_min, bw_max : float
        Range for the fractional bandwidth.
    bwr_min, bwr_max : float
        Range for the bandwidth reference level (dB).
    tpr_min, tpr_max : float
        Range for the taper reference level (dB).
    """
    # Draw order (bw, fc, bwr, tpr) is kept so seeded runs reproduce the
    # same pulses as before.
    params = {
        "bw": np.random.uniform(bw_min, bw_max),
        "fc": np.random.uniform(fc_min, fc_max),
        "bwr": np.random.uniform(bwr_min, bwr_max),
        "tpr": np.random.uniform(tpr_min, tpr_max),
    }
    half = duration / 2
    t = np.linspace(-half, half, int(sample_rate * duration), endpoint=False)
    return t, gausspulse(t, **params)
69
+
70
+
71
def ringdown(duration, sample_rate=4096, n_signals=1):
    """Generate a damped-sinusoid (ringdown) waveform with random parameters.

    Returns ``(t, h)`` where ``h`` is the first of ``n_signals`` waveforms,
    rescaled and — with probability 0.5 — time-reversed.
    """
    # NOTE(review): unlike the other generators, endpoint=False is not passed
    # here, so t includes the endpoint — confirm this asymmetry is intended.
    t = np.linspace(0, duration, int(sample_rate * duration))
    phi = np.random.uniform(0, 2 * np.pi)  # random initial phase (shared)
    A = 1.0  # unit amplitude; SNR scaling is applied downstream
    f_0 = np.random.uniform(10, NYQUIST_FREQ, n_signals)       # ringdown frequency (Hz)
    t_0 = np.random.uniform(t[-1] / 4, 3 * t[-1] / 4, n_signals)  # onset time (s)
    Q = np.random.uniform(5, 150, n_signals)                   # quality factor
    # Damping time from Q and f_0 via a project helper (presumably
    # tau = Q / (pi * f_0) — TODO confirm), floored at 10 ms.
    tau = np.maximum(quality_factor_conversion(Q, f_0), 0.01)
    # Add a trailing axis so each parameter broadcasts against the time axis.
    f_0 = np.expand_dims(f_0, axis=1)
    t_0 = np.expand_dims(t_0, axis=1)
    tau = np.expand_dims(tau, axis=1)
    # Exponentially damped sinusoid starting at t_0.
    h_1 = A * np.exp(-1.0 * ((t - t_0) / (tau))) * np.sin(2 * np.pi * f_0 * (t - t_0) + phi)
    h_1 = ((t - t_0) > 0) * h_1  # zero everything before the onset time
    h_1 = rescale(h_1)
    if np.random.rand() < 0.5:
        h_1 = np.flip(h_1)  # random time-reversal: "rise" instead of "decay"
    return t, h_1[0]
88
+
89
+
90
def generate_gengli_glitch(ifo):
    """
    Generate a glitch sample using the gengli library.

    Requires the ``[generative]`` optional dependencies:
    ``pip install deepextractor[generative]``
    """
    # gengli is an optional dependency; fail with install instructions.
    try:
        import gengli
    except ImportError as e:
        raise ImportError(
            "gengli is required for this function. "
            "Install it with: pip install deepextractor[generative]"
        ) from e
    generator = gengli.glitch_generator(ifo)
    sample = generator.get_glitch(1, srate=4096, snr=10, alpha=0.2, fhigh=1024)
    # No time axis is produced for gengli glitches; return None in its place.
    return None, sample
107
+
108
+
109
def generate_cdvgan_glitch(gtype, cdvgan_generator):
    """
    Generate a glitch sample using a pretrained CDVGAN TensorFlow model.

    Parameters
    ----------
    gtype : str
        Glitch type: one of ``'blip'``, ``'tomte'``, ``'bbh'``,
        ``'simplex'``, ``'uniform'``.
    cdvgan_generator : tf.keras.Model
        The loaded CDVGAN generator model.

    Returns
    -------
    tuple
        ``(None, glitch)`` — no time axis is produced; ``glitch`` is the
        raw generator output converted to a numpy array.

    Raises
    ------
    ImportError
        If TensorFlow is not installed.
    KeyError
        If ``gtype`` is not one of the supported class names.
    """
    try:
        import tensorflow as tf
    except ImportError as e:
        raise ImportError(
            "TensorFlow is required for CDVGAN glitches. "
            "Install it with: pip install deepextractor[generative]"
        ) from e

    latent_dim = 100  # latent size the pretrained generator expects

    # Random point on the 2-simplex. Redraw when all three integers are
    # zero: the previous code divided unconditionally and produced NaN
    # class vectors with probability (1/100)^3 per call.
    random_ints = np.random.randint(0, 100, size=(1, 3))
    while not np.any(random_ints):
        random_ints = np.random.randint(0, 100, size=(1, 3))
    simplex_classes = random_ints / np.sum(random_ints, axis=1).reshape(1, 1)
    uniform_classes = np.random.uniform(low=0.0, high=1.0, size=(1, 3))

    # One-hot vectors for the named classes; stochastic vectors otherwise.
    class_vector_map = {
        "blip": [[1, 0, 0]],
        "tomte": [[0, 1, 0]],
        "bbh": [[0, 0, 1]],
        "simplex": simplex_classes,
        "uniform": uniform_classes,
    }

    class_vector = np.array(class_vector_map[gtype])
    latent_vector = tf.random.normal(shape=(1, latent_dim))
    generated_glitch = cdvgan_generator([latent_vector, class_vector]).numpy()
    return None, generated_glitch