ml4gw 0.6.3__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ml4gw might be problematic. Click here for more details.
- ml4gw/__init__.py +1 -0
- ml4gw/dataloading/chunked_dataset.py +1 -1
- ml4gw/dataloading/hdf5_dataset.py +36 -6
- ml4gw/dataloading/in_memory_dataset.py +1 -1
- ml4gw/gw.py +4 -3
- ml4gw/nn/autoencoder/base.py +1 -1
- ml4gw/nn/autoencoder/convolutional.py +3 -3
- ml4gw/nn/autoencoder/skip_connection.py +1 -1
- ml4gw/nn/resnet/resnet_1d.py +1 -1
- ml4gw/nn/resnet/resnet_2d.py +1 -1
- ml4gw/nn/streaming/online_average.py +1 -1
- ml4gw/nn/streaming/snapshotter.py +1 -1
- ml4gw/spectral.py +24 -6
- ml4gw/transforms/__init__.py +1 -0
- ml4gw/transforms/iirfilter.py +100 -0
- ml4gw/transforms/pearson.py +2 -2
- ml4gw/transforms/qtransform.py +2 -2
- ml4gw/transforms/scaler.py +1 -1
- ml4gw/transforms/snr_rescaler.py +3 -3
- ml4gw/transforms/spectral.py +2 -2
- ml4gw/transforms/spectrogram.py +1 -1
- ml4gw/transforms/transform.py +2 -2
- ml4gw/transforms/waveforms.py +2 -2
- ml4gw/transforms/whitening.py +19 -4
- ml4gw/utils/slicing.py +1 -6
- ml4gw/waveforms/cbc/coefficients.py +35 -0
- ml4gw/waveforms/cbc/phenom_d.py +3 -3
- ml4gw/waveforms/cbc/phenom_p.py +1 -0
- ml4gw/waveforms/cbc/taylorf2.py +5 -4
- ml4gw/waveforms/cbc/utils.py +111 -0
- ml4gw/waveforms/conversion.py +2 -2
- ml4gw/waveforms/generator.py +289 -26
- ml4gw-0.7.0.dist-info/LICENSE +674 -0
- ml4gw-0.7.0.dist-info/METADATA +78 -0
- ml4gw-0.7.0.dist-info/RECORD +55 -0
- {ml4gw-0.6.3.dist-info → ml4gw-0.7.0.dist-info}/WHEEL +1 -1
- ml4gw-0.6.3.dist-info/METADATA +0 -154
- ml4gw-0.6.3.dist-info/RECORD +0 -51
ml4gw/waveforms/cbc/taylorf2.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
from jaxtyping import Float
|
|
3
3
|
|
|
4
|
-
from
|
|
5
|
-
from
|
|
6
|
-
from
|
|
7
|
-
from
|
|
4
|
+
from ...constants import MPC_SEC, MTSUN_SI, PI
|
|
5
|
+
from ...constants import EulerGamma as GAMMA
|
|
6
|
+
from ...types import BatchTensor, FrequencySeries1d
|
|
7
|
+
from ..conversion import chirp_mass_and_mass_ratio_to_components
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
class TaylorF2(torch.nn.Module):
|
|
@@ -22,6 +22,7 @@ class TaylorF2(torch.nn.Module):
|
|
|
22
22
|
phic: BatchTensor,
|
|
23
23
|
inclination: BatchTensor,
|
|
24
24
|
f_ref: float,
|
|
25
|
+
**kwargs
|
|
25
26
|
):
|
|
26
27
|
"""
|
|
27
28
|
TaylorF2 up to 3.5 PN in phase. Newtonian SPA amplitude.
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utilities for conditioning waveforms
|
|
3
|
+
See https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/lib/LALSimInspiral.c # noqa
|
|
4
|
+
"""
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from ml4gw.constants import MRSUN, MSUN, MTSUN_SI, C, G
|
|
8
|
+
from ml4gw.types import BatchTensor
|
|
9
|
+
from ml4gw.waveforms.cbc import coefficients
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def chirp_time_bound(
    fstart: BatchTensor,
    mass_1: BatchTensor,
    mass_2: BatchTensor,
    s1: BatchTensor,
    s2: BatchTensor,
) -> BatchTensor:
    """
    Upper bound on the chirp time from `fstart` to coalescence.

    Port of lalsimulation's chirp time bound, see
    https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/lib/LALSimInspiral.c#L4969

    The spin contribution is over-estimated by using the larger of the
    two spin magnitudes. Masses are combined with the SI constants
    `G` and `C`. The caller in this package passes masses in kg.

    Returns:
        The bound, cast to float32.
    """
    M = mass_1 + mass_2
    eta = (mass_1 * mass_2 / M) / M

    # over-estimate of the effective spin: max(|s1|, |s2|)
    chi = torch.maximum(s1.abs(), s2.abs())

    c0 = torch.abs(coefficients.taylor_t2_timing_0pn_coeff(M, eta))
    c2 = coefficients.taylor_t2_timing_2pn_coeff(eta)
    c3 = (226.0 / 15.0) * chi
    c4 = coefficients.taylor_t2_timing_4pn_coeff(eta)

    # post-Newtonian velocity parameter at fstart
    v = (torch.pi * M * fstart * G) ** (1.0 / 3.0) / C

    newtonian = c0 * (v**-8)
    pn_series = 1 + (c2 + (c3 + c4 * v) * v) * v * v
    return (newtonian * pn_series).float()
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def chirp_start_frequency_bound(
    tchirp: BatchTensor,
    mass_1: BatchTensor,
    mass_2: BatchTensor,
):
    """
    Lower bound on the frequency at which a chirp of duration `tchirp`
    must start.

    Port of lalsimulation's chirp start frequency bound, see
    https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/lib/LALSimInspiral.c#L5104

    Returns:
        The bound, cast to float32.
    """
    M = mass_1 + mass_2
    eta = mass_1 * mass_2 / M / M

    c0 = coefficients.taylor_t3_frequency_0pn_coeff(M)
    # total mass expressed in seconds, scaled by the Newtonian chirp time
    scaled = 5.0 * M * (MTSUN_SI / MSUN) / (eta * tchirp)
    return (c0 * scaled ** (3.0 / 8.0)).float()
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def final_black_hole_spin_bound(
    s1z: BatchTensor, s2z: BatchTensor
) -> BatchTensor:
    """
    Upper bound on the spin of the final black hole.

    Port of lalsimulation's final spin bound, see
    https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/lib/LALSimInspiral.c#L5081

    The estimate `0.686 + 0.15 * (s1z + s2z)` is raised to the larger
    component spin magnitude if that is bigger. The result is then capped
    at 0.998.
    """
    spin_cap = 0.998
    estimate = 0.686 + 0.15 * (s1z + s2z)
    largest_component = torch.maximum(s1z.abs(), s2z.abs())
    return torch.maximum(estimate, largest_component).clamp(max=spin_cap)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def merge_time_bound(mass_1: BatchTensor, mass_2: BatchTensor) -> BatchTensor:
    """
    Upper bound on the plunge/merger time.

    The bound is the time for a single orbit at radius 9 M, travelling
    at speed c / 3. It is a port of lalsimulation's merge time bound, see
    https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/lib/LALSimInspiral.c#L5007

    Returns:
        The bound, cast to float32.
    """
    orbits = 1
    orbit_radius = 9.0 * (mass_1 + mass_2) * MRSUN / MSUN
    orbital_speed = C / 3.0
    period = 2.0 * torch.pi * orbit_radius / orbital_speed
    return (orbits * period).float()
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def ringdown_time_bound(
    total_mass: BatchTensor, s: BatchTensor
) -> BatchTensor:
    """
    Upper bound on the ringdown time: 11 e-folding times of the
    dominant ringdown mode.

    Port of lalsimulation's ringdown time bound, see
    https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/lib/LALSimInspiral.c#L5032

    Returns:
        The bound, cast to float32.
    """
    efolds = 11
    # quasinormal-mode fit coefficients for frequency and quality factor
    # (presumably the fits lalsimulation uses; see link above)
    f = (1.5251, -1.1568, 0.1292)
    q = (0.7000, 1.4187, -0.4990)

    one_minus_s = 1.0 - s
    mass_seconds = total_mass * MTSUN_SI / MSUN
    omega = (f[0] + f[1] * one_minus_s ** f[2]) / mass_seconds
    quality = q[0] + q[1] * one_minus_s ** q[2]
    damping_time = 2.0 * quality / omega
    return (efolds * damping_time).float()
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def frequency_isco(mass_1: BatchTensor, mass_2: BatchTensor):
    """
    Innermost-stable-circular-orbit frequency bound used to clip `f_min`.

    Computed as `1 / (9**1.5 * pi * M)`, with the total mass M converted
    to seconds. The denominator is cast to float32 before inverting.
    """
    denominator = (9.0**1.5) * torch.pi * (mass_1 + mass_2) * MTSUN_SI / MSUN
    return 1.0 / denominator.float()
|
ml4gw/waveforms/conversion.py
CHANGED
ml4gw/waveforms/generator.py
CHANGED
|
@@ -1,43 +1,306 @@
|
|
|
1
|
-
|
|
1
|
+
import math
|
|
2
|
+
from typing import Callable, Tuple
|
|
2
3
|
|
|
4
|
+
import numpy as np
|
|
3
5
|
import torch
|
|
4
6
|
from jaxtyping import Float
|
|
5
7
|
from torch import Tensor
|
|
6
8
|
|
|
9
|
+
from ..constants import MSUN
|
|
10
|
+
from ..transforms import IIRFilter
|
|
11
|
+
from ..types import BatchTensor
|
|
12
|
+
from .cbc import utils
|
|
7
13
|
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
14
|
+
# fraction of the chirp duration added as extra time for tapering
EXTRA_TIME_FRACTION = 0.1
# number of cycles at the low frequency added as extra time; this helps
# when the frequency is close to merger and sweeping rapidly
EXTRA_CYCLES = 3.0
|
|
12
18
|
|
|
13
|
-
def forward(
|
|
14
|
-
self,
|
|
15
|
-
N: int,
|
|
16
|
-
) -> Dict[str, Float[Tensor, " {N}"]]:
|
|
17
|
-
return {k: v.sample((N,)) for k, v in self.parameters.items()}
|
|
18
19
|
|
|
20
|
+
class TimeDomainCBCWaveformGenerator(torch.nn.Module):
    """
    Waveform generator that generates time-domain waveforms from
    frequency-domain approximants.

    Frequency domain waveforms are conditioned as done by lalsimulation.
    Specifically, waveforms are generated with a starting frequency `fstart`
    slightly below the requested `f_min`, so that they can be tapered from
    `fstart` to `f_min` using a cosine window.

    Please see https://lscsoft.docs.ligo.org/lalsuite/lalsimulation/group___l_a_l_sim_inspiral__c.html#gac9f16dab2cbca5a431738ee7d2505969 # noqa
    for more information

    Args:
        approximant:
            A callable that returns the cross and plus polarizations
            (in that order) given requested frequencies and relevant
            set of parameters.
            See `ml4gw.waveforms.cbc` for implemented approximants.
        sample_rate:
            Rate at which returned time domain waveform will be
            sampled in Hz. This also specifies `f_max` for generating
            waveforms via the nyquist frequency: `f_max = sample_rate // 2`.
        duration:
            Length of waveform in seconds.
            Waveforms will be left padded with zeros
            appropriately to fill the requested duration
        f_min:
            Lower frequency bound for waveforms
        f_ref:
            Reference frequency for the waveform
        right_pad:
            How far from the right edge of the window
            in seconds the returned waveform coalescence
            will be placed.
    """

    def __init__(
        self,
        approximant: Callable,
        sample_rate: float,
        duration: float,
        f_min: float,
        f_ref: float,
        right_pad: float,
    ) -> None:
        super().__init__()
        self.approximant = approximant
        self.f_min = f_min
        self.sample_rate = sample_rate
        self.duration = duration
        self.right_pad = right_pad
        self.f_ref = f_ref

        self.highpass = self.build_highpass_filter()

    @property
    def delta_t(self):
        """Sampling interval in seconds"""
        return 1 / self.sample_rate

    @property
    def nyquist(self):
        """Nyquist frequency, truncated to an int"""
        return int(self.sample_rate / 2)

    @property
    def size(self):
        """Number of samples in the waveform"""
        return int(self.duration * self.sample_rate)

    @property
    def delta_f(self):
        """Frequency resolution corresponding to `duration`"""
        return 1 / self.duration

    def build_highpass_filter(self):
        """
        Builds highpass filter object.

        See https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/python/lalsimulation/gwsignal/core/conditioning_subroutines.py?ref_type=heads#L10 # noqa
        """
        order = 8.0
        # bilinear-transformed frequency at f_min
        w1 = np.tan(np.pi * (self.f_min) / self.sample_rate)
        # attenuation at f_min, used to place the cutoff just below it
        attenuation = 0.99
        wc = w1 * (1.0 / attenuation**0.5 - 1) ** (1.0 / (2.0 * order))
        fc = self.sample_rate * np.arctan(wc) / np.pi

        return IIRFilter(
            order,
            fc / (self.sample_rate / 2),
            btype="highpass",
            ftype="butterworth",
        )

    def get_frequencies(self, df: float):
        """Get the frequencies from 0 to nyquist for corresponding df"""
        num_freqs = int(self.nyquist / df) + 1
        return torch.linspace(0, self.nyquist, num_freqs)

    def generate_conditioned_fd_waveform(
        self, **parameters: BatchTensor
    ) -> Tuple[Float[Tensor, "{N} samples"], Float[Tensor, "{N} samples"]]:
        """
        Generate a conditioned frequency domain waveform from a frequency
        domain approximant.

        Based on https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/python/lalsimulation/gwsignal/core/waveform_conditioning.py?ref_type=heads#L248 # noqa

        Args:
            **parameters:
                Dictionary of parameters for waveform generation
                where key is the parameter name and value is a tensor of
                parameters. It is required that `parameters` contains
                `mass_1`, `mass_2`, `s1z`, and `s2z` keys, which are used
                for determining parameters of data conditioning.
                If the specified approximant takes other parameters for
                waveform generation, like `chirp_mass` and `mass_ratio`,
                the utility functions in `ml4gw.waveforms.conversion` may
                be useful for populating the parameters dictionary with
                these additional parameters. Note that, if using an
                approximant from `ml4gw.waveforms.cbc`, any additional
                keys in `parameters` not ingested by the approximant will
                be ignored.

        Returns:
            The conditioned cross and plus spectra (in that order), each
            of shape `(batch_size, num_freqs)`, spanning 0 to nyquist.
            Their dtype is whatever the approximant returns. It is
            presumably complex, since a complex phase shift is applied
            in place.
        """
        # convert masses to kg, make sure
        # they are doubles so there is no
        # overflow in the calculations
        mass_1, mass_2 = (
            parameters["mass_1"].double() * MSUN,
            parameters["mass_2"].double() * MSUN,
        )
        total_mass = mass_1 + mass_2
        s1z, s2z = parameters["s1z"], parameters["s2z"]
        device = mass_1.device

        # never start above the ISCO frequency
        f_isco = utils.frequency_isco(mass_1, mass_2)
        f_min = torch.minimum(
            f_isco,
            torch.tensor(self.f_min, device=device),
        )

        # upper bound on chirp time
        tchirp = utils.chirp_time_bound(f_min, mass_1, mass_2, s1z, s2z)

        # upper bound on final black hole spin
        s = utils.final_black_hole_spin_bound(s1z, s2z)

        # upper bound on the final plunge, merger, and ringdown time
        tmerge = utils.merge_time_bound(
            mass_1, mass_2
        ) + utils.ringdown_time_bound(total_mass, s)

        # extra time to include for all waveforms to take care of situations
        # where the frequency is close to merger (and is sweeping rapidly):
        # this is a few cycles at the low frequency
        textra = EXTRA_CYCLES / f_min

        # lower bound on chirp frequency start used for
        # conditioning the frequency domain waveform
        fstart = utils.chirp_start_frequency_bound(
            (1.0 + EXTRA_TIME_FRACTION) * tchirp, mass_1, mass_2
        )

        # revised chirp time estimate based on fstart
        tchirp_fstart = utils.chirp_time_bound(
            fstart, mass_1, mass_2, s1z, s2z
        )

        # chirp length in samples
        chirplen = torch.round(
            (tchirp_fstart + tmerge + 2.0 * textra) * self.sample_rate
        )

        # pad to next power of 2
        chirplen = 2 ** torch.ceil(torch.log(chirplen) / math.log(2))

        # get smallest df corresponding to longest chirp length,
        # which will make sure there is no wrap around effects.
        df = 1.0 / (chirplen.max() / self.sample_rate)

        # generate frequency array from 0 to nyquist based on df
        frequencies = self.get_frequencies(df).to(device)

        # downselect to frequencies above the smallest fstart
        # in the batch; per-waveform zeroing happens below
        freq_mask = frequencies >= fstart.min()
        waveform_frequencies = frequencies[freq_mask]

        # generate the waveform at specified frequencies
        cross, plus = self.approximant(
            waveform_frequencies, **parameters, f_ref=self.f_ref
        )
        batch_size = cross.size(0)
        num_freqs = frequencies.size(0)

        # create tensors to hold the full spectrum
        # of frequencies from 0 to nyquist, and then
        # fill in the requested frequencies with the waveform values
        shape = (batch_size, num_freqs)
        hc_spectrum = torch.zeros(shape, dtype=cross.dtype, device=device)
        hp_spectrum = torch.zeros(shape, dtype=plus.dtype, device=device)

        hc_spectrum[:, freq_mask] = cross
        hp_spectrum[:, freq_mask] = plus

        # build a taper that is dependent on each
        # individual waveform's fstart;
        # since this means that the taper sizes
        # will be different for each waveform,
        # construct the tapers based on the maximum size
        # and only apply them inside each waveform's own
        # taper region
        k0s = torch.round(fstart / df)
        k1s = torch.round(f_min / df)

        # indices must live on the same device as the spectra,
        # otherwise comparisons/multiplications fail on GPU
        frequency_indices = torch.arange(num_freqs, device=device)

        # lalsimulation tapers k0 <= k < k1. An exclusive upper bound
        # also avoids a 0/0 = NaN window when k0 == k1. At k == k1 the
        # window would be exactly 1 anyway, so nothing else changes.
        taper_mask = frequency_indices < k1s[:, None]
        taper_mask &= frequency_indices >= k0s[:, None]

        indices = frequency_indices.expand(batch_size, -1)

        kvals = indices[taper_mask]
        k0s_expanded = k0s.unsqueeze(1).expand(-1, num_freqs)[taper_mask]
        k1s_expanded = k1s.unsqueeze(1).expand(-1, num_freqs)[taper_mask]

        windows = 0.5 - 0.5 * torch.cos(
            torch.pi * (kvals - k0s_expanded) / (k1s_expanded - k0s_expanded)
        )

        hc_spectrum[taper_mask] *= windows
        hp_spectrum[taper_mask] *= windows

        # zero out frequencies below fstart
        zero_mask = frequencies < fstart[:, None]
        hc_spectrum[zero_mask] = 0
        hp_spectrum[zero_mask] = 0

        # set nyquist frequency to zero
        hc_spectrum[..., -1], hp_spectrum[..., -1] = 0.0, 0.0

        # apply time translation (i.e. phase shift in frequency domain)
        # that will translate the coalescence time such that it is
        # `right_pad` seconds from the right edge of the window;
        # the shift is rounded to an integer number of samples
        tshift = round(self.right_pad * self.sample_rate) / self.sample_rate
        phase_shift = torch.exp(
            1j * 2 * torch.pi * df * tshift * frequency_indices
        )

        hc_spectrum *= phase_shift
        hp_spectrum *= phase_shift

        return hc_spectrum, hp_spectrum

    def forward(
        self,
        **parameters,
    ) -> Tuple[Float[Tensor, "{N} samples"], Float[Tensor, "{N} samples"]]:
        """
        Generates a time-domain waveform from a frequency domain approximant.
        Conditioning is based on https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/python/lalsimulation/gwsignal/core/waveform_conditioning.py?ref_type=heads#L248 # noqa

        A frequency domain waveform is generated, conditioned (see
        `generate_conditioned_fd_waveform`), and inverse FFT'd into the
        time domain.

        Args:
            **parameters:
                Dictionary of parameters for waveform generation
                where key is the parameter name and value is a tensor of
                parameters. It is required that `parameters` contains
                `mass_1`, `mass_2`, `s1z`, and `s2z` keys, which are used
                for determining parameters of data conditioning.
                If the specified approximant takes other parameters for
                waveform generation, like `chirp_mass` and `mass_ratio`,
                the utility functions in `ml4gw.waveforms.conversion` may
                be useful for populating the parameters dictionary with
                these additional parameters. Note that, if using an
                approximant from `ml4gw.waveforms.cbc`, any additional
                keys in `parameters` not ingested by the approximant will
                be ignored.

        Returns:
            The highpassed cross and plus time-domain polarizations
            (in that order), in double precision.
        """
        hc, hp = self.generate_conditioned_fd_waveform(**parameters)

        # fft to time domain and apply appropriate scaling
        # (irfft divides by n; multiplying by sample_rate = n * df
        # restores the continuous-transform normalization)
        hc = torch.fft.irfft(hc) * self.sample_rate
        hp = torch.fft.irfft(hp) * self.sample_rate

        # TODO: some additional tapering in the time
        # domain is performed in lalsimulation

        # pad waveforms on left up to requested duration.
        # NOTE(review): a negative `pad` makes torch's `pad` trim
        # samples from the left instead; confirm that is intended
        pad = int((self.duration * self.sample_rate) - hp.shape[-1])
        hc = torch.nn.functional.pad(hc, (pad, 0))
        hp = torch.nn.functional.pad(hp, (pad, 0))

        # finally, highpass the waveforms,
        # going to double precision
        hp = self.highpass(hp.double())
        hc = self.highpass(hc.double())

        return hc, hp
|