ml4gw 0.6.2__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ml4gw might be problematic. Click here for more details.
- ml4gw/__init__.py +1 -0
- ml4gw/dataloading/chunked_dataset.py +1 -1
- ml4gw/dataloading/hdf5_dataset.py +36 -6
- ml4gw/dataloading/in_memory_dataset.py +1 -1
- ml4gw/gw.py +4 -3
- ml4gw/nn/autoencoder/base.py +1 -1
- ml4gw/nn/autoencoder/convolutional.py +3 -3
- ml4gw/nn/autoencoder/skip_connection.py +1 -1
- ml4gw/nn/resnet/resnet_1d.py +1 -1
- ml4gw/nn/resnet/resnet_2d.py +1 -1
- ml4gw/nn/streaming/online_average.py +1 -1
- ml4gw/nn/streaming/snapshotter.py +1 -1
- ml4gw/spectral.py +24 -6
- ml4gw/transforms/__init__.py +1 -0
- ml4gw/transforms/iirfilter.py +100 -0
- ml4gw/transforms/pearson.py +2 -2
- ml4gw/transforms/qtransform.py +2 -2
- ml4gw/transforms/scaler.py +1 -1
- ml4gw/transforms/snr_rescaler.py +3 -3
- ml4gw/transforms/spectral.py +2 -2
- ml4gw/transforms/spectrogram.py +1 -1
- ml4gw/transforms/transform.py +2 -2
- ml4gw/transforms/waveforms.py +2 -2
- ml4gw/transforms/whitening.py +19 -4
- ml4gw/utils/slicing.py +1 -6
- ml4gw/waveforms/cbc/coefficients.py +35 -0
- ml4gw/waveforms/cbc/phenom_d.py +3 -3
- ml4gw/waveforms/cbc/phenom_p.py +4 -1
- ml4gw/waveforms/cbc/taylorf2.py +5 -4
- ml4gw/waveforms/cbc/utils.py +111 -0
- ml4gw/waveforms/conversion.py +2 -2
- ml4gw/waveforms/generator.py +289 -26
- ml4gw-0.7.0.dist-info/LICENSE +674 -0
- ml4gw-0.7.0.dist-info/METADATA +78 -0
- ml4gw-0.7.0.dist-info/RECORD +55 -0
- {ml4gw-0.6.2.dist-info → ml4gw-0.7.0.dist-info}/WHEEL +1 -1
- ml4gw-0.6.2.dist-info/METADATA +0 -155
- ml4gw-0.6.2.dist-info/RECORD +0 -51
ml4gw/__init__.py
CHANGED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .constants import *
|
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
import warnings
|
|
2
|
-
from typing import Sequence, Union
|
|
2
|
+
from typing import Optional, Sequence, Union
|
|
3
3
|
|
|
4
4
|
import h5py
|
|
5
5
|
import numpy as np
|
|
6
6
|
import torch
|
|
7
7
|
|
|
8
|
-
from
|
|
8
|
+
from ..types import WaveformTensor
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
class ContiguousHdf5Warning(Warning):
|
|
@@ -50,6 +50,13 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
|
|
|
50
50
|
channel. The latter setting limits the amount of
|
|
51
51
|
entropy in the effective dataset, but can provide
|
|
52
52
|
over 2x improvement in total throughput.
|
|
53
|
+
num_files_per_batch:
|
|
54
|
+
The number of unique files from which to sample
|
|
55
|
+
batch elements each epoch. If left as `None`,
|
|
56
|
+
will use all available files. Useful when reading
|
|
57
|
+
from many files is bottlenecking dataloading.
|
|
58
|
+
|
|
59
|
+
|
|
53
60
|
"""
|
|
54
61
|
|
|
55
62
|
def __init__(
|
|
@@ -60,6 +67,7 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
|
|
|
60
67
|
batch_size: int,
|
|
61
68
|
batches_per_epoch: int,
|
|
62
69
|
coincident: Union[bool, str],
|
|
70
|
+
num_files_per_batch: Optional[int] = None,
|
|
63
71
|
) -> None:
|
|
64
72
|
if not isinstance(coincident, bool) and coincident != "files":
|
|
65
73
|
raise ValueError(
|
|
@@ -67,13 +75,21 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
|
|
|
67
75
|
"got unrecognized value {}".format(coincident)
|
|
68
76
|
)
|
|
69
77
|
|
|
70
|
-
self.fnames = fnames
|
|
78
|
+
self.fnames = np.array(fnames)
|
|
71
79
|
self.channels = channels
|
|
72
80
|
self.num_channels = len(channels)
|
|
73
81
|
self.kernel_size = kernel_size
|
|
74
82
|
self.batch_size = batch_size
|
|
75
83
|
self.batches_per_epoch = batches_per_epoch
|
|
76
84
|
self.coincident = coincident
|
|
85
|
+
self.num_files_per_batch = (
|
|
86
|
+
len(fnames) if num_files_per_batch is None else num_files_per_batch
|
|
87
|
+
)
|
|
88
|
+
if self.num_files_per_batch > len(fnames):
|
|
89
|
+
raise ValueError(
|
|
90
|
+
f"Number of files per batch ({self.num_files_per_batch}) "
|
|
91
|
+
f"cannot exceed number of files ({len(fnames)}) "
|
|
92
|
+
)
|
|
77
93
|
|
|
78
94
|
self.sizes = {}
|
|
79
95
|
for fname in self.fnames:
|
|
@@ -85,13 +101,14 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
|
|
|
85
101
|
"without using chunked storage. This can have "
|
|
86
102
|
"severe performance impacts at data loading time. "
|
|
87
103
|
"If you need faster loading, try re-generating "
|
|
88
|
-
"your
|
|
104
|
+
"your dataset with chunked storage turned on.".format(
|
|
89
105
|
fname
|
|
90
106
|
),
|
|
91
107
|
category=ContiguousHdf5Warning,
|
|
92
108
|
)
|
|
93
109
|
|
|
94
110
|
self.sizes[fname] = len(dset)
|
|
111
|
+
|
|
95
112
|
total = sum(self.sizes.values())
|
|
96
113
|
self.probs = np.array([i / total for i in self.sizes.values()])
|
|
97
114
|
|
|
@@ -99,9 +116,22 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
|
|
|
99
116
|
return self.batches_per_epoch
|
|
100
117
|
|
|
101
118
|
def sample_fnames(self, size) -> np.ndarray:
|
|
102
|
-
|
|
103
|
-
|
|
119
|
+
# first, randomly select `self.num_files_per_batch`
|
|
120
|
+
# file indices based on their probabilities
|
|
121
|
+
fname_indices = np.arange(len(self.fnames))
|
|
122
|
+
fname_indices = np.random.choice(
|
|
123
|
+
fname_indices,
|
|
104
124
|
p=self.probs,
|
|
125
|
+
size=(self.num_files_per_batch),
|
|
126
|
+
replace=False,
|
|
127
|
+
)
|
|
128
|
+
# now renormalize the probabilities, and sample
|
|
129
|
+
# the requested size from this subset of files
|
|
130
|
+
probs = self.probs[fname_indices]
|
|
131
|
+
probs /= probs.sum()
|
|
132
|
+
return np.random.choice(
|
|
133
|
+
self.fnames[fname_indices],
|
|
134
|
+
p=probs,
|
|
105
135
|
size=size,
|
|
106
136
|
replace=True,
|
|
107
137
|
)
|
ml4gw/gw.py
CHANGED
|
@@ -16,8 +16,10 @@ import torch
|
|
|
16
16
|
from jaxtyping import Float
|
|
17
17
|
from torch import Tensor
|
|
18
18
|
|
|
19
|
-
from ml4gw.
|
|
20
|
-
|
|
19
|
+
from ml4gw.utils.interferometer import InterferometerGeometry
|
|
20
|
+
|
|
21
|
+
from .constants import C
|
|
22
|
+
from .types import (
|
|
21
23
|
BatchTensor,
|
|
22
24
|
NetworkDetectorTensors,
|
|
23
25
|
NetworkVertices,
|
|
@@ -26,7 +28,6 @@ from ml4gw.types import (
|
|
|
26
28
|
VectorGeometry,
|
|
27
29
|
WaveformTensor,
|
|
28
30
|
)
|
|
29
|
-
from ml4gw.utils.interferometer import InterferometerGeometry
|
|
30
31
|
|
|
31
32
|
|
|
32
33
|
def outer(x: VectorGeometry, y: VectorGeometry) -> TensorGeometry:
|
ml4gw/nn/autoencoder/base.py
CHANGED
|
@@ -4,9 +4,9 @@ from typing import Optional
|
|
|
4
4
|
import torch
|
|
5
5
|
from torch import Tensor
|
|
6
6
|
|
|
7
|
-
from
|
|
8
|
-
from
|
|
9
|
-
from
|
|
7
|
+
from .base import Autoencoder
|
|
8
|
+
from .skip_connection import SkipConnection
|
|
9
|
+
from .utils import match_size
|
|
10
10
|
|
|
11
11
|
Module = Callable[[...], torch.nn.Module]
|
|
12
12
|
|
ml4gw/nn/resnet/resnet_1d.py
CHANGED
ml4gw/nn/resnet/resnet_2d.py
CHANGED
ml4gw/spectral.py
CHANGED
|
@@ -15,7 +15,7 @@ import torch
|
|
|
15
15
|
from jaxtyping import Float
|
|
16
16
|
from torch import Tensor
|
|
17
17
|
|
|
18
|
-
from
|
|
18
|
+
from .types import (
|
|
19
19
|
FrequencySeries1to3d,
|
|
20
20
|
PSDTensor,
|
|
21
21
|
TimeSeries1to3d,
|
|
@@ -343,6 +343,7 @@ def truncate_inverse_power_spectrum(
|
|
|
343
343
|
fduration: Union[Float[Tensor, " time"], float],
|
|
344
344
|
sample_rate: float,
|
|
345
345
|
highpass: Optional[float] = None,
|
|
346
|
+
lowpass: Optional[float] = None,
|
|
346
347
|
) -> PSDTensor:
|
|
347
348
|
"""
|
|
348
349
|
Truncate the length of the time domain response
|
|
@@ -375,6 +376,10 @@ def truncate_inverse_power_spectrum(
|
|
|
375
376
|
If specified, will zero out the frequency response
|
|
376
377
|
of all frequencies below this value in Hz. If left
|
|
377
378
|
as `None`, no highpass filtering will be applied.
|
|
379
|
+
lowpass:
|
|
380
|
+
If specified, will zero out the frequency response
|
|
381
|
+
of all frequencies above this value in Hz. If left
|
|
382
|
+
as `None`, no lowpass filtering will be applied.
|
|
378
383
|
Returns:
|
|
379
384
|
The PSD with its time domain response truncated
|
|
380
385
|
to `fduration` and any highpassed frequencies
|
|
@@ -388,12 +393,15 @@ def truncate_inverse_power_spectrum(
|
|
|
388
393
|
# impulse response function
|
|
389
394
|
inv_asd = 1 / psd**0.5
|
|
390
395
|
|
|
391
|
-
# zero
|
|
392
|
-
#
|
|
396
|
+
# zero out frequencies if we want the filter
|
|
397
|
+
# to perform highpass/lowpass filtering
|
|
398
|
+
df = sample_rate / N
|
|
393
399
|
if highpass is not None:
|
|
394
|
-
df = sample_rate / N
|
|
395
400
|
idx = int(highpass / df)
|
|
396
401
|
inv_asd[:, :, :idx] = 0
|
|
402
|
+
if lowpass is not None:
|
|
403
|
+
idx = int(lowpass / df)
|
|
404
|
+
inv_asd[:, :, idx:] = 0
|
|
397
405
|
|
|
398
406
|
if inv_asd.size(-1) % 2:
|
|
399
407
|
inv_asd[:, :, -1] = 0
|
|
@@ -455,12 +463,13 @@ def whiten(
|
|
|
455
463
|
fduration: Union[Float[Tensor, " time"], float],
|
|
456
464
|
sample_rate: float,
|
|
457
465
|
highpass: Optional[float] = None,
|
|
466
|
+
lowpass: Optional[float] = None,
|
|
458
467
|
) -> WaveformTensor:
|
|
459
468
|
"""
|
|
460
469
|
Whiten a batch of timeseries using the specified
|
|
461
470
|
background one-sided power spectral densities (PSDs),
|
|
462
471
|
modified to have the desired time domain response length
|
|
463
|
-
`fduration` and possibly to highpass filter.
|
|
472
|
+
`fduration` and possibly to highpass/lowpass filter.
|
|
464
473
|
|
|
465
474
|
Args:
|
|
466
475
|
X:
|
|
@@ -493,6 +502,11 @@ def whiten(
|
|
|
493
502
|
the data, setting the frequency response in the
|
|
494
503
|
whitening filter to 0. If left as `None`, no
|
|
495
504
|
highpass filtering will be applied.
|
|
505
|
+
lowpass:
|
|
506
|
+
The frequency in Hz at which to lowpass filter
|
|
507
|
+
the data, setting the frequency response in the
|
|
508
|
+
whitening filter to 0. If left as `None`, no
|
|
509
|
+
lowpass filtering will be applied.
|
|
496
510
|
Returns:
|
|
497
511
|
Batch of whitened multichannel timeseries with
|
|
498
512
|
`fduration / 2` seconds trimmed from each side.
|
|
@@ -529,7 +543,11 @@ def whiten(
|
|
|
529
543
|
# truncate it to have the desired
|
|
530
544
|
# time domain response length
|
|
531
545
|
psd = truncate_inverse_power_spectrum(
|
|
532
|
-
psd,
|
|
546
|
+
psd,
|
|
547
|
+
fduration,
|
|
548
|
+
sample_rate,
|
|
549
|
+
highpass,
|
|
550
|
+
lowpass,
|
|
533
551
|
)
|
|
534
552
|
|
|
535
553
|
return normalize_by_psd(X, psd, sample_rate, pad)
|
ml4gw/transforms/__init__.py
CHANGED
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
from typing import Union
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from scipy.signal import iirfilter
|
|
5
|
+
from torchaudio.functional import filtfilt
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class IIRFilter(torch.nn.Module):
|
|
9
|
+
r"""
|
|
10
|
+
IIR digital and analog filter design given order and critical points.
|
|
11
|
+
Design an Nth-order digital or analog filter and apply it to a signal.
|
|
12
|
+
Uses SciPy's `iirfilter` function to create the filter coefficients.
|
|
13
|
+
https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirfilter.html # noqa E501
|
|
14
|
+
|
|
15
|
+
The forward call of this module accepts a batch tensor of shape
|
|
16
|
+
(n_waveforms, n_samples) and returns the filtered waveforms.
|
|
17
|
+
|
|
18
|
+
Args:
|
|
19
|
+
N:
|
|
20
|
+
The order of the filter.
|
|
21
|
+
Wn:
|
|
22
|
+
A scalar or length-2 sequence giving the critical frequencies.
|
|
23
|
+
For digital filters, Wn are in the same units as fs. By
|
|
24
|
+
default, fs is 2 half-cycles/sample, so these are normalized
|
|
25
|
+
from 0 to 1, where 1 is the Nyquist frequency. (Wn is thus in
|
|
26
|
+
half-cycles / sample). For analog filters, Wn is an angular
|
|
27
|
+
frequency (e.g., rad/s). When Wn is a length-2 sequence, `Wn[0]`
|
|
28
|
+
must be less than `Wn[1]`.
|
|
29
|
+
rp:
|
|
30
|
+
For Chebyshev and elliptic filters, provides the maximum ripple in
|
|
31
|
+
the passband. (dB)
|
|
32
|
+
rs:
|
|
33
|
+
For Chebyshev and elliptic filters, provides the minimum
|
|
34
|
+
attenuation in the stop band. (dB)
|
|
35
|
+
btype:
|
|
36
|
+
The type of filter. Default is 'band' (bandpass).
|
|
37
|
+
analog:
|
|
38
|
+
When True, return an analog filter, otherwise a digital filter
|
|
39
|
+
is returned.
|
|
40
|
+
ftype:
|
|
41
|
+
The type of IIR filter to design:
|
|
42
|
+
|
|
43
|
+
- Butterworth : 'butter'
|
|
44
|
+
- Chebyshev I : 'cheby1'
|
|
45
|
+
- Chebyshev II : 'cheby2'
|
|
46
|
+
- Cauer/elliptic: 'ellip'
|
|
47
|
+
- Bessel/Thomson: 'bessel'
|
|
48
|
+
fs:
|
|
49
|
+
The sampling frequency of the digital system.
|
|
50
|
+
|
|
51
|
+
Returns:
|
|
52
|
+
Filtered signal on the forward pass.
|
|
53
|
+
"""
|
|
54
|
+
|
|
55
|
+
def __init__(
|
|
56
|
+
self,
|
|
57
|
+
N: int,
|
|
58
|
+
Wn: Union[float, torch.Tensor],
|
|
59
|
+
rs: Union[None, float, torch.Tensor] = None,
|
|
60
|
+
rp: Union[None, float, torch.Tensor] = None,
|
|
61
|
+
btype="band",
|
|
62
|
+
analog=False,
|
|
63
|
+
ftype="butter",
|
|
64
|
+
fs=None,
|
|
65
|
+
) -> None:
|
|
66
|
+
super().__init__()
|
|
67
|
+
|
|
68
|
+
if isinstance(Wn, torch.Tensor):
|
|
69
|
+
Wn = Wn.numpy()
|
|
70
|
+
if isinstance(rs, torch.Tensor):
|
|
71
|
+
rs = rs.numpy()
|
|
72
|
+
if isinstance(rp, torch.Tensor):
|
|
73
|
+
rp = rp.numpy()
|
|
74
|
+
|
|
75
|
+
b, a = iirfilter(
|
|
76
|
+
N,
|
|
77
|
+
Wn,
|
|
78
|
+
rs=rs,
|
|
79
|
+
rp=rp,
|
|
80
|
+
btype=btype,
|
|
81
|
+
analog=analog,
|
|
82
|
+
ftype=ftype,
|
|
83
|
+
output="ba",
|
|
84
|
+
fs=fs,
|
|
85
|
+
)
|
|
86
|
+
self.register_buffer("b", torch.tensor(b))
|
|
87
|
+
self.register_buffer("a", torch.tensor(a))
|
|
88
|
+
|
|
89
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
90
|
+
r"""
|
|
91
|
+
Apply the filter to the input signal.
|
|
92
|
+
|
|
93
|
+
Args:
|
|
94
|
+
x:
|
|
95
|
+
The input signal to be filtered.
|
|
96
|
+
|
|
97
|
+
Returns:
|
|
98
|
+
The filtered signal.
|
|
99
|
+
"""
|
|
100
|
+
return filtfilt(x, self.a, self.b, clamp=False)
|
ml4gw/transforms/pearson.py
CHANGED
|
@@ -2,8 +2,8 @@ import torch
|
|
|
2
2
|
from jaxtyping import Float
|
|
3
3
|
from torch import Tensor
|
|
4
4
|
|
|
5
|
-
from
|
|
6
|
-
from
|
|
5
|
+
from ..types import TimeSeries1to3d
|
|
6
|
+
from ..utils.slicing import unfold_windows
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
class ShiftedPearsonCorrelation(torch.nn.Module):
|
ml4gw/transforms/qtransform.py
CHANGED
|
@@ -7,8 +7,8 @@ import torch.nn.functional as F
|
|
|
7
7
|
from jaxtyping import Float, Int
|
|
8
8
|
from torch import Tensor
|
|
9
9
|
|
|
10
|
-
from
|
|
11
|
-
from
|
|
10
|
+
from ..types import FrequencySeries1to3d, TimeSeries1to3d, TimeSeries3d
|
|
11
|
+
from .spline_interpolation import SplineInterpolate
|
|
12
12
|
|
|
13
13
|
"""
|
|
14
14
|
All based on https://github.com/gwpy/gwpy/blob/v3.0.8/gwpy/signal/qtransform.py
|
ml4gw/transforms/scaler.py
CHANGED
ml4gw/transforms/snr_rescaler.py
CHANGED
|
@@ -2,9 +2,9 @@ from typing import Optional
|
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
4
|
|
|
5
|
-
from
|
|
6
|
-
from
|
|
7
|
-
from
|
|
5
|
+
from ..gw import compute_network_snr
|
|
6
|
+
from ..types import BatchTensor, TimeSeries2d, WaveformTensor
|
|
7
|
+
from .transform import FittableSpectralTransform
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
class SnrRescaler(FittableSpectralTransform):
|
ml4gw/transforms/spectral.py
CHANGED
|
@@ -4,8 +4,8 @@ import torch
|
|
|
4
4
|
from jaxtyping import Float
|
|
5
5
|
from torch import Tensor
|
|
6
6
|
|
|
7
|
-
from
|
|
8
|
-
from
|
|
7
|
+
from ..spectral import fast_spectral_density, spectral_density
|
|
8
|
+
from ..types import FrequencySeries1to3d, TimeSeries1to3d
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
class SpectralDensity(torch.nn.Module):
|
ml4gw/transforms/spectrogram.py
CHANGED
ml4gw/transforms/transform.py
CHANGED
|
@@ -2,8 +2,8 @@ from typing import Optional
|
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
4
|
|
|
5
|
-
from
|
|
6
|
-
from
|
|
5
|
+
from ..spectral import spectral_density
|
|
6
|
+
from ..types import FrequencySeries1to3d, TimeSeries1to3d
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
class FittableTransform(torch.nn.Module):
|
ml4gw/transforms/waveforms.py
CHANGED
ml4gw/transforms/whitening.py
CHANGED
|
@@ -2,14 +2,14 @@ from typing import Optional, Union
|
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
4
|
|
|
5
|
-
from
|
|
6
|
-
from
|
|
7
|
-
from ml4gw.types import (
|
|
5
|
+
from .. import spectral
|
|
6
|
+
from ..types import (
|
|
8
7
|
FrequencySeries1d,
|
|
9
8
|
FrequencySeries1to3d,
|
|
10
9
|
TimeSeries1d,
|
|
11
10
|
TimeSeries3d,
|
|
12
11
|
)
|
|
12
|
+
from .transform import FittableSpectralTransform
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
class Whiten(torch.nn.Module):
|
|
@@ -45,6 +45,10 @@ class Whiten(torch.nn.Module):
|
|
|
45
45
|
Cutoff frequency to apply highpass filtering
|
|
46
46
|
during whitening. If left as `None`, no highpass
|
|
47
47
|
filtering will be performed.
|
|
48
|
+
lowpass:
|
|
49
|
+
Cutoff frequency to apply lowpass filtering
|
|
50
|
+
during whitening. If left as `None`, no lowpass
|
|
51
|
+
filtering will be performed.
|
|
48
52
|
"""
|
|
49
53
|
|
|
50
54
|
def __init__(
|
|
@@ -52,11 +56,13 @@ class Whiten(torch.nn.Module):
|
|
|
52
56
|
fduration: float,
|
|
53
57
|
sample_rate: float,
|
|
54
58
|
highpass: Optional[float] = None,
|
|
59
|
+
lowpass: Optional[float] = None,
|
|
55
60
|
) -> None:
|
|
56
61
|
super().__init__()
|
|
57
62
|
self.fduration = fduration
|
|
58
63
|
self.sample_rate = sample_rate
|
|
59
64
|
self.highpass = highpass
|
|
65
|
+
self.lowpass = lowpass
|
|
60
66
|
|
|
61
67
|
# register a window up front to signify our
|
|
62
68
|
# fduration at inference time
|
|
@@ -104,6 +110,7 @@ class Whiten(torch.nn.Module):
|
|
|
104
110
|
fduration=self.window,
|
|
105
111
|
sample_rate=self.sample_rate,
|
|
106
112
|
highpass=self.highpass,
|
|
113
|
+
lowpass=self.lowpass,
|
|
107
114
|
)
|
|
108
115
|
|
|
109
116
|
|
|
@@ -153,6 +160,7 @@ class FixedWhiten(FittableSpectralTransform):
|
|
|
153
160
|
*background: Union[TimeSeries1d, FrequencySeries1d],
|
|
154
161
|
fftlength: Optional[float] = None,
|
|
155
162
|
highpass: Optional[float] = None,
|
|
163
|
+
lowpass: Optional[float] = None,
|
|
156
164
|
overlap: Optional[float] = None
|
|
157
165
|
) -> None:
|
|
158
166
|
"""
|
|
@@ -200,6 +208,13 @@ class FixedWhiten(FittableSpectralTransform):
|
|
|
200
208
|
in the frequency bins below this value to 0.
|
|
201
209
|
If left as `None`, the fit filter won't have any
|
|
202
210
|
highpass filtering properties.
|
|
211
|
+
lowpass:
|
|
212
|
+
Cutoff frequency, in Hz, used for lowpass filtering
|
|
213
|
+
with the fit whitening filter. This is achieved by
|
|
214
|
+
setting the frequency response of the fit PSDs
|
|
215
|
+
in the frequency bins above this value to 0.
|
|
216
|
+
If left as `None`, the fit filter won't have any
|
|
217
|
+
lowpass filtering properties.
|
|
203
218
|
overlap:
|
|
204
219
|
Overlap between FFT frames used to convert
|
|
205
220
|
time-domain data to the frequency domain via
|
|
@@ -224,7 +239,7 @@ class FixedWhiten(FittableSpectralTransform):
|
|
|
224
239
|
x = x.view(1, 1, -1)
|
|
225
240
|
|
|
226
241
|
psd = spectral.truncate_inverse_power_spectrum(
|
|
227
|
-
x, fduration, self.sample_rate, highpass
|
|
242
|
+
x, fduration, self.sample_rate, highpass, lowpass
|
|
228
243
|
)
|
|
229
244
|
psds.append(psd[0, 0])
|
|
230
245
|
psd = torch.stack(psds)
|
ml4gw/utils/slicing.py
CHANGED
|
@@ -5,12 +5,7 @@ from jaxtyping import Float, Int64
|
|
|
5
5
|
from torch import Tensor
|
|
6
6
|
from torch.nn.functional import unfold
|
|
7
7
|
|
|
8
|
-
from
|
|
9
|
-
TimeSeries1d,
|
|
10
|
-
TimeSeries1to3d,
|
|
11
|
-
TimeSeries2d,
|
|
12
|
-
TimeSeries3d,
|
|
13
|
-
)
|
|
8
|
+
from ..types import TimeSeries1d, TimeSeries1to3d, TimeSeries2d, TimeSeries3d
|
|
14
9
|
|
|
15
10
|
BatchTimeSeriesTensor = Union[Float[Tensor, "batch time"], TimeSeries3d]
|
|
16
11
|
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
from ml4gw.constants import C, G
|
|
4
|
+
from ml4gw.types import BatchTensor
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def taylor_t2_timing_0pn_coeff(total_mass: BatchTensor, eta: BatchTensor):
|
|
8
|
+
"""
|
|
9
|
+
https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/lib/LALSimInspiralPNCoefficients.c#L1528
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
output = total_mass * G / C**3
|
|
13
|
+
return -5.0 * output / (256.0 * eta)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def taylor_t2_timing_2pn_coeff(eta: BatchTensor):
|
|
17
|
+
"""
|
|
18
|
+
https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/lib/LALSimInspiralPNCoefficients.c#L1545
|
|
19
|
+
"""
|
|
20
|
+
return 7.43 / 2.52 + 11.0 / 3.0 * eta
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def taylor_t2_timing_4pn_coeff(eta: BatchTensor):
|
|
24
|
+
"""
|
|
25
|
+
https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/lib/LALSimInspiralPNCoefficients.c#L1560
|
|
26
|
+
"""
|
|
27
|
+
return 30.58673 / 5.08032 + 54.29 / 5.04 * eta + 61.7 / 7.2 * eta**2
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def taylor_t3_frequency_0pn_coeff(total_mass: BatchTensor):
|
|
31
|
+
"""
|
|
32
|
+
https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/lib/LALSimInspiralPNCoefficients.c#L1723
|
|
33
|
+
"""
|
|
34
|
+
output = total_mass * G / C**3.0
|
|
35
|
+
return 1.0 / (8.0 * torch.pi * output)
|
ml4gw/waveforms/cbc/phenom_d.py
CHANGED
|
@@ -1,9 +1,8 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
from jaxtyping import Float
|
|
3
3
|
|
|
4
|
-
from
|
|
5
|
-
from
|
|
6
|
-
|
|
4
|
+
from ...constants import MTSUN_SI, PI
|
|
5
|
+
from ...types import BatchTensor, FrequencySeries1d
|
|
7
6
|
from .phenom_d_data import QNMData_a, QNMData_fdamp, QNMData_fring
|
|
8
7
|
from .taylorf2 import TaylorF2
|
|
9
8
|
|
|
@@ -26,6 +25,7 @@ class IMRPhenomD(TaylorF2):
|
|
|
26
25
|
phic: BatchTensor,
|
|
27
26
|
inclination: BatchTensor,
|
|
28
27
|
f_ref: float,
|
|
28
|
+
**kwargs
|
|
29
29
|
):
|
|
30
30
|
"""
|
|
31
31
|
IMRPhenomD waveform
|
ml4gw/waveforms/cbc/phenom_p.py
CHANGED
|
@@ -39,6 +39,7 @@ class IMRPhenomPv2(IMRPhenomD):
|
|
|
39
39
|
inclination: BatchTensor,
|
|
40
40
|
f_ref: float,
|
|
41
41
|
tc: Optional[BatchTensor] = None,
|
|
42
|
+
**kwargs,
|
|
42
43
|
):
|
|
43
44
|
"""
|
|
44
45
|
IMRPhenomPv2 waveform
|
|
@@ -382,7 +383,9 @@ class IMRPhenomPv2(IMRPhenomD):
|
|
|
382
383
|
# reshape x to have same shape as diffRDphase
|
|
383
384
|
x = x[1:-1].unsqueeze(0).expand(diffRDphase.shape)
|
|
384
385
|
# interpolate at x = 1, as that's the same as f = fRD
|
|
385
|
-
diffRDphase = -self.interpolate(
|
|
386
|
+
diffRDphase = -self.interpolate(
|
|
387
|
+
torch.tensor([1], device=x.device), x, diffRDphase
|
|
388
|
+
)
|
|
386
389
|
return hPhenom, diffRDphase
|
|
387
390
|
|
|
388
391
|
# Utility functions
|