ml4gw 0.4.0__tar.gz → 0.4.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ml4gw might be problematic. Click here for more details.
- {ml4gw-0.4.0 → ml4gw-0.4.2}/PKG-INFO +3 -5
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/dataloading/in_memory_dataset.py +18 -32
- ml4gw-0.4.2/ml4gw/distributions.py +166 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/transforms/__init__.py +1 -0
- ml4gw-0.4.2/ml4gw/transforms/qtransform.py +463 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/waveforms/generator.py +1 -1
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/waveforms/phenom_d.py +51 -49
- {ml4gw-0.4.0 → ml4gw-0.4.2}/pyproject.toml +4 -9
- ml4gw-0.4.0/ml4gw/distributions.py +0 -122
- {ml4gw-0.4.0 → ml4gw-0.4.2}/README.md +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/__init__.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/augmentations.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/dataloading/__init__.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/dataloading/chunked_dataset.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/dataloading/hdf5_dataset.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/gw.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/nn/__init__.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/nn/autoencoder/__init__.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/nn/autoencoder/base.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/nn/autoencoder/convolutional.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/nn/autoencoder/skip_connection.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/nn/autoencoder/utils.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/nn/norm.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/nn/resnet/__init__.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/nn/resnet/resnet_1d.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/nn/resnet/resnet_2d.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/nn/streaming/__init__.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/nn/streaming/online_average.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/nn/streaming/snapshotter.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/spectral.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/transforms/pearson.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/transforms/scaler.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/transforms/snr_rescaler.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/transforms/spectral.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/transforms/spectrogram.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/transforms/transform.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/transforms/waveforms.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/transforms/whitening.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/types.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/utils/interferometer.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/utils/slicing.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/waveforms/__init__.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/waveforms/phenom_d_data.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/waveforms/sine_gaussian.py +0 -0
- {ml4gw-0.4.0 → ml4gw-0.4.2}/ml4gw/waveforms/taylorf2.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ml4gw
|
|
3
|
-
Version: 0.4.
|
|
3
|
+
Version: 0.4.2
|
|
4
4
|
Summary: Tools for training torch models on gravitational wave data
|
|
5
5
|
Author: Alec Gunny
|
|
6
6
|
Author-email: alec.gunny@ligo.org
|
|
@@ -10,10 +10,8 @@ Classifier: Programming Language :: Python :: 3.8
|
|
|
10
10
|
Classifier: Programming Language :: Python :: 3.9
|
|
11
11
|
Classifier: Programming Language :: Python :: 3.10
|
|
12
12
|
Classifier: Programming Language :: Python :: 3.11
|
|
13
|
-
Requires-Dist: torch (>=
|
|
14
|
-
Requires-Dist:
|
|
15
|
-
Requires-Dist: torchaudio (>=0.13,<0.14) ; python_version >= "3.8" and python_version < "3.11"
|
|
16
|
-
Requires-Dist: torchaudio (>=2.0,<3.0) ; python_version >= "3.11"
|
|
13
|
+
Requires-Dist: torch (>=2.0,<3.0)
|
|
14
|
+
Requires-Dist: torchaudio (>=2.0,<3.0)
|
|
17
15
|
Requires-Dist: torchtyping (>=0.1,<0.2)
|
|
18
16
|
Description-Content-Type: text/markdown
|
|
19
17
|
|
|
@@ -7,7 +7,7 @@ from ml4gw import types
|
|
|
7
7
|
from ml4gw.utils.slicing import slice_kernels
|
|
8
8
|
|
|
9
9
|
|
|
10
|
-
class InMemoryDataset:
|
|
10
|
+
class InMemoryDataset(torch.utils.data.IterableDataset):
|
|
11
11
|
"""Dataset for iterating through in-memory multi-channel timeseries
|
|
12
12
|
|
|
13
13
|
Dataset for arrays of timeseries data which can be stored
|
|
@@ -131,7 +131,6 @@ class InMemoryDataset:
|
|
|
131
131
|
self.batches_per_epoch = batches_per_epoch
|
|
132
132
|
self.shuffle = shuffle
|
|
133
133
|
self.coincident = coincident
|
|
134
|
-
self._i = self._idx = None
|
|
135
134
|
|
|
136
135
|
@property
|
|
137
136
|
def num_kernels(self) -> int:
|
|
@@ -157,7 +156,7 @@ class InMemoryDataset:
|
|
|
157
156
|
num_kernels = self.num_kernels ** len(self.X)
|
|
158
157
|
return (num_kernels - 1) // self.batch_size + 1
|
|
159
158
|
|
|
160
|
-
def
|
|
159
|
+
def init_indices(self):
|
|
161
160
|
"""
|
|
162
161
|
Initialize arrays of indices we'll use to slice
|
|
163
162
|
through X and y at iteration time. This helps by
|
|
@@ -204,36 +203,23 @@ class InMemoryDataset:
|
|
|
204
203
|
# the simplest case: deteriminstic and coincident
|
|
205
204
|
idx = torch.arange(num_kernels, device=device)
|
|
206
205
|
|
|
207
|
-
|
|
208
|
-
self._i = 0
|
|
209
|
-
return self
|
|
206
|
+
return idx
|
|
210
207
|
|
|
211
|
-
def
|
|
208
|
+
def __iter__(
|
|
212
209
|
self,
|
|
213
210
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
|
214
|
-
if self._i is None or self._idx is None:
|
|
215
|
-
raise TypeError(
|
|
216
|
-
"Must initialize InMemoryDataset iteration "
|
|
217
|
-
"before calling __next__"
|
|
218
|
-
)
|
|
219
211
|
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
y = slice_kernels(self.y, idx, self.kernel_size)
|
|
235
|
-
|
|
236
|
-
self._i += 1
|
|
237
|
-
if self.y is not None:
|
|
238
|
-
return X, y
|
|
239
|
-
return X
|
|
212
|
+
indices = self.init_indices()
|
|
213
|
+
for i in range(len(self)):
|
|
214
|
+
# slice the array of _indices_ we'll be using to
|
|
215
|
+
# slice our timeseries, and scale them by the stride
|
|
216
|
+
slc = slice(i * self.batch_size, (i + 1) * self.batch_size)
|
|
217
|
+
idx = indices[slc] * self.stride
|
|
218
|
+
|
|
219
|
+
# slice our timeseries
|
|
220
|
+
X = slice_kernels(self.X, idx, self.kernel_size)
|
|
221
|
+
if self.y is not None:
|
|
222
|
+
y = slice_kernels(self.y, idx, self.kernel_size)
|
|
223
|
+
yield X, y
|
|
224
|
+
else:
|
|
225
|
+
yield X
|
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Module containing callables classes for generating samples
|
|
3
|
+
from specified distributions. Each callable should map from
|
|
4
|
+
an integer `N` to a 1D torch `Tensor` containing `N` samples
|
|
5
|
+
from the corresponding distribution.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Optional
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
import torch.distributions as dist
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Cosine(dist.Distribution):
|
|
15
|
+
"""
|
|
16
|
+
Cosine distribution based on
|
|
17
|
+
``torch.distributions.TransformedDistribution``.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
arg_constraints = {}
|
|
21
|
+
|
|
22
|
+
def __init__(
|
|
23
|
+
self,
|
|
24
|
+
low: float = torch.as_tensor(-torch.pi / 2),
|
|
25
|
+
high: float = torch.as_tensor(torch.pi / 2),
|
|
26
|
+
validate_args=None,
|
|
27
|
+
):
|
|
28
|
+
batch_shape = torch.Size()
|
|
29
|
+
super().__init__(batch_shape, validate_args=validate_args)
|
|
30
|
+
self.low = low
|
|
31
|
+
self.norm = 1 / (torch.sin(high) - torch.sin(low))
|
|
32
|
+
|
|
33
|
+
def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
|
|
34
|
+
u = torch.rand(sample_shape, device=self.low.device)
|
|
35
|
+
return torch.arcsin(u / self.norm + torch.sin(self.low))
|
|
36
|
+
|
|
37
|
+
def log_prob(self, value):
|
|
38
|
+
value = torch.as_tensor(value)
|
|
39
|
+
inside_range = (value >= self.low) & (value <= self.high)
|
|
40
|
+
return value.cos().log() * inside_range
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class Sine(dist.TransformedDistribution):
|
|
44
|
+
"""
|
|
45
|
+
Sine distribution based on
|
|
46
|
+
``torch.distributions.TransformedDistribution``.
|
|
47
|
+
"""
|
|
48
|
+
|
|
49
|
+
def __init__(
|
|
50
|
+
self,
|
|
51
|
+
low: float = torch.as_tensor(0),
|
|
52
|
+
high: float = torch.as_tensor(torch.pi),
|
|
53
|
+
validate_args=None,
|
|
54
|
+
):
|
|
55
|
+
base_dist = Cosine(
|
|
56
|
+
low - torch.pi / 2, high - torch.pi / 2, validate_args
|
|
57
|
+
)
|
|
58
|
+
super().__init__(
|
|
59
|
+
base_dist,
|
|
60
|
+
[
|
|
61
|
+
dist.AffineTransform(
|
|
62
|
+
loc=torch.pi / 2,
|
|
63
|
+
scale=1,
|
|
64
|
+
)
|
|
65
|
+
],
|
|
66
|
+
validate_args=validate_args,
|
|
67
|
+
)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class LogUniform(dist.TransformedDistribution):
|
|
71
|
+
"""
|
|
72
|
+
Sample from a log uniform distribution
|
|
73
|
+
"""
|
|
74
|
+
|
|
75
|
+
def __init__(self, low: float, high: float, validate_args=None):
|
|
76
|
+
base_dist = dist.Uniform(
|
|
77
|
+
torch.as_tensor(low).log(),
|
|
78
|
+
torch.as_tensor(high).log(),
|
|
79
|
+
validate_args,
|
|
80
|
+
)
|
|
81
|
+
super().__init__(
|
|
82
|
+
base_dist,
|
|
83
|
+
[dist.ExpTransform()],
|
|
84
|
+
validate_args=validate_args,
|
|
85
|
+
)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class LogNormal(dist.LogNormal):
|
|
89
|
+
def __init__(
|
|
90
|
+
self,
|
|
91
|
+
mean: float,
|
|
92
|
+
std: float,
|
|
93
|
+
low: Optional[float] = None,
|
|
94
|
+
validate_args=None,
|
|
95
|
+
):
|
|
96
|
+
self.low = low
|
|
97
|
+
super().__init__(loc=mean, scale=std, validate_args=validate_args)
|
|
98
|
+
|
|
99
|
+
def support(self):
|
|
100
|
+
if self.low is not None:
|
|
101
|
+
return dist.constraints.greater_than(self.low)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class PowerLaw(dist.TransformedDistribution):
|
|
105
|
+
"""
|
|
106
|
+
Sample from a power law distribution,
|
|
107
|
+
.. math::
|
|
108
|
+
p(x) \approx x^{\alpha}.
|
|
109
|
+
|
|
110
|
+
Index alpha cannot be 0, since it is equivalent to a Uniform distribution.
|
|
111
|
+
This could be used, for example, as a universal distribution of
|
|
112
|
+
signal-to-noise ratios (SNRs) from uniformly volume distributed
|
|
113
|
+
sources
|
|
114
|
+
.. math::
|
|
115
|
+
|
|
116
|
+
p(\rho) = 3*\rho_0^3 / \rho^4
|
|
117
|
+
|
|
118
|
+
where :math:`\rho_0` is a representative minimum SNR
|
|
119
|
+
considered for detection. See, for example,
|
|
120
|
+
`Schutz (2011) <https://arxiv.org/abs/1102.5421>`_.
|
|
121
|
+
Or, for example, ``index=2`` for uniform in Euclidean volume.
|
|
122
|
+
"""
|
|
123
|
+
|
|
124
|
+
support = dist.constraints.nonnegative
|
|
125
|
+
|
|
126
|
+
def __init__(
|
|
127
|
+
self, minimum: float, maximum: float, index: int, validate_args=None
|
|
128
|
+
):
|
|
129
|
+
if index == 0:
|
|
130
|
+
raise RuntimeError("Index of 0 is the same as Uniform")
|
|
131
|
+
elif index == -1:
|
|
132
|
+
base_min = torch.as_tensor(minimum).log()
|
|
133
|
+
base_max = torch.as_tensor(maximum).log()
|
|
134
|
+
transforms = [dist.ExpTransform()]
|
|
135
|
+
else:
|
|
136
|
+
index_plus = index + 1
|
|
137
|
+
base_min = minimum**index_plus / index_plus
|
|
138
|
+
base_max = maximum**index_plus / index_plus
|
|
139
|
+
transforms = [
|
|
140
|
+
dist.AffineTransform(loc=0, scale=index_plus),
|
|
141
|
+
dist.PowerTransform(1 / index_plus),
|
|
142
|
+
]
|
|
143
|
+
base_dist = dist.Uniform(base_min, base_max, validate_args=False)
|
|
144
|
+
super().__init__(
|
|
145
|
+
base_dist,
|
|
146
|
+
transforms,
|
|
147
|
+
validate_args=validate_args,
|
|
148
|
+
)
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class DeltaFunction(dist.Distribution):
|
|
152
|
+
arg_constraints = {}
|
|
153
|
+
|
|
154
|
+
def __init__(
|
|
155
|
+
self,
|
|
156
|
+
peak: float = torch.as_tensor(0.0),
|
|
157
|
+
validate_args=None,
|
|
158
|
+
):
|
|
159
|
+
batch_shape = torch.Size()
|
|
160
|
+
super().__init__(batch_shape, validate_args=validate_args)
|
|
161
|
+
self.peak = peak
|
|
162
|
+
|
|
163
|
+
def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
|
|
164
|
+
return self.peak * torch.ones(
|
|
165
|
+
sample_shape, device=self.peak.device, dtype=torch.float32
|
|
166
|
+
)
|
|
@@ -0,0 +1,463 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from typing import List, Optional, Tuple
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn.functional as F
|
|
6
|
+
|
|
7
|
+
"""
|
|
8
|
+
All based on https://github.com/gwpy/gwpy/blob/v3.0.8/gwpy/signal/qtransform.py
|
|
9
|
+
The methods, names, and descriptions come almost entirely from GWpy.
|
|
10
|
+
This code allows the Q-transform to be performed on batches of multi-channel
|
|
11
|
+
input on GPU.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class QTile(torch.nn.Module):
|
|
16
|
+
"""
|
|
17
|
+
Compute the row of Q-tiles for a single Q value and a single
|
|
18
|
+
frequency for a batch of multi-channel frequency series data.
|
|
19
|
+
Should really be called `QRow`, but I want to match GWpy.
|
|
20
|
+
Input data should have three dimensions or fewer.
|
|
21
|
+
If fewer, dimensions will be added until the input is
|
|
22
|
+
three-dimensional.
|
|
23
|
+
|
|
24
|
+
Args:
|
|
25
|
+
q:
|
|
26
|
+
The Q value to use in computing the Q tile
|
|
27
|
+
frequency:
|
|
28
|
+
The frequency for which to compute the Q tile in Hz
|
|
29
|
+
duration:
|
|
30
|
+
The length of time in seconds that the input frequency
|
|
31
|
+
series represents
|
|
32
|
+
sample_rate:
|
|
33
|
+
The sample rate of the original time series in Hz
|
|
34
|
+
mismatch:
|
|
35
|
+
The maximum fractional mismatch between neighboring tiles
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
"""
|
|
39
|
+
|
|
40
|
+
def __init__(
|
|
41
|
+
self,
|
|
42
|
+
q: float,
|
|
43
|
+
frequency: float,
|
|
44
|
+
duration: float,
|
|
45
|
+
sample_rate: float,
|
|
46
|
+
mismatch: float,
|
|
47
|
+
):
|
|
48
|
+
super().__init__()
|
|
49
|
+
self.mismatch = mismatch
|
|
50
|
+
self.q = q
|
|
51
|
+
self.deltam = torch.tensor(2 * (self.mismatch / 3.0) ** (1 / 2.0))
|
|
52
|
+
self.qprime = self.q / 11 ** (1 / 2.0)
|
|
53
|
+
self.frequency = frequency
|
|
54
|
+
self.duration = duration
|
|
55
|
+
self.sample_rate = sample_rate
|
|
56
|
+
|
|
57
|
+
self.windowsize = (
|
|
58
|
+
2 * int(self.frequency / self.qprime * self.duration) + 1
|
|
59
|
+
)
|
|
60
|
+
pad = self.ntiles() - self.windowsize
|
|
61
|
+
padding = torch.Tensor((int((pad - 1) / 2.0), int((pad + 1) / 2.0)))
|
|
62
|
+
self.register_buffer("padding", padding)
|
|
63
|
+
self.register_buffer("indices", self.get_data_indices())
|
|
64
|
+
self.register_buffer("window", self.get_window())
|
|
65
|
+
|
|
66
|
+
def ntiles(self):
|
|
67
|
+
"""
|
|
68
|
+
Number of tiles in this frequency row
|
|
69
|
+
"""
|
|
70
|
+
tcum_mismatch = self.duration * 2 * torch.pi * self.frequency / self.q
|
|
71
|
+
return int(2 ** torch.ceil(torch.log2(tcum_mismatch / self.deltam)))
|
|
72
|
+
|
|
73
|
+
def _get_indices(self):
|
|
74
|
+
half = int((self.windowsize - 1) / 2)
|
|
75
|
+
return torch.arange(-half, half + 1)
|
|
76
|
+
|
|
77
|
+
def get_window(self):
|
|
78
|
+
"""
|
|
79
|
+
Generate the bi-square window for this row
|
|
80
|
+
"""
|
|
81
|
+
wfrequencies = self._get_indices() / self.duration
|
|
82
|
+
xfrequencies = wfrequencies * self.qprime / self.frequency
|
|
83
|
+
norm = (
|
|
84
|
+
self.ntiles()
|
|
85
|
+
/ (self.duration * self.sample_rate)
|
|
86
|
+
* (315 * self.qprime / (128 * self.frequency)) ** (1 / 2.0)
|
|
87
|
+
)
|
|
88
|
+
return torch.Tensor((1 - xfrequencies**2) ** 2 * norm)
|
|
89
|
+
|
|
90
|
+
def get_data_indices(self):
|
|
91
|
+
"""
|
|
92
|
+
Get the index array of relevant frequencies for this row
|
|
93
|
+
"""
|
|
94
|
+
return torch.round(
|
|
95
|
+
self._get_indices() + 1 + self.frequency * self.duration,
|
|
96
|
+
).type(torch.long)
|
|
97
|
+
|
|
98
|
+
def forward(self, fseries: torch.Tensor, norm: str = "median"):
|
|
99
|
+
"""
|
|
100
|
+
Compute the transform for this row
|
|
101
|
+
|
|
102
|
+
Args:
|
|
103
|
+
fseries:
|
|
104
|
+
Frequency series of data. Should correspond to data with
|
|
105
|
+
the duration and sample rate used to initialize this object.
|
|
106
|
+
Expected input shape is `(B, C, F)`, where F is the number
|
|
107
|
+
of samples, C is the number of channels, and B is the number
|
|
108
|
+
of batches. If less than three-dimensional, axes will be
|
|
109
|
+
added.
|
|
110
|
+
norm:
|
|
111
|
+
The method of normalization. Options are "median", "mean", or
|
|
112
|
+
`None`.
|
|
113
|
+
|
|
114
|
+
Returns:
|
|
115
|
+
The row of Q-tiles for the given Q and frequency. Output is
|
|
116
|
+
three-dimensional: `(B, C, T)`
|
|
117
|
+
"""
|
|
118
|
+
if len(fseries.shape) > 3:
|
|
119
|
+
raise ValueError("Input data has more than 3 dimensions")
|
|
120
|
+
|
|
121
|
+
while len(fseries.shape) < 3:
|
|
122
|
+
fseries = fseries[None]
|
|
123
|
+
|
|
124
|
+
windowed = fseries[..., self.indices] * self.window
|
|
125
|
+
left, right = self.padding
|
|
126
|
+
padded = F.pad(windowed, (int(left), int(right)), mode="constant")
|
|
127
|
+
wenergy = torch.fft.ifftshift(padded, dim=-1)
|
|
128
|
+
|
|
129
|
+
tdenergy = torch.fft.ifft(wenergy)
|
|
130
|
+
energy = tdenergy.real**2.0 + tdenergy.imag**2.0
|
|
131
|
+
if norm:
|
|
132
|
+
norm = norm.lower() if isinstance(norm, str) else norm
|
|
133
|
+
if norm == "median":
|
|
134
|
+
medians = torch.quantile(energy, q=0.5, dim=-1, keepdim=True)
|
|
135
|
+
energy /= medians
|
|
136
|
+
elif norm == "mean":
|
|
137
|
+
means = torch.mean(energy, dim=-1, keepdim=True)
|
|
138
|
+
energy /= means
|
|
139
|
+
else:
|
|
140
|
+
raise ValueError("Invalid normalisation %r" % norm)
|
|
141
|
+
return energy.type(torch.float32)
|
|
142
|
+
return energy
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
class SingleQTransform(torch.nn.Module):
|
|
146
|
+
"""
|
|
147
|
+
Compute the Q-transform for a single Q value for a batch of
|
|
148
|
+
multi-channel time series data. Input data should have
|
|
149
|
+
three dimensions or fewer.
|
|
150
|
+
|
|
151
|
+
Args:
|
|
152
|
+
duration:
|
|
153
|
+
Length of the time series data in seconds
|
|
154
|
+
sample_rate:
|
|
155
|
+
Sample rate of the data in Hz
|
|
156
|
+
spectrogram_shape:
|
|
157
|
+
The shape of the interpolated spectrogram, specified as
|
|
158
|
+
`(num_f_bins, num_t_bins)`. Because the
|
|
159
|
+
frequency spacing of the Q-tiles is in log-space, the frequency
|
|
160
|
+
interpolation is log-spaced as well.
|
|
161
|
+
q:
|
|
162
|
+
The Q value to use for the Q transform
|
|
163
|
+
frange:
|
|
164
|
+
The lower and upper frequency limit to consider for
|
|
165
|
+
the transform. If unspecified, default values will
|
|
166
|
+
be chosen based on q, sample_rate, and duration
|
|
167
|
+
mismatch:
|
|
168
|
+
The maximum fractional mismatch between neighboring tiles
|
|
169
|
+
"""
|
|
170
|
+
|
|
171
|
+
def __init__(
|
|
172
|
+
self,
|
|
173
|
+
duration: float,
|
|
174
|
+
sample_rate: float,
|
|
175
|
+
spectrogram_shape: Tuple[int, int],
|
|
176
|
+
q: float = 12,
|
|
177
|
+
frange: List[float] = [0, torch.inf],
|
|
178
|
+
mismatch: float = 0.2,
|
|
179
|
+
):
|
|
180
|
+
super().__init__()
|
|
181
|
+
self.q = q
|
|
182
|
+
self.spectrogram_shape = spectrogram_shape
|
|
183
|
+
self.frange = frange
|
|
184
|
+
self.duration = duration
|
|
185
|
+
self.mismatch = mismatch
|
|
186
|
+
|
|
187
|
+
qprime = self.q / 11 ** (1 / 2.0)
|
|
188
|
+
if self.frange[0] <= 0: # set non-zero lower frequency
|
|
189
|
+
self.frange[0] = 50 * self.q / (2 * torch.pi * duration)
|
|
190
|
+
if math.isinf(self.frange[1]): # set non-infinite upper frequency
|
|
191
|
+
self.frange[1] = sample_rate / 2 / (1 + 1 / qprime)
|
|
192
|
+
self.freqs = self.get_freqs()
|
|
193
|
+
self.qtile_transforms = torch.nn.ModuleList(
|
|
194
|
+
[
|
|
195
|
+
QTile(self.q, freq, self.duration, sample_rate, self.mismatch)
|
|
196
|
+
for freq in self.freqs
|
|
197
|
+
]
|
|
198
|
+
)
|
|
199
|
+
self.qtiles = None
|
|
200
|
+
|
|
201
|
+
def get_freqs(self):
|
|
202
|
+
"""
|
|
203
|
+
Calculate the frequencies that will be used in this transform.
|
|
204
|
+
For each frequency, a `QTile` is created.
|
|
205
|
+
"""
|
|
206
|
+
minf, maxf = self.frange
|
|
207
|
+
fcum_mismatch = (
|
|
208
|
+
math.log(maxf / minf) * (2 + self.q**2) ** (1 / 2.0) / 2.0
|
|
209
|
+
)
|
|
210
|
+
deltam = 2 * (self.mismatch / 3.0) ** (1 / 2.0)
|
|
211
|
+
nfreq = int(max(1, math.ceil(fcum_mismatch / deltam)))
|
|
212
|
+
fstep = fcum_mismatch / nfreq
|
|
213
|
+
fstepmin = 1 / self.duration
|
|
214
|
+
|
|
215
|
+
freq_base = math.exp(2 / ((2 + self.q**2) ** (1 / 2.0)) * fstep)
|
|
216
|
+
freqs = torch.Tensor([freq_base ** (i + 0.5) for i in range(nfreq)])
|
|
217
|
+
freqs = (minf * freqs // fstepmin) * fstepmin
|
|
218
|
+
return torch.unique(freqs)
|
|
219
|
+
|
|
220
|
+
def get_max_energy(
|
|
221
|
+
self, fsearch_range: List[float] = None, dimension: str = "both"
|
|
222
|
+
):
|
|
223
|
+
"""
|
|
224
|
+
Gets the maximum energy value among the QTiles. The maximum can
|
|
225
|
+
be computed across all batches and channels, across all channels,
|
|
226
|
+
across all batches, or individually for each channel/batch
|
|
227
|
+
combination. This could be useful for allowing the use of different
|
|
228
|
+
Q values for different channels and batches, but the slicing would
|
|
229
|
+
be slow, so this isn't used yet.
|
|
230
|
+
|
|
231
|
+
Optionally, a pair of frequency values can be specified for
|
|
232
|
+
`fsearch_range` to restrict the frequencies in which the maximum
|
|
233
|
+
energy value is sought.
|
|
234
|
+
"""
|
|
235
|
+
allowed_dimensions = ["both", "neither", "channel", "batch"]
|
|
236
|
+
if dimension not in allowed_dimensions:
|
|
237
|
+
raise ValueError(f"Dimension must be one of {allowed_dimensions}")
|
|
238
|
+
|
|
239
|
+
if self.qtiles is None:
|
|
240
|
+
raise RuntimeError(
|
|
241
|
+
"Q-tiles must first be computed with .compute_qtiles()"
|
|
242
|
+
)
|
|
243
|
+
|
|
244
|
+
if fsearch_range is not None:
|
|
245
|
+
start = min(torch.argwhere(self.freqs > fsearch_range[0]))
|
|
246
|
+
stop = min(torch.argwhere(self.freqs > fsearch_range[1]))
|
|
247
|
+
qtiles = self.qtiles[start:stop]
|
|
248
|
+
else:
|
|
249
|
+
qtiles = self.qtiles
|
|
250
|
+
|
|
251
|
+
if dimension == "both":
|
|
252
|
+
return max([torch.max(qtile) for qtile in qtiles])
|
|
253
|
+
|
|
254
|
+
max_across_t = [torch.max(qtile, dim=-1).values for qtile in qtiles]
|
|
255
|
+
max_across_t = torch.stack(max_across_t, dim=-1)
|
|
256
|
+
max_across_ft = torch.max(max_across_t, dim=-1).values
|
|
257
|
+
|
|
258
|
+
if dimension == "neither":
|
|
259
|
+
return max_across_ft
|
|
260
|
+
if dimension == "channel":
|
|
261
|
+
return torch.max(max_across_ft, dim=-2).values
|
|
262
|
+
if dimension == "batch":
|
|
263
|
+
return torch.max(max_across_ft, dim=-1).values
|
|
264
|
+
|
|
265
|
+
def compute_qtiles(self, X: torch.Tensor, norm: str = "median"):
|
|
266
|
+
"""
|
|
267
|
+
Take the FFT of the input timeseries and calculate the transform
|
|
268
|
+
for each `QTile`
|
|
269
|
+
"""
|
|
270
|
+
# Computing the FFT with the same normalization and scaling as GWpy
|
|
271
|
+
X = torch.fft.rfft(X, norm="forward")
|
|
272
|
+
X[..., 1:] *= 2
|
|
273
|
+
self.qtiles = [qtile(X, norm) for qtile in self.qtile_transforms]
|
|
274
|
+
|
|
275
|
+
def interpolate(self, num_f_bins: int, num_t_bins: int):
|
|
276
|
+
"""
|
|
277
|
+
Interpolate each `QTile` to the specified number of time and
|
|
278
|
+
frequency bins. Note that PyTorch does not have the same
|
|
279
|
+
interpolation methods that GWpy uses, and so the interpolated
|
|
280
|
+
spectrograms will be different even though the uninterpolated
|
|
281
|
+
values match. The `bicubic` interpolation method is used as
|
|
282
|
+
it seems to match GWpy most closely.
|
|
283
|
+
"""
|
|
284
|
+
if self.qtiles is None:
|
|
285
|
+
raise RuntimeError(
|
|
286
|
+
"Q-tiles must first be computed with .compute_qtiles()"
|
|
287
|
+
)
|
|
288
|
+
resampled = [
|
|
289
|
+
F.interpolate(
|
|
290
|
+
qtile[None], (qtile.shape[-2], num_t_bins), mode="bicubic"
|
|
291
|
+
)
|
|
292
|
+
for qtile in self.qtiles
|
|
293
|
+
]
|
|
294
|
+
resampled = torch.stack(resampled, dim=-2)
|
|
295
|
+
resampled = F.interpolate(
|
|
296
|
+
resampled[0], (num_f_bins, num_t_bins), mode="bicubic"
|
|
297
|
+
)
|
|
298
|
+
return torch.squeeze(resampled)
|
|
299
|
+
|
|
300
|
+
def forward(
|
|
301
|
+
self,
|
|
302
|
+
X: torch.Tensor,
|
|
303
|
+
norm: str = "median",
|
|
304
|
+
spectrogram_shape: Optional[Tuple[int, int]] = None,
|
|
305
|
+
):
|
|
306
|
+
"""
|
|
307
|
+
Compute the Q-tiles and interpolate
|
|
308
|
+
|
|
309
|
+
Args:
|
|
310
|
+
X:
|
|
311
|
+
Time series of data. Should have the duration and sample rate
|
|
312
|
+
used to initialize this object. Expected input shape is
|
|
313
|
+
`(B, C, T)`, where T is the number of samples, C is the number
|
|
314
|
+
of channels, and B is the number of batches. If less than
|
|
315
|
+
three-dimensional, axes will be added during Q-tile
|
|
316
|
+
computation.
|
|
317
|
+
norm:
|
|
318
|
+
The method of interpolation used by each QTile
|
|
319
|
+
spectrogram_shape:
|
|
320
|
+
The shape of the interpolated spectrogram, specified as
|
|
321
|
+
`(num_f_bins, num_t_bins)`. Because the
|
|
322
|
+
frequency spacing of the Q-tiles is in log-space, the frequency
|
|
323
|
+
interpolation is log-spaced as well. If not given, the shape
|
|
324
|
+
used to initialize the transform will be used.
|
|
325
|
+
|
|
326
|
+
Returns:
|
|
327
|
+
The interpolated Q-transform for the batch of data. Output will
|
|
328
|
+
have one more dimension than the input
|
|
329
|
+
"""
|
|
330
|
+
|
|
331
|
+
if spectrogram_shape is None:
|
|
332
|
+
spectrogram_shape = self.spectrogram_shape
|
|
333
|
+
num_f_bins, num_t_bins = spectrogram_shape
|
|
334
|
+
self.compute_qtiles(X, norm)
|
|
335
|
+
return self.interpolate(num_f_bins, num_t_bins)
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
class QScan(torch.nn.Module):
|
|
339
|
+
"""
|
|
340
|
+
Calculate the Q-transform of a batch of multi-channel
|
|
341
|
+
time series data for a range of Q values and return
|
|
342
|
+
the interpolated Q-transform with the highest energy.
|
|
343
|
+
|
|
344
|
+
Args:
|
|
345
|
+
duration:
|
|
346
|
+
Length of the time series data in seconds
|
|
347
|
+
sample_rate:
|
|
348
|
+
Sample rate of the data in Hz
|
|
349
|
+
spectrogram_shape:
|
|
350
|
+
The shape of the interpolated spectrogram, specified as
|
|
351
|
+
`(num_f_bins, num_t_bins)`. Because the
|
|
352
|
+
frequency spacing of the Q-tiles is in log-space, the frequency
|
|
353
|
+
interpolation is log-spaced as well.
|
|
354
|
+
qrange:
|
|
355
|
+
The lower and upper values of Q to consider. The
|
|
356
|
+
actual values of Q used for the transforms are
|
|
357
|
+
determined by the `get_qs` method
|
|
358
|
+
frange:
|
|
359
|
+
The lower and upper frequency limit to consider for
|
|
360
|
+
the transform. If unspecified, default values will
|
|
361
|
+
be chosen based on q, sample_rate, and duration
|
|
362
|
+
mismatch:
|
|
363
|
+
The maximum fractional mismatch between neighboring tiles
|
|
364
|
+
"""
|
|
365
|
+
|
|
366
|
+
def __init__(
|
|
367
|
+
self,
|
|
368
|
+
duration: float,
|
|
369
|
+
sample_rate: float,
|
|
370
|
+
spectrogram_shape: Tuple[int, int],
|
|
371
|
+
qrange: List[float] = [4, 64],
|
|
372
|
+
frange: List[float] = [0, torch.inf],
|
|
373
|
+
mismatch: float = 0.2,
|
|
374
|
+
):
|
|
375
|
+
super().__init__()
|
|
376
|
+
self.qrange = qrange
|
|
377
|
+
self.mismatch = mismatch
|
|
378
|
+
self.qs = self.get_qs()
|
|
379
|
+
self.frange = frange
|
|
380
|
+
self.spectrogram_shape = spectrogram_shape
|
|
381
|
+
|
|
382
|
+
# Deliberately doing something different from GWpy here.
|
|
383
|
+
# Their final frange is the intersection of the frange
|
|
384
|
+
# from each q. This implementation uses the frange of
|
|
385
|
+
# the chosen q.
|
|
386
|
+
self.q_transforms = torch.nn.ModuleList(
|
|
387
|
+
[
|
|
388
|
+
SingleQTransform(
|
|
389
|
+
duration=duration,
|
|
390
|
+
sample_rate=sample_rate,
|
|
391
|
+
spectrogram_shape=spectrogram_shape,
|
|
392
|
+
q=q,
|
|
393
|
+
frange=self.frange.copy(),
|
|
394
|
+
mismatch=self.mismatch,
|
|
395
|
+
)
|
|
396
|
+
for q in self.qs
|
|
397
|
+
]
|
|
398
|
+
)
|
|
399
|
+
|
|
400
|
+
def get_qs(self):
|
|
401
|
+
"""
|
|
402
|
+
Determine the values of Q to try for the set of Q-transforms
|
|
403
|
+
"""
|
|
404
|
+
deltam = 2 * (self.mismatch / 3.0) ** (1 / 2.0)
|
|
405
|
+
cumum = math.log(self.qrange[1] / self.qrange[0]) / 2 ** (1 / 2.0)
|
|
406
|
+
nplanes = int(max(math.ceil(cumum / deltam), 1))
|
|
407
|
+
dq = cumum / nplanes
|
|
408
|
+
qs = [
|
|
409
|
+
self.qrange[0] * math.exp(2 ** (1 / 2.0) * dq * (i + 0.5))
|
|
410
|
+
for i in range(nplanes)
|
|
411
|
+
]
|
|
412
|
+
return qs
|
|
413
|
+
|
|
414
|
+
def forward(
|
|
415
|
+
self,
|
|
416
|
+
X: torch.Tensor,
|
|
417
|
+
fsearch_range: List[float] = None,
|
|
418
|
+
norm: str = "median",
|
|
419
|
+
spectrogram_shape: Optional[Tuple[int, int]] = None,
|
|
420
|
+
):
|
|
421
|
+
"""
|
|
422
|
+
Compute the set of QTiles for each Q transform and determine which
|
|
423
|
+
has the highest energy value. Interpolate and return the
|
|
424
|
+
corresponding set of tiles.
|
|
425
|
+
|
|
426
|
+
Args:
|
|
427
|
+
X:
|
|
428
|
+
Time series of data. Should have the duration and sample rate
|
|
429
|
+
used to initialize this object. Expected input shape is
|
|
430
|
+
`(B, C, T)`, where T is the number of samples, C is the number
|
|
431
|
+
of channels, and B is the number of batches. If less than
|
|
432
|
+
three-dimensional, axes will be added during Q-tile
|
|
433
|
+
computation.
|
|
434
|
+
fsearch_range:
|
|
435
|
+
The lower and upper frequency values within which to search
|
|
436
|
+
for the maximum energy
|
|
437
|
+
norm:
|
|
438
|
+
The method of interpolation used by each QTile
|
|
439
|
+
spectrogram_shape:
|
|
440
|
+
The shape of the interpolated spectrogram, specified as
|
|
441
|
+
`(num_f_bins, num_t_bins)`. Because the
|
|
442
|
+
frequency spacing of the Q-tiles is in log-space, the frequency
|
|
443
|
+
interpolation is log-spaced as well. If not given, the shape
|
|
444
|
+
used to initialize the transform will be used.
|
|
445
|
+
|
|
446
|
+
Returns:
|
|
447
|
+
An interpolated Q-transform for the batch of data. Output will
|
|
448
|
+
have one more dimension than the input
|
|
449
|
+
"""
|
|
450
|
+
for transform in self.q_transforms:
|
|
451
|
+
transform.compute_qtiles(X, norm)
|
|
452
|
+
idx = torch.argmax(
|
|
453
|
+
torch.Tensor(
|
|
454
|
+
[
|
|
455
|
+
transform.get_max_energy(fsearch_range=fsearch_range)
|
|
456
|
+
for transform in self.q_transforms
|
|
457
|
+
]
|
|
458
|
+
)
|
|
459
|
+
)
|
|
460
|
+
if spectrogram_shape is None:
|
|
461
|
+
spectrogram_shape = self.spectrogram_shape
|
|
462
|
+
num_f_bins, num_t_bins = spectrogram_shape
|
|
463
|
+
return self.q_transforms[idx].interpolate(num_f_bins, num_t_bins)
|
|
@@ -905,15 +905,15 @@ def phenom_d_inspiral_phase(Mf, mass_1, mass_2, eta, eta2, xi, chi1, chi2):
|
|
|
905
905
|
sigma3 = sigma3Fit(eta, eta2, xi)
|
|
906
906
|
sigma4 = sigma4Fit(eta, eta2, xi)
|
|
907
907
|
|
|
908
|
-
ins_phasing += (Mf.
|
|
909
|
-
ins_phasing += (Mf.
|
|
910
|
-
ins_phasing += (Mf.
|
|
911
|
-
ins_phasing += (Mf.
|
|
908
|
+
ins_phasing += (Mf.mT * sigma1 / eta).mT
|
|
909
|
+
ins_phasing += (Mf.mT ** (4.0 / 3.0) * 0.75 * sigma2 / eta).mT
|
|
910
|
+
ins_phasing += (Mf.mT ** (5.0 / 3.0) * 0.6 * sigma3 / eta).mT
|
|
911
|
+
ins_phasing += (Mf.mT**2.0 * 0.5 * sigma4 / eta).mT
|
|
912
912
|
|
|
913
|
-
ins_Dphasing = (ins_Dphasing.T + sigma1 / eta).
|
|
914
|
-
ins_Dphasing += (Mf.
|
|
915
|
-
ins_Dphasing += (Mf.
|
|
916
|
-
ins_Dphasing += (Mf.
|
|
913
|
+
ins_Dphasing = (ins_Dphasing.T + sigma1 / eta).mT
|
|
914
|
+
ins_Dphasing += (Mf.mT ** (1.0 / 3.0) * sigma2 / eta).mT
|
|
915
|
+
ins_Dphasing += (Mf.mT ** (2.0 / 3.0) * sigma3 / eta).mT
|
|
916
|
+
ins_Dphasing += (Mf.mT * sigma4 / eta).mT
|
|
917
917
|
|
|
918
918
|
return ins_phasing, ins_Dphasing
|
|
919
919
|
|
|
@@ -925,16 +925,16 @@ def phenom_d_int_phase(Mf, eta, eta2, xi):
|
|
|
925
925
|
# Merger phase
|
|
926
926
|
# Leading beta0 is not added here
|
|
927
927
|
# overall 1/eta is not multiplied
|
|
928
|
-
int_phasing = (Mf.
|
|
929
|
-
int_phasing += (torch.log(Mf).
|
|
930
|
-
int_phasing -= (Mf.
|
|
928
|
+
int_phasing = (Mf.mT * beta1).mT
|
|
929
|
+
int_phasing += (torch.log(Mf).mT * beta2).mT
|
|
930
|
+
int_phasing -= (Mf.mT ** (-3.0) / 3.0 * beta3).mT
|
|
931
931
|
|
|
932
932
|
# overall 1/eta is multiple in derivative of
|
|
933
933
|
# intermediate phase
|
|
934
|
-
int_Dphasing = (Mf.
|
|
935
|
-
int_Dphasing += (Mf.
|
|
936
|
-
int_Dphasing = (int_Dphasing.T + beta1).
|
|
937
|
-
int_Dphasing = (int_Dphasing.T / eta).
|
|
934
|
+
int_Dphasing = (Mf.mT ** (-4.0) * beta3).mT
|
|
935
|
+
int_Dphasing += (Mf.mT ** (-1.0) * beta2).mT
|
|
936
|
+
int_Dphasing = (int_Dphasing.T + beta1).mT
|
|
937
|
+
int_Dphasing = (int_Dphasing.T / eta).mT
|
|
938
938
|
return int_phasing, int_Dphasing
|
|
939
939
|
|
|
940
940
|
|
|
@@ -947,19 +947,21 @@ def phenom_d_mrd_phase(Mf, eta, eta2, chi1, chi2, xi):
|
|
|
947
947
|
|
|
948
948
|
# merger ringdown
|
|
949
949
|
fRD, fDM = fring_fdamp(eta, eta2, chi1, chi2)
|
|
950
|
-
f_minus_alpha5_fRD = (Mf.
|
|
950
|
+
f_minus_alpha5_fRD = (Mf.t() - alpha5 * fRD).t()
|
|
951
951
|
|
|
952
952
|
# Leading 1/eta is not multiplied at this stage
|
|
953
|
-
mrd_phasing = (Mf.
|
|
954
|
-
mrd_phasing -= (1 / Mf.
|
|
955
|
-
mrd_phasing += (4.0 / 3.0) * (Mf.
|
|
956
|
-
mrd_phasing += (torch.atan(f_minus_alpha5_fRD.
|
|
957
|
-
|
|
958
|
-
mrd_Dphasing = (
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
mrd_Dphasing
|
|
962
|
-
mrd_Dphasing
|
|
953
|
+
mrd_phasing = (Mf.t() * alpha1).t()
|
|
954
|
+
mrd_phasing -= (1 / Mf.t() * alpha2).t()
|
|
955
|
+
mrd_phasing += (4.0 / 3.0) * (Mf.t() ** (3.0 / 4.0) * alpha3).t()
|
|
956
|
+
mrd_phasing += (torch.atan(f_minus_alpha5_fRD.t() / fDM) * alpha4).t()
|
|
957
|
+
|
|
958
|
+
mrd_Dphasing = (
|
|
959
|
+
alpha4 * fDM / (f_minus_alpha5_fRD.t() ** 2 + fDM**2)
|
|
960
|
+
).t()
|
|
961
|
+
mrd_Dphasing += (Mf.t() ** (-1.0 / 4.0) * alpha3).t()
|
|
962
|
+
mrd_Dphasing += (Mf.t() ** (-2.0) * alpha2).t()
|
|
963
|
+
mrd_Dphasing = (mrd_Dphasing.t() + alpha1).t()
|
|
964
|
+
mrd_Dphasing = (mrd_Dphasing.t() / eta).t()
|
|
963
965
|
|
|
964
966
|
return mrd_phasing, mrd_Dphasing
|
|
965
967
|
|
|
@@ -984,23 +986,23 @@ def phenom_d_phase(Mf, mass_1, mass_2, eta, eta2, chi1, chi2, xi):
|
|
|
984
986
|
PHI_fJoin_INS, eta, eta2, xi
|
|
985
987
|
)
|
|
986
988
|
C2Int = ins_Dphase_f1 - int_Dphase_f1
|
|
987
|
-
C1Int = ins_phase_f1 - (int_phase_f1.T / eta).
|
|
989
|
+
C1Int = ins_phase_f1 - (int_phase_f1.T / eta).mT - C2Int * PHI_fJoin_INS
|
|
988
990
|
# C1 continuity at ringdown
|
|
989
|
-
fRDJoin = (0.5 * torch.ones_like(Mf).
|
|
991
|
+
fRDJoin = (0.5 * torch.ones_like(Mf).mT * fRD).mT
|
|
990
992
|
int_phase_rd, int_Dphase_rd = phenom_d_int_phase(fRDJoin, eta, eta2, xi)
|
|
991
993
|
mrd_phase_rd, mrd_Dphase_rd = phenom_d_mrd_phase(
|
|
992
994
|
fRDJoin, eta, eta2, chi1, chi2, xi
|
|
993
995
|
)
|
|
994
|
-
PhiIntTempVal = (int_phase_rd.T / eta).
|
|
996
|
+
PhiIntTempVal = (int_phase_rd.T / eta).mT + C1Int + C2Int * fRDJoin
|
|
995
997
|
# C2MRD = int_Dphase_rd - mrd_Dphase_rd
|
|
996
998
|
C2MRD = C2Int + int_Dphase_rd - mrd_Dphase_rd
|
|
997
|
-
C1MRD = PhiIntTempVal - (mrd_phase_rd.T / eta).
|
|
999
|
+
C1MRD = PhiIntTempVal - (mrd_phase_rd.T / eta).mT - C2MRD * fRDJoin
|
|
998
1000
|
|
|
999
|
-
int_phase = (int_phase.T / eta).
|
|
1001
|
+
int_phase = (int_phase.T / eta).mT
|
|
1000
1002
|
int_phase += C1Int
|
|
1001
1003
|
int_phase += Mf * C2Int
|
|
1002
1004
|
|
|
1003
|
-
mrd_phase = (mrd_phase.T / eta).
|
|
1005
|
+
mrd_phase = (mrd_phase.T / eta).mT
|
|
1004
1006
|
mrd_phase += C1MRD
|
|
1005
1007
|
mrd_phase += Mf * C2MRD
|
|
1006
1008
|
|
|
@@ -1133,10 +1135,10 @@ def phenom_d_inspiral_amp(Mf, eta, eta2, Seta, xi, chi1, chi2, chi12, chi22):
|
|
|
1133
1135
|
+ Mf_five_third.T * prefactors_five_thirds
|
|
1134
1136
|
+ Mf_seven_third.T * prefactors_seven_thirds
|
|
1135
1137
|
+ MF_eight_third.T * prefactors_eight_thirds
|
|
1136
|
-
+ Mf.
|
|
1138
|
+
+ Mf.mT * prefactors_one
|
|
1137
1139
|
+ Mf_two.T * prefactors_two
|
|
1138
1140
|
+ Mf_three.T * prefactors_three
|
|
1139
|
-
).
|
|
1141
|
+
).mT
|
|
1140
1142
|
|
|
1141
1143
|
Damp = (
|
|
1142
1144
|
(2.0 / 3.0) / Mf_one_third.T * prefactors_two_thirds
|
|
@@ -1145,9 +1147,9 @@ def phenom_d_inspiral_amp(Mf, eta, eta2, Seta, xi, chi1, chi2, chi12, chi22):
|
|
|
1145
1147
|
+ (7.0 / 3.0) * Mf_four_third.T * prefactors_seven_thirds
|
|
1146
1148
|
+ (8.0 / 3.0) * Mf_five_third.T * prefactors_eight_thirds
|
|
1147
1149
|
+ prefactors_one
|
|
1148
|
-
+ 2.0 * Mf.
|
|
1150
|
+
+ 2.0 * Mf.mT * prefactors_two
|
|
1149
1151
|
+ 3.0 * Mf_two.T * prefactors_three
|
|
1150
|
-
).
|
|
1152
|
+
).mT
|
|
1151
1153
|
|
|
1152
1154
|
return amp, Damp
|
|
1153
1155
|
|
|
@@ -1160,15 +1162,15 @@ def phenom_d_mrd_amp(Mf, eta, eta2, chi1, chi2, xi):
|
|
|
1160
1162
|
gamma2 = gamma2_fun(eta, eta2, xi)
|
|
1161
1163
|
gamma3 = gamma3_fun(eta, eta2, xi)
|
|
1162
1164
|
fDMgamma3 = fDM * gamma3
|
|
1163
|
-
pow2_fDMgamma3 = (torch.ones_like(Mf).
|
|
1164
|
-
fminfRD = Mf - (torch.ones_like(Mf).
|
|
1165
|
-
exp_times_lorentzian = torch.exp(fminfRD.
|
|
1165
|
+
pow2_fDMgamma3 = (torch.ones_like(Mf).mT * fDMgamma3 * fDMgamma3).mT
|
|
1166
|
+
fminfRD = Mf - (torch.ones_like(Mf).mT * fRD).mT
|
|
1167
|
+
exp_times_lorentzian = torch.exp(fminfRD.mT * gamma2 / fDMgamma3).mT
|
|
1166
1168
|
exp_times_lorentzian *= fminfRD**2 + pow2_fDMgamma3
|
|
1167
1169
|
|
|
1168
|
-
amp = (1 / exp_times_lorentzian.T * gamma1 * gamma3 * fDM).
|
|
1169
|
-
Damp = (fminfRD.
|
|
1170
|
+
amp = (1 / exp_times_lorentzian.T * gamma1 * gamma3 * fDM).mT
|
|
1171
|
+
Damp = (fminfRD.mT * -2 * fDM * gamma1 * gamma3) / (
|
|
1170
1172
|
fminfRD * fminfRD + pow2_fDMgamma3
|
|
1171
|
-
).
|
|
1173
|
+
).mT - (gamma2 * gamma1)
|
|
1172
1174
|
Damp = Damp.T / exp_times_lorentzian
|
|
1173
1175
|
return amp, Damp
|
|
1174
1176
|
|
|
@@ -1184,7 +1186,7 @@ def phenom_d_int_amp(Mf, eta, eta2, Seta, chi1, chi2, chi12, chi22, xi):
|
|
|
1184
1186
|
gamma3 = gamma3_fun(eta, eta2, xi)
|
|
1185
1187
|
|
|
1186
1188
|
fpeak = fmaxCalc(fRD, fDM, gamma2, gamma3)
|
|
1187
|
-
Mf3 = (torch.ones_like(Mf).
|
|
1189
|
+
Mf3 = (torch.ones_like(Mf).mT * fpeak).mT
|
|
1188
1190
|
dfx = 0.5 * (Mf3 - Mf1)
|
|
1189
1191
|
Mf2 = Mf1 + dfx
|
|
1190
1192
|
|
|
@@ -1192,7 +1194,7 @@ def phenom_d_int_amp(Mf, eta, eta2, Seta, chi1, chi2, chi12, chi22, xi):
|
|
|
1192
1194
|
Mf1, eta, eta2, Seta, xi, chi1, chi2, chi12, chi22
|
|
1193
1195
|
)
|
|
1194
1196
|
v3, d2 = phenom_d_mrd_amp(Mf3, eta, eta2, chi1, chi2, xi)
|
|
1195
|
-
v2 = (torch.ones_like(Mf).
|
|
1197
|
+
v2 = (torch.ones_like(Mf).mT * AmpIntColFitCoeff(eta, eta2, xi)).mT
|
|
1196
1198
|
|
|
1197
1199
|
delta_0, delta_1, delta_2, delta_3, delta_4 = delta_values(
|
|
1198
1200
|
f1=Mf1, f2=Mf2, f3=Mf3, v1=v1, v2=v2, v3=v3, d1=d1, d2=d2
|
|
@@ -1225,7 +1227,7 @@ def phenom_d_amp(
|
|
|
1225
1227
|
fRD, fDM = fring_fdamp(eta, eta2, chi1, chi2)
|
|
1226
1228
|
Mf_peak = fmaxCalc(fRD, fDM, gamma2, gamma3)
|
|
1227
1229
|
# Geometric peak and joining frequencies
|
|
1228
|
-
Mf_peak = (torch.ones_like(Mf).
|
|
1230
|
+
Mf_peak = (torch.ones_like(Mf).mT * Mf_peak).mT
|
|
1229
1231
|
Mf_join_ins = 0.014 * torch.ones_like(Mf)
|
|
1230
1232
|
|
|
1231
1233
|
# construct full IMR Amp
|
|
@@ -1290,9 +1292,9 @@ def phenom_d_htilde(
|
|
|
1290
1292
|
Mf_ref, mass_1, mass_2, eta, eta2, chi1, chi2, xi
|
|
1291
1293
|
)
|
|
1292
1294
|
|
|
1293
|
-
Psi = (Psi.T - 2 * phic).
|
|
1295
|
+
Psi = (Psi.T - 2 * phic).mT
|
|
1294
1296
|
Psi -= Psi_ref
|
|
1295
|
-
Psi -= ((Mf - Mf_ref).
|
|
1297
|
+
Psi -= ((Mf - Mf_ref).mT * t0).mT
|
|
1296
1298
|
|
|
1297
1299
|
amp, _ = phenom_d_amp(
|
|
1298
1300
|
Mf,
|
|
@@ -1353,7 +1355,7 @@ def IMRPhenomD(
|
|
|
1353
1355
|
f, chirp_mass, mass_ratio, chi1, chi2, distance, phic, f_ref
|
|
1354
1356
|
)
|
|
1355
1357
|
|
|
1356
|
-
hp = (htilde.
|
|
1357
|
-
hc = -1j * (htilde.
|
|
1358
|
+
hp = (htilde.mT * pfac).mT
|
|
1359
|
+
hc = -1j * (htilde.mT * cfac).mT
|
|
1358
1360
|
|
|
1359
1361
|
return hp, hc
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[tool.poetry]
|
|
2
2
|
name = "ml4gw"
|
|
3
|
-
version = "0.4.
|
|
3
|
+
version = "0.4.2"
|
|
4
4
|
description = "Tools for training torch models on gravitational wave data"
|
|
5
5
|
readme = "README.md"
|
|
6
6
|
authors = [
|
|
@@ -11,14 +11,8 @@ authors = [
|
|
|
11
11
|
python = "^3.8,<3.12"
|
|
12
12
|
|
|
13
13
|
# torch deps
|
|
14
|
-
torch =
|
|
15
|
-
|
|
16
|
-
{version = "^2.0", python = ">=3.11"}
|
|
17
|
-
]
|
|
18
|
-
torchaudio = [
|
|
19
|
-
{version = "^0.13", python = ">=3.8,<3.11"},
|
|
20
|
-
{version = "^2.0", python = ">=3.11"}
|
|
21
|
-
]
|
|
14
|
+
torch = "^2.0"
|
|
15
|
+
torchaudio = "^2.0"
|
|
22
16
|
torchtyping = "^0.1"
|
|
23
17
|
|
|
24
18
|
[tool.poetry.group.dev.dependencies]
|
|
@@ -29,6 +23,7 @@ pytest = "^7.0"
|
|
|
29
23
|
lalsuite = "^7.0"
|
|
30
24
|
bilby = "^2.1"
|
|
31
25
|
jupyter = "^1.0.0"
|
|
26
|
+
gwpy = "^2.1"
|
|
32
27
|
|
|
33
28
|
Sphinx = ">5.0"
|
|
34
29
|
sphinx-rtd-theme = "^2.0.0"
|
|
@@ -1,122 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Module containing callables classes for generating samples
|
|
3
|
-
from specified distributions. Each callable should map from
|
|
4
|
-
an integer `N` to a 1D torch `Tensor` containing `N` samples
|
|
5
|
-
from the corresponding distribution.
|
|
6
|
-
"""
|
|
7
|
-
|
|
8
|
-
import math
|
|
9
|
-
from typing import Optional
|
|
10
|
-
|
|
11
|
-
import torch
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
class Uniform:
|
|
15
|
-
"""
|
|
16
|
-
Sample uniformly between `low` and `high`.
|
|
17
|
-
"""
|
|
18
|
-
|
|
19
|
-
def __init__(self, low: float = 0, high: float = 1) -> None:
|
|
20
|
-
self.low = low
|
|
21
|
-
self.high = high
|
|
22
|
-
|
|
23
|
-
def __call__(self, N: int) -> torch.Tensor:
|
|
24
|
-
return self.low + torch.rand(size=(N,)) * (self.high - self.low)
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
class Cosine:
|
|
28
|
-
"""
|
|
29
|
-
Sample from a raised Cosine distribution between
|
|
30
|
-
`low` and `high`. Based on the implementation from
|
|
31
|
-
bilby documented here:
|
|
32
|
-
https://lscsoft.docs.ligo.org/bilby/api/bilby.core.prior.analytical.Cosine.html # noqa
|
|
33
|
-
"""
|
|
34
|
-
|
|
35
|
-
def __init__(
|
|
36
|
-
self, low: float = -math.pi / 2, high: float = math.pi / 2
|
|
37
|
-
) -> None:
|
|
38
|
-
self.low = low
|
|
39
|
-
self.norm = 1 / (math.sin(high) - math.sin(low))
|
|
40
|
-
|
|
41
|
-
def __call__(self, N: int) -> torch.Tensor:
|
|
42
|
-
"""
|
|
43
|
-
Implementation lifted from
|
|
44
|
-
https://lscsoft.docs.ligo.org/bilby/_modules/bilby/core/prior/analytical.html#Cosine # noqa
|
|
45
|
-
"""
|
|
46
|
-
u = torch.rand(size=(N,))
|
|
47
|
-
return torch.arcsin(u / self.norm + math.sin(self.low))
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
class LogNormal:
|
|
51
|
-
"""
|
|
52
|
-
Sample from a log normal distribution with the
|
|
53
|
-
specified `mean` and standard deviation `std`.
|
|
54
|
-
If a `low` value is specified, values sampled
|
|
55
|
-
lower than this will be clipped to `low`.
|
|
56
|
-
"""
|
|
57
|
-
|
|
58
|
-
def __init__(
|
|
59
|
-
self, mean: float, std: float, low: Optional[float] = None
|
|
60
|
-
) -> None:
|
|
61
|
-
self.sigma = math.log((std / mean) ** 2 + 1) ** 0.5
|
|
62
|
-
self.mu = 2 * math.log(mean / (mean**2 + std**2) ** 0.25)
|
|
63
|
-
self.low = low
|
|
64
|
-
|
|
65
|
-
def __call__(self, N: int) -> torch.Tensor:
|
|
66
|
-
|
|
67
|
-
u = self.mu + torch.randn(N) * self.sigma
|
|
68
|
-
x = torch.exp(u)
|
|
69
|
-
|
|
70
|
-
if self.low is not None:
|
|
71
|
-
x = torch.clip(x, self.low)
|
|
72
|
-
return x
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
class LogUniform(Uniform):
|
|
76
|
-
"""
|
|
77
|
-
Sample from a log uniform distribution
|
|
78
|
-
"""
|
|
79
|
-
|
|
80
|
-
def __init__(self, low: float, high: float) -> None:
|
|
81
|
-
super().__init__(math.log(low), math.log(high))
|
|
82
|
-
|
|
83
|
-
def __call__(self, N: int) -> torch.Tensor:
|
|
84
|
-
u = super().__call__(N)
|
|
85
|
-
return torch.exp(u)
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
class PowerLaw:
|
|
89
|
-
"""
|
|
90
|
-
Sample from a power law distribution,
|
|
91
|
-
.. math::
|
|
92
|
-
p(x) \approx x^{-\alpha}.
|
|
93
|
-
|
|
94
|
-
Index alpha must be greater than 1.
|
|
95
|
-
This could be used, for example, as a universal distribution of
|
|
96
|
-
signal-to-noise ratios (SNRs) from uniformly volume distributed
|
|
97
|
-
sources
|
|
98
|
-
.. math::
|
|
99
|
-
|
|
100
|
-
p(\rho) = 3*\rho_0^3 / \rho^4
|
|
101
|
-
|
|
102
|
-
where :math:`\rho_0` is a representative minimum SNR
|
|
103
|
-
considered for detection. See, for example,
|
|
104
|
-
`Schutz (2011) <https://arxiv.org/abs/1102.5421>`_.
|
|
105
|
-
"""
|
|
106
|
-
|
|
107
|
-
def __init__(
|
|
108
|
-
self, x_min: float, x_max: float = float("inf"), alpha: float = 2
|
|
109
|
-
) -> None:
|
|
110
|
-
self.x_min = x_min
|
|
111
|
-
self.x_max = x_max
|
|
112
|
-
self.alpha = alpha
|
|
113
|
-
|
|
114
|
-
self.normalization = x_min ** (-self.alpha + 1)
|
|
115
|
-
self.normalization -= x_max ** (-self.alpha + 1)
|
|
116
|
-
|
|
117
|
-
def __call__(self, N: int) -> torch.Tensor:
|
|
118
|
-
u = torch.rand(N)
|
|
119
|
-
u *= self.normalization
|
|
120
|
-
samples = self.x_min ** (-self.alpha + 1) - u
|
|
121
|
-
samples = torch.pow(samples, -1.0 / (self.alpha - 1))
|
|
122
|
-
return samples
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|