ml4gw 0.7.5__py3-none-any.whl → 0.7.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ml4gw might be problematic. Click here for more details.
- ml4gw/augmentations.py +4 -4
- ml4gw/dataloading/chunked_dataset.py +3 -3
- ml4gw/dataloading/hdf5_dataset.py +7 -10
- ml4gw/dataloading/in_memory_dataset.py +21 -21
- ml4gw/distributions.py +20 -18
- ml4gw/gw.py +60 -53
- ml4gw/nn/autoencoder/base.py +9 -9
- ml4gw/nn/autoencoder/convolutional.py +4 -4
- ml4gw/nn/resnet/resnet_1d.py +13 -13
- ml4gw/nn/resnet/resnet_2d.py +12 -12
- ml4gw/nn/streaming/online_average.py +1 -1
- ml4gw/nn/streaming/snapshotter.py +14 -14
- ml4gw/spectral.py +48 -48
- ml4gw/transforms/__init__.py +1 -1
- ml4gw/transforms/iirfilter.py +3 -3
- ml4gw/transforms/pearson.py +7 -8
- ml4gw/transforms/qtransform.py +29 -34
- ml4gw/transforms/scaler.py +4 -4
- ml4gw/transforms/spectral.py +10 -10
- ml4gw/transforms/spectrogram.py +12 -11
- ml4gw/transforms/spline_interpolation.py +310 -146
- ml4gw/transforms/transform.py +1 -1
- ml4gw/transforms/whitening.py +36 -36
- ml4gw/utils/slicing.py +40 -40
- ml4gw/waveforms/cbc/phenom_d.py +22 -66
- ml4gw/waveforms/cbc/phenom_p.py +9 -5
- ml4gw/waveforms/cbc/taylorf2.py +8 -7
- ml4gw/waveforms/conversion.py +2 -1
- ml4gw/waveforms/generator.py +33 -32
- {ml4gw-0.7.5.dist-info → ml4gw-0.7.7.dist-info}/METADATA +6 -5
- ml4gw-0.7.7.dist-info/RECORD +56 -0
- {ml4gw-0.7.5.dist-info → ml4gw-0.7.7.dist-info}/WHEEL +2 -1
- ml4gw-0.7.7.dist-info/top_level.txt +1 -0
- ml4gw-0.7.5.dist-info/RECORD +0 -55
- {ml4gw-0.7.5.dist-info → ml4gw-0.7.7.dist-info}/licenses/LICENSE +0 -0
ml4gw/transforms/spectral.py
CHANGED
|
@@ -14,39 +14,39 @@ class SpectralDensity(torch.nn.Module):
|
|
|
14
14
|
of a batch of multichannel timeseries, or the cross spectral
|
|
15
15
|
density of two batches of multichannel timeseries.
|
|
16
16
|
|
|
17
|
-
On
|
|
17
|
+
On ``SpectralDensity.forward`` call, if only one tensor is provided,
|
|
18
18
|
this transform will compute its power spectral density. If a second
|
|
19
19
|
tensor is provided, the cross spectral density between the two
|
|
20
20
|
timeseries will be computed. For information about the allowed
|
|
21
21
|
relationships between these two tensors, see the documentation to
|
|
22
|
-
|
|
22
|
+
:meth:`~ml4gw.spectral.fast_spectral_density`.
|
|
23
23
|
|
|
24
24
|
Note that the cross spectral density computation is currently
|
|
25
|
-
only available for
|
|
26
|
-
|
|
27
|
-
a
|
|
25
|
+
only available for :meth:`~ml4gw.spectral.fast_spectral_density`. If
|
|
26
|
+
``fast=False`` and a second tensor is passed to ``SpectralDensity.forward``, # noqa E501
|
|
27
|
+
a ``NotImplementedError`` will be raised.
|
|
28
28
|
|
|
29
29
|
Args:
|
|
30
30
|
sample_rate:
|
|
31
|
-
Rate at which tensors passed to
|
|
31
|
+
Rate at which tensors passed to ``forward`` will be sampled
|
|
32
32
|
fftlength:
|
|
33
33
|
Length of the window, in seconds, to use for FFT estimates
|
|
34
34
|
overlap:
|
|
35
35
|
Overlap between windows used for FFT calculation. If left
|
|
36
|
-
as
|
|
36
|
+
as ``None``, this will be set to ``fftlength / 2``.
|
|
37
37
|
average:
|
|
38
38
|
Aggregation method to use for combining windowed FFTs.
|
|
39
|
-
Allowed values are
|
|
39
|
+
Allowed values are ``"mean"`` and ``"median"``.
|
|
40
40
|
window:
|
|
41
41
|
Window array to multiply by each FFT window before
|
|
42
|
-
FFT computation. Should have length
|
|
42
|
+
FFT computation. Should have length ``nperseg``.
|
|
43
43
|
Defaults to a hanning window.
|
|
44
44
|
fast:
|
|
45
45
|
Whether to use a faster spectral density computation that
|
|
46
46
|
support cross spectral density, or a slower one which does
|
|
47
47
|
not. The cost of the fast implementation is that it is not
|
|
48
48
|
exact for the two lowest frequency bins.
|
|
49
|
-
"""
|
|
49
|
+
""" # noqa E501
|
|
50
50
|
|
|
51
51
|
def __init__(
|
|
52
52
|
self,
|
ml4gw/transforms/spectrogram.py
CHANGED
|
@@ -14,18 +14,18 @@ class MultiResolutionSpectrogram(torch.nn.Module):
|
|
|
14
14
|
"""
|
|
15
15
|
Create a batch of multi-resolution spectrograms
|
|
16
16
|
from a batch of timeseries. Input is expected to
|
|
17
|
-
have the shape
|
|
18
|
-
of batches,
|
|
17
|
+
have the shape ``(B, C, T)``, where ``B`` is the number
|
|
18
|
+
of batches, ``C`` is the number of channels, and ``T``
|
|
19
19
|
is the number of time samples.
|
|
20
20
|
|
|
21
21
|
For each timeseries, calculate multiple normalized
|
|
22
|
-
spectrograms based on the
|
|
22
|
+
spectrograms based on the ``Spectrogram`` ``kwargs`` given.
|
|
23
23
|
Combine the spectrograms by taking the maximum value
|
|
24
24
|
from the nearest time-frequncy bin.
|
|
25
25
|
|
|
26
26
|
If the largest number of time bins among the spectrograms
|
|
27
|
-
is
|
|
28
|
-
the output will have dimensions
|
|
27
|
+
is ``N`` and the largest number of frequency bins is ``M``,
|
|
28
|
+
the output will have dimensions ``(B, C, M, N)``
|
|
29
29
|
|
|
30
30
|
Args:
|
|
31
31
|
kernel_length:
|
|
@@ -34,10 +34,11 @@ class MultiResolutionSpectrogram(torch.nn.Module):
|
|
|
34
34
|
spectrogram
|
|
35
35
|
sample_rate:
|
|
36
36
|
The sample rate of the timeseries in Hz
|
|
37
|
-
kwargs:
|
|
37
|
+
**kwargs:
|
|
38
38
|
Arguments passed in kwargs will used to create
|
|
39
|
-
|
|
40
|
-
|
|
39
|
+
``torchaudio.transforms.Spectrogram`` (see
|
|
40
|
+
`documentation <https://docs.pytorch.org/audio/main/generated/torchaudio.transforms.Spectrogram.html>`_).
|
|
41
|
+
Each argument should be a list of values. Any list
|
|
41
42
|
of length greater than 1 should be the same
|
|
42
43
|
length
|
|
43
44
|
"""
|
|
@@ -140,9 +141,9 @@ class MultiResolutionSpectrogram(torch.nn.Module):
|
|
|
140
141
|
Batch of multichannel timeseries which will
|
|
141
142
|
be used to calculate the multi-resolution
|
|
142
143
|
spectrogram. Should have the shape
|
|
143
|
-
|
|
144
|
-
batches,
|
|
145
|
-
and
|
|
144
|
+
``(B, C, T)``, where ``B`` is the number of
|
|
145
|
+
batches, ``C`` is the number of channels,
|
|
146
|
+
and ``T`` is the number of time samples.
|
|
146
147
|
"""
|
|
147
148
|
if X.shape[-1] != self.kernel_size:
|
|
148
149
|
raise ValueError(
|
|
@@ -1,131 +1,29 @@
|
|
|
1
1
|
"""
|
|
2
|
-
Adaptation of code from https://github.com/dottormale/
|
|
2
|
+
Adaptation of code from https://github.com/dottormale/Qtransform_torch/
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
5
|
from typing import Optional, Tuple
|
|
6
6
|
|
|
7
7
|
import torch
|
|
8
|
-
import torch.nn.functional as F
|
|
9
8
|
from torch import Tensor
|
|
10
9
|
|
|
11
10
|
|
|
12
|
-
class
|
|
11
|
+
class SplineInterpolateBase(torch.nn.Module):
|
|
13
12
|
"""
|
|
14
|
-
|
|
15
|
-
Supports batched, multi-channel inputs, so acceptable data
|
|
16
|
-
shapes are `(width)`, `(height, width)`, `(batch, width)`,
|
|
17
|
-
`(batch, height, width)`, `(batch, channel, width)`, and
|
|
18
|
-
`(batch, channel, height, width)`.
|
|
19
|
-
|
|
20
|
-
During initialization of this Module, both the desired input
|
|
21
|
-
and output coordinate Tensors can be specified to allow
|
|
22
|
-
pre-computation of the B-spline basis matrices, though the only
|
|
23
|
-
mandatory argument is the coordinates of the data along the
|
|
24
|
-
`width` dimension. If no argument is given for coordinates along
|
|
25
|
-
the `height` dimension, it is assumed that 1D interpolation is
|
|
26
|
-
desired.
|
|
27
|
-
|
|
28
|
-
Unlike scipy's implementation of spline interpolation, the data
|
|
29
|
-
to be interpolated is not passed until actually calling the
|
|
30
|
-
object. This is useful for cases where the input and output
|
|
31
|
-
coordinates are known in advance, but the data is not, so that
|
|
32
|
-
the interpolator can be set up ahead of time.
|
|
33
|
-
|
|
34
|
-
WARNING: compared to scipy's spline interpolation, this method
|
|
35
|
-
produces edge artifacts when the output coordinates are near
|
|
36
|
-
the boundaries of the input coordinates. Therefore, it is
|
|
37
|
-
recommended to interpolate only to coordinates that are well
|
|
38
|
-
within the input coordinate range. Unfortunately, the specific
|
|
39
|
-
definition of "well within" changes based on the size of the
|
|
40
|
-
data, so some testing may be required to get good results.
|
|
41
|
-
|
|
42
|
-
Args:
|
|
43
|
-
x_in:
|
|
44
|
-
Coordinates of the width dimension of the data
|
|
45
|
-
y_in:
|
|
46
|
-
Coordinates of the height dimension of the data. If not
|
|
47
|
-
specified, it is assumed the 1D interpolation is desired,
|
|
48
|
-
and so the default value is a Tensor of length 1
|
|
49
|
-
kx:
|
|
50
|
-
Degree of spline interpolation along the width dimension.
|
|
51
|
-
Default is cubic.
|
|
52
|
-
ky:
|
|
53
|
-
Degree of spline interpolation along the height dimension.
|
|
54
|
-
Default is cubic.
|
|
55
|
-
sx:
|
|
56
|
-
Regularization factor to avoid singularities during matrix
|
|
57
|
-
inversion for interpolation along the width dimension. Not
|
|
58
|
-
to be confused with the `s` parameter in scipy's spline
|
|
59
|
-
methods, which controls the number of knots.
|
|
60
|
-
sy:
|
|
61
|
-
Regularization factor to avoid singularities during matrix
|
|
62
|
-
inversion for interpolation along the height dimension.
|
|
63
|
-
x_out:
|
|
64
|
-
Coordinates for the data to be interpolated to along the
|
|
65
|
-
width dimension. If not specified during initialization,
|
|
66
|
-
this must be specified during the object call.
|
|
67
|
-
y_out:
|
|
68
|
-
Coordinates for the data to be interpolated to along the
|
|
69
|
-
height dimension. If not specified during initialization,
|
|
70
|
-
this must be specified during the object call.
|
|
71
|
-
|
|
13
|
+
Base class for spline interpolation.
|
|
72
14
|
"""
|
|
73
15
|
|
|
74
|
-
def __init__(
|
|
75
|
-
self,
|
|
76
|
-
x_in: Tensor,
|
|
77
|
-
y_in: Tensor = None,
|
|
78
|
-
kx: int = 3,
|
|
79
|
-
ky: int = 3,
|
|
80
|
-
sx: float = 0.001,
|
|
81
|
-
sy: float = 0.001,
|
|
82
|
-
x_out: Optional[Tensor] = None,
|
|
83
|
-
y_out: Optional[Tensor] = None,
|
|
84
|
-
):
|
|
85
|
-
super().__init__()
|
|
86
|
-
if y_in is None:
|
|
87
|
-
y_in = Tensor([1])
|
|
88
|
-
self.kx = kx
|
|
89
|
-
self.ky = ky
|
|
90
|
-
self.sx = sx
|
|
91
|
-
self.sy = sy
|
|
92
|
-
self.register_buffer("x_in", x_in)
|
|
93
|
-
self.register_buffer("y_in", y_in)
|
|
94
|
-
self.register_buffer("x_out", x_out)
|
|
95
|
-
self.register_buffer("y_out", y_out)
|
|
96
|
-
|
|
97
|
-
tx, Bx, BxT_Bx = self._compute_knots_and_basis_matrices(x_in, kx, sx)
|
|
98
|
-
self.register_buffer("tx", tx)
|
|
99
|
-
self.register_buffer("Bx", Bx)
|
|
100
|
-
self.register_buffer("BxT_Bx", BxT_Bx)
|
|
101
|
-
|
|
102
|
-
ty, By, ByT_By = self._compute_knots_and_basis_matrices(y_in, ky, sy)
|
|
103
|
-
self.register_buffer("ty", ty)
|
|
104
|
-
self.register_buffer("By", By)
|
|
105
|
-
self.register_buffer("ByT_By", ByT_By)
|
|
106
|
-
|
|
107
|
-
if self.x_out is not None:
|
|
108
|
-
Bx_out = self.bspline_basis_natural(x_out, kx, self.tx)
|
|
109
|
-
self.register_buffer("Bx_out", Bx_out)
|
|
110
|
-
if self.y_out is not None:
|
|
111
|
-
By_out = self.bspline_basis_natural(y_out, ky, self.ty)
|
|
112
|
-
self.register_buffer("By_out", By_out)
|
|
113
|
-
|
|
114
16
|
def _compute_knots_and_basis_matrices(self, x, k, s):
|
|
115
|
-
knots = self.
|
|
17
|
+
knots = self.generate_fitpack_knots(x, k)
|
|
116
18
|
basis_matrix = self.bspline_basis_natural(x, k, knots)
|
|
117
19
|
identity = torch.eye(basis_matrix.shape[-1])
|
|
118
20
|
B_T_B = basis_matrix.T @ basis_matrix + s * identity
|
|
119
21
|
return knots, basis_matrix, B_T_B
|
|
120
22
|
|
|
121
|
-
def
|
|
23
|
+
def generate_fitpack_knots(self, x: Tensor, k: int) -> Tensor:
|
|
122
24
|
"""
|
|
123
|
-
Generates a
|
|
124
|
-
|
|
125
|
-
and end of datapoints as replicas of first and last datapoint
|
|
126
|
-
respectively in order to enforce natural boundary conditions,
|
|
127
|
-
i.e. second derivative = 0.
|
|
128
|
-
The other n nodes are placed in correspondece of the data points.
|
|
25
|
+
Generates a knot sequence for B-spline interpolation
|
|
26
|
+
in the same way as the FITPACK algorithm used by SciPy.
|
|
129
27
|
|
|
130
28
|
Args:
|
|
131
29
|
x: Tensor of data point positions.
|
|
@@ -134,7 +32,17 @@ class SplineInterpolate(torch.nn.Module):
|
|
|
134
32
|
Returns:
|
|
135
33
|
Tensor of knot positions.
|
|
136
34
|
"""
|
|
137
|
-
|
|
35
|
+
num_knots = x.shape[-1] + k + 1
|
|
36
|
+
knots = torch.zeros(num_knots, dtype=x.dtype)
|
|
37
|
+
knots[: k + 1] = x[0]
|
|
38
|
+
knots[-(k + 1) :] = x[-1]
|
|
39
|
+
|
|
40
|
+
# Interior knots are the rolling average of the data points
|
|
41
|
+
# excluding the first and last points
|
|
42
|
+
windows = x[1:-1].unfold(dimension=-1, size=k, step=1)
|
|
43
|
+
knots[k + 1 : -k - 1] = windows.mean(dim=-1)
|
|
44
|
+
|
|
45
|
+
return knots
|
|
138
46
|
|
|
139
47
|
def compute_L_R(
|
|
140
48
|
self,
|
|
@@ -233,9 +141,6 @@ class SplineInterpolate(torch.nn.Module):
|
|
|
233
141
|
Returns:
|
|
234
142
|
Tensor containing the kth-order B-spline basis functions
|
|
235
143
|
"""
|
|
236
|
-
|
|
237
|
-
if len(x) == 1:
|
|
238
|
-
return torch.eye(1)
|
|
239
144
|
n = x.shape[0]
|
|
240
145
|
m = t.shape[0] - k - 1
|
|
241
146
|
|
|
@@ -255,12 +160,272 @@ class SplineInterpolate(torch.nn.Module):
|
|
|
255
160
|
|
|
256
161
|
return b[:, :, -1]
|
|
257
162
|
|
|
258
|
-
def bivariate_spline_fit_natural(self, Z):
|
|
259
|
-
if len(Z.shape) == 3:
|
|
260
|
-
Z_Bx = torch.matmul(Z, self.Bx)
|
|
261
|
-
# ((BxT @ Bx)^-1 @ (Z @ Bx)T)T = Z @ BxT^-1
|
|
262
|
-
return torch.linalg.solve(self.BxT_Bx, Z_Bx.mT).mT
|
|
263
163
|
|
|
164
|
+
class SplineInterpolate1D(SplineInterpolateBase):
|
|
165
|
+
"""
|
|
166
|
+
Perform 1D spline interpolation based on De Boor's method.
|
|
167
|
+
It is allowed to have two spatial dimensions, but the second
|
|
168
|
+
dimension cannot be interpolated along. To interpolate along both
|
|
169
|
+
dimensions, use :class:`SplineInterpolate2D`.
|
|
170
|
+
|
|
171
|
+
Supports batched, multi-channel inputs, so acceptable data
|
|
172
|
+
shapes are ``(width)``, ``(height, width)``, ``(batch, width)``,
|
|
173
|
+
``(batch, height, width)``, ``(batch, channel, width)``, and
|
|
174
|
+
``(batch, channel, height, width)``.
|
|
175
|
+
|
|
176
|
+
During initialization of this Module, both the desired input
|
|
177
|
+
and output coordinate Tensors can be specified to allow
|
|
178
|
+
pre-computation of the B-spline basis matrices, though the only
|
|
179
|
+
mandatory argument is the coordinates of the data along the
|
|
180
|
+
``width`` dimension.
|
|
181
|
+
|
|
182
|
+
Unlike scipy's implementation of spline interpolation, the data
|
|
183
|
+
to be interpolated is not passed until actually calling the
|
|
184
|
+
object. This is useful for cases where the input and output
|
|
185
|
+
coordinates are known in advance, but the data is not, so that
|
|
186
|
+
the interpolator can be set up ahead of time.
|
|
187
|
+
|
|
188
|
+
Args:
|
|
189
|
+
x_in:
|
|
190
|
+
Coordinates of the width dimension of the data
|
|
191
|
+
kx:
|
|
192
|
+
Degree of spline interpolation along the width dimension.
|
|
193
|
+
Default is cubic.
|
|
194
|
+
sx:
|
|
195
|
+
Regularization factor to avoid singularities during matrix
|
|
196
|
+
inversion for interpolation along the width dimension. Not
|
|
197
|
+
to be confused with the ``s`` parameter in scipy's spline
|
|
198
|
+
methods, which controls the number of knots.
|
|
199
|
+
x_out:
|
|
200
|
+
Coordinates for the data to be interpolated to along the
|
|
201
|
+
width dimension. If not specified during initialization,
|
|
202
|
+
this must be specified during the object call.
|
|
203
|
+
|
|
204
|
+
"""
|
|
205
|
+
|
|
206
|
+
def __init__(
|
|
207
|
+
self,
|
|
208
|
+
x_in: Tensor,
|
|
209
|
+
kx: int = 3,
|
|
210
|
+
sx: float = 0.0,
|
|
211
|
+
x_out: Optional[Tensor] = None,
|
|
212
|
+
):
|
|
213
|
+
super().__init__()
|
|
214
|
+
|
|
215
|
+
if len(x_in) < kx + 2:
|
|
216
|
+
raise ValueError(
|
|
217
|
+
"Input x-coordinates must have at least kx + 2 points."
|
|
218
|
+
)
|
|
219
|
+
|
|
220
|
+
# Ensure that coordinates are floats
|
|
221
|
+
x_in = x_in.float()
|
|
222
|
+
x_out = x_out.float() if x_out is not None else None
|
|
223
|
+
|
|
224
|
+
self.kx = kx
|
|
225
|
+
self.sx = sx
|
|
226
|
+
self.register_buffer("x_in", x_in)
|
|
227
|
+
self.register_buffer("x_out", x_out)
|
|
228
|
+
|
|
229
|
+
tx, Bx, BxT_Bx = self._compute_knots_and_basis_matrices(x_in, kx, sx)
|
|
230
|
+
self.register_buffer("tx", tx)
|
|
231
|
+
self.register_buffer("Bx", Bx)
|
|
232
|
+
self.register_buffer("BxT_Bx", BxT_Bx)
|
|
233
|
+
|
|
234
|
+
if self.x_out is not None:
|
|
235
|
+
x_clamped = torch.clamp(x_out, tx[kx], tx[-kx - 1])
|
|
236
|
+
Bx_out = self.bspline_basis_natural(x_clamped, kx, self.tx)
|
|
237
|
+
self.register_buffer("Bx_out", Bx_out)
|
|
238
|
+
|
|
239
|
+
def spline_fit_natural(self, Z):
|
|
240
|
+
# Adding batch/channel dimension handling
|
|
241
|
+
# Bx @ Z
|
|
242
|
+
BxT_Z = torch.einsum("ij,bchj->bchi", self.Bx.T, Z)
|
|
243
|
+
# (BxT @ Bx)^-1 @ (BxT @ Z) = Bx^-1 @ Z
|
|
244
|
+
C = torch.linalg.solve(self.BxT_Bx, BxT_Z.unsqueeze(-1))
|
|
245
|
+
return C.squeeze(-1)
|
|
246
|
+
|
|
247
|
+
def evaluate_spline(self, C: Tensor):
|
|
248
|
+
"""
|
|
249
|
+
Evaluate a bivariate spline on a grid of x and y points.
|
|
250
|
+
|
|
251
|
+
Args:
|
|
252
|
+
C: Coefficient tensor of shape (batch_size, mx, my).
|
|
253
|
+
|
|
254
|
+
Returns:
|
|
255
|
+
Z_interp: Interpolated values at the grid points.
|
|
256
|
+
"""
|
|
257
|
+
# Perform matrix multiplication using einsum to get Z_interp
|
|
258
|
+
return torch.einsum("ij,bchj->bchi", self.Bx_out, C)
|
|
259
|
+
|
|
260
|
+
def _validate_inputs(self, Z, x_out):
|
|
261
|
+
if x_out is None and self.x_out is None:
|
|
262
|
+
raise ValueError(
|
|
263
|
+
"Output x-coordinates were not specified in either object "
|
|
264
|
+
"creation or in forward call"
|
|
265
|
+
)
|
|
266
|
+
|
|
267
|
+
dims = len(Z.shape)
|
|
268
|
+
if dims > 4:
|
|
269
|
+
raise ValueError("Input data has more than 4 dimensions")
|
|
270
|
+
|
|
271
|
+
if Z.shape[-1] != len(self.x_in):
|
|
272
|
+
raise ValueError(
|
|
273
|
+
"The spatial dimensions of the data tensor do not match "
|
|
274
|
+
"the given input dimensions. "
|
|
275
|
+
f"Expected {len(self.x_in)}, but got {Z.shape[-1]}"
|
|
276
|
+
)
|
|
277
|
+
|
|
278
|
+
# Expand Z to have a batch, channel, and height dimension if needed
|
|
279
|
+
while len(Z.shape) < 4:
|
|
280
|
+
Z = Z.unsqueeze(0)
|
|
281
|
+
|
|
282
|
+
return Z
|
|
283
|
+
|
|
284
|
+
def forward(
|
|
285
|
+
self,
|
|
286
|
+
Z: Tensor,
|
|
287
|
+
x_out: Optional[Tensor] = None,
|
|
288
|
+
) -> Tensor:
|
|
289
|
+
"""
|
|
290
|
+
Compute the interpolated data
|
|
291
|
+
|
|
292
|
+
Args:
|
|
293
|
+
Z:
|
|
294
|
+
Tensor of data to be interpolated. Must be between 2 and 4
|
|
295
|
+
dimensions. The shape of the tensor must agree with the
|
|
296
|
+
input coordinates given on initialization.
|
|
297
|
+
x_out:
|
|
298
|
+
Coordinates to interpolate the data to along the width
|
|
299
|
+
dimension. Overrides any value that was set during
|
|
300
|
+
initialization.
|
|
301
|
+
|
|
302
|
+
Returns:
|
|
303
|
+
A 4D tensor with shape ``(batch, channel, height, width)``.
|
|
304
|
+
Depending on the input data shape, many of these dimensions
|
|
305
|
+
may have length 1.
|
|
306
|
+
"""
|
|
307
|
+
|
|
308
|
+
Z = self._validate_inputs(Z, x_out)
|
|
309
|
+
|
|
310
|
+
if x_out is not None:
|
|
311
|
+
x_out = x_out.float()
|
|
312
|
+
x_clamped = torch.clamp(
|
|
313
|
+
x_out, self.tx[self.kx], self.tx[-self.kx - 1]
|
|
314
|
+
)
|
|
315
|
+
self.Bx_out = self.bspline_basis_natural(
|
|
316
|
+
x_clamped, self.kx, self.tx
|
|
317
|
+
)
|
|
318
|
+
|
|
319
|
+
coef = self.spline_fit_natural(Z)
|
|
320
|
+
Z_interp = self.evaluate_spline(coef)
|
|
321
|
+
return Z_interp
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
class SplineInterpolate2D(SplineInterpolateBase):
|
|
325
|
+
"""
|
|
326
|
+
Perform 2D spline interpolation based on De Boor's method.
|
|
327
|
+
Supports batched, multi-channel inputs, so acceptable data
|
|
328
|
+
shapes are ``(height, width)``, ``(batch, height, width)``,
|
|
329
|
+
and ``(batch, channel, height, width)``.
|
|
330
|
+
|
|
331
|
+
During initialization of this Module, both the desired input
|
|
332
|
+
and output coordinate Tensors can be specified to allow
|
|
333
|
+
pre-computation of the B-spline basis matrices, though the only
|
|
334
|
+
mandatory arguments are the input coordinates.
|
|
335
|
+
|
|
336
|
+
Unlike scipy's implementation of spline interpolation, the data
|
|
337
|
+
to be interpolated is not passed until actually calling the
|
|
338
|
+
object. This is useful for cases where the input and output
|
|
339
|
+
coordinates are known in advance, but the data is not, so that
|
|
340
|
+
the interpolator can be set up ahead of time.
|
|
341
|
+
|
|
342
|
+
Args:
|
|
343
|
+
x_in:
|
|
344
|
+
Coordinates of the width dimension of the data
|
|
345
|
+
y_in:
|
|
346
|
+
Coordinates of the height dimension of the data.
|
|
347
|
+
kx:
|
|
348
|
+
Degree of spline interpolation along the width dimension.
|
|
349
|
+
Default is cubic.
|
|
350
|
+
ky:
|
|
351
|
+
Degree of spline interpolation along the height dimension.
|
|
352
|
+
Default is cubic.
|
|
353
|
+
sx:
|
|
354
|
+
Regularization factor to avoid singularities during matrix
|
|
355
|
+
inversion for interpolation along the width dimension. Not
|
|
356
|
+
to be confused with the ``s`` parameter in scipy's spline
|
|
357
|
+
methods, which controls the number of knots.
|
|
358
|
+
sy:
|
|
359
|
+
Regularization factor to avoid singularities during matrix
|
|
360
|
+
inversion for interpolation along the height dimension.
|
|
361
|
+
x_out:
|
|
362
|
+
Coordinates for the data to be interpolated to along the
|
|
363
|
+
width dimension. If not specified during initialization,
|
|
364
|
+
this must be specified during the object call.
|
|
365
|
+
y_out:
|
|
366
|
+
Coordinates for the data to be interpolated to along the
|
|
367
|
+
height dimension. If not specified during initialization,
|
|
368
|
+
this must be specified during the object call.
|
|
369
|
+
|
|
370
|
+
"""
|
|
371
|
+
|
|
372
|
+
def __init__(
|
|
373
|
+
self,
|
|
374
|
+
x_in: Tensor,
|
|
375
|
+
y_in: Tensor,
|
|
376
|
+
kx: int = 3,
|
|
377
|
+
ky: int = 3,
|
|
378
|
+
sx: float = 0.0,
|
|
379
|
+
sy: float = 0.0,
|
|
380
|
+
x_out: Optional[Tensor] = None,
|
|
381
|
+
y_out: Optional[Tensor] = None,
|
|
382
|
+
):
|
|
383
|
+
super().__init__()
|
|
384
|
+
|
|
385
|
+
if len(x_in) < kx + 2:
|
|
386
|
+
raise ValueError(
|
|
387
|
+
"Input x-coordinates must have at least kx + 2 points."
|
|
388
|
+
)
|
|
389
|
+
if len(y_in) < ky + 2:
|
|
390
|
+
raise ValueError(
|
|
391
|
+
"Input y-coordinates must have at least ky + 2 points."
|
|
392
|
+
)
|
|
393
|
+
|
|
394
|
+
# Ensure that coordinates are floats
|
|
395
|
+
x_in = x_in.float()
|
|
396
|
+
y_in = y_in.float()
|
|
397
|
+
x_out = x_out.float() if x_out is not None else None
|
|
398
|
+
y_out = y_out.float() if y_out is not None else None
|
|
399
|
+
|
|
400
|
+
self.kx = kx
|
|
401
|
+
self.ky = ky
|
|
402
|
+
self.sx = sx
|
|
403
|
+
self.sy = sy
|
|
404
|
+
self.register_buffer("x_in", x_in)
|
|
405
|
+
self.register_buffer("y_in", y_in)
|
|
406
|
+
self.register_buffer("x_out", x_out)
|
|
407
|
+
self.register_buffer("y_out", y_out)
|
|
408
|
+
|
|
409
|
+
tx, Bx, BxT_Bx = self._compute_knots_and_basis_matrices(x_in, kx, sx)
|
|
410
|
+
self.register_buffer("tx", tx)
|
|
411
|
+
self.register_buffer("Bx", Bx)
|
|
412
|
+
self.register_buffer("BxT_Bx", BxT_Bx)
|
|
413
|
+
|
|
414
|
+
ty, By, ByT_By = self._compute_knots_and_basis_matrices(y_in, ky, sy)
|
|
415
|
+
self.register_buffer("ty", ty)
|
|
416
|
+
self.register_buffer("By", By)
|
|
417
|
+
self.register_buffer("ByT_By", ByT_By)
|
|
418
|
+
|
|
419
|
+
if self.x_out is not None:
|
|
420
|
+
x_clamped = torch.clamp(x_out, tx[kx], tx[-kx - 1])
|
|
421
|
+
Bx_out = self.bspline_basis_natural(x_clamped, kx, self.tx)
|
|
422
|
+
self.register_buffer("Bx_out", Bx_out)
|
|
423
|
+
if self.y_out is not None:
|
|
424
|
+
y_clamped = torch.clamp(y_out, ty[ky], ty[-ky - 1])
|
|
425
|
+
By_out = self.bspline_basis_natural(y_clamped, ky, self.ty)
|
|
426
|
+
self.register_buffer("By_out", By_out)
|
|
427
|
+
|
|
428
|
+
def bivariate_spline_fit_natural(self, Z):
|
|
264
429
|
# Adding batch/channel dimension handling
|
|
265
430
|
# ByT @ Z @ BxW
|
|
266
431
|
ByT_Z_Bx = torch.einsum("ij,bcik,kl->bcjl", self.By, Z, self.Bx)
|
|
@@ -280,8 +445,6 @@ class SplineInterpolate(torch.nn.Module):
|
|
|
280
445
|
Z_interp: Interpolated values at the grid points.
|
|
281
446
|
"""
|
|
282
447
|
# Perform matrix multiplication using einsum to get Z_interp
|
|
283
|
-
if len(C.shape) == 3:
|
|
284
|
-
return torch.matmul(C, self.Bx_out.mT)
|
|
285
448
|
return torch.einsum("ik,bckm,mj->bcij", self.By_out, C, self.Bx_out.mT)
|
|
286
449
|
|
|
287
450
|
def _validate_inputs(self, Z, x_out, y_out):
|
|
@@ -292,29 +455,16 @@ class SplineInterpolate(torch.nn.Module):
|
|
|
292
455
|
)
|
|
293
456
|
|
|
294
457
|
if y_out is None and self.y_out is None:
|
|
295
|
-
|
|
458
|
+
raise ValueError(
|
|
459
|
+
"Output y-coordinates were not specified in either object "
|
|
460
|
+
"creation or in forward call"
|
|
461
|
+
)
|
|
296
462
|
|
|
297
463
|
dims = len(Z.shape)
|
|
298
464
|
if dims > 4:
|
|
299
465
|
raise ValueError("Input data has more than 4 dimensions")
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
raise ValueError(
|
|
303
|
-
"An input y-coordinate array with length greater than 1 "
|
|
304
|
-
"was given, but the input data is 1-dimensional. Expected "
|
|
305
|
-
"input data to be at least 2-dimensional"
|
|
306
|
-
)
|
|
307
|
-
|
|
308
|
-
# Expand Z to have 4 dimensions
|
|
309
|
-
# There are 6 valid input shapes: (w), (b, w), (b, c, w),
|
|
310
|
-
# (h, w), (b, h, w), and (b, c, h, w).
|
|
311
|
-
|
|
312
|
-
# If the input y coordinate array has length 1,
|
|
313
|
-
# assume the first dimension(s) are batch dimensions
|
|
314
|
-
# and that no height dimension is included in Z
|
|
315
|
-
idx = -2 if len(self.y_in) == 1 else -3
|
|
316
|
-
while len(Z.shape) < 4:
|
|
317
|
-
Z = Z.unsqueeze(idx)
|
|
466
|
+
if dims < 2:
|
|
467
|
+
raise ValueError("Input data has fewer than 2 dimensions")
|
|
318
468
|
|
|
319
469
|
if Z.shape[-2:] != torch.Size([len(self.y_in), len(self.x_in)]):
|
|
320
470
|
raise ValueError(
|
|
@@ -324,6 +474,10 @@ class SplineInterpolate(torch.nn.Module):
|
|
|
324
474
|
f"[{Z.shape[-2]}, {Z.shape[-1]}]"
|
|
325
475
|
)
|
|
326
476
|
|
|
477
|
+
# Expand Z to have a batch and channel dimension if needed
|
|
478
|
+
while len(Z.shape) < 4:
|
|
479
|
+
Z = Z.unsqueeze(0)
|
|
480
|
+
|
|
327
481
|
return Z, y_out
|
|
328
482
|
|
|
329
483
|
def forward(
|
|
@@ -337,11 +491,9 @@ class SplineInterpolate(torch.nn.Module):
|
|
|
337
491
|
|
|
338
492
|
Args:
|
|
339
493
|
Z:
|
|
340
|
-
Tensor of data to be interpolated. Must be between
|
|
494
|
+
Tensor of data to be interpolated. Must be between 2 and 4
|
|
341
495
|
dimensions. The shape of the tensor must agree with the
|
|
342
|
-
input coordinates given on initialization.
|
|
343
|
-
not specified during initialization, it is assumed that
|
|
344
|
-
Z does not have a height dimension.
|
|
496
|
+
input coordinates given on initialization.
|
|
345
497
|
x_out:
|
|
346
498
|
Coordinates to interpolate the data to along the width
|
|
347
499
|
dimension. Overrides any value that was set during
|
|
@@ -352,7 +504,7 @@ class SplineInterpolate(torch.nn.Module):
|
|
|
352
504
|
initialization.
|
|
353
505
|
|
|
354
506
|
Returns:
|
|
355
|
-
A 4D tensor with shape
|
|
507
|
+
A 4D tensor with shape ``(batch, channel, height, width)``.
|
|
356
508
|
Depending on the input data shape, many of these dimensions
|
|
357
509
|
may have length 1.
|
|
358
510
|
"""
|
|
@@ -360,9 +512,21 @@ class SplineInterpolate(torch.nn.Module):
|
|
|
360
512
|
Z, y_out = self._validate_inputs(Z, x_out, y_out)
|
|
361
513
|
|
|
362
514
|
if x_out is not None:
|
|
363
|
-
|
|
515
|
+
x_out = x_out.float()
|
|
516
|
+
x_clamped = torch.clamp(
|
|
517
|
+
x_out, self.tx[self.kx], self.tx[-self.kx - 1]
|
|
518
|
+
)
|
|
519
|
+
self.Bx_out = self.bspline_basis_natural(
|
|
520
|
+
x_clamped, self.kx, self.tx
|
|
521
|
+
)
|
|
364
522
|
if y_out is not None:
|
|
365
|
-
|
|
523
|
+
y_out = y_out.float()
|
|
524
|
+
y_clamped = torch.clamp(
|
|
525
|
+
y_out, self.ty[self.ky], self.ty[-self.ky - 1]
|
|
526
|
+
)
|
|
527
|
+
self.By_out = self.bspline_basis_natural(
|
|
528
|
+
y_clamped, self.ky, self.ty
|
|
529
|
+
)
|
|
366
530
|
|
|
367
531
|
coef = self.bivariate_spline_fit_natural(Z)
|
|
368
532
|
Z_interp = self.evaluate_bivariate_spline(coef)
|
ml4gw/transforms/transform.py
CHANGED
|
@@ -70,7 +70,7 @@ class FittableSpectralTransform(FittableTransform):
|
|
|
70
70
|
)
|
|
71
71
|
|
|
72
72
|
# add two dummy dimensions in case we need to interpolate
|
|
73
|
-
# the frequency dimension, since
|
|
73
|
+
# the frequency dimension, since ``interpolate`` expects
|
|
74
74
|
# a (batch, channel, spatial) formatted tensor as input
|
|
75
75
|
x = x.view(1, 1, -1)
|
|
76
76
|
if x.size(-1) != num_freqs:
|