ml4gw 0.6.3__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ml4gw has been flagged as potentially problematic; see the package registry page for more details.

@@ -1,10 +1,10 @@
1
1
  import torch
2
2
  from jaxtyping import Float
3
3
 
4
- from ml4gw.constants import MPC_SEC, MTSUN_SI, PI
5
- from ml4gw.constants import EulerGamma as GAMMA
6
- from ml4gw.types import BatchTensor, FrequencySeries1d
7
- from ml4gw.waveforms.conversion import chirp_mass_and_mass_ratio_to_components
4
+ from ...constants import MPC_SEC, MTSUN_SI, PI
5
+ from ...constants import EulerGamma as GAMMA
6
+ from ...types import BatchTensor, FrequencySeries1d
7
+ from ..conversion import chirp_mass_and_mass_ratio_to_components
8
8
 
9
9
 
10
10
  class TaylorF2(torch.nn.Module):
@@ -22,6 +22,7 @@ class TaylorF2(torch.nn.Module):
22
22
  phic: BatchTensor,
23
23
  inclination: BatchTensor,
24
24
  f_ref: float,
25
+ **kwargs
25
26
  ):
26
27
  """
27
28
  TaylorF2 up to 3.5 PN in phase. Newtonian SPA amplitude.
@@ -0,0 +1,111 @@
1
+ """
2
+ Utilities for conditioning waveforms
3
+ See https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/lib/LALSimInspiral.c # noqa
4
+ """
5
+ import torch
6
+
7
+ from ml4gw.constants import MRSUN, MSUN, MTSUN_SI, C, G
8
+ from ml4gw.types import BatchTensor
9
+ from ml4gw.waveforms.cbc import coefficients
10
+
11
+
12
def chirp_time_bound(
    fstart: BatchTensor,
    mass_1: BatchTensor,
    mass_2: BatchTensor,
    s1: BatchTensor,
    s2: BatchTensor,
) -> BatchTensor:
    """
    Upper bound on the chirp time starting from frequency `fstart`.

    Port of `XLALSimInspiralChirpTimeBound`:
    https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/lib/LALSimInspiral.c#L4969

    Args:
        fstart:
            Starting frequency in Hz.
        mass_1, mass_2:
            Component masses in the units expected by
            `coefficients.taylor_t2_timing_0pn_coeff` (presumably kg,
            as in LAL -- confirm against `coefficients`).
        s1, s2:
            Component spins. Only their magnitudes are used; the larger
            magnitude stands in as an over-estimate of the effective spin.

    Returns:
        The TaylorT2-based chirp-time bound (seconds in LAL), cast
        to float32.
    """
    m_total = mass_1 + mass_2
    # symmetric mass ratio: reduced mass divided by total mass
    eta = (mass_1 * mass_2 / m_total) / m_total
    # magnitudes are non-negative, so their maximum needs no further abs()
    chi_bound = torch.maximum(s1.abs(), s2.abs())

    c0 = torch.abs(coefficients.taylor_t2_timing_0pn_coeff(m_total, eta))
    c2 = coefficients.taylor_t2_timing_2pn_coeff(eta)
    c3 = (226.0 / 15.0) * chi_bound  # spin term, as in LAL
    c4 = coefficients.taylor_t2_timing_4pn_coeff(eta)

    # post-Newtonian velocity parameter at fstart
    v = (torch.pi * m_total * fstart * G) ** (1.0 / 3.0) / C

    pn_series = 1 + (c2 + (c3 + c4 * v) * v) * v * v
    return (c0 * v**-8 * pn_series).float()
38
+
39
+
40
def chirp_start_frequency_bound(
    tchirp: BatchTensor,
    mass_1: BatchTensor,
    mass_2: BatchTensor,
) -> BatchTensor:
    """
    Lower bound on the starting frequency for a given chirp time.

    Port of `XLALSimInspiralChirpStartFrequencyBound`:
    https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/lib/LALSimInspiral.c#L5104

    Args:
        tchirp:
            Desired chirp duration (seconds in LAL).
        mass_1, mass_2:
            Component masses. The expression converts the total mass with
            ``MTSUN_SI / MSUN``, so these are presumably in kg -- confirm
            against callers.

    Returns:
        The Newtonian (0PN TaylorT3) start-frequency bound, cast to float32.
    """
    total_mass = mass_1 + mass_2
    # symmetric mass ratio: reduced mass divided by total mass
    eta = (mass_1 * mass_2 / total_mass) / total_mass

    c0 = coefficients.taylor_t3_frequency_0pn_coeff(total_mass)
    newtonian_factor = (
        5.0 * total_mass * (MTSUN_SI / MSUN) / (eta * tchirp)
    )
    return (c0 * newtonian_factor ** (3.0 / 8.0)).float()
57
+
58
+
59
def final_black_hole_spin_bound(
    s1z: BatchTensor, s2z: BatchTensor
) -> BatchTensor:
    """
    Over-estimate of the dimensionless spin of the final black hole.

    Port of `XLALSimInspiralFinalBlackHoleSpinBound`:
    https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/lib/LALSimInspiral.c#L5081

    The estimate ``0.686 + 0.15 * (s1z + s2z)`` is raised to at least the
    larger of ``|s1z|`` and ``|s2z|``, then capped at 0.998.

    Args:
        s1z, s2z:
            Aligned spin components of the two bodies (broadcastable).

    Returns:
        Spin bound with the same dtype as the inputs.
    """
    spin_cap = 0.998
    estimate = 0.686 + 0.15 * (s1z + s2z)
    # never smaller than either component spin magnitude
    bounded = torch.maximum(torch.maximum(estimate, s1z.abs()), s2z.abs())
    return bounded.clamp(max=spin_cap)
70
+
71
+
72
def merge_time_bound(mass_1: BatchTensor, mass_2: BatchTensor) -> BatchTensor:
    """
    Upper bound on the plunge/merger duration.

    Port of `XLALSimInspiralMergeTimeBound`:
    https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/lib/LALSimInspiral.c#L5007

    Models a fixed number of orbits (here one) at radius
    ``9 * M * MRSUN / MSUN`` travelling at ``C / 3``.

    Args:
        mass_1, mass_2:
            Component masses; converted with ``MRSUN / MSUN``, so presumably
            in kg -- confirm against callers.

    Returns:
        Time bound cast to float32.
    """
    n_orbits = 1
    radius = 9.0 * (mass_1 + mass_2) * MRSUN / MSUN
    speed = C / 3.0
    orbit_period = 2.0 * torch.pi * radius / speed
    return (n_orbits * orbit_period).float()
82
+
83
+
84
def ringdown_time_bound(
    total_mass: BatchTensor, s: BatchTensor
) -> BatchTensor:
    """
    Upper bound on the ringdown duration of the final black hole.

    Port of `XLALSimInspiralRingdownTimeBound`:
    https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/lib/LALSimInspiral.c#L5032

    The bound is 11 damping times ``tau = 2 * Q / omega``. ``omega`` and
    ``Q`` come from fits in ``1 - s`` whose coefficients are copied from
    the LAL implementation linked above.

    Args:
        total_mass:
            Total mass; converted with ``MTSUN_SI / MSUN``, so presumably
            in kg -- confirm against callers.
        s:
            Final black hole spin.

    Returns:
        Time bound cast to float32.
    """
    n_efolds = 11
    freq_fit = (1.5251, -1.1568, 0.1292)
    quality_fit = (0.7000, 1.4187, -0.4990)

    one_minus_s = 1.0 - s
    omega = (freq_fit[0] + freq_fit[1] * one_minus_s ** freq_fit[2]) / (
        total_mass * MTSUN_SI / MSUN
    )
    quality = quality_fit[0] + quality_fit[1] * one_minus_s ** quality_fit[2]
    damping_time = 2.0 * quality / omega
    return (n_efolds * damping_time).float()
103
+
104
+
105
def frequency_isco(mass_1: BatchTensor, mass_2: BatchTensor) -> BatchTensor:
    """
    ISCO-based frequency bound, ``1 / (9**1.5 * pi * M)``.

    ``M`` is the total mass converted with ``MTSUN_SI / MSUN``. This mirrors
    the check in LAL's waveform conditioning, which lowers `f_min` to this
    value when it lies above it.

    Args:
        mass_1, mass_2:
            Component masses; presumably in kg given the ``/ MSUN``
            conversion -- confirm against callers.

    Returns:
        Frequency bound cast to float32.
    """
    denominator = (9.0**1.5) * torch.pi * (mass_1 + mass_2) * MTSUN_SI / MSUN
    # divide at input precision, then cast once (as the sibling bounds do)
    return (1.0 / denominator).float()
@@ -1,7 +1,7 @@
1
1
  import torch
2
2
 
3
- from ml4gw.constants import MTSUN_SI, PI
4
- from ml4gw.types import BatchTensor
3
+ from ..constants import MTSUN_SI, PI
4
+ from ..types import BatchTensor
5
5
 
6
6
 
7
7
  def rotate_z(angle: BatchTensor, x, y, z):
@@ -1,43 +1,306 @@
1
- from typing import Callable, Dict, Tuple
1
+ import math
2
+ from typing import Callable, Tuple
2
3
 
4
+ import numpy as np
3
5
  import torch
4
6
  from jaxtyping import Float
5
7
  from torch import Tensor
6
8
 
9
+ from ..constants import MSUN
10
+ from ..transforms import IIRFilter
11
+ from ..types import BatchTensor
12
+ from .cbc import utils
7
13
 
8
- class ParameterSampler(torch.nn.Module):
9
- def __init__(self, **parameters: Callable) -> None:
10
- super().__init__()
11
- self.parameters = parameters
14
# Fraction of the chirp-time bound added as extra time, so that waveforms
# start below `f_min` and can be tapered up to it.
EXTRA_TIME_FRACTION = 0.1
# Number of cycles at the low frequency added as extra padding time.
EXTRA_CYCLES = 3.0
12
18
 
13
- def forward(
14
- self,
15
- N: int,
16
- ) -> Dict[str, Float[Tensor, " {N}"]]:
17
- return {k: v.sample((N,)) for k, v in self.parameters.items()}
18
19
 
20
class TimeDomainCBCWaveformGenerator(torch.nn.Module):
    """
    Waveform generator that generates time-domain waveforms from
    frequency-domain approximants.

    Frequency-domain waveforms are conditioned as done by lalsimulation.
    Specifically, waveforms are generated with a starting frequency `fstart`
    slightly below the requested `f_min`, so that they can be tapered from
    `fstart` to `f_min` using a cosine window.

    Please see https://lscsoft.docs.ligo.org/lalsuite/lalsimulation/group___l_a_l_sim_inspiral__c.html#gac9f16dab2cbca5a431738ee7d2505969 # noqa
    for more information

    Args:
        approximant:
            A callable that, given requested frequencies and the relevant
            set of parameters, returns the cross and plus polarizations
            (unpacked in that order, cross first).
            See `ml4gw.waveforms.cbc` for implemented approximants.
        sample_rate:
            Rate at which returned time domain waveform will be
            sampled in Hz. This also specifies `f_max` for generating
            waveforms via the Nyquist frequency: `f_max = sample_rate // 2`.
        duration:
            Length of waveform in seconds. Waveforms are left padded with
            zeros to fill the requested duration. If the conditioned
            waveform is longer, its earliest samples are cropped.
        f_min:
            Lower frequency bound for waveforms
        f_ref:
            Reference frequency for the waveform
        right_pad:
            How far from the right edge of the window, in seconds, the
            returned waveform coalescence will be placed.
    """

    def __init__(
        self,
        approximant: Callable,
        sample_rate: float,
        duration: float,
        f_min: float,
        f_ref: float,
        right_pad: float,
    ) -> None:

        super().__init__()
        self.approximant = approximant
        self.f_min = f_min
        self.sample_rate = sample_rate
        self.duration = duration
        self.right_pad = right_pad
        self.f_ref = f_ref

        self.highpass = self.build_highpass_filter()

    @property
    def delta_t(self):
        """Sampling interval in seconds"""
        return 1 / self.sample_rate

    @property
    def nyquist(self):
        """Nyquist frequency, truncated to an integer"""
        return int(self.sample_rate / 2)

    @property
    def size(self):
        """Number of samples in the waveform"""
        return int(self.duration * self.sample_rate)

    @property
    def delta_f(self):
        """Frequency resolution corresponding to `duration`"""
        return 1 / self.duration

    def build_highpass_filter(self):
        """
        Builds highpass filter object.

        The cutoff is chosen so that the response at `f_min` equals
        `attenuation`, following
        https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/python/lalsimulation/gwsignal/core/conditioning_subroutines.py?ref_type=heads#L10 # noqa
        """
        order = 8.0
        w1 = np.tan(np.pi * (self.f_min) / self.sample_rate)
        attenuation = 0.99
        wc = w1 * (1.0 / attenuation**0.5 - 1) ** (1.0 / (2.0 * order))
        fc = self.sample_rate * np.arctan(wc) / np.pi

        # cutoff is normalized to the Nyquist frequency
        return IIRFilter(
            order,
            fc / (self.sample_rate / 2),
            btype="highpass",
            ftype="butterworth",
        )

    def get_frequencies(self, df: float):
        """Get the frequencies from 0 to nyquist for corresponding df"""
        num_freqs = int(self.nyquist / df) + 1
        return torch.linspace(0, self.nyquist, num_freqs)

    def generate_conditioned_fd_waveform(
        self, **parameters: BatchTensor
    ) -> Tuple[Float[Tensor, "{N} samples"], Float[Tensor, "{N} samples"]]:
        """
        Generate a conditioned frequency domain waveform from a frequency
        domain approximant.

        Based on https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/python/lalsimulation/gwsignal/core/waveform_conditioning.py?ref_type=heads#L248 # noqa

        Args:
            **parameters:
                Parameters for waveform generation, where each key is the
                parameter name and each value is a tensor of parameters.
                `parameters` must contain `mass_1`, `mass_2`, `s1z`, and
                `s2z` keys. These are used to determine the data-conditioning
                settings, and the masses are presumably in solar masses,
                since they are multiplied by `MSUN` below.
                If the specified approximant takes other parameters for
                waveform generation, like `chirp_mass` and `mass_ratio`, the
                utility functions in `ml4gw.waveforms.conversion` may be
                useful for populating the parameters dictionary with these
                additional parameters. Note that, if using an approximant
                from `ml4gw.waveforms.cbc`, any additional keys in
                `parameters` not ingested by the approximant will be ignored.

        Returns:
            The cross and plus spectra (in that order), each of shape
            ``(batch, num_freqs)``. They are sampled from 0 to Nyquist at
            a common resolution chosen from the longest chirp in the batch.
        """
        # convert masses to kg, making sure they are doubles
        # so there is no overflow in the calculations
        mass_1 = parameters["mass_1"].double() * MSUN
        mass_2 = parameters["mass_2"].double() * MSUN
        total_mass = mass_1 + mass_2
        s1z, s2z = parameters["s1z"], parameters["s2z"]
        device = mass_1.device

        # as in LAL, lower f_min to the ISCO frequency bound if needed
        f_isco = utils.frequency_isco(mass_1, mass_2)
        f_min = torch.minimum(
            f_isco,
            torch.tensor(self.f_min, device=device),
        )

        # upper bound on chirp time
        tchirp = utils.chirp_time_bound(f_min, mass_1, mass_2, s1z, s2z)

        # upper bound on final black hole spin
        s = utils.final_black_hole_spin_bound(s1z, s2z)

        # upper bound on the final plunge, merger, and ringdown time
        tmerge = utils.merge_time_bound(
            mass_1, mass_2
        ) + utils.ringdown_time_bound(total_mass, s)

        # extra time to include for all waveforms to take care of situations
        # where the frequency is close to merger (and is sweeping rapidly):
        # this is a few cycles at the low frequency
        textra = EXTRA_CYCLES / f_min

        # lower bound on the chirp start frequency, used for
        # conditioning the frequency domain waveform
        fstart = utils.chirp_start_frequency_bound(
            (1.0 + EXTRA_TIME_FRACTION) * tchirp, mass_1, mass_2
        )

        # revised chirp time estimate based on fstart
        tchirp_fstart = utils.chirp_time_bound(
            fstart, mass_1, mass_2, s1z, s2z
        )

        # chirp length in samples
        chirplen = torch.round(
            (tchirp_fstart + tmerge + 2.0 * textra) * self.sample_rate
        )

        # pad to next power of 2; log2 avoids the extra
        # rounding step of log(x) / log(2)
        chirplen = 2 ** torch.ceil(torch.log2(chirplen))

        # get smallest df corresponding to longest chirp length,
        # which will make sure there is no wrap around effects.
        df = 1.0 / (chirplen.max() / self.sample_rate)

        # generate frequency array from 0 to nyquist based on df
        frequencies = self.get_frequencies(df).to(device)
        num_freqs = frequencies.size(0)

        # downselect to frequencies above the smallest fstart
        # and generate the waveform at those frequencies
        freq_mask = frequencies >= fstart.min()
        waveform_frequencies = frequencies[freq_mask]
        cross, plus = self.approximant(
            waveform_frequencies, **parameters, f_ref=self.f_ref
        )
        batch_size = cross.size(0)

        # embed the waveform in the full 0..nyquist spectrum
        shape = (batch_size, num_freqs)
        hc_spectrum = torch.zeros(shape, dtype=cross.dtype, device=device)
        hp_spectrum = torch.zeros(shape, dtype=plus.dtype, device=device)
        hc_spectrum[:, freq_mask] = cross
        hp_spectrum[:, freq_mask] = plus

        # Cosine taper between each waveform's own fstart bin (k0) and
        # f_min bin (k1). As in LAL, the tapered bins are k0 <= k < k1.
        # The exclusive upper bound matters: an inclusive bound divides
        # 0 / 0 (NaN) whenever fstart and f_min round to the same bin,
        # and that NaN would spread over the entire irfft output.
        k0s = torch.round(fstart / df)
        k1s = torch.round(f_min / df)
        frequency_indices = torch.arange(num_freqs, device=device)
        taper_mask = (frequency_indices >= k0s[:, None]) & (
            frequency_indices < k1s[:, None]
        )

        indices = frequency_indices.expand(batch_size, -1)
        kvals = indices[taper_mask]
        k0s_expanded = k0s.unsqueeze(1).expand(-1, num_freqs)[taper_mask]
        k1s_expanded = k1s.unsqueeze(1).expand(-1, num_freqs)[taper_mask]

        # k1 - k0 >= 1 for every selected bin, so this never divides by 0
        windows = 0.5 - 0.5 * torch.cos(
            torch.pi * (kvals - k0s_expanded) / (k1s_expanded - k0s_expanded)
        )
        hc_spectrum[taper_mask] *= windows
        hp_spectrum[taper_mask] *= windows

        # zero out frequencies below each waveform's fstart
        zero_mask = frequencies < fstart[:, None]
        hc_spectrum[zero_mask] = 0
        hp_spectrum[zero_mask] = 0

        # set nyquist frequency to zero
        hc_spectrum[..., -1] = 0.0
        hp_spectrum[..., -1] = 0.0

        # Apply a time translation (a phase shift in the frequency domain)
        # that places the coalescence `right_pad` seconds from the right
        # edge of the window. The phase reaches 2*pi*f_nyquist*tshift
        # radians, so it is evaluated in double precision before being
        # cast to the spectrum dtype.
        tshift = round(self.right_pad * self.sample_rate) / self.sample_rate
        bin_indices = torch.arange(
            num_freqs, device=device, dtype=torch.float64
        )
        phase = 2 * torch.pi * tshift * df.double() * bin_indices
        phase_shift = torch.exp(1j * phase)

        hc_spectrum *= phase_shift.to(hc_spectrum.dtype)
        hp_spectrum *= phase_shift.to(hp_spectrum.dtype)

        return hc_spectrum, hp_spectrum

    def forward(
        self,
        **parameters: BatchTensor,
    ) -> Tuple[Float[Tensor, "{N} samples"], Float[Tensor, "{N} samples"]]:
        """
        Generates a time-domain waveform from a frequency domain approximant.

        Conditioning is based on https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/python/lalsimulation/gwsignal/core/waveform_conditioning.py?ref_type=heads#L248 # noqa

        A frequency domain waveform is generated and conditioned (see
        `generate_conditioned_fd_waveform`). It is then inverse-FFT'd into
        the time domain, padded or cropped on the left to `self.size`
        samples, and highpassed.

        Args:
            **parameters:
                Parameters for waveform generation, where each key is the
                parameter name and each value is a tensor of parameters.
                `parameters` must contain `mass_1`, `mass_2`, `s1z`, and
                `s2z` keys, which are used to determine the data-conditioning
                settings. If the specified approximant takes other
                parameters for waveform generation, like `chirp_mass` and
                `mass_ratio`, the utility functions in
                `ml4gw.waveforms.conversion` may be useful for populating
                the parameters dictionary with these additional parameters.
                Note that, if using an approximant from
                `ml4gw.waveforms.cbc`, any additional keys in `parameters`
                not ingested by the approximant will be ignored.

        Returns:
            The cross and plus polarizations (in that order) in double
            precision, each with `self.size` samples along the last axis.
        """
        hc, hp = self.generate_conditioned_fd_waveform(**parameters)

        # irfft applies a 1/N normalization; multiplying by the
        # sample rate turns it into the df-weighted inverse transform
        hc = torch.fft.irfft(hc) * self.sample_rate
        hp = torch.fft.irfft(hp) * self.sample_rate

        # TODO: some additional tapering in the time
        # domain is performed in lalsimulation

        # Left-pad to exactly `self.size` samples. A negative pad makes
        # torch crop the earliest samples instead.
        pad = self.size - hp.shape[-1]
        hc = torch.nn.functional.pad(hc, (pad, 0))
        hp = torch.nn.functional.pad(hp, (pad, 0))

        # finally, highpass the waveforms in double precision
        hp = self.highpass(hp.double())
        hc = self.highpass(hc.double())

        return hc, hp