ml4gw 0.7.6__py3-none-any.whl → 0.7.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4gw/augmentations.py +5 -0
- ml4gw/dataloading/__init__.py +5 -0
- ml4gw/dataloading/chunked_dataset.py +2 -4
- ml4gw/dataloading/hdf5_dataset.py +12 -10
- ml4gw/dataloading/in_memory_dataset.py +12 -12
- ml4gw/distributions.py +3 -3
- ml4gw/gw.py +18 -21
- ml4gw/nn/__init__.py +6 -0
- ml4gw/nn/autoencoder/base.py +5 -9
- ml4gw/nn/autoencoder/convolutional.py +7 -10
- ml4gw/nn/autoencoder/skip_connection.py +3 -5
- ml4gw/nn/norm.py +4 -4
- ml4gw/nn/resnet/resnet_1d.py +12 -13
- ml4gw/nn/resnet/resnet_2d.py +13 -14
- ml4gw/nn/streaming/online_average.py +3 -5
- ml4gw/nn/streaming/snapshotter.py +10 -14
- ml4gw/spectral.py +20 -23
- ml4gw/transforms/__init__.py +7 -1
- ml4gw/transforms/decimator.py +183 -0
- ml4gw/transforms/iirfilter.py +3 -5
- ml4gw/transforms/pearson.py +3 -4
- ml4gw/transforms/qtransform.py +20 -26
- ml4gw/transforms/scaler.py +3 -5
- ml4gw/transforms/snr_rescaler.py +7 -11
- ml4gw/transforms/spectral.py +6 -13
- ml4gw/transforms/spectrogram.py +6 -3
- ml4gw/transforms/spline_interpolation.py +312 -143
- ml4gw/transforms/transform.py +4 -6
- ml4gw/transforms/waveforms.py +8 -15
- ml4gw/transforms/whitening.py +11 -16
- ml4gw/types.py +8 -5
- ml4gw/utils/interferometer.py +20 -3
- ml4gw/utils/slicing.py +26 -30
- ml4gw/waveforms/__init__.py +6 -0
- ml4gw/waveforms/cbc/phenom_p.py +7 -9
- ml4gw/waveforms/conversion.py +2 -4
- ml4gw/waveforms/generator.py +3 -3
- {ml4gw-0.7.6.dist-info → ml4gw-0.7.8.dist-info}/METADATA +33 -12
- ml4gw-0.7.8.dist-info/RECORD +57 -0
- {ml4gw-0.7.6.dist-info → ml4gw-0.7.8.dist-info}/WHEEL +2 -1
- ml4gw-0.7.8.dist-info/top_level.txt +1 -0
- ml4gw-0.7.6.dist-info/RECORD +0 -55
- {ml4gw-0.7.6.dist-info → ml4gw-0.7.8.dist-info}/licenses/LICENSE +0 -0
ml4gw/transforms/spectrogram.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import warnings
|
|
2
|
-
from typing import Dict, List
|
|
3
2
|
|
|
4
3
|
import torch
|
|
5
4
|
import torch.nn.functional as F
|
|
@@ -104,7 +103,7 @@ class MultiResolutionSpectrogram(torch.nn.Module):
|
|
|
104
103
|
self.register_buffer("freq_idxs", freq_idxs)
|
|
105
104
|
self.register_buffer("time_idxs", time_idxs)
|
|
106
105
|
|
|
107
|
-
def _check_and_format_kwargs(self, kwargs:
|
|
106
|
+
def _check_and_format_kwargs(self, kwargs: dict[str, list]) -> list:
|
|
108
107
|
lengths = sorted(len(v) for v in kwargs.values())
|
|
109
108
|
lengths = list(set(lengths))
|
|
110
109
|
|
|
@@ -127,7 +126,10 @@ class MultiResolutionSpectrogram(torch.nn.Module):
|
|
|
127
126
|
size = lengths[1]
|
|
128
127
|
kwargs = {k: v * int(size / len(v)) for k, v in kwargs.items()}
|
|
129
128
|
|
|
130
|
-
return [
|
|
129
|
+
return [
|
|
130
|
+
dict(zip(kwargs, col, strict=True))
|
|
131
|
+
for col in zip(*kwargs.values(), strict=True)
|
|
132
|
+
]
|
|
131
133
|
|
|
132
134
|
def forward(
|
|
133
135
|
self, X: TimeSeries3d
|
|
@@ -161,6 +163,7 @@ class MultiResolutionSpectrogram(torch.nn.Module):
|
|
|
161
163
|
self.right_pad,
|
|
162
164
|
self.top_pad,
|
|
163
165
|
self.bottom_pad,
|
|
166
|
+
strict=True,
|
|
164
167
|
):
|
|
165
168
|
padded_specs.append(F.pad(spec, (left, right, top, bottom)))
|
|
166
169
|
|
|
@@ -1,131 +1,27 @@
|
|
|
1
1
|
"""
|
|
2
|
-
Adaptation of code from https://github.com/dottormale/
|
|
2
|
+
Adaptation of code from https://github.com/dottormale/Qtransform_torch/
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
from typing import Optional, Tuple
|
|
6
|
-
|
|
7
5
|
import torch
|
|
8
|
-
import torch.nn.functional as F
|
|
9
6
|
from torch import Tensor
|
|
10
7
|
|
|
11
8
|
|
|
12
|
-
class
|
|
9
|
+
class SplineInterpolateBase(torch.nn.Module):
|
|
13
10
|
"""
|
|
14
|
-
|
|
15
|
-
Supports batched, multi-channel inputs, so acceptable data
|
|
16
|
-
shapes are ``(width)``, ``(height, width)``, ``(batch, width)``,
|
|
17
|
-
``(batch, height, width)``, ``(batch, channel, width)``, and
|
|
18
|
-
``(batch, channel, height, width)``.
|
|
19
|
-
|
|
20
|
-
During initialization of this Module, both the desired input
|
|
21
|
-
and output coordinate Tensors can be specified to allow
|
|
22
|
-
pre-computation of the B-spline basis matrices, though the only
|
|
23
|
-
mandatory argument is the coordinates of the data along the
|
|
24
|
-
``width`` dimension. If no argument is given for coordinates along
|
|
25
|
-
the ``height`` dimension, it is assumed that 1D interpolation is
|
|
26
|
-
desired.
|
|
27
|
-
|
|
28
|
-
Unlike scipy's implementation of spline interpolation, the data
|
|
29
|
-
to be interpolated is not passed until actually calling the
|
|
30
|
-
object. This is useful for cases where the input and output
|
|
31
|
-
coordinates are known in advance, but the data is not, so that
|
|
32
|
-
the interpolator can be set up ahead of time.
|
|
33
|
-
|
|
34
|
-
WARNING: compared to scipy's spline interpolation, this method
|
|
35
|
-
produces edge artifacts when the output coordinates are near
|
|
36
|
-
the boundaries of the input coordinates. Therefore, it is
|
|
37
|
-
recommended to interpolate only to coordinates that are well
|
|
38
|
-
within the input coordinate range. Unfortunately, the specific
|
|
39
|
-
definition of "well within" changes based on the size of the
|
|
40
|
-
data, so some testing may be required to get good results.
|
|
41
|
-
|
|
42
|
-
Args:
|
|
43
|
-
x_in:
|
|
44
|
-
Coordinates of the width dimension of the data
|
|
45
|
-
y_in:
|
|
46
|
-
Coordinates of the height dimension of the data. If not
|
|
47
|
-
specified, it is assumed the 1D interpolation is desired,
|
|
48
|
-
and so the default value is a Tensor of length 1
|
|
49
|
-
kx:
|
|
50
|
-
Degree of spline interpolation along the width dimension.
|
|
51
|
-
Default is cubic.
|
|
52
|
-
ky:
|
|
53
|
-
Degree of spline interpolation along the height dimension.
|
|
54
|
-
Default is cubic.
|
|
55
|
-
sx:
|
|
56
|
-
Regularization factor to avoid singularities during matrix
|
|
57
|
-
inversion for interpolation along the width dimension. Not
|
|
58
|
-
to be confused with the ``s`` parameter in scipy's spline
|
|
59
|
-
methods, which controls the number of knots.
|
|
60
|
-
sy:
|
|
61
|
-
Regularization factor to avoid singularities during matrix
|
|
62
|
-
inversion for interpolation along the height dimension.
|
|
63
|
-
x_out:
|
|
64
|
-
Coordinates for the data to be interpolated to along the
|
|
65
|
-
width dimension. If not specified during initialization,
|
|
66
|
-
this must be specified during the object call.
|
|
67
|
-
y_out:
|
|
68
|
-
Coordinates for the data to be interpolated to along the
|
|
69
|
-
height dimension. If not specified during initialization,
|
|
70
|
-
this must be specified during the object call.
|
|
71
|
-
|
|
11
|
+
Base class for spline interpolation.
|
|
72
12
|
"""
|
|
73
13
|
|
|
74
|
-
def __init__(
|
|
75
|
-
self,
|
|
76
|
-
x_in: Tensor,
|
|
77
|
-
y_in: Tensor = None,
|
|
78
|
-
kx: int = 3,
|
|
79
|
-
ky: int = 3,
|
|
80
|
-
sx: float = 0.001,
|
|
81
|
-
sy: float = 0.001,
|
|
82
|
-
x_out: Optional[Tensor] = None,
|
|
83
|
-
y_out: Optional[Tensor] = None,
|
|
84
|
-
):
|
|
85
|
-
super().__init__()
|
|
86
|
-
if y_in is None:
|
|
87
|
-
y_in = Tensor([1])
|
|
88
|
-
self.kx = kx
|
|
89
|
-
self.ky = ky
|
|
90
|
-
self.sx = sx
|
|
91
|
-
self.sy = sy
|
|
92
|
-
self.register_buffer("x_in", x_in)
|
|
93
|
-
self.register_buffer("y_in", y_in)
|
|
94
|
-
self.register_buffer("x_out", x_out)
|
|
95
|
-
self.register_buffer("y_out", y_out)
|
|
96
|
-
|
|
97
|
-
tx, Bx, BxT_Bx = self._compute_knots_and_basis_matrices(x_in, kx, sx)
|
|
98
|
-
self.register_buffer("tx", tx)
|
|
99
|
-
self.register_buffer("Bx", Bx)
|
|
100
|
-
self.register_buffer("BxT_Bx", BxT_Bx)
|
|
101
|
-
|
|
102
|
-
ty, By, ByT_By = self._compute_knots_and_basis_matrices(y_in, ky, sy)
|
|
103
|
-
self.register_buffer("ty", ty)
|
|
104
|
-
self.register_buffer("By", By)
|
|
105
|
-
self.register_buffer("ByT_By", ByT_By)
|
|
106
|
-
|
|
107
|
-
if self.x_out is not None:
|
|
108
|
-
Bx_out = self.bspline_basis_natural(x_out, kx, self.tx)
|
|
109
|
-
self.register_buffer("Bx_out", Bx_out)
|
|
110
|
-
if self.y_out is not None:
|
|
111
|
-
By_out = self.bspline_basis_natural(y_out, ky, self.ty)
|
|
112
|
-
self.register_buffer("By_out", By_out)
|
|
113
|
-
|
|
114
14
|
def _compute_knots_and_basis_matrices(self, x, k, s):
|
|
115
|
-
knots = self.
|
|
15
|
+
knots = self.generate_fitpack_knots(x, k)
|
|
116
16
|
basis_matrix = self.bspline_basis_natural(x, k, knots)
|
|
117
17
|
identity = torch.eye(basis_matrix.shape[-1])
|
|
118
18
|
B_T_B = basis_matrix.T @ basis_matrix + s * identity
|
|
119
19
|
return knots, basis_matrix, B_T_B
|
|
120
20
|
|
|
121
|
-
def
|
|
21
|
+
def generate_fitpack_knots(self, x: Tensor, k: int) -> Tensor:
|
|
122
22
|
"""
|
|
123
|
-
Generates a
|
|
124
|
-
|
|
125
|
-
and end of datapoints as replicas of first and last datapoint
|
|
126
|
-
respectively in order to enforce natural boundary conditions,
|
|
127
|
-
i.e. second derivative = 0.
|
|
128
|
-
The other n nodes are placed in correspondece of the data points.
|
|
23
|
+
Generates a knot sequence for B-spline interpolation
|
|
24
|
+
in the same way as the FITPACK algorithm used by SciPy.
|
|
129
25
|
|
|
130
26
|
Args:
|
|
131
27
|
x: Tensor of data point positions.
|
|
@@ -134,7 +30,17 @@ class SplineInterpolate(torch.nn.Module):
|
|
|
134
30
|
Returns:
|
|
135
31
|
Tensor of knot positions.
|
|
136
32
|
"""
|
|
137
|
-
|
|
33
|
+
num_knots = x.shape[-1] + k + 1
|
|
34
|
+
knots = torch.zeros(num_knots, dtype=x.dtype)
|
|
35
|
+
knots[: k + 1] = x[0]
|
|
36
|
+
knots[-(k + 1) :] = x[-1]
|
|
37
|
+
|
|
38
|
+
# Interior knots are the rolling average of the data points
|
|
39
|
+
# excluding the first and last points
|
|
40
|
+
windows = x[1:-1].unfold(dimension=-1, size=k, step=1)
|
|
41
|
+
knots[k + 1 : -k - 1] = windows.mean(dim=-1)
|
|
42
|
+
|
|
43
|
+
return knots
|
|
138
44
|
|
|
139
45
|
def compute_L_R(
|
|
140
46
|
self,
|
|
@@ -142,7 +48,7 @@ class SplineInterpolate(torch.nn.Module):
|
|
|
142
48
|
t: Tensor,
|
|
143
49
|
d: int,
|
|
144
50
|
m: int,
|
|
145
|
-
) ->
|
|
51
|
+
) -> tuple[Tensor, Tensor]:
|
|
146
52
|
"""
|
|
147
53
|
Compute the L and R values for B-spline basis functions.
|
|
148
54
|
L and R are respectively the first and second coefficient multiplying
|
|
@@ -233,9 +139,6 @@ class SplineInterpolate(torch.nn.Module):
|
|
|
233
139
|
Returns:
|
|
234
140
|
Tensor containing the kth-order B-spline basis functions
|
|
235
141
|
"""
|
|
236
|
-
|
|
237
|
-
if len(x) == 1:
|
|
238
|
-
return torch.eye(1)
|
|
239
142
|
n = x.shape[0]
|
|
240
143
|
m = t.shape[0] - k - 1
|
|
241
144
|
|
|
@@ -255,6 +158,271 @@ class SplineInterpolate(torch.nn.Module):
|
|
|
255
158
|
|
|
256
159
|
return b[:, :, -1]
|
|
257
160
|
|
|
161
|
+
|
|
162
|
+
class SplineInterpolate1D(SplineInterpolateBase):
|
|
163
|
+
"""
|
|
164
|
+
Perform 1D spline interpolation based on De Boor's method.
|
|
165
|
+
It is allowed to have two spatial dimensions, but the second
|
|
166
|
+
dimension cannot be interpolated along. To interpolate along both
|
|
167
|
+
dimensions, use :class:`SplineInterpolate2D`.
|
|
168
|
+
|
|
169
|
+
Supports batched, multi-channel inputs, so acceptable data
|
|
170
|
+
shapes are ``(width)``, ``(height, width)``, ``(batch, width)``,
|
|
171
|
+
``(batch, height, width)``, ``(batch, channel, width)``, and
|
|
172
|
+
``(batch, channel, height, width)``.
|
|
173
|
+
|
|
174
|
+
During initialization of this Module, both the desired input
|
|
175
|
+
and output coordinate Tensors can be specified to allow
|
|
176
|
+
pre-computation of the B-spline basis matrices, though the only
|
|
177
|
+
mandatory argument is the coordinates of the data along the
|
|
178
|
+
``width`` dimension.
|
|
179
|
+
|
|
180
|
+
Unlike scipy's implementation of spline interpolation, the data
|
|
181
|
+
to be interpolated is not passed until actually calling the
|
|
182
|
+
object. This is useful for cases where the input and output
|
|
183
|
+
coordinates are known in advance, but the data is not, so that
|
|
184
|
+
the interpolator can be set up ahead of time.
|
|
185
|
+
|
|
186
|
+
Args:
|
|
187
|
+
x_in:
|
|
188
|
+
Coordinates of the width dimension of the data
|
|
189
|
+
kx:
|
|
190
|
+
Degree of spline interpolation along the width dimension.
|
|
191
|
+
Default is cubic.
|
|
192
|
+
sx:
|
|
193
|
+
Regularization factor to avoid singularities during matrix
|
|
194
|
+
inversion for interpolation along the width dimension. Not
|
|
195
|
+
to be confused with the ``s`` parameter in scipy's spline
|
|
196
|
+
methods, which controls the number of knots.
|
|
197
|
+
x_out:
|
|
198
|
+
Coordinates for the data to be interpolated to along the
|
|
199
|
+
width dimension. If not specified during initialization,
|
|
200
|
+
this must be specified during the object call.
|
|
201
|
+
|
|
202
|
+
"""
|
|
203
|
+
|
|
204
|
+
def __init__(
|
|
205
|
+
self,
|
|
206
|
+
x_in: Tensor,
|
|
207
|
+
kx: int = 3,
|
|
208
|
+
sx: float = 0.0,
|
|
209
|
+
x_out: Tensor | None = None,
|
|
210
|
+
):
|
|
211
|
+
super().__init__()
|
|
212
|
+
|
|
213
|
+
if len(x_in) < kx + 2:
|
|
214
|
+
raise ValueError(
|
|
215
|
+
"Input x-coordinates must have at least kx + 2 points."
|
|
216
|
+
)
|
|
217
|
+
|
|
218
|
+
# Ensure that coordinates are floats
|
|
219
|
+
x_in = x_in.float()
|
|
220
|
+
x_out = x_out.float() if x_out is not None else None
|
|
221
|
+
|
|
222
|
+
self.kx = kx
|
|
223
|
+
self.sx = sx
|
|
224
|
+
self.register_buffer("x_in", x_in)
|
|
225
|
+
self.register_buffer("x_out", x_out)
|
|
226
|
+
|
|
227
|
+
tx, Bx, BxT_Bx = self._compute_knots_and_basis_matrices(x_in, kx, sx)
|
|
228
|
+
self.register_buffer("tx", tx)
|
|
229
|
+
self.register_buffer("Bx", Bx)
|
|
230
|
+
self.register_buffer("BxT_Bx", BxT_Bx)
|
|
231
|
+
|
|
232
|
+
if self.x_out is not None:
|
|
233
|
+
x_clamped = torch.clamp(x_out, tx[kx], tx[-kx - 1])
|
|
234
|
+
Bx_out = self.bspline_basis_natural(x_clamped, kx, self.tx)
|
|
235
|
+
self.register_buffer("Bx_out", Bx_out)
|
|
236
|
+
|
|
237
|
+
def spline_fit_natural(self, Z):
|
|
238
|
+
# Adding batch/channel dimension handling
|
|
239
|
+
# Bx @ Z
|
|
240
|
+
BxT_Z = torch.einsum("ij,bchj->bchi", self.Bx.T, Z)
|
|
241
|
+
# (BxT @ Bx)^-1 @ (BxT @ Z) = Bx^-1 @ Z
|
|
242
|
+
C = torch.linalg.solve(self.BxT_Bx, BxT_Z.unsqueeze(-1))
|
|
243
|
+
return C.squeeze(-1)
|
|
244
|
+
|
|
245
|
+
def evaluate_spline(self, C: Tensor):
|
|
246
|
+
"""
|
|
247
|
+
Evaluate a bivariate spline on a grid of x and y points.
|
|
248
|
+
|
|
249
|
+
Args:
|
|
250
|
+
C: Coefficient tensor of shape (batch_size, mx, my).
|
|
251
|
+
|
|
252
|
+
Returns:
|
|
253
|
+
Z_interp: Interpolated values at the grid points.
|
|
254
|
+
"""
|
|
255
|
+
# Perform matrix multiplication using einsum to get Z_interp
|
|
256
|
+
return torch.einsum("ij,bchj->bchi", self.Bx_out, C)
|
|
257
|
+
|
|
258
|
+
def _validate_inputs(self, Z, x_out):
|
|
259
|
+
if x_out is None and self.x_out is None:
|
|
260
|
+
raise ValueError(
|
|
261
|
+
"Output x-coordinates were not specified in either object "
|
|
262
|
+
"creation or in forward call"
|
|
263
|
+
)
|
|
264
|
+
|
|
265
|
+
dims = len(Z.shape)
|
|
266
|
+
if dims > 4:
|
|
267
|
+
raise ValueError("Input data has more than 4 dimensions")
|
|
268
|
+
|
|
269
|
+
if Z.shape[-1] != len(self.x_in):
|
|
270
|
+
raise ValueError(
|
|
271
|
+
"The spatial dimensions of the data tensor do not match "
|
|
272
|
+
"the given input dimensions. "
|
|
273
|
+
f"Expected {len(self.x_in)}, but got {Z.shape[-1]}"
|
|
274
|
+
)
|
|
275
|
+
|
|
276
|
+
# Expand Z to have a batch, channel, and height dimension if needed
|
|
277
|
+
while len(Z.shape) < 4:
|
|
278
|
+
Z = Z.unsqueeze(0)
|
|
279
|
+
|
|
280
|
+
return Z
|
|
281
|
+
|
|
282
|
+
def forward(
|
|
283
|
+
self,
|
|
284
|
+
Z: Tensor,
|
|
285
|
+
x_out: Tensor | None = None,
|
|
286
|
+
) -> Tensor:
|
|
287
|
+
"""
|
|
288
|
+
Compute the interpolated data
|
|
289
|
+
|
|
290
|
+
Args:
|
|
291
|
+
Z:
|
|
292
|
+
Tensor of data to be interpolated. Must be between 2 and 4
|
|
293
|
+
dimensions. The shape of the tensor must agree with the
|
|
294
|
+
input coordinates given on initialization.
|
|
295
|
+
x_out:
|
|
296
|
+
Coordinates to interpolate the data to along the width
|
|
297
|
+
dimension. Overrides any value that was set during
|
|
298
|
+
initialization.
|
|
299
|
+
|
|
300
|
+
Returns:
|
|
301
|
+
A 4D tensor with shape ``(batch, channel, height, width)``.
|
|
302
|
+
Depending on the input data shape, many of these dimensions
|
|
303
|
+
may have length 1.
|
|
304
|
+
"""
|
|
305
|
+
|
|
306
|
+
Z = self._validate_inputs(Z, x_out)
|
|
307
|
+
|
|
308
|
+
if x_out is not None:
|
|
309
|
+
x_out = x_out.float()
|
|
310
|
+
x_clamped = torch.clamp(
|
|
311
|
+
x_out, self.tx[self.kx], self.tx[-self.kx - 1]
|
|
312
|
+
)
|
|
313
|
+
self.Bx_out = self.bspline_basis_natural(
|
|
314
|
+
x_clamped, self.kx, self.tx
|
|
315
|
+
)
|
|
316
|
+
|
|
317
|
+
coef = self.spline_fit_natural(Z)
|
|
318
|
+
Z_interp = self.evaluate_spline(coef)
|
|
319
|
+
return Z_interp
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
class SplineInterpolate2D(SplineInterpolateBase):
|
|
323
|
+
"""
|
|
324
|
+
Perform 2D spline interpolation based on De Boor's method.
|
|
325
|
+
Supports batched, multi-channel inputs, so acceptable data
|
|
326
|
+
shapes are ``(height, width)``, ``(batch, height, width)``,
|
|
327
|
+
and ``(batch, channel, height, width)``.
|
|
328
|
+
|
|
329
|
+
During initialization of this Module, both the desired input
|
|
330
|
+
and output coordinate Tensors can be specified to allow
|
|
331
|
+
pre-computation of the B-spline basis matrices, though the only
|
|
332
|
+
mandatory arguments are the input coordinates.
|
|
333
|
+
|
|
334
|
+
Unlike scipy's implementation of spline interpolation, the data
|
|
335
|
+
to be interpolated is not passed until actually calling the
|
|
336
|
+
object. This is useful for cases where the input and output
|
|
337
|
+
coordinates are known in advance, but the data is not, so that
|
|
338
|
+
the interpolator can be set up ahead of time.
|
|
339
|
+
|
|
340
|
+
Args:
|
|
341
|
+
x_in:
|
|
342
|
+
Coordinates of the width dimension of the data
|
|
343
|
+
y_in:
|
|
344
|
+
Coordinates of the height dimension of the data.
|
|
345
|
+
kx:
|
|
346
|
+
Degree of spline interpolation along the width dimension.
|
|
347
|
+
Default is cubic.
|
|
348
|
+
ky:
|
|
349
|
+
Degree of spline interpolation along the height dimension.
|
|
350
|
+
Default is cubic.
|
|
351
|
+
sx:
|
|
352
|
+
Regularization factor to avoid singularities during matrix
|
|
353
|
+
inversion for interpolation along the width dimension. Not
|
|
354
|
+
to be confused with the ``s`` parameter in scipy's spline
|
|
355
|
+
methods, which controls the number of knots.
|
|
356
|
+
sy:
|
|
357
|
+
Regularization factor to avoid singularities during matrix
|
|
358
|
+
inversion for interpolation along the height dimension.
|
|
359
|
+
x_out:
|
|
360
|
+
Coordinates for the data to be interpolated to along the
|
|
361
|
+
width dimension. If not specified during initialization,
|
|
362
|
+
this must be specified during the object call.
|
|
363
|
+
y_out:
|
|
364
|
+
Coordinates for the data to be interpolated to along the
|
|
365
|
+
height dimension. If not specified during initialization,
|
|
366
|
+
this must be specified during the object call.
|
|
367
|
+
|
|
368
|
+
"""
|
|
369
|
+
|
|
370
|
+
def __init__(
|
|
371
|
+
self,
|
|
372
|
+
x_in: Tensor,
|
|
373
|
+
y_in: Tensor,
|
|
374
|
+
kx: int = 3,
|
|
375
|
+
ky: int = 3,
|
|
376
|
+
sx: float = 0.0,
|
|
377
|
+
sy: float = 0.0,
|
|
378
|
+
x_out: Tensor | None = None,
|
|
379
|
+
y_out: Tensor | None = None,
|
|
380
|
+
):
|
|
381
|
+
super().__init__()
|
|
382
|
+
|
|
383
|
+
if len(x_in) < kx + 2:
|
|
384
|
+
raise ValueError(
|
|
385
|
+
"Input x-coordinates must have at least kx + 2 points."
|
|
386
|
+
)
|
|
387
|
+
if len(y_in) < ky + 2:
|
|
388
|
+
raise ValueError(
|
|
389
|
+
"Input y-coordinates must have at least ky + 2 points."
|
|
390
|
+
)
|
|
391
|
+
|
|
392
|
+
# Ensure that coordinates are floats
|
|
393
|
+
x_in = x_in.float()
|
|
394
|
+
y_in = y_in.float()
|
|
395
|
+
x_out = x_out.float() if x_out is not None else None
|
|
396
|
+
y_out = y_out.float() if y_out is not None else None
|
|
397
|
+
|
|
398
|
+
self.kx = kx
|
|
399
|
+
self.ky = ky
|
|
400
|
+
self.sx = sx
|
|
401
|
+
self.sy = sy
|
|
402
|
+
self.register_buffer("x_in", x_in)
|
|
403
|
+
self.register_buffer("y_in", y_in)
|
|
404
|
+
self.register_buffer("x_out", x_out)
|
|
405
|
+
self.register_buffer("y_out", y_out)
|
|
406
|
+
|
|
407
|
+
tx, Bx, BxT_Bx = self._compute_knots_and_basis_matrices(x_in, kx, sx)
|
|
408
|
+
self.register_buffer("tx", tx)
|
|
409
|
+
self.register_buffer("Bx", Bx)
|
|
410
|
+
self.register_buffer("BxT_Bx", BxT_Bx)
|
|
411
|
+
|
|
412
|
+
ty, By, ByT_By = self._compute_knots_and_basis_matrices(y_in, ky, sy)
|
|
413
|
+
self.register_buffer("ty", ty)
|
|
414
|
+
self.register_buffer("By", By)
|
|
415
|
+
self.register_buffer("ByT_By", ByT_By)
|
|
416
|
+
|
|
417
|
+
if self.x_out is not None:
|
|
418
|
+
x_clamped = torch.clamp(x_out, tx[kx], tx[-kx - 1])
|
|
419
|
+
Bx_out = self.bspline_basis_natural(x_clamped, kx, self.tx)
|
|
420
|
+
self.register_buffer("Bx_out", Bx_out)
|
|
421
|
+
if self.y_out is not None:
|
|
422
|
+
y_clamped = torch.clamp(y_out, ty[ky], ty[-ky - 1])
|
|
423
|
+
By_out = self.bspline_basis_natural(y_clamped, ky, self.ty)
|
|
424
|
+
self.register_buffer("By_out", By_out)
|
|
425
|
+
|
|
258
426
|
def bivariate_spline_fit_natural(self, Z):
|
|
259
427
|
# Adding batch/channel dimension handling
|
|
260
428
|
# ByT @ Z @ BxW
|
|
@@ -285,29 +453,16 @@ class SplineInterpolate(torch.nn.Module):
|
|
|
285
453
|
)
|
|
286
454
|
|
|
287
455
|
if y_out is None and self.y_out is None:
|
|
288
|
-
|
|
456
|
+
raise ValueError(
|
|
457
|
+
"Output y-coordinates were not specified in either object "
|
|
458
|
+
"creation or in forward call"
|
|
459
|
+
)
|
|
289
460
|
|
|
290
461
|
dims = len(Z.shape)
|
|
291
462
|
if dims > 4:
|
|
292
463
|
raise ValueError("Input data has more than 4 dimensions")
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
raise ValueError(
|
|
296
|
-
"An input y-coordinate array with length greater than 1 "
|
|
297
|
-
"was given, but the input data is 1-dimensional. Expected "
|
|
298
|
-
"input data to be at least 2-dimensional"
|
|
299
|
-
)
|
|
300
|
-
|
|
301
|
-
# Expand Z to have 4 dimensions
|
|
302
|
-
# There are 6 valid input shapes: (w), (b, w), (b, c, w),
|
|
303
|
-
# (h, w), (b, h, w), and (b, c, h, w).
|
|
304
|
-
|
|
305
|
-
# If the input y coordinate array has length 1,
|
|
306
|
-
# assume the first dimension(s) are batch dimensions
|
|
307
|
-
# and that no height dimension is included in Z
|
|
308
|
-
idx = -2 if len(self.y_in) == 1 else -3
|
|
309
|
-
while len(Z.shape) < 4:
|
|
310
|
-
Z = Z.unsqueeze(idx)
|
|
464
|
+
if dims < 2:
|
|
465
|
+
raise ValueError("Input data has fewer than 2 dimensions")
|
|
311
466
|
|
|
312
467
|
if Z.shape[-2:] != torch.Size([len(self.y_in), len(self.x_in)]):
|
|
313
468
|
raise ValueError(
|
|
@@ -317,24 +472,26 @@ class SplineInterpolate(torch.nn.Module):
|
|
|
317
472
|
f"[{Z.shape[-2]}, {Z.shape[-1]}]"
|
|
318
473
|
)
|
|
319
474
|
|
|
475
|
+
# Expand Z to have a batch and channel dimension if needed
|
|
476
|
+
while len(Z.shape) < 4:
|
|
477
|
+
Z = Z.unsqueeze(0)
|
|
478
|
+
|
|
320
479
|
return Z, y_out
|
|
321
480
|
|
|
322
481
|
def forward(
|
|
323
482
|
self,
|
|
324
483
|
Z: Tensor,
|
|
325
|
-
x_out:
|
|
326
|
-
y_out:
|
|
484
|
+
x_out: Tensor | None = None,
|
|
485
|
+
y_out: Tensor | None = None,
|
|
327
486
|
) -> Tensor:
|
|
328
487
|
"""
|
|
329
488
|
Compute the interpolated data
|
|
330
489
|
|
|
331
490
|
Args:
|
|
332
491
|
Z:
|
|
333
|
-
Tensor of data to be interpolated. Must be between
|
|
492
|
+
Tensor of data to be interpolated. Must be between 2 and 4
|
|
334
493
|
dimensions. The shape of the tensor must agree with the
|
|
335
|
-
input coordinates given on initialization.
|
|
336
|
-
not specified during initialization, it is assumed that
|
|
337
|
-
Z does not have a height dimension.
|
|
494
|
+
input coordinates given on initialization.
|
|
338
495
|
x_out:
|
|
339
496
|
Coordinates to interpolate the data to along the width
|
|
340
497
|
dimension. Overrides any value that was set during
|
|
@@ -353,9 +510,21 @@ class SplineInterpolate(torch.nn.Module):
|
|
|
353
510
|
Z, y_out = self._validate_inputs(Z, x_out, y_out)
|
|
354
511
|
|
|
355
512
|
if x_out is not None:
|
|
356
|
-
|
|
513
|
+
x_out = x_out.float()
|
|
514
|
+
x_clamped = torch.clamp(
|
|
515
|
+
x_out, self.tx[self.kx], self.tx[-self.kx - 1]
|
|
516
|
+
)
|
|
517
|
+
self.Bx_out = self.bspline_basis_natural(
|
|
518
|
+
x_clamped, self.kx, self.tx
|
|
519
|
+
)
|
|
357
520
|
if y_out is not None:
|
|
358
|
-
|
|
521
|
+
y_out = y_out.float()
|
|
522
|
+
y_clamped = torch.clamp(
|
|
523
|
+
y_out, self.ty[self.ky], self.ty[-self.ky - 1]
|
|
524
|
+
)
|
|
525
|
+
self.By_out = self.bspline_basis_natural(
|
|
526
|
+
y_clamped, self.ky, self.ty
|
|
527
|
+
)
|
|
359
528
|
|
|
360
529
|
coef = self.bivariate_spline_fit_natural(Z)
|
|
361
530
|
Z_interp = self.evaluate_bivariate_spline(coef)
|
ml4gw/transforms/transform.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
from typing import Optional
|
|
2
|
-
|
|
3
1
|
import torch
|
|
4
2
|
|
|
5
3
|
from ..spectral import spectral_density
|
|
@@ -20,8 +18,8 @@ class FittableTransform(torch.nn.Module):
|
|
|
20
18
|
def _check_built(self):
|
|
21
19
|
if not self.built:
|
|
22
20
|
raise ValueError(
|
|
23
|
-
"Must fit parameters of {} transform
|
|
24
|
-
"before calling forward step"
|
|
21
|
+
f"Must fit parameters of {self.__class__.__name__} transform "
|
|
22
|
+
"to data before calling forward step"
|
|
25
23
|
)
|
|
26
24
|
|
|
27
25
|
def __call__(self, *args, **kwargs):
|
|
@@ -47,8 +45,8 @@ class FittableSpectralTransform(FittableTransform):
|
|
|
47
45
|
x: TimeSeries1to3d,
|
|
48
46
|
sample_rate: float,
|
|
49
47
|
num_freqs: int,
|
|
50
|
-
fftlength:
|
|
51
|
-
overlap:
|
|
48
|
+
fftlength: float | None = None,
|
|
49
|
+
overlap: float | None = None,
|
|
52
50
|
) -> FrequencySeries1to3d:
|
|
53
51
|
# if we specified an FFT length, convert
|
|
54
52
|
# the (assumed) time-domain data to the
|
ml4gw/transforms/waveforms.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
from typing import List, Optional
|
|
2
|
-
|
|
3
1
|
import torch
|
|
4
2
|
from jaxtyping import Float
|
|
5
3
|
from torch import Tensor
|
|
@@ -13,7 +11,7 @@ from ..types import BatchTensor
|
|
|
13
11
|
class WaveformSampler(torch.nn.Module):
|
|
14
12
|
def __init__(
|
|
15
13
|
self,
|
|
16
|
-
parameters:
|
|
14
|
+
parameters: Float[Tensor, "batch num_params"] | None = None,
|
|
17
15
|
**polarizations: Float[Tensor, "batch time"],
|
|
18
16
|
):
|
|
19
17
|
super().__init__()
|
|
@@ -24,10 +22,8 @@ class WaveformSampler(torch.nn.Module):
|
|
|
24
22
|
for polarization, tensor in polarizations.items():
|
|
25
23
|
if num_waveforms is not None and len(tensor) != num_waveforms:
|
|
26
24
|
raise ValueError(
|
|
27
|
-
"Polarization {} has {} waveforms "
|
|
28
|
-
"associated with it, expected {}"
|
|
29
|
-
polarization, len(tensor), num_waveforms
|
|
30
|
-
)
|
|
25
|
+
f"Polarization {polarization} has {len(tensor)} waveforms "
|
|
26
|
+
f"associated with it, expected {num_waveforms}"
|
|
31
27
|
)
|
|
32
28
|
elif num_waveforms is None:
|
|
33
29
|
num_waveforms = tensor.shape[0]
|
|
@@ -36,10 +32,8 @@ class WaveformSampler(torch.nn.Module):
|
|
|
36
32
|
|
|
37
33
|
if parameters is not None and len(parameters) != num_waveforms:
|
|
38
34
|
raise ValueError(
|
|
39
|
-
"Waveform parameters has {} waveforms "
|
|
40
|
-
"associated with it, expected {}"
|
|
41
|
-
len(parameters), num_waveforms
|
|
42
|
-
)
|
|
35
|
+
f"Waveform parameters has {len(parameters)} waveforms "
|
|
36
|
+
f"associated with it, expected {num_waveforms}"
|
|
43
37
|
)
|
|
44
38
|
self.num_waveforms = num_waveforms
|
|
45
39
|
self.parameters = parameters
|
|
@@ -48,9 +42,8 @@ class WaveformSampler(torch.nn.Module):
|
|
|
48
42
|
# TODO: should we allow sampling with replacement?
|
|
49
43
|
if N > self.num_waveforms:
|
|
50
44
|
raise ValueError(
|
|
51
|
-
"Requested {} waveforms, but only {} are
|
|
52
|
-
|
|
53
|
-
)
|
|
45
|
+
f"Requested {N} waveforms, but only {self.num_waveforms} are "
|
|
46
|
+
"available"
|
|
54
47
|
)
|
|
55
48
|
# TODO: do we still really want this behavior here when a
|
|
56
49
|
# user can do this without instantiating a WaveformSampler?
|
|
@@ -67,7 +60,7 @@ class WaveformSampler(torch.nn.Module):
|
|
|
67
60
|
|
|
68
61
|
|
|
69
62
|
class WaveformProjector(torch.nn.Module):
|
|
70
|
-
def __init__(self, ifos:
|
|
63
|
+
def __init__(self, ifos: list[str], sample_rate: float):
|
|
71
64
|
super().__init__()
|
|
72
65
|
tensors, vertices = gw.get_ifo_geometry(*ifos)
|
|
73
66
|
self.sample_rate = sample_rate
|