ml4gw 0.7.6__py3-none-any.whl → 0.7.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ml4gw might be problematic. Click here for more details.
- ml4gw/distributions.py +1 -1
- ml4gw/transforms/__init__.py +1 -1
- ml4gw/transforms/qtransform.py +10 -15
- ml4gw/transforms/spline_interpolation.py +309 -138
- {ml4gw-0.7.6.dist-info → ml4gw-0.7.7.dist-info}/METADATA +6 -5
- {ml4gw-0.7.6.dist-info → ml4gw-0.7.7.dist-info}/RECORD +9 -8
- {ml4gw-0.7.6.dist-info → ml4gw-0.7.7.dist-info}/WHEEL +2 -1
- ml4gw-0.7.7.dist-info/top_level.txt +1 -0
- {ml4gw-0.7.6.dist-info → ml4gw-0.7.7.dist-info}/licenses/LICENSE +0 -0
ml4gw/distributions.py
CHANGED
|
@@ -137,7 +137,7 @@ class PowerLaw(dist.TransformedDistribution):
|
|
|
137
137
|
support = dist.constraints.nonnegative
|
|
138
138
|
|
|
139
139
|
def __init__(
|
|
140
|
-
self, minimum: float, maximum: float, index:
|
|
140
|
+
self, minimum: float, maximum: float, index: float, validate_args=None
|
|
141
141
|
):
|
|
142
142
|
if index == 0:
|
|
143
143
|
raise ValueError("Index of 0 is the same as Uniform")
|
ml4gw/transforms/__init__.py
CHANGED
|
@@ -5,6 +5,6 @@ from .scaler import ChannelWiseScaler
|
|
|
5
5
|
from .snr_rescaler import SnrRescaler
|
|
6
6
|
from .spectral import SpectralDensity
|
|
7
7
|
from .spectrogram import MultiResolutionSpectrogram
|
|
8
|
-
from .spline_interpolation import
|
|
8
|
+
from .spline_interpolation import SplineInterpolate1D, SplineInterpolate2D
|
|
9
9
|
from .waveforms import WaveformProjector, WaveformSampler
|
|
10
10
|
from .whitening import FixedWhiten, Whiten
|
ml4gw/transforms/qtransform.py
CHANGED
|
@@ -8,7 +8,7 @@ from jaxtyping import Float, Int
|
|
|
8
8
|
from torch import Tensor
|
|
9
9
|
|
|
10
10
|
from ..types import FrequencySeries1to3d, TimeSeries1to3d, TimeSeries3d
|
|
11
|
-
from .spline_interpolation import
|
|
11
|
+
from .spline_interpolation import SplineInterpolate1D
|
|
12
12
|
|
|
13
13
|
"""
|
|
14
14
|
All based on https://github.com/gwpy/gwpy/blob/v3.0.8/gwpy/signal/qtransform.py
|
|
@@ -260,18 +260,15 @@ class SingleQTransform(torch.nn.Module):
|
|
|
260
260
|
)
|
|
261
261
|
self.qtile_interpolators = torch.nn.ModuleList(
|
|
262
262
|
[
|
|
263
|
-
|
|
263
|
+
SplineInterpolate1D(
|
|
264
264
|
kx=3,
|
|
265
265
|
x_in=torch.arange(0, self.duration, self.duration / tiles),
|
|
266
|
-
y_in=torch.arange(len(idx)),
|
|
267
266
|
x_out=t_out,
|
|
268
|
-
y_out=torch.arange(len(idx)),
|
|
269
267
|
)
|
|
270
|
-
for tiles
|
|
268
|
+
for tiles in unique_ntiles
|
|
271
269
|
]
|
|
272
270
|
)
|
|
273
271
|
|
|
274
|
-
t_in = t_out
|
|
275
272
|
f_in = self.freqs
|
|
276
273
|
f_out = torch.logspace(
|
|
277
274
|
math.log10(self.frange[0]),
|
|
@@ -279,13 +276,10 @@ class SingleQTransform(torch.nn.Module):
|
|
|
279
276
|
self.spectrogram_shape[0],
|
|
280
277
|
)
|
|
281
278
|
|
|
282
|
-
self.interpolator =
|
|
279
|
+
self.interpolator = SplineInterpolate1D(
|
|
283
280
|
kx=3,
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
y_in=f_in,
|
|
287
|
-
x_out=t_out,
|
|
288
|
-
y_out=f_out,
|
|
281
|
+
x_in=f_in,
|
|
282
|
+
x_out=f_out,
|
|
289
283
|
)
|
|
290
284
|
|
|
291
285
|
def get_freqs(self) -> Float[Tensor, " nfreq"]:
|
|
@@ -379,14 +373,15 @@ class SingleQTransform(torch.nn.Module):
|
|
|
379
373
|
]
|
|
380
374
|
time_interped = torch.cat(
|
|
381
375
|
[
|
|
382
|
-
|
|
383
|
-
for qtile,
|
|
376
|
+
qtile_interpolator(qtile)
|
|
377
|
+
for qtile, qtile_interpolator in zip(
|
|
384
378
|
qtiles, self.qtile_interpolators
|
|
385
379
|
)
|
|
386
380
|
],
|
|
387
381
|
dim=-2,
|
|
388
382
|
)
|
|
389
|
-
|
|
383
|
+
# Transpose because the final dimension gets interpolated
|
|
384
|
+
return self.interpolator(time_interped.mT).mT
|
|
390
385
|
num_f_bins, num_t_bins = self.spectrogram_shape
|
|
391
386
|
resampled = [
|
|
392
387
|
F.interpolate(
|
|
@@ -1,131 +1,29 @@
|
|
|
1
1
|
"""
|
|
2
|
-
Adaptation of code from https://github.com/dottormale/
|
|
2
|
+
Adaptation of code from https://github.com/dottormale/Qtransform_torch/
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
5
|
from typing import Optional, Tuple
|
|
6
6
|
|
|
7
7
|
import torch
|
|
8
|
-
import torch.nn.functional as F
|
|
9
8
|
from torch import Tensor
|
|
10
9
|
|
|
11
10
|
|
|
12
|
-
class
|
|
11
|
+
class SplineInterpolateBase(torch.nn.Module):
|
|
13
12
|
"""
|
|
14
|
-
|
|
15
|
-
Supports batched, multi-channel inputs, so acceptable data
|
|
16
|
-
shapes are ``(width)``, ``(height, width)``, ``(batch, width)``,
|
|
17
|
-
``(batch, height, width)``, ``(batch, channel, width)``, and
|
|
18
|
-
``(batch, channel, height, width)``.
|
|
19
|
-
|
|
20
|
-
During initialization of this Module, both the desired input
|
|
21
|
-
and output coordinate Tensors can be specified to allow
|
|
22
|
-
pre-computation of the B-spline basis matrices, though the only
|
|
23
|
-
mandatory argument is the coordinates of the data along the
|
|
24
|
-
``width`` dimension. If no argument is given for coordinates along
|
|
25
|
-
the ``height`` dimension, it is assumed that 1D interpolation is
|
|
26
|
-
desired.
|
|
27
|
-
|
|
28
|
-
Unlike scipy's implementation of spline interpolation, the data
|
|
29
|
-
to be interpolated is not passed until actually calling the
|
|
30
|
-
object. This is useful for cases where the input and output
|
|
31
|
-
coordinates are known in advance, but the data is not, so that
|
|
32
|
-
the interpolator can be set up ahead of time.
|
|
33
|
-
|
|
34
|
-
WARNING: compared to scipy's spline interpolation, this method
|
|
35
|
-
produces edge artifacts when the output coordinates are near
|
|
36
|
-
the boundaries of the input coordinates. Therefore, it is
|
|
37
|
-
recommended to interpolate only to coordinates that are well
|
|
38
|
-
within the input coordinate range. Unfortunately, the specific
|
|
39
|
-
definition of "well within" changes based on the size of the
|
|
40
|
-
data, so some testing may be required to get good results.
|
|
41
|
-
|
|
42
|
-
Args:
|
|
43
|
-
x_in:
|
|
44
|
-
Coordinates of the width dimension of the data
|
|
45
|
-
y_in:
|
|
46
|
-
Coordinates of the height dimension of the data. If not
|
|
47
|
-
specified, it is assumed the 1D interpolation is desired,
|
|
48
|
-
and so the default value is a Tensor of length 1
|
|
49
|
-
kx:
|
|
50
|
-
Degree of spline interpolation along the width dimension.
|
|
51
|
-
Default is cubic.
|
|
52
|
-
ky:
|
|
53
|
-
Degree of spline interpolation along the height dimension.
|
|
54
|
-
Default is cubic.
|
|
55
|
-
sx:
|
|
56
|
-
Regularization factor to avoid singularities during matrix
|
|
57
|
-
inversion for interpolation along the width dimension. Not
|
|
58
|
-
to be confused with the ``s`` parameter in scipy's spline
|
|
59
|
-
methods, which controls the number of knots.
|
|
60
|
-
sy:
|
|
61
|
-
Regularization factor to avoid singularities during matrix
|
|
62
|
-
inversion for interpolation along the height dimension.
|
|
63
|
-
x_out:
|
|
64
|
-
Coordinates for the data to be interpolated to along the
|
|
65
|
-
width dimension. If not specified during initialization,
|
|
66
|
-
this must be specified during the object call.
|
|
67
|
-
y_out:
|
|
68
|
-
Coordinates for the data to be interpolated to along the
|
|
69
|
-
height dimension. If not specified during initialization,
|
|
70
|
-
this must be specified during the object call.
|
|
71
|
-
|
|
13
|
+
Base class for spline interpolation.
|
|
72
14
|
"""
|
|
73
15
|
|
|
74
|
-
def __init__(
|
|
75
|
-
self,
|
|
76
|
-
x_in: Tensor,
|
|
77
|
-
y_in: Tensor = None,
|
|
78
|
-
kx: int = 3,
|
|
79
|
-
ky: int = 3,
|
|
80
|
-
sx: float = 0.001,
|
|
81
|
-
sy: float = 0.001,
|
|
82
|
-
x_out: Optional[Tensor] = None,
|
|
83
|
-
y_out: Optional[Tensor] = None,
|
|
84
|
-
):
|
|
85
|
-
super().__init__()
|
|
86
|
-
if y_in is None:
|
|
87
|
-
y_in = Tensor([1])
|
|
88
|
-
self.kx = kx
|
|
89
|
-
self.ky = ky
|
|
90
|
-
self.sx = sx
|
|
91
|
-
self.sy = sy
|
|
92
|
-
self.register_buffer("x_in", x_in)
|
|
93
|
-
self.register_buffer("y_in", y_in)
|
|
94
|
-
self.register_buffer("x_out", x_out)
|
|
95
|
-
self.register_buffer("y_out", y_out)
|
|
96
|
-
|
|
97
|
-
tx, Bx, BxT_Bx = self._compute_knots_and_basis_matrices(x_in, kx, sx)
|
|
98
|
-
self.register_buffer("tx", tx)
|
|
99
|
-
self.register_buffer("Bx", Bx)
|
|
100
|
-
self.register_buffer("BxT_Bx", BxT_Bx)
|
|
101
|
-
|
|
102
|
-
ty, By, ByT_By = self._compute_knots_and_basis_matrices(y_in, ky, sy)
|
|
103
|
-
self.register_buffer("ty", ty)
|
|
104
|
-
self.register_buffer("By", By)
|
|
105
|
-
self.register_buffer("ByT_By", ByT_By)
|
|
106
|
-
|
|
107
|
-
if self.x_out is not None:
|
|
108
|
-
Bx_out = self.bspline_basis_natural(x_out, kx, self.tx)
|
|
109
|
-
self.register_buffer("Bx_out", Bx_out)
|
|
110
|
-
if self.y_out is not None:
|
|
111
|
-
By_out = self.bspline_basis_natural(y_out, ky, self.ty)
|
|
112
|
-
self.register_buffer("By_out", By_out)
|
|
113
|
-
|
|
114
16
|
def _compute_knots_and_basis_matrices(self, x, k, s):
|
|
115
|
-
knots = self.
|
|
17
|
+
knots = self.generate_fitpack_knots(x, k)
|
|
116
18
|
basis_matrix = self.bspline_basis_natural(x, k, knots)
|
|
117
19
|
identity = torch.eye(basis_matrix.shape[-1])
|
|
118
20
|
B_T_B = basis_matrix.T @ basis_matrix + s * identity
|
|
119
21
|
return knots, basis_matrix, B_T_B
|
|
120
22
|
|
|
121
|
-
def
|
|
23
|
+
def generate_fitpack_knots(self, x: Tensor, k: int) -> Tensor:
|
|
122
24
|
"""
|
|
123
|
-
Generates a
|
|
124
|
-
|
|
125
|
-
and end of datapoints as replicas of first and last datapoint
|
|
126
|
-
respectively in order to enforce natural boundary conditions,
|
|
127
|
-
i.e. second derivative = 0.
|
|
128
|
-
The other n nodes are placed in correspondece of the data points.
|
|
25
|
+
Generates a knot sequence for B-spline interpolation
|
|
26
|
+
in the same way as the FITPACK algorithm used by SciPy.
|
|
129
27
|
|
|
130
28
|
Args:
|
|
131
29
|
x: Tensor of data point positions.
|
|
@@ -134,7 +32,17 @@ class SplineInterpolate(torch.nn.Module):
|
|
|
134
32
|
Returns:
|
|
135
33
|
Tensor of knot positions.
|
|
136
34
|
"""
|
|
137
|
-
|
|
35
|
+
num_knots = x.shape[-1] + k + 1
|
|
36
|
+
knots = torch.zeros(num_knots, dtype=x.dtype)
|
|
37
|
+
knots[: k + 1] = x[0]
|
|
38
|
+
knots[-(k + 1) :] = x[-1]
|
|
39
|
+
|
|
40
|
+
# Interior knots are the rolling average of the data points
|
|
41
|
+
# excluding the first and last points
|
|
42
|
+
windows = x[1:-1].unfold(dimension=-1, size=k, step=1)
|
|
43
|
+
knots[k + 1 : -k - 1] = windows.mean(dim=-1)
|
|
44
|
+
|
|
45
|
+
return knots
|
|
138
46
|
|
|
139
47
|
def compute_L_R(
|
|
140
48
|
self,
|
|
@@ -233,9 +141,6 @@ class SplineInterpolate(torch.nn.Module):
|
|
|
233
141
|
Returns:
|
|
234
142
|
Tensor containing the kth-order B-spline basis functions
|
|
235
143
|
"""
|
|
236
|
-
|
|
237
|
-
if len(x) == 1:
|
|
238
|
-
return torch.eye(1)
|
|
239
144
|
n = x.shape[0]
|
|
240
145
|
m = t.shape[0] - k - 1
|
|
241
146
|
|
|
@@ -255,6 +160,271 @@ class SplineInterpolate(torch.nn.Module):
|
|
|
255
160
|
|
|
256
161
|
return b[:, :, -1]
|
|
257
162
|
|
|
163
|
+
|
|
164
|
+
class SplineInterpolate1D(SplineInterpolateBase):
|
|
165
|
+
"""
|
|
166
|
+
Perform 1D spline interpolation based on De Boor's method.
|
|
167
|
+
It is allowed to have two spatial dimensions, but the second
|
|
168
|
+
dimension cannot be interpolated along. To interpolate along both
|
|
169
|
+
dimensions, use :class:`SplineInterpolate2D`.
|
|
170
|
+
|
|
171
|
+
Supports batched, multi-channel inputs, so acceptable data
|
|
172
|
+
shapes are ``(width)``, ``(height, width)``, ``(batch, width)``,
|
|
173
|
+
``(batch, height, width)``, ``(batch, channel, width)``, and
|
|
174
|
+
``(batch, channel, height, width)``.
|
|
175
|
+
|
|
176
|
+
During initialization of this Module, both the desired input
|
|
177
|
+
and output coordinate Tensors can be specified to allow
|
|
178
|
+
pre-computation of the B-spline basis matrices, though the only
|
|
179
|
+
mandatory argument is the coordinates of the data along the
|
|
180
|
+
``width`` dimension.
|
|
181
|
+
|
|
182
|
+
Unlike scipy's implementation of spline interpolation, the data
|
|
183
|
+
to be interpolated is not passed until actually calling the
|
|
184
|
+
object. This is useful for cases where the input and output
|
|
185
|
+
coordinates are known in advance, but the data is not, so that
|
|
186
|
+
the interpolator can be set up ahead of time.
|
|
187
|
+
|
|
188
|
+
Args:
|
|
189
|
+
x_in:
|
|
190
|
+
Coordinates of the width dimension of the data
|
|
191
|
+
kx:
|
|
192
|
+
Degree of spline interpolation along the width dimension.
|
|
193
|
+
Default is cubic.
|
|
194
|
+
sx:
|
|
195
|
+
Regularization factor to avoid singularities during matrix
|
|
196
|
+
inversion for interpolation along the width dimension. Not
|
|
197
|
+
to be confused with the ``s`` parameter in scipy's spline
|
|
198
|
+
methods, which controls the number of knots.
|
|
199
|
+
x_out:
|
|
200
|
+
Coordinates for the data to be interpolated to along the
|
|
201
|
+
width dimension. If not specified during initialization,
|
|
202
|
+
this must be specified during the object call.
|
|
203
|
+
|
|
204
|
+
"""
|
|
205
|
+
|
|
206
|
+
def __init__(
|
|
207
|
+
self,
|
|
208
|
+
x_in: Tensor,
|
|
209
|
+
kx: int = 3,
|
|
210
|
+
sx: float = 0.0,
|
|
211
|
+
x_out: Optional[Tensor] = None,
|
|
212
|
+
):
|
|
213
|
+
super().__init__()
|
|
214
|
+
|
|
215
|
+
if len(x_in) < kx + 2:
|
|
216
|
+
raise ValueError(
|
|
217
|
+
"Input x-coordinates must have at least kx + 2 points."
|
|
218
|
+
)
|
|
219
|
+
|
|
220
|
+
# Ensure that coordinates are floats
|
|
221
|
+
x_in = x_in.float()
|
|
222
|
+
x_out = x_out.float() if x_out is not None else None
|
|
223
|
+
|
|
224
|
+
self.kx = kx
|
|
225
|
+
self.sx = sx
|
|
226
|
+
self.register_buffer("x_in", x_in)
|
|
227
|
+
self.register_buffer("x_out", x_out)
|
|
228
|
+
|
|
229
|
+
tx, Bx, BxT_Bx = self._compute_knots_and_basis_matrices(x_in, kx, sx)
|
|
230
|
+
self.register_buffer("tx", tx)
|
|
231
|
+
self.register_buffer("Bx", Bx)
|
|
232
|
+
self.register_buffer("BxT_Bx", BxT_Bx)
|
|
233
|
+
|
|
234
|
+
if self.x_out is not None:
|
|
235
|
+
x_clamped = torch.clamp(x_out, tx[kx], tx[-kx - 1])
|
|
236
|
+
Bx_out = self.bspline_basis_natural(x_clamped, kx, self.tx)
|
|
237
|
+
self.register_buffer("Bx_out", Bx_out)
|
|
238
|
+
|
|
239
|
+
def spline_fit_natural(self, Z):
|
|
240
|
+
# Adding batch/channel dimension handling
|
|
241
|
+
# Bx @ Z
|
|
242
|
+
BxT_Z = torch.einsum("ij,bchj->bchi", self.Bx.T, Z)
|
|
243
|
+
# (BxT @ Bx)^-1 @ (BxT @ Z) = Bx^-1 @ Z
|
|
244
|
+
C = torch.linalg.solve(self.BxT_Bx, BxT_Z.unsqueeze(-1))
|
|
245
|
+
return C.squeeze(-1)
|
|
246
|
+
|
|
247
|
+
def evaluate_spline(self, C: Tensor):
|
|
248
|
+
"""
|
|
249
|
+
Evaluate a bivariate spline on a grid of x and y points.
|
|
250
|
+
|
|
251
|
+
Args:
|
|
252
|
+
C: Coefficient tensor of shape (batch_size, mx, my).
|
|
253
|
+
|
|
254
|
+
Returns:
|
|
255
|
+
Z_interp: Interpolated values at the grid points.
|
|
256
|
+
"""
|
|
257
|
+
# Perform matrix multiplication using einsum to get Z_interp
|
|
258
|
+
return torch.einsum("ij,bchj->bchi", self.Bx_out, C)
|
|
259
|
+
|
|
260
|
+
def _validate_inputs(self, Z, x_out):
|
|
261
|
+
if x_out is None and self.x_out is None:
|
|
262
|
+
raise ValueError(
|
|
263
|
+
"Output x-coordinates were not specified in either object "
|
|
264
|
+
"creation or in forward call"
|
|
265
|
+
)
|
|
266
|
+
|
|
267
|
+
dims = len(Z.shape)
|
|
268
|
+
if dims > 4:
|
|
269
|
+
raise ValueError("Input data has more than 4 dimensions")
|
|
270
|
+
|
|
271
|
+
if Z.shape[-1] != len(self.x_in):
|
|
272
|
+
raise ValueError(
|
|
273
|
+
"The spatial dimensions of the data tensor do not match "
|
|
274
|
+
"the given input dimensions. "
|
|
275
|
+
f"Expected {len(self.x_in)}, but got {Z.shape[-1]}"
|
|
276
|
+
)
|
|
277
|
+
|
|
278
|
+
# Expand Z to have a batch, channel, and height dimension if needed
|
|
279
|
+
while len(Z.shape) < 4:
|
|
280
|
+
Z = Z.unsqueeze(0)
|
|
281
|
+
|
|
282
|
+
return Z
|
|
283
|
+
|
|
284
|
+
def forward(
|
|
285
|
+
self,
|
|
286
|
+
Z: Tensor,
|
|
287
|
+
x_out: Optional[Tensor] = None,
|
|
288
|
+
) -> Tensor:
|
|
289
|
+
"""
|
|
290
|
+
Compute the interpolated data
|
|
291
|
+
|
|
292
|
+
Args:
|
|
293
|
+
Z:
|
|
294
|
+
Tensor of data to be interpolated. Must be between 2 and 4
|
|
295
|
+
dimensions. The shape of the tensor must agree with the
|
|
296
|
+
input coordinates given on initialization.
|
|
297
|
+
x_out:
|
|
298
|
+
Coordinates to interpolate the data to along the width
|
|
299
|
+
dimension. Overrides any value that was set during
|
|
300
|
+
initialization.
|
|
301
|
+
|
|
302
|
+
Returns:
|
|
303
|
+
A 4D tensor with shape ``(batch, channel, height, width)``.
|
|
304
|
+
Depending on the input data shape, many of these dimensions
|
|
305
|
+
may have length 1.
|
|
306
|
+
"""
|
|
307
|
+
|
|
308
|
+
Z = self._validate_inputs(Z, x_out)
|
|
309
|
+
|
|
310
|
+
if x_out is not None:
|
|
311
|
+
x_out = x_out.float()
|
|
312
|
+
x_clamped = torch.clamp(
|
|
313
|
+
x_out, self.tx[self.kx], self.tx[-self.kx - 1]
|
|
314
|
+
)
|
|
315
|
+
self.Bx_out = self.bspline_basis_natural(
|
|
316
|
+
x_clamped, self.kx, self.tx
|
|
317
|
+
)
|
|
318
|
+
|
|
319
|
+
coef = self.spline_fit_natural(Z)
|
|
320
|
+
Z_interp = self.evaluate_spline(coef)
|
|
321
|
+
return Z_interp
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
class SplineInterpolate2D(SplineInterpolateBase):
|
|
325
|
+
"""
|
|
326
|
+
Perform 2D spline interpolation based on De Boor's method.
|
|
327
|
+
Supports batched, multi-channel inputs, so acceptable data
|
|
328
|
+
shapes are ``(height, width)``, ``(batch, height, width)``,
|
|
329
|
+
and ``(batch, channel, height, width)``.
|
|
330
|
+
|
|
331
|
+
During initialization of this Module, both the desired input
|
|
332
|
+
and output coordinate Tensors can be specified to allow
|
|
333
|
+
pre-computation of the B-spline basis matrices, though the only
|
|
334
|
+
mandatory arguments are the input coordinates.
|
|
335
|
+
|
|
336
|
+
Unlike scipy's implementation of spline interpolation, the data
|
|
337
|
+
to be interpolated is not passed until actually calling the
|
|
338
|
+
object. This is useful for cases where the input and output
|
|
339
|
+
coordinates are known in advance, but the data is not, so that
|
|
340
|
+
the interpolator can be set up ahead of time.
|
|
341
|
+
|
|
342
|
+
Args:
|
|
343
|
+
x_in:
|
|
344
|
+
Coordinates of the width dimension of the data
|
|
345
|
+
y_in:
|
|
346
|
+
Coordinates of the height dimension of the data.
|
|
347
|
+
kx:
|
|
348
|
+
Degree of spline interpolation along the width dimension.
|
|
349
|
+
Default is cubic.
|
|
350
|
+
ky:
|
|
351
|
+
Degree of spline interpolation along the height dimension.
|
|
352
|
+
Default is cubic.
|
|
353
|
+
sx:
|
|
354
|
+
Regularization factor to avoid singularities during matrix
|
|
355
|
+
inversion for interpolation along the width dimension. Not
|
|
356
|
+
to be confused with the ``s`` parameter in scipy's spline
|
|
357
|
+
methods, which controls the number of knots.
|
|
358
|
+
sy:
|
|
359
|
+
Regularization factor to avoid singularities during matrix
|
|
360
|
+
inversion for interpolation along the height dimension.
|
|
361
|
+
x_out:
|
|
362
|
+
Coordinates for the data to be interpolated to along the
|
|
363
|
+
width dimension. If not specified during initialization,
|
|
364
|
+
this must be specified during the object call.
|
|
365
|
+
y_out:
|
|
366
|
+
Coordinates for the data to be interpolated to along the
|
|
367
|
+
height dimension. If not specified during initialization,
|
|
368
|
+
this must be specified during the object call.
|
|
369
|
+
|
|
370
|
+
"""
|
|
371
|
+
|
|
372
|
+
def __init__(
|
|
373
|
+
self,
|
|
374
|
+
x_in: Tensor,
|
|
375
|
+
y_in: Tensor,
|
|
376
|
+
kx: int = 3,
|
|
377
|
+
ky: int = 3,
|
|
378
|
+
sx: float = 0.0,
|
|
379
|
+
sy: float = 0.0,
|
|
380
|
+
x_out: Optional[Tensor] = None,
|
|
381
|
+
y_out: Optional[Tensor] = None,
|
|
382
|
+
):
|
|
383
|
+
super().__init__()
|
|
384
|
+
|
|
385
|
+
if len(x_in) < kx + 2:
|
|
386
|
+
raise ValueError(
|
|
387
|
+
"Input x-coordinates must have at least kx + 2 points."
|
|
388
|
+
)
|
|
389
|
+
if len(y_in) < ky + 2:
|
|
390
|
+
raise ValueError(
|
|
391
|
+
"Input y-coordinates must have at least ky + 2 points."
|
|
392
|
+
)
|
|
393
|
+
|
|
394
|
+
# Ensure that coordinates are floats
|
|
395
|
+
x_in = x_in.float()
|
|
396
|
+
y_in = y_in.float()
|
|
397
|
+
x_out = x_out.float() if x_out is not None else None
|
|
398
|
+
y_out = y_out.float() if y_out is not None else None
|
|
399
|
+
|
|
400
|
+
self.kx = kx
|
|
401
|
+
self.ky = ky
|
|
402
|
+
self.sx = sx
|
|
403
|
+
self.sy = sy
|
|
404
|
+
self.register_buffer("x_in", x_in)
|
|
405
|
+
self.register_buffer("y_in", y_in)
|
|
406
|
+
self.register_buffer("x_out", x_out)
|
|
407
|
+
self.register_buffer("y_out", y_out)
|
|
408
|
+
|
|
409
|
+
tx, Bx, BxT_Bx = self._compute_knots_and_basis_matrices(x_in, kx, sx)
|
|
410
|
+
self.register_buffer("tx", tx)
|
|
411
|
+
self.register_buffer("Bx", Bx)
|
|
412
|
+
self.register_buffer("BxT_Bx", BxT_Bx)
|
|
413
|
+
|
|
414
|
+
ty, By, ByT_By = self._compute_knots_and_basis_matrices(y_in, ky, sy)
|
|
415
|
+
self.register_buffer("ty", ty)
|
|
416
|
+
self.register_buffer("By", By)
|
|
417
|
+
self.register_buffer("ByT_By", ByT_By)
|
|
418
|
+
|
|
419
|
+
if self.x_out is not None:
|
|
420
|
+
x_clamped = torch.clamp(x_out, tx[kx], tx[-kx - 1])
|
|
421
|
+
Bx_out = self.bspline_basis_natural(x_clamped, kx, self.tx)
|
|
422
|
+
self.register_buffer("Bx_out", Bx_out)
|
|
423
|
+
if self.y_out is not None:
|
|
424
|
+
y_clamped = torch.clamp(y_out, ty[ky], ty[-ky - 1])
|
|
425
|
+
By_out = self.bspline_basis_natural(y_clamped, ky, self.ty)
|
|
426
|
+
self.register_buffer("By_out", By_out)
|
|
427
|
+
|
|
258
428
|
def bivariate_spline_fit_natural(self, Z):
|
|
259
429
|
# Adding batch/channel dimension handling
|
|
260
430
|
# ByT @ Z @ BxW
|
|
@@ -285,29 +455,16 @@ class SplineInterpolate(torch.nn.Module):
|
|
|
285
455
|
)
|
|
286
456
|
|
|
287
457
|
if y_out is None and self.y_out is None:
|
|
288
|
-
|
|
458
|
+
raise ValueError(
|
|
459
|
+
"Output y-coordinates were not specified in either object "
|
|
460
|
+
"creation or in forward call"
|
|
461
|
+
)
|
|
289
462
|
|
|
290
463
|
dims = len(Z.shape)
|
|
291
464
|
if dims > 4:
|
|
292
465
|
raise ValueError("Input data has more than 4 dimensions")
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
raise ValueError(
|
|
296
|
-
"An input y-coordinate array with length greater than 1 "
|
|
297
|
-
"was given, but the input data is 1-dimensional. Expected "
|
|
298
|
-
"input data to be at least 2-dimensional"
|
|
299
|
-
)
|
|
300
|
-
|
|
301
|
-
# Expand Z to have 4 dimensions
|
|
302
|
-
# There are 6 valid input shapes: (w), (b, w), (b, c, w),
|
|
303
|
-
# (h, w), (b, h, w), and (b, c, h, w).
|
|
304
|
-
|
|
305
|
-
# If the input y coordinate array has length 1,
|
|
306
|
-
# assume the first dimension(s) are batch dimensions
|
|
307
|
-
# and that no height dimension is included in Z
|
|
308
|
-
idx = -2 if len(self.y_in) == 1 else -3
|
|
309
|
-
while len(Z.shape) < 4:
|
|
310
|
-
Z = Z.unsqueeze(idx)
|
|
466
|
+
if dims < 2:
|
|
467
|
+
raise ValueError("Input data has fewer than 2 dimensions")
|
|
311
468
|
|
|
312
469
|
if Z.shape[-2:] != torch.Size([len(self.y_in), len(self.x_in)]):
|
|
313
470
|
raise ValueError(
|
|
@@ -317,6 +474,10 @@ class SplineInterpolate(torch.nn.Module):
|
|
|
317
474
|
f"[{Z.shape[-2]}, {Z.shape[-1]}]"
|
|
318
475
|
)
|
|
319
476
|
|
|
477
|
+
# Expand Z to have a batch and channel dimension if needed
|
|
478
|
+
while len(Z.shape) < 4:
|
|
479
|
+
Z = Z.unsqueeze(0)
|
|
480
|
+
|
|
320
481
|
return Z, y_out
|
|
321
482
|
|
|
322
483
|
def forward(
|
|
@@ -330,11 +491,9 @@ class SplineInterpolate(torch.nn.Module):
|
|
|
330
491
|
|
|
331
492
|
Args:
|
|
332
493
|
Z:
|
|
333
|
-
Tensor of data to be interpolated. Must be between
|
|
494
|
+
Tensor of data to be interpolated. Must be between 2 and 4
|
|
334
495
|
dimensions. The shape of the tensor must agree with the
|
|
335
|
-
input coordinates given on initialization.
|
|
336
|
-
not specified during initialization, it is assumed that
|
|
337
|
-
Z does not have a height dimension.
|
|
496
|
+
input coordinates given on initialization.
|
|
338
497
|
x_out:
|
|
339
498
|
Coordinates to interpolate the data to along the width
|
|
340
499
|
dimension. Overrides any value that was set during
|
|
@@ -353,9 +512,21 @@ class SplineInterpolate(torch.nn.Module):
|
|
|
353
512
|
Z, y_out = self._validate_inputs(Z, x_out, y_out)
|
|
354
513
|
|
|
355
514
|
if x_out is not None:
|
|
356
|
-
|
|
515
|
+
x_out = x_out.float()
|
|
516
|
+
x_clamped = torch.clamp(
|
|
517
|
+
x_out, self.tx[self.kx], self.tx[-self.kx - 1]
|
|
518
|
+
)
|
|
519
|
+
self.Bx_out = self.bspline_basis_natural(
|
|
520
|
+
x_clamped, self.kx, self.tx
|
|
521
|
+
)
|
|
357
522
|
if y_out is not None:
|
|
358
|
-
|
|
523
|
+
y_out = y_out.float()
|
|
524
|
+
y_clamped = torch.clamp(
|
|
525
|
+
y_out, self.ty[self.ky], self.ty[-self.ky - 1]
|
|
526
|
+
)
|
|
527
|
+
self.By_out = self.bspline_basis_natural(
|
|
528
|
+
y_clamped, self.ky, self.ty
|
|
529
|
+
)
|
|
359
530
|
|
|
360
531
|
coef = self.bivariate_spline_fit_natural(Z)
|
|
361
532
|
Z_interp = self.evaluate_bivariate_spline(coef)
|
|
@@ -1,9 +1,8 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ml4gw
|
|
3
|
-
Version: 0.7.
|
|
3
|
+
Version: 0.7.7
|
|
4
4
|
Summary: Tools for training torch models on gravitational wave data
|
|
5
5
|
Author-email: Ethan Marx <emarx@mit.edu>, Will Benoit <benoi090@umn.edu>, Deep Chatterjee <deep1018@mit.edu>, Alec Gunny <alec.gunny@ligo.org>
|
|
6
|
-
License-File: LICENSE
|
|
7
6
|
Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
|
|
8
7
|
Classifier: Programming Language :: Python :: 3.9
|
|
9
8
|
Classifier: Programming Language :: Python :: 3.10
|
|
@@ -11,12 +10,14 @@ Classifier: Programming Language :: Python :: 3.11
|
|
|
11
10
|
Classifier: Programming Language :: Python :: 3.12
|
|
12
11
|
Classifier: Programming Language :: Python :: 3.13
|
|
13
12
|
Requires-Python: <3.13,>=3.9
|
|
13
|
+
Description-Content-Type: text/markdown
|
|
14
|
+
License-File: LICENSE
|
|
14
15
|
Requires-Dist: jaxtyping<0.3,>=0.2
|
|
16
|
+
Requires-Dist: torch~=2.0
|
|
17
|
+
Requires-Dist: torchaudio~=2.0
|
|
15
18
|
Requires-Dist: numpy<2.0.0
|
|
16
19
|
Requires-Dist: scipy<1.15,>=1.9.0
|
|
17
|
-
|
|
18
|
-
Requires-Dist: torch~=2.0
|
|
19
|
-
Description-Content-Type: text/markdown
|
|
20
|
+
Dynamic: license-file
|
|
20
21
|
|
|
21
22
|
# ML4GW
|
|
22
23
|

|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
ml4gw/__init__.py,sha256=81quoggCuIypZjZs3bbf1Ty70KHdva5RGEJxi0oC57E,25
|
|
2
2
|
ml4gw/augmentations.py,sha256=4tSWO-I4Eg2QJWzdcLFg9QcOLlvRjNHvnjLCZS8K-Wc,1270
|
|
3
3
|
ml4gw/constants.py,sha256=RQPXwavlw_cWu3ByltvTejPsi6EWXHDJQ1HaV9iE3Lg,850
|
|
4
|
-
ml4gw/distributions.py,sha256=
|
|
4
|
+
ml4gw/distributions.py,sha256=B3heOf_x-71UJBExkmaiqBfjTHiGDds_bYo62wydddU,12832
|
|
5
5
|
ml4gw/gw.py,sha256=bJ-GCZxanqrhbm373h9muOSZpam7wM-dJBZroy_pVNQ,20291
|
|
6
6
|
ml4gw/spectral.py,sha256=Mx_zRjZ9tD7N-wknv35oA3fk2X0rDJxQdQzRyuCFryw,19982
|
|
7
7
|
ml4gw/types.py,sha256=CcctqDcNajR7khGT6BD-WYsfRKpiP0udoSAB0k1qcFw,863
|
|
@@ -22,15 +22,15 @@ ml4gw/nn/resnet/resnet_2d.py,sha256=MAbXtkSrP4aWGtY-QC8ox3-y5jDHJrzRPL5ryQ4RBvM,
|
|
|
22
22
|
ml4gw/nn/streaming/__init__.py,sha256=zgjGR2L8t0txXLnil9ceZT0tM8Y2FC8yPxqIKYH0o1A,80
|
|
23
23
|
ml4gw/nn/streaming/online_average.py,sha256=YSFUHhwNfQjUJbzQCqaCVApSueswzYB4yel981Omiqw,4718
|
|
24
24
|
ml4gw/nn/streaming/snapshotter.py,sha256=vEQLFi-fEH-o7TO9SmYXy5whxFxXQBDeOQOFhSnofSg,4503
|
|
25
|
-
ml4gw/transforms/__init__.py,sha256=
|
|
25
|
+
ml4gw/transforms/__init__.py,sha256=sREcTsJYPOHBk_QEvhlwvLeoR61Th8LVGaBicWngmv4,470
|
|
26
26
|
ml4gw/transforms/iirfilter.py,sha256=HcdsjcSaSi2xe65ojxnaqeSdbYvSQVFIkHKon3nW238,3194
|
|
27
27
|
ml4gw/transforms/pearson.py,sha256=sFyHD6IdskbRS8V1fY0Kt9N8R2_EhnuL6UjFa6fnmTU,3244
|
|
28
|
-
ml4gw/transforms/qtransform.py,sha256=
|
|
28
|
+
ml4gw/transforms/qtransform.py,sha256=2GtPyY5DPhoUMDs68juLIzmxbelfsS4CpI0HyTPe3Oo,20636
|
|
29
29
|
ml4gw/transforms/scaler.py,sha256=BKn4RQ_TNArdwPI_j5nAe7H2jOH_-MrZPsNByE-8Pl8,2518
|
|
30
30
|
ml4gw/transforms/snr_rescaler.py,sha256=lfuwdwMY117gB-emmn0_22gsK_A9xnkHJv2-76HFWc4,2728
|
|
31
31
|
ml4gw/transforms/spectral.py,sha256=ebAuPSdQqha6J3MMzxqJqR31XPKUDrSz3iJaHM3orpk,4449
|
|
32
32
|
ml4gw/transforms/spectrogram.py,sha256=NIyTD8kZRe8rjMUTy1_-wpFyvAswzTfYwD4TJJcPqgs,6369
|
|
33
|
-
ml4gw/transforms/spline_interpolation.py,sha256=
|
|
33
|
+
ml4gw/transforms/spline_interpolation.py,sha256=_D_P2jNIOL8-XiSSjsGClxuVwthO-CxcoqNwvrBWQpk,18668
|
|
34
34
|
ml4gw/transforms/transform.py,sha256=lpHQbM4PhdijvNBsZigPX-mS04aiVVq5q3HMfxvpFg0,2506
|
|
35
35
|
ml4gw/transforms/waveforms.py,sha256=koWOuHuUpQWmTT1yawSWa_MOuLfDBuugy91KIyuklOo,3189
|
|
36
36
|
ml4gw/transforms/whitening.py,sha256=UyFustRhu3zv0ynJBvvxekWA-YOMwEIOYDNpoD5r_qQ,10400
|
|
@@ -49,7 +49,8 @@ ml4gw/waveforms/cbc/phenom_d_data.py,sha256=WA1FBxUp9fo1IQaV_OLJ_5g5gI166mY1FtG9
|
|
|
49
49
|
ml4gw/waveforms/cbc/phenom_p.py,sha256=tOUBoYfr0ub6OGRjDQbquGoW8AnThiGjJvbHhyGnAnk,27680
|
|
50
50
|
ml4gw/waveforms/cbc/taylorf2.py,sha256=emWbl3vjsCzBOooHOVO7pPlPcj05r4up6InlMkO5m_E,10422
|
|
51
51
|
ml4gw/waveforms/cbc/utils.py,sha256=LT1ky10_6ZrbwTcxIrWP1O75GUEuU5q2ZE2yYDhadQE,3037
|
|
52
|
-
ml4gw-0.7.
|
|
53
|
-
ml4gw-0.7.
|
|
54
|
-
ml4gw-0.7.
|
|
55
|
-
ml4gw-0.7.
|
|
52
|
+
ml4gw-0.7.7.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
|
53
|
+
ml4gw-0.7.7.dist-info/METADATA,sha256=C3IsV7-AGBWfid4FvE1xuGL0HFX_wzRii0Wzb4aSOiM,3402
|
|
54
|
+
ml4gw-0.7.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
55
|
+
ml4gw-0.7.7.dist-info/top_level.txt,sha256=JnWLyPXJ3_WUcjr6fRV0ZTXj8FR0x4vBzjkg-1bl2tw,6
|
|
56
|
+
ml4gw-0.7.7.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
ml4gw
|
|
File without changes
|