ml4gw 0.5.0-py3-none-any.whl → 0.6.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ml4gw has been flagged as potentially problematic.
- ml4gw/augmentations.py +8 -2
- ml4gw/constants.py +10 -19
- ml4gw/dataloading/chunked_dataset.py +4 -2
- ml4gw/dataloading/hdf5_dataset.py +1 -1
- ml4gw/dataloading/in_memory_dataset.py +8 -4
- ml4gw/distributions.py +5 -3
- ml4gw/gw.py +21 -27
- ml4gw/nn/autoencoder/base.py +11 -6
- ml4gw/nn/autoencoder/convolutional.py +7 -4
- ml4gw/nn/autoencoder/skip_connection.py +7 -6
- ml4gw/nn/autoencoder/utils.py +2 -1
- ml4gw/nn/norm.py +5 -1
- ml4gw/nn/streaming/online_average.py +7 -5
- ml4gw/nn/streaming/snapshotter.py +7 -5
- ml4gw/spectral.py +41 -37
- ml4gw/transforms/__init__.py +1 -0
- ml4gw/transforms/pearson.py +7 -3
- ml4gw/transforms/qtransform.py +151 -53
- ml4gw/transforms/scaler.py +9 -3
- ml4gw/transforms/snr_rescaler.py +6 -5
- ml4gw/transforms/spectral.py +9 -2
- ml4gw/transforms/spectrogram.py +7 -1
- ml4gw/transforms/spline_interpolation.py +370 -0
- ml4gw/transforms/transform.py +4 -3
- ml4gw/transforms/waveforms.py +10 -7
- ml4gw/transforms/whitening.py +12 -4
- ml4gw/types.py +25 -10
- ml4gw/utils/interferometer.py +1 -1
- ml4gw/utils/slicing.py +24 -16
- ml4gw/waveforms/__init__.py +2 -5
- ml4gw/waveforms/adhoc/__init__.py +2 -0
- ml4gw/waveforms/{ringdown.py → adhoc/ringdown.py} +8 -9
- ml4gw/waveforms/{sine_gaussian.py → adhoc/sine_gaussian.py} +6 -6
- ml4gw/waveforms/cbc/__init__.py +3 -0
- ml4gw/waveforms/{phenom_d.py → cbc/phenom_d.py} +20 -18
- ml4gw/waveforms/{phenom_p.py → cbc/phenom_p.py} +106 -95
- ml4gw/waveforms/{taylorf2.py → cbc/taylorf2.py} +33 -27
- ml4gw/waveforms/conversion.py +187 -0
- ml4gw/waveforms/generator.py +9 -5
- {ml4gw-0.5.0.dist-info → ml4gw-0.6.0.dist-info}/METADATA +4 -3
- ml4gw-0.6.0.dist-info/RECORD +51 -0
- {ml4gw-0.5.0.dist-info → ml4gw-0.6.0.dist-info}/WHEEL +1 -1
- ml4gw-0.5.0.dist-info/RECORD +0 -47
- /ml4gw/waveforms/{phenom_d_data.py → cbc/phenom_d_data.py} +0 -0
ml4gw/transforms/spectral.py
CHANGED
@@ -1,8 +1,11 @@
 from typing import Optional

 import torch
+from jaxtyping import Float
+from torch import Tensor

 from ml4gw.spectral import fast_spectral_density, spectral_density
+from ml4gw.types import FrequencySeries1to3d, TimeSeries1to3d


 class SpectralDensity(torch.nn.Module):
@@ -51,7 +54,9 @@ class SpectralDensity(torch.nn.Module):
         fftlength: float,
         overlap: Optional[float] = None,
         average: str = "mean",
-        window: Optional[
+        window: Optional[
+            Float[Tensor, " {int(fftlength*sample_rate)}"]
+        ] = None,
         fast: bool = False,
     ) -> None:
         if overlap is None:
@@ -93,7 +98,9 @@ class SpectralDensity(torch.nn.Module):
         self.average = average
         self.fast = fast

-    def forward(
+    def forward(
+        self, x: TimeSeries1to3d, y: Optional[TimeSeries1to3d] = None
+    ) -> FrequencySeries1to3d:
         if self.fast:
             return fast_spectral_density(
                 x,
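As a usage sketch (not part of the diff): the new annotation ties the window length to `int(fftlength * sample_rate)`. This assumes `sample_rate` is the constructor argument just above the lines shown in the second hunk:

```python
import torch

from ml4gw.transforms import SpectralDensity

sample_rate, fftlength = 2048.0, 2.0
# Per the annotation above, the window needs int(fftlength * sample_rate)
# samples (4096 here); any torch window function of that length works
window = torch.blackman_window(int(fftlength * sample_rate))

psd_estimator = SpectralDensity(sample_rate, fftlength, window=window)
x = torch.randn(8, 2, int(8 * sample_rate))  # TimeSeries3d: (batch, channel, time)
psd = psd_estimator(x)  # FrequencySeries3d: (batch, channel, frequency)
```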
ml4gw/transforms/spectrogram.py
CHANGED
@@ -3,8 +3,12 @@ from typing import Dict, List

 import torch
 import torch.nn.functional as F
+from jaxtyping import Float
+from torch import Tensor
 from torchaudio.transforms import Spectrogram

+from ml4gw.types import TimeSeries3d
+

 class MultiResolutionSpectrogram(torch.nn.Module):
     """
@@ -122,7 +126,9 @@ class MultiResolutionSpectrogram(torch.nn.Module):

         return [dict(zip(kwargs, col)) for col in zip(*kwargs.values())]

-    def forward(
+    def forward(
+        self, X: TimeSeries3d
+    ) -> Float[Tensor, "batch channel frequency time"]:
         """
         Calculate spectrograms of the input tensor and
         combine them into a single spectrogram
ml4gw/transforms/spline_interpolation.py
ADDED
@@ -0,0 +1,370 @@
+"""
+Adaptation of code from https://github.com/dottormale/Qtransform
+"""
+
+from typing import Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+
+class SplineInterpolate(torch.nn.Module):
+    """
+    Perform 1D or 2D spline interpolation based on De Boor's method.
+    Supports batched, multi-channel inputs, so acceptable data
+    shapes are `(width)`, `(height, width)`, `(batch, width)`,
+    `(batch, height, width)`, `(batch, channel, width)`, and
+    `(batch, channel, height, width)`.
+
+    During initialization of this Module, both the desired input
+    and output coordinate Tensors can be specified to allow
+    pre-computation of the B-spline basis matrices, though the only
+    mandatory argument is the coordinates of the data along the
+    `width` dimension. If no argument is given for coordinates along
+    the `height` dimension, it is assumed that 1D interpolation is
+    desired.
+
+    Unlike scipy's implementation of spline interpolation, the data
+    to be interpolated is not passed until actually calling the
+    object. This is useful for cases where the input and output
+    coordinates are known in advance, but the data is not, so that
+    the interpolator can be set up ahead of time.
+
+    WARNING: compared to scipy's spline interpolation, this method
+    produces edge artifacts when the output coordinates are near
+    the boundaries of the input coordinates. Therefore, it is
+    recommended to interpolate only to coordinates that are well
+    within the input coordinate range. Unfortunately, the specific
+    definition of "well within" changes based on the size of the
+    data, so some testing may be required to get good results.
+
+    Args:
+        x_in:
+            Coordinates of the width dimension of the data
+        y_in:
+            Coordinates of the height dimension of the data. If not
+            specified, it is assumed that 1D interpolation is desired,
+            and so the default value is a Tensor of length 1
+        kx:
+            Degree of spline interpolation along the width dimension.
+            Default is cubic.
+        ky:
+            Degree of spline interpolation along the height dimension.
+            Default is cubic.
+        sx:
+            Regularization factor to avoid singularities during matrix
+            inversion for interpolation along the width dimension. Not
+            to be confused with the `s` parameter in scipy's spline
+            methods, which controls the number of knots.
+        sy:
+            Regularization factor to avoid singularities during matrix
+            inversion for interpolation along the height dimension.
+        x_out:
+            Coordinates for the data to be interpolated to along the
+            width dimension. If not specified during initialization,
+            this must be specified during the object call.
+        y_out:
+            Coordinates for the data to be interpolated to along the
+            height dimension. If not specified during initialization,
+            this must be specified during the object call.
+
+    """
+
+    def __init__(
+        self,
+        x_in: Tensor,
+        y_in: Tensor = Tensor([1]),
+        kx: int = 3,
+        ky: int = 3,
+        sx: float = 0.001,
+        sy: float = 0.001,
+        x_out: Optional[Tensor] = None,
+        y_out: Optional[Tensor] = None,
+    ):
+        super().__init__()
+        self.kx = kx
+        self.ky = ky
+        self.sx = sx
+        self.sy = sy
+        self.register_buffer("x_in", x_in)
+        self.register_buffer("y_in", y_in)
+        self.register_buffer("x_out", x_out)
+        self.register_buffer("y_out", y_out)
+
+        tx, Bx, BxT_Bx = self._compute_knots_and_basis_matrices(x_in, kx, sx)
+        self.register_buffer("tx", tx)
+        self.register_buffer("Bx", Bx)
+        self.register_buffer("BxT_Bx", BxT_Bx)
+
+        ty, By, ByT_By = self._compute_knots_and_basis_matrices(y_in, ky, sy)
+        self.register_buffer("ty", ty)
+        self.register_buffer("By", By)
+        self.register_buffer("ByT_By", ByT_By)
+
+        if self.x_out is not None:
+            Bx_out = self.bspline_basis_natural(x_out, kx, self.tx)
+            self.register_buffer("Bx_out", Bx_out)
+        if self.y_out is not None:
+            By_out = self.bspline_basis_natural(y_out, ky, self.ty)
+            self.register_buffer("By_out", By_out)
+
+    def _compute_knots_and_basis_matrices(self, x, k, s):
+        knots = self.generate_natural_knots(x, k)
+        basis_matrix = self.bspline_basis_natural(x, k, knots)
+        identity = torch.eye(basis_matrix.shape[-1])
+        B_T_B = basis_matrix.T @ basis_matrix + s * identity
+        return knots, basis_matrix, B_T_B
+
+    def generate_natural_knots(self, x: Tensor, k: int) -> Tensor:
+        """
+        Generates a natural knot sequence for B-spline interpolation.
+        Natural knot sequence means that 2*k knots are added to the beginning
+        and end of datapoints as replicas of first and last datapoint
+        respectively in order to enforce natural boundary conditions,
+        i.e. second derivative = 0.
+        The other n knots are placed at the data points.
+
+        Args:
+            x: Tensor of data point positions.
+            k: Degree of the spline.
+
+        Returns:
+            Tensor of knot positions.
+        """
+        return F.pad(x[None], (k, k), mode="replicate")[0]
+
+    def compute_L_R(
+        self,
+        x: Tensor,
+        t: Tensor,
+        d: int,
+        m: int,
+    ) -> Tuple[Tensor, Tensor]:
+
+        """
+        Compute the L and R values for B-spline basis functions.
+        L and R are respectively the first and second coefficients multiplying
+        B_{i,p-1}(x) and B_{i+1,p-1}(x) in De Boor's recursive formula for
+        B-spline basis function computation.
+        See https://en.wikipedia.org/wiki/De_Boor%27s_algorithm for details
+
+        Args:
+            x:
+                Tensor of data point positions.
+            t:
+                Tensor of knot positions.
+            d:
+                Current degree of the basis function.
+            m:
+                Number of intervals (n - k - 1, where n is the number of knots
+                and k is the degree).
+
+        Returns:
+            L: Tensor containing left values for the B-spline basis functions.
+            R: Tensor containing right values for the B-spline basis functions.
+        """
+        left_num = x.unsqueeze(1) - t[:m].unsqueeze(0)
+        left_den = t[d : m + d] - t[:m]
+        L = left_num / left_den.unsqueeze(0)
+        L = torch.nan_to_num_(L, nan=0.0, posinf=0.0, neginf=0.0)
+
+        right_num = t[d + 1 : m + d + 1] - x.unsqueeze(1)
+        right_den = t[d + 1 : m + d + 1] - t[1 : m + 1]
+        R = right_num / right_den.unsqueeze(0)
+        R = torch.nan_to_num_(R, nan=0.0, posinf=0.0, neginf=0.0)
+
+        return L, R
+
+    def zeroth_order(
+        self,
+        x: Tensor,
+        k: int,
+        t: Tensor,
+        n: int,
+        m: int,
+    ) -> Tensor:
+
+        """
+        Compute the zeroth-order B-spline basis functions
+        according to de Boor's recursive formula.
+        See https://en.wikipedia.org/wiki/De_Boor%27s_algorithm for reference
+
+        Args:
+            x:
+                Tensor of data point positions.
+            k:
+                Degree of the spline.
+            t:
+                Tensor of knot positions.
+            n:
+                Number of data points.
+            m:
+                Number of intervals (n - k - 1, where n is the number of knots
+                and k is the degree).
+
+        Returns:
+            b: Tensor containing the zeroth-order B-spline basis functions.
+        """
+        b = torch.zeros((n, m, k + 1))
+
+        mask_lower = t[: m + 1].unsqueeze(0)[:, :-1] <= x.unsqueeze(1)
+        mask_upper = x.unsqueeze(1) < t[: m + 1].unsqueeze(0)[:, 1:]
+
+        b[:, :, 0] = mask_lower & mask_upper
+        b[:, 0, 0] = torch.where(x < t[1], torch.ones_like(x), b[:, 0, 0])
+        b[:, -1, 0] = torch.where(x >= t[-2], torch.ones_like(x), b[:, -1, 0])
+        return b
+
+    def bspline_basis_natural(
+        self,
+        x: Tensor,
+        k: int,
+        t: Tensor,
+    ) -> Tensor:
+        """
+        Compute B-spline basis functions using de Boor's recursive formula
+        (See https://en.wikipedia.org/wiki/De_Boor%27s_algorithm for reference)
+        Args:
+            x: Tensor of data point positions.
+            k: Degree of the spline.
+            t: Tensor of knot positions.
+
+        Returns:
+            Tensor containing the kth-order B-spline basis functions
+        """
+
+        if len(x) == 1:
+            return torch.eye(1)
+        n = x.shape[0]
+        m = t.shape[0] - k - 1
+
+        # calculate zeroth-order basis functions
+        b = self.zeroth_order(x, k, t, n, m)
+
+        zeros_tensor = torch.zeros(b.shape[0], 1)
+        # recursive de Boor's formula for B-spline basis functions
+        for d in range(1, k + 1):
+            L, R = self.compute_L_R(x, t, d, m)
+            left = L * b[:, :, d - 1]
+
+            temp_b = torch.cat([b[:, 1:, d - 1], zeros_tensor], dim=1)
+
+            right = R * temp_b
+            b[:, :, d] = left + right
+
+        return b[:, :, -1]
+
+    def bivariate_spline_fit_natural(self, Z):
+
+        if len(Z.shape) == 3:
+            Z_Bx = torch.matmul(Z, self.Bx)
+            # ((BxT @ Bx)^-1 @ (Z @ Bx)T)T = Z @ BxT^-1
+            return torch.linalg.solve(self.BxT_Bx, Z_Bx.mT).mT
+
+        # Adding batch/channel dimension handling
+        # ByT @ Z @ Bx
+        ByT_Z_Bx = torch.einsum("ij,bcik,kl->bcjl", self.By, Z, self.Bx)
+        # (ByT @ By)^-1 @ (ByT @ Z @ Bx) = By^-1 @ Z @ Bx
+        E = torch.linalg.solve(self.ByT_By, ByT_Z_Bx)
+        # ((BxT @ Bx)^-1 @ (By^-1 @ Z @ Bx)T)T = By^-1 @ Z @ BxT^-1
+        return torch.linalg.solve(self.BxT_Bx, E.mT).mT
+
+    def evaluate_bivariate_spline(self, C: Tensor):
+        """
+        Evaluate a bivariate spline on a grid of x and y points.
+
+        Args:
+            C: Coefficient tensor of shape (batch_size, mx, my).
+
+        Returns:
+            Z_interp: Interpolated values at the grid points.
+        """
+        # Perform matrix multiplication using einsum to get Z_interp
+        if len(C.shape) == 3:
+            return torch.matmul(C, self.Bx_out.mT)
+        return torch.einsum("ik,bckm,mj->bcij", self.By_out, C, self.Bx_out.mT)
+
+    def _validate_inputs(self, Z, x_out, y_out):
+        if x_out is None and self.x_out is None:
+            raise ValueError(
+                "Output x-coordinates were not specified in either object "
+                "creation or in forward call"
+            )
+
+        if y_out is None and self.y_out is None:
+            y_out = self.y_in
+
+        dims = len(Z.shape)
+        if dims > 4:
+            raise ValueError("Input data has more than 4 dimensions")
+
+        if len(self.y_in) > 1 and dims == 1:
+            raise ValueError(
+                "An input y-coordinate array with length greater than 1 "
+                "was given, but the input data is 1-dimensional. Expected "
+                "input data to be at least 2-dimensional"
+            )
+
+        # Expand Z to have 4 dimensions
+        # There are 6 valid input shapes: (w), (b, w), (b, c, w),
+        # (h, w), (b, h, w), and (b, c, h, w).
+
+        # If the input y coordinate array has length 1,
+        # assume the first dimension(s) are batch dimensions
+        # and that no height dimension is included in Z
+        idx = -2 if len(self.y_in) == 1 else -3
+        while len(Z.shape) < 4:
+            Z = Z.unsqueeze(idx)
+
+        if Z.shape[-2:] != torch.Size([len(self.y_in), len(self.x_in)]):
+            raise ValueError(
+                "The spatial dimensions of the data tensor do not match "
+                "the given input dimensions. "
+                f"Expected [{len(self.y_in)}, {len(self.x_in)}], but got "
+                f"[{Z.shape[-2]}, {Z.shape[-1]}]"
+            )
+
+        return Z, y_out
+
+    def forward(
+        self,
+        Z: Tensor,
+        x_out: Optional[Tensor] = None,
+        y_out: Optional[Tensor] = None,
+    ) -> Tensor:
+        """
+        Compute the interpolated data
+
+        Args:
+            Z:
+                Tensor of data to be interpolated. Must be between 1 and 4
+                dimensions. The shape of the tensor must agree with the
+                input coordinates given on initialization. If `y_in` was
+                not specified during initialization, it is assumed that
+                Z does not have a height dimension.
+            x_out:
+                Coordinates to interpolate the data to along the width
+                dimension. Overrides any value that was set during
+                initialization.
+            y_out:
+                Coordinates to interpolate the data to along the height
+                dimension. Overrides any value that was set during
+                initialization.
+
+        Returns:
+            A 4D tensor with shape `(batch, channel, height, width)`.
+            Depending on the input data shape, many of these dimensions
+            may have length 1.
+        """
+
+        Z, y_out = self._validate_inputs(Z, x_out, y_out)
+
+        if x_out is not None:
+            self.Bx_out = self.bspline_basis_natural(x_out, self.kx, self.tx)
+        if y_out is not None:
+            self.By_out = self.bspline_basis_natural(y_out, self.ky, self.ty)
+
+        coef = self.bivariate_spline_fit_natural(Z)
+        Z_interp = self.evaluate_bivariate_spline(coef)
+        return Z_interp
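A usage sketch for the new module, with made-up grid sizes; the import assumes the one-line addition to `ml4gw/transforms/__init__.py` re-exports the class. Per the docstring's warning, the output coordinates are kept well inside the input range:

```python
import torch

from ml4gw.transforms import SplineInterpolate  # assumed export path

# Input grid: 64 rows (height) by 128 columns (width)
x_in = torch.linspace(0, 1, 128)
y_in = torch.linspace(0, 1, 64)
# Output grid kept away from the boundaries to reduce edge artifacts
x_out = torch.linspace(0.1, 0.9, 256)
y_out = torch.linspace(0.1, 0.9, 128)

interpolator = SplineInterpolate(x_in, y_in=y_in, x_out=x_out, y_out=y_out)
Z = torch.randn(8, 2, 64, 128)  # (batch, channel, height, width)
Z_interp = interpolator(Z)      # -> (8, 2, 128, 256)
```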
ml4gw/transforms/transform.py
CHANGED
@@ -3,6 +3,7 @@ from typing import Optional
 import torch

 from ml4gw.spectral import spectral_density
+from ml4gw.types import FrequencySeries1to3d, TimeSeries1to3d


 class FittableTransform(torch.nn.Module):
@@ -43,12 +44,12 @@ class FittableTransform(torch.nn.Module):
 class FittableSpectralTransform(FittableTransform):
     def normalize_psd(
         self,
-        x,
+        x: TimeSeries1to3d,
         sample_rate: float,
         num_freqs: int,
         fftlength: Optional[float] = None,
        overlap: Optional[float] = None,
-    ):
+    ) -> FrequencySeries1to3d:
         # if we specified an FFT length, convert
         # the (assumed) time-domain data to the
         # frequency domain
@@ -68,7 +69,7 @@ class FittableSpectralTransform(FittableTransform):
             scale=scale,
         )

-        # add two dummy dimensions in case we need to
+        # add two dummy dimensions in case we need to interpolate
         # the frequency dimension, since `interpolate` expects
         # a (batch, channel, spatial) formatted tensor as input
         x = x.view(1, 1, -1)
ml4gw/transforms/waveforms.py
CHANGED
@@ -1,8 +1,11 @@
 from typing import List, Optional

 import torch
+from jaxtyping import Float
+from torch import Tensor

 from ml4gw import gw
+from ml4gw.types import BatchTensor


 # TODO: should these live in ml4gw.waveforms submodule?
@@ -10,8 +13,8 @@ from ml4gw import gw
 class WaveformSampler(torch.nn.Module):
     def __init__(
         self,
-        parameters: Optional[
-        **polarizations:
+        parameters: Optional[Float[Tensor, "batch num_params"]] = None,
+        **polarizations: Float[Tensor, "batch time"],
     ):
         super().__init__()
         # make sure we have the same number of waveforms
@@ -29,7 +32,7 @@ class WaveformSampler(torch.nn.Module):
             elif num_waveforms is None:
                 num_waveforms = tensor.shape[0]

-            self.polarizations[polarization] =
+            self.polarizations[polarization] = Tensor(tensor)

         if parameters is not None and len(parameters) != num_waveforms:
             raise ValueError(
@@ -73,10 +76,10 @@ class WaveformProjector(torch.nn.Module):

     def forward(
         self,
-        dec:
-        psi:
-        phi:
-        **polarizations,
+        dec: BatchTensor,
+        psi: BatchTensor,
+        phi: BatchTensor,
+        **polarizations: Float[Tensor, "batch time"],
     ):
         ifo_responses = gw.compute_observed_strain(
             dec,
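A construction sketch grounded only in the `__init__` signature shown above; the polarization names and parameter count are placeholders:

```python
import torch

from ml4gw.transforms import WaveformSampler  # assumed export path

num_waveforms, waveform_size = 100, 4096
sampler = WaveformSampler(
    # Float[Tensor, "batch num_params"]; must match the number of waveforms
    parameters=torch.randn(num_waveforms, 2),
    # Each polarization is Float[Tensor, "batch time"]
    cross=torch.randn(num_waveforms, waveform_size),
    plus=torch.randn(num_waveforms, waveform_size),
)
```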
ml4gw/transforms/whitening.py
CHANGED
@@ -1,9 +1,15 @@
-from typing import Optional
+from typing import Optional, Union

 import torch

 from ml4gw import spectral
 from ml4gw.transforms.transform import FittableSpectralTransform
+from ml4gw.types import (
+    FrequencySeries1d,
+    FrequencySeries1to3d,
+    TimeSeries1d,
+    TimeSeries3d,
+)


 class Whiten(torch.nn.Module):
@@ -58,7 +64,9 @@ class Whiten(torch.nn.Module):
         window = torch.hann_window(size, dtype=torch.float64)
         self.register_buffer("window", window)

-    def forward(
+    def forward(
+        self, X: TimeSeries3d, psd: FrequencySeries1to3d
+    ) -> TimeSeries3d:
         """
         Whiten a batch of multichannel timeseries by a
         background power spectral density.
@@ -142,7 +150,7 @@ class FixedWhiten(FittableSpectralTransform):
     def fit(
         self,
         fduration: float,
-        *background:
+        *background: Union[TimeSeries1d, FrequencySeries1d],
         fftlength: Optional[float] = None,
         highpass: Optional[float] = None,
         overlap: Optional[float] = None
@@ -224,7 +232,7 @@ class FixedWhiten(FittableSpectralTransform):
         fduration = torch.Tensor([fduration])
         self.build(psd=psd, fduration=fduration)

-    def forward(self, X:
+    def forward(self, X: TimeSeries3d) -> TimeSeries3d:
         """
         Whiten the input timeseries tensor using the
         PSD fit by the `.fit` method, which must be
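A whitening sketch; only `forward`'s types are confirmed by this hunk, so the constructor call assumes the `Whiten(fduration, sample_rate, highpass)` signature from 0.5.x:

```python
import torch

from ml4gw.transforms import SpectralDensity, Whiten

sample_rate = 2048.0
# Assumed constructor; the hunk above only shows the hann window buffer
whiten = Whiten(fduration=1.0, sample_rate=sample_rate, highpass=32.0)
psd_estimator = SpectralDensity(sample_rate, fftlength=2.0)

background = torch.randn(2, int(16 * sample_rate))  # per-channel background
psd = psd_estimator(background)                     # FrequencySeries2d
X = torch.randn(8, 2, int(4 * sample_rate))         # TimeSeries3d
whitened = whiten(X, psd)  # fduration / 2 seconds cropped from each edge
```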
ml4gw/types.py
CHANGED
@@ -1,10 +1,25 @@
-from
-
-
-
-
-
-
-
-
-
+from typing import Union
+
+from jaxtyping import Float
+from torch import Tensor
+
+WaveformTensor = Float[Tensor, "batch num_ifos time"]
+PSDTensor = Float[Tensor, "num_ifos frequency"]
+BatchTensor = Float[Tensor, "batch"]
+VectorGeometry = Float[Tensor, "batch space"]
+TensorGeometry = Float[Tensor, "batch space space"]
+NetworkVertices = Float[Tensor, "num_ifos 3"]
+NetworkDetectorTensors = Float[Tensor, "num_ifos 3 3"]
+
+
+TimeSeries1d = Float[Tensor, "time"]
+TimeSeries2d = Float[TimeSeries1d, "channel"]
+TimeSeries3d = Float[TimeSeries2d, "batch"]
+TimeSeries1to3d = Union[TimeSeries1d, TimeSeries2d, TimeSeries3d]
+
+FrequencySeries1d = Float[Tensor, "frequency"]
+FrequencySeries2d = Float[FrequencySeries1d, "channel"]
+FrequencySeries3d = Float[FrequencySeries2d, "batch"]
+FrequencySeries1to3d = Union[
+    FrequencySeries1d, FrequencySeries2d, FrequencySeries3d
+]
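These aliases are ordinary `jaxtyping` annotations, so they document shapes rather than enforce them (unless a runtime checker is hooked in). A small sketch:

```python
import torch
from torch import Tensor

from ml4gw.types import TimeSeries1to3d, WaveformTensor

def rms(x: TimeSeries1to3d) -> Tensor:
    # Works on 1d, 2d, or 3d timeseries: reduces the trailing time axis
    return x.pow(2).mean(dim=-1).sqrt()

X: WaveformTensor = torch.randn(8, 2, 4096)  # (batch, num_ifos, time)
print(rms(X).shape)  # torch.Size([8, 2])
```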
ml4gw/utils/interferometer.py
CHANGED
@@ -4,7 +4,7 @@ import torch

 # based on values from
 # https://lscsoft.docs.ligo.org/lalsuite/lal/_l_a_l_detectors_8h_source.html
 class InterferometerGeometry:
-    def __init__(self, name: str):
+    def __init__(self, name: str) -> None:
         if name == "H1":
             self.x_arm = torch.Tensor(
                 (-0.22389266154, +0.79983062746, +0.55690487831)
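For reference, geometries are looked up by detector name, as in the `"H1"` branch shown above:

```python
from ml4gw.utils.interferometer import InterferometerGeometry

h1 = InterferometerGeometry("H1")
print(h1.x_arm)  # unit vector along the detector's X arm
```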
ml4gw/utils/slicing.py
CHANGED
@@ -1,25 +1,30 @@
 from typing import Optional, Union

 import torch
+from jaxtyping import Float, Int64
+from torch import Tensor
 from torch.nn.functional import unfold
-from torchtyping import TensorType

-
-
+from ml4gw.types import (
+    TimeSeries1d,
+    TimeSeries1to3d,
+    TimeSeries2d,
+    TimeSeries3d,
+)

-
-
-BatchTimeSeriesTensor = Union[
-    TensorType["batch", "time"], TensorType["batch", "channel", "time"]
-]
+BatchTimeSeriesTensor = Union[Float[Tensor, "batch time"], TimeSeries3d]


 def unfold_windows(
-    x:
+    x: TimeSeries1to3d,
     window_size: int,
     stride: int,
     drop_last: bool = True,
-)
+) -> Union[
+    Float[TimeSeries1d, " window"],
+    Float[TimeSeries2d, " window"],
+    Float[TimeSeries3d, " window"],
+]:
     """Unfold a timeseries into windows

     Args:
@@ -83,8 +88,8 @@ def unfold_windows(


 def slice_kernels(
-    x:
-    idx:
+    x: TimeSeries1to3d,
+    idx: Int64[Tensor, "..."],
     kernel_size: int,
 ) -> BatchTimeSeriesTensor:
     """Slice kernels from single or multichannel timeseries
@@ -96,7 +101,8 @@ def slice_kernels(
     one more dimension than `x`.

     Args:
-        x:
+        x:
+            The timeseries tensor to slice kernels from
         idx:
             The indices in `x` of the first sample of each
             kernel. If `x` is 1D, `idx` must be 1D as well.
@@ -114,6 +120,7 @@ def slice_kernels(
             coincidentally among the channels.
         kernel_size:
             The length of the kernels to slice from the timeseries
+
     Returns:
         A tensor of shape `(batch_size, kernel_size)` if `x` is
         1D and `(batch_size, num_channels, kernel_size)` if `x`
@@ -225,7 +232,7 @@ def slice_kernels(


 def sample_kernels(
-    X:
+    X: TimeSeries1to3d,
     kernel_size: int,
     N: Optional[int] = None,
     max_center_offset: Optional[int] = None,
@@ -245,8 +252,9 @@ def sample_kernels(
     either be `None` or be equal to `len(X)`.

     Args:
-        X:
-
+        X:
+            The timeseries tensor from which to sample kernels
+        kernel_size: The size of the kernels to sample
         N:
             The number of kernels to sample. Can be left as
             `None` if `X` is 3D, otherwise must be specified