ml4gw 0.6.3__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ml4gw might be problematic. See the package registry's advisory page for this release for more details.

ml4gw/__init__.py CHANGED
@@ -0,0 +1 @@
1
+ from .constants import *
@@ -2,7 +2,7 @@ from collections.abc import Iterable
2
2
 
3
3
  import torch
4
4
 
5
- from ml4gw.types import WaveformTensor
5
+ from ..types import WaveformTensor
6
6
 
7
7
 
8
8
  class ChunkedTimeSeriesDataset(torch.utils.data.IterableDataset):
@@ -1,11 +1,11 @@
1
1
  import warnings
2
- from typing import Sequence, Union
2
+ from typing import Optional, Sequence, Union
3
3
 
4
4
  import h5py
5
5
  import numpy as np
6
6
  import torch
7
7
 
8
- from ml4gw.types import WaveformTensor
8
+ from ..types import WaveformTensor
9
9
 
10
10
 
11
11
  class ContiguousHdf5Warning(Warning):
@@ -50,6 +50,13 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
50
50
  channel. The latter setting limits the amount of
51
51
  entropy in the effective dataset, but can provide
52
52
  over 2x improvement in total throughput.
53
+ num_files_per_batch:
54
+ The number of unique files from which to sample
55
+ batch elements each epoch. If left as `None`,
56
+ will use all available files. Useful when reading
57
+ from many files is bottlenecking dataloading.
58
+
59
+
53
60
  """
54
61
 
55
62
  def __init__(
@@ -60,6 +67,7 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
60
67
  batch_size: int,
61
68
  batches_per_epoch: int,
62
69
  coincident: Union[bool, str],
70
+ num_files_per_batch: Optional[int] = None,
63
71
  ) -> None:
64
72
  if not isinstance(coincident, bool) and coincident != "files":
65
73
  raise ValueError(
@@ -67,13 +75,21 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
67
75
  "got unrecognized value {}".format(coincident)
68
76
  )
69
77
 
70
- self.fnames = fnames
78
+ self.fnames = np.array(fnames)
71
79
  self.channels = channels
72
80
  self.num_channels = len(channels)
73
81
  self.kernel_size = kernel_size
74
82
  self.batch_size = batch_size
75
83
  self.batches_per_epoch = batches_per_epoch
76
84
  self.coincident = coincident
85
+ self.num_files_per_batch = (
86
+ len(fnames) if num_files_per_batch is None else num_files_per_batch
87
+ )
88
+ if self.num_files_per_batch > len(fnames):
89
+ raise ValueError(
90
+ f"Number of files per batch ({self.num_files_per_batch}) "
91
+ f"cannot exceed number of files ({len(fnames)}) "
92
+ )
77
93
 
78
94
  self.sizes = {}
79
95
  for fname in self.fnames:
@@ -85,13 +101,14 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
85
101
  "without using chunked storage. This can have "
86
102
  "severe performance impacts at data loading time. "
87
103
  "If you need faster loading, try re-generating "
88
- "your datset with chunked storage turned on.".format(
104
+ "your dataset with chunked storage turned on.".format(
89
105
  fname
90
106
  ),
91
107
  category=ContiguousHdf5Warning,
92
108
  )
93
109
 
94
110
  self.sizes[fname] = len(dset)
111
+
95
112
  total = sum(self.sizes.values())
96
113
  self.probs = np.array([i / total for i in self.sizes.values()])
97
114
 
@@ -99,9 +116,22 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
99
116
  return self.batches_per_epoch
100
117
 
101
118
  def sample_fnames(self, size) -> np.ndarray:
102
- return np.random.choice(
103
- self.fnames,
119
+ # first, randomly select `self.num_files_per_batch`
120
+ # file indices based on their probabilities
121
+ fname_indices = np.arange(len(self.fnames))
122
+ fname_indices = np.random.choice(
123
+ fname_indices,
104
124
  p=self.probs,
125
+ size=(self.num_files_per_batch),
126
+ replace=False,
127
+ )
128
+ # now renormalize the probabilities, and sample
129
+ # the requested size from this subset of files
130
+ probs = self.probs[fname_indices]
131
+ probs /= probs.sum()
132
+ return np.random.choice(
133
+ self.fnames[fname_indices],
134
+ p=probs,
105
135
  size=size,
106
136
  replace=True,
107
137
  )
@@ -5,7 +5,7 @@ import torch
5
5
  from jaxtyping import Float
6
6
  from torch import Tensor
7
7
 
8
- from ml4gw.utils.slicing import slice_kernels
8
+ from ..utils.slicing import slice_kernels
9
9
 
10
10
 
11
11
  class InMemoryDataset(torch.utils.data.IterableDataset):
ml4gw/gw.py CHANGED
@@ -16,8 +16,10 @@ import torch
16
16
  from jaxtyping import Float
17
17
  from torch import Tensor
18
18
 
19
- from ml4gw.constants import C
20
- from ml4gw.types import (
19
+ from ml4gw.utils.interferometer import InterferometerGeometry
20
+
21
+ from .constants import C
22
+ from .types import (
21
23
  BatchTensor,
22
24
  NetworkDetectorTensors,
23
25
  NetworkVertices,
@@ -26,7 +28,6 @@ from ml4gw.types import (
26
28
  VectorGeometry,
27
29
  WaveformTensor,
28
30
  )
29
- from ml4gw.utils.interferometer import InterferometerGeometry
30
31
 
31
32
 
32
33
  def outer(x: VectorGeometry, y: VectorGeometry) -> TensorGeometry:
@@ -4,7 +4,7 @@ from typing import Optional, Tuple, Union
4
4
  import torch
5
5
  from torch import Tensor
6
6
 
7
- from ml4gw.nn.autoencoder.skip_connection import SkipConnection
7
+ from .skip_connection import SkipConnection
8
8
 
9
9
 
10
10
  class Autoencoder(torch.nn.Module):
@@ -4,9 +4,9 @@ from typing import Optional
4
4
  import torch
5
5
  from torch import Tensor
6
6
 
7
- from ml4gw.nn.autoencoder.base import Autoencoder
8
- from ml4gw.nn.autoencoder.skip_connection import SkipConnection
9
- from ml4gw.nn.autoencoder.utils import match_size
7
+ from .base import Autoencoder
8
+ from .skip_connection import SkipConnection
9
+ from .utils import match_size
10
10
 
11
11
  Module = Callable[[...], torch.nn.Module]
12
12
 
@@ -1,7 +1,7 @@
1
1
  import torch
2
2
  from torch import Tensor
3
3
 
4
- from ml4gw.nn.autoencoder.utils import match_size
4
+ from .utils import match_size
5
5
 
6
6
 
7
7
  class SkipConnection(torch.nn.Module):
@@ -13,7 +13,7 @@ import torch
13
13
  import torch.nn as nn
14
14
  from torch import Tensor
15
15
 
16
- from ml4gw.nn.norm import GroupNorm1DGetter, NormLayer
16
+ from ..norm import GroupNorm1DGetter, NormLayer
17
17
 
18
18
 
19
19
  def convN(
@@ -10,7 +10,7 @@ import torch
10
10
  import torch.nn as nn
11
11
  from torch import Tensor
12
12
 
13
- from ml4gw.nn.norm import GroupNorm2DGetter, NormLayer
13
+ from ..norm import GroupNorm2DGetter, NormLayer
14
14
 
15
15
 
16
16
  def convN(
@@ -4,7 +4,7 @@ import torch
4
4
  from jaxtyping import Float
5
5
  from torch import Tensor
6
6
 
7
- from ml4gw.utils.slicing import unfold_windows
7
+ from ...utils.slicing import unfold_windows
8
8
 
9
9
 
10
10
  class OnlineAverager(torch.nn.Module):
@@ -4,7 +4,7 @@ import torch
4
4
  from jaxtyping import Float
5
5
  from torch import Tensor
6
6
 
7
- from ml4gw.utils.slicing import unfold_windows
7
+ from ...utils.slicing import unfold_windows
8
8
 
9
9
 
10
10
  class Snapshotter(torch.nn.Module):
ml4gw/spectral.py CHANGED
@@ -15,7 +15,7 @@ import torch
15
15
  from jaxtyping import Float
16
16
  from torch import Tensor
17
17
 
18
- from ml4gw.types import (
18
+ from .types import (
19
19
  FrequencySeries1to3d,
20
20
  PSDTensor,
21
21
  TimeSeries1to3d,
@@ -343,6 +343,7 @@ def truncate_inverse_power_spectrum(
343
343
  fduration: Union[Float[Tensor, " time"], float],
344
344
  sample_rate: float,
345
345
  highpass: Optional[float] = None,
346
+ lowpass: Optional[float] = None,
346
347
  ) -> PSDTensor:
347
348
  """
348
349
  Truncate the length of the time domain response
@@ -375,6 +376,10 @@ def truncate_inverse_power_spectrum(
375
376
  If specified, will zero out the frequency response
376
377
  of all frequencies below this value in Hz. If left
377
378
  as `None`, no highpass filtering will be applied.
379
+ lowpass:
380
+ If specified, will zero out the frequency response
381
+ of all frequencies above this value in Hz. If left
382
+ as `None`, no lowpass filtering will be applied.
378
383
  Returns:
379
384
  The PSD with its time domain response truncated
380
385
  to `fduration` and any highpassed frequencies
@@ -388,12 +393,15 @@ def truncate_inverse_power_spectrum(
388
393
  # impulse response function
389
394
  inv_asd = 1 / psd**0.5
390
395
 
391
- # zero our leading frequencies if we want the
392
- # filter to perform highpass filtering
396
+ # zero out frequencies if we want the filter
397
+ # to perform highpass/lowpass filtering
398
+ df = sample_rate / N
393
399
  if highpass is not None:
394
- df = sample_rate / N
395
400
  idx = int(highpass / df)
396
401
  inv_asd[:, :, :idx] = 0
402
+ if lowpass is not None:
403
+ idx = int(lowpass / df)
404
+ inv_asd[:, :, idx:] = 0
397
405
 
398
406
  if inv_asd.size(-1) % 2:
399
407
  inv_asd[:, :, -1] = 0
@@ -455,12 +463,13 @@ def whiten(
455
463
  fduration: Union[Float[Tensor, " time"], float],
456
464
  sample_rate: float,
457
465
  highpass: Optional[float] = None,
466
+ lowpass: Optional[float] = None,
458
467
  ) -> WaveformTensor:
459
468
  """
460
469
  Whiten a batch of timeseries using the specified
461
470
  background one-sided power spectral densities (PSDs),
462
471
  modified to have the desired time domain response length
463
- `fduration` and possibly to highpass filter.
472
+ `fduration` and possibly to highpass/lowpass filter.
464
473
 
465
474
  Args:
466
475
  X:
@@ -493,6 +502,11 @@ def whiten(
493
502
  the data, setting the frequency response in the
494
503
  whitening filter to 0. If left as `None`, no
495
504
  highpass filtering will be applied.
505
+ lowpass:
506
+ The frequency in Hz at which to lowpass filter
507
+ the data, setting the frequency response in the
508
+ whitening filter to 0. If left as `None`, no
509
+ lowpass filtering will be applied.
496
510
  Returns:
497
511
  Batch of whitened multichannel timeseries with
498
512
  `fduration / 2` seconds trimmed from each side.
@@ -529,7 +543,11 @@ def whiten(
529
543
  # truncate it to have the desired
530
544
  # time domain response length
531
545
  psd = truncate_inverse_power_spectrum(
532
- psd, fduration, sample_rate, highpass
546
+ psd,
547
+ fduration,
548
+ sample_rate,
549
+ highpass,
550
+ lowpass,
533
551
  )
534
552
 
535
553
  return normalize_by_psd(X, psd, sample_rate, pad)
@@ -1,3 +1,4 @@
1
+ from .iirfilter import IIRFilter
1
2
  from .pearson import ShiftedPearsonCorrelation
2
3
  from .qtransform import QScan, SingleQTransform
3
4
  from .scaler import ChannelWiseScaler
@@ -0,0 +1,100 @@
1
+ from typing import Union
2
+
3
+ import torch
4
+ from scipy.signal import iirfilter
5
+ from torchaudio.functional import filtfilt
6
+
7
+
8
+ class IIRFilter(torch.nn.Module):
9
+ r"""
10
+ IIR digital and analog filter design given order and critical points.
11
+ Design an Nth-order digital or analog filter and apply it to a signal.
12
+ Uses SciPy's `iirfilter` function to create the filter coefficients.
13
+ https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirfilter.html # noqa E501
14
+
15
+ The forward call of this module accepts a batch tensor of shape
16
+ (n_waveforms, n_samples) and returns the filtered waveforms.
17
+
18
+ Args:
19
+ N:
20
+ The order of the filter.
21
+ Wn:
22
+ A scalar or length-2 sequence giving the critical frequencies.
23
+ For digital filters, Wn are in the same units as fs. By
24
+ default, fs is 2 half-cycles/sample, so these are normalized
25
+ from 0 to 1, where 1 is the Nyquist frequency. (Wn is thus in
26
+ half-cycles / sample). For analog filters, Wn is an angular
27
+ frequency (e.g., rad/s). When Wn is a length-2 sequence,`Wn[0]`
28
+ must be less than `Wn[1]`.
29
+ rp:
30
+ For Chebyshev and elliptic filters, provides the maximum ripple in
31
+ the passband. (dB)
32
+ rs:
33
+ For Chebyshev and elliptic filters, provides the minimum
34
+ attenuation in the stop band. (dB)
35
+ btype:
36
+ The type of filter. Default is 'bandpass'.
37
+ analog:
38
+ When True, return an analog filter, otherwise a digital filter
39
+ is returned.
40
+ ftype:
41
+ The type of IIR filter to design:
42
+
43
+ - Butterworth : 'butter'
44
+ - Chebyshev I : 'cheby1'
45
+ - Chebyshev II : 'cheby2'
46
+ - Cauer/elliptic: 'ellip'
47
+ - Bessel/Thomson: 'bessel's
48
+ fs:
49
+ The sampling frequency of the digital system.
50
+
51
+ Returns:
52
+ Filtered signal on the forward pass.
53
+ """
54
+
55
+ def __init__(
56
+ self,
57
+ N: int,
58
+ Wn: Union[float, torch.Tensor],
59
+ rs: Union[None, float, torch.Tensor] = None,
60
+ rp: Union[None, float, torch.Tensor] = None,
61
+ btype="band",
62
+ analog=False,
63
+ ftype="butter",
64
+ fs=None,
65
+ ) -> None:
66
+ super().__init__()
67
+
68
+ if isinstance(Wn, torch.Tensor):
69
+ Wn = Wn.numpy()
70
+ if isinstance(rs, torch.Tensor):
71
+ rs = rs.numpy()
72
+ if isinstance(rp, torch.Tensor):
73
+ rp = rp.numpy()
74
+
75
+ b, a = iirfilter(
76
+ N,
77
+ Wn,
78
+ rs=rs,
79
+ rp=rp,
80
+ btype=btype,
81
+ analog=analog,
82
+ ftype=ftype,
83
+ output="ba",
84
+ fs=fs,
85
+ )
86
+ self.register_buffer("b", torch.tensor(b))
87
+ self.register_buffer("a", torch.tensor(a))
88
+
89
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
90
+ r"""
91
+ Apply the filter to the input signal.
92
+
93
+ Args:
94
+ x:
95
+ The input signal to be filtered.
96
+
97
+ Returns:
98
+ The filtered signal.
99
+ """
100
+ return filtfilt(x, self.a, self.b, clamp=False)
@@ -2,8 +2,8 @@ import torch
2
2
  from jaxtyping import Float
3
3
  from torch import Tensor
4
4
 
5
- from ml4gw.types import TimeSeries1to3d
6
- from ml4gw.utils.slicing import unfold_windows
5
+ from ..types import TimeSeries1to3d
6
+ from ..utils.slicing import unfold_windows
7
7
 
8
8
 
9
9
  class ShiftedPearsonCorrelation(torch.nn.Module):
@@ -7,8 +7,8 @@ import torch.nn.functional as F
7
7
  from jaxtyping import Float, Int
8
8
  from torch import Tensor
9
9
 
10
- from ml4gw.transforms.spline_interpolation import SplineInterpolate
11
- from ml4gw.types import FrequencySeries1to3d, TimeSeries1to3d, TimeSeries3d
10
+ from ..types import FrequencySeries1to3d, TimeSeries1to3d, TimeSeries3d
11
+ from .spline_interpolation import SplineInterpolate
12
12
 
13
13
  """
14
14
  All based on https://github.com/gwpy/gwpy/blob/v3.0.8/gwpy/signal/qtransform.py
@@ -4,7 +4,7 @@ import torch
4
4
  from jaxtyping import Float
5
5
  from torch import Tensor
6
6
 
7
- from ml4gw.transforms.transform import FittableTransform
7
+ from .transform import FittableTransform
8
8
 
9
9
 
10
10
  class ChannelWiseScaler(FittableTransform):
@@ -2,9 +2,9 @@ from typing import Optional
2
2
 
3
3
  import torch
4
4
 
5
- from ml4gw.gw import compute_network_snr
6
- from ml4gw.transforms.transform import FittableSpectralTransform
7
- from ml4gw.types import BatchTensor, TimeSeries2d, WaveformTensor
5
+ from ..gw import compute_network_snr
6
+ from ..types import BatchTensor, TimeSeries2d, WaveformTensor
7
+ from .transform import FittableSpectralTransform
8
8
 
9
9
 
10
10
  class SnrRescaler(FittableSpectralTransform):
@@ -4,8 +4,8 @@ import torch
4
4
  from jaxtyping import Float
5
5
  from torch import Tensor
6
6
 
7
- from ml4gw.spectral import fast_spectral_density, spectral_density
8
- from ml4gw.types import FrequencySeries1to3d, TimeSeries1to3d
7
+ from ..spectral import fast_spectral_density, spectral_density
8
+ from ..types import FrequencySeries1to3d, TimeSeries1to3d
9
9
 
10
10
 
11
11
  class SpectralDensity(torch.nn.Module):
@@ -7,7 +7,7 @@ from jaxtyping import Float
7
7
  from torch import Tensor
8
8
  from torchaudio.transforms import Spectrogram
9
9
 
10
- from ml4gw.types import TimeSeries3d
10
+ from ..types import TimeSeries3d
11
11
 
12
12
 
13
13
  class MultiResolutionSpectrogram(torch.nn.Module):
@@ -2,8 +2,8 @@ from typing import Optional
2
2
 
3
3
  import torch
4
4
 
5
- from ml4gw.spectral import spectral_density
6
- from ml4gw.types import FrequencySeries1to3d, TimeSeries1to3d
5
+ from ..spectral import spectral_density
6
+ from ..types import FrequencySeries1to3d, TimeSeries1to3d
7
7
 
8
8
 
9
9
  class FittableTransform(torch.nn.Module):
@@ -4,8 +4,8 @@ import torch
4
4
  from jaxtyping import Float
5
5
  from torch import Tensor
6
6
 
7
- from ml4gw import gw
8
- from ml4gw.types import BatchTensor
7
+ from .. import gw
8
+ from ..types import BatchTensor
9
9
 
10
10
 
11
11
  # TODO: should these live in ml4gw.waveforms submodule?
@@ -2,14 +2,14 @@ from typing import Optional, Union
2
2
 
3
3
  import torch
4
4
 
5
- from ml4gw import spectral
6
- from ml4gw.transforms.transform import FittableSpectralTransform
7
- from ml4gw.types import (
5
+ from .. import spectral
6
+ from ..types import (
8
7
  FrequencySeries1d,
9
8
  FrequencySeries1to3d,
10
9
  TimeSeries1d,
11
10
  TimeSeries3d,
12
11
  )
12
+ from .transform import FittableSpectralTransform
13
13
 
14
14
 
15
15
  class Whiten(torch.nn.Module):
@@ -45,6 +45,10 @@ class Whiten(torch.nn.Module):
45
45
  Cutoff frequency to apply highpass filtering
46
46
  during whitening. If left as `None`, no highpass
47
47
  filtering will be performed.
48
+ lowpass:
49
+ Cutoff frequency to apply lowpass filtering
50
+ during whitening. If left as `None`, no lowpass
51
+ filtering will be performed.
48
52
  """
49
53
 
50
54
  def __init__(
@@ -52,11 +56,13 @@ class Whiten(torch.nn.Module):
52
56
  fduration: float,
53
57
  sample_rate: float,
54
58
  highpass: Optional[float] = None,
59
+ lowpass: Optional[float] = None,
55
60
  ) -> None:
56
61
  super().__init__()
57
62
  self.fduration = fduration
58
63
  self.sample_rate = sample_rate
59
64
  self.highpass = highpass
65
+ self.lowpass = lowpass
60
66
 
61
67
  # register a window up front to signify our
62
68
  # fduration at inference time
@@ -104,6 +110,7 @@ class Whiten(torch.nn.Module):
104
110
  fduration=self.window,
105
111
  sample_rate=self.sample_rate,
106
112
  highpass=self.highpass,
113
+ lowpass=self.lowpass,
107
114
  )
108
115
 
109
116
 
@@ -153,6 +160,7 @@ class FixedWhiten(FittableSpectralTransform):
153
160
  *background: Union[TimeSeries1d, FrequencySeries1d],
154
161
  fftlength: Optional[float] = None,
155
162
  highpass: Optional[float] = None,
163
+ lowpass: Optional[float] = None,
156
164
  overlap: Optional[float] = None
157
165
  ) -> None:
158
166
  """
@@ -200,6 +208,13 @@ class FixedWhiten(FittableSpectralTransform):
200
208
  in the frequency bins below this value to 0.
201
209
  If left as `None`, the fit filter won't have any
202
210
  highpass filtering properties.
211
+ lowpass:
212
+ Cutoff frequency, in Hz, used for lowpass filtering
213
+ with the fit whitening filter. This is achieved by
214
+ setting the frequency response of the fit PSDs
215
+ in the frequency bins above this value to 0.
216
+ If left as `None`, the fit filter won't have any
217
+ lowpass filtering properties.
203
218
  overlap:
204
219
  Overlap between FFT frames used to convert
205
220
  time-domain data to the frequency domain via
@@ -224,7 +239,7 @@ class FixedWhiten(FittableSpectralTransform):
224
239
  x = x.view(1, 1, -1)
225
240
 
226
241
  psd = spectral.truncate_inverse_power_spectrum(
227
- x, fduration, self.sample_rate, highpass
242
+ x, fduration, self.sample_rate, highpass, lowpass
228
243
  )
229
244
  psds.append(psd[0, 0])
230
245
  psd = torch.stack(psds)
ml4gw/utils/slicing.py CHANGED
@@ -5,12 +5,7 @@ from jaxtyping import Float, Int64
5
5
  from torch import Tensor
6
6
  from torch.nn.functional import unfold
7
7
 
8
- from ml4gw.types import (
9
- TimeSeries1d,
10
- TimeSeries1to3d,
11
- TimeSeries2d,
12
- TimeSeries3d,
13
- )
8
+ from ..types import TimeSeries1d, TimeSeries1to3d, TimeSeries2d, TimeSeries3d
14
9
 
15
10
  BatchTimeSeriesTensor = Union[Float[Tensor, "batch time"], TimeSeries3d]
16
11
 
@@ -0,0 +1,35 @@
1
+ import torch
2
+
3
+ from ml4gw.constants import C, G
4
+ from ml4gw.types import BatchTensor
5
+
6
+
7
+ def taylor_t2_timing_0pn_coeff(total_mass: BatchTensor, eta: BatchTensor):
8
+ """
9
+ https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/lib/LALSimInspiralPNCoefficients.c#L1528
10
+ """
11
+
12
+ output = total_mass * G / C**3
13
+ return -5.0 * output / (256.0 * eta)
14
+
15
+
16
+ def taylor_t2_timing_2pn_coeff(eta: BatchTensor):
17
+ """
18
+ https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/lib/LALSimInspiralPNCoefficients.c#L1545
19
+ """
20
+ return 7.43 / 2.52 + 11.0 / 3.0 * eta
21
+
22
+
23
+ def taylor_t2_timing_4pn_coeff(eta: BatchTensor):
24
+ """
25
+ https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/lib/LALSimInspiralPNCoefficients.c#L1560
26
+ """
27
+ return 30.58673 / 5.08032 + 54.29 / 5.04 * eta + 61.7 / 7.2 * eta**2
28
+
29
+
30
+ def taylor_t3_frequency_0pn_coeff(total_mass: BatchTensor):
31
+ """
32
+ https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/lib/LALSimInspiralPNCoefficients.c#L1723
33
+ """
34
+ output = total_mass * G / C**3.0
35
+ return 1.0 / (8.0 * torch.pi * output)
@@ -1,9 +1,8 @@
1
1
  import torch
2
2
  from jaxtyping import Float
3
3
 
4
- from ml4gw.constants import MTSUN_SI, PI
5
- from ml4gw.types import BatchTensor, FrequencySeries1d
6
-
4
+ from ...constants import MTSUN_SI, PI
5
+ from ...types import BatchTensor, FrequencySeries1d
7
6
  from .phenom_d_data import QNMData_a, QNMData_fdamp, QNMData_fring
8
7
  from .taylorf2 import TaylorF2
9
8
 
@@ -26,6 +25,7 @@ class IMRPhenomD(TaylorF2):
26
25
  phic: BatchTensor,
27
26
  inclination: BatchTensor,
28
27
  f_ref: float,
28
+ **kwargs
29
29
  ):
30
30
  """
31
31
  IMRPhenomD waveform
@@ -39,6 +39,7 @@ class IMRPhenomPv2(IMRPhenomD):
39
39
  inclination: BatchTensor,
40
40
  f_ref: float,
41
41
  tc: Optional[BatchTensor] = None,
42
+ **kwargs,
42
43
  ):
43
44
  """
44
45
  IMRPhenomPv2 waveform