ml4gw 0.7.4__py3-none-any.whl → 0.7.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ml4gw might be problematic. Click here for more details.
- ml4gw/augmentations.py +4 -4
- ml4gw/dataloading/chunked_dataset.py +3 -3
- ml4gw/dataloading/hdf5_dataset.py +7 -10
- ml4gw/dataloading/in_memory_dataset.py +21 -21
- ml4gw/distributions.py +216 -10
- ml4gw/gw.py +60 -53
- ml4gw/nn/autoencoder/base.py +9 -9
- ml4gw/nn/autoencoder/convolutional.py +4 -4
- ml4gw/nn/resnet/resnet_1d.py +13 -13
- ml4gw/nn/resnet/resnet_2d.py +12 -12
- ml4gw/nn/streaming/online_average.py +1 -1
- ml4gw/nn/streaming/snapshotter.py +14 -14
- ml4gw/spectral.py +48 -48
- ml4gw/transforms/iirfilter.py +3 -3
- ml4gw/transforms/pearson.py +7 -8
- ml4gw/transforms/qtransform.py +19 -19
- ml4gw/transforms/scaler.py +4 -4
- ml4gw/transforms/spectral.py +10 -10
- ml4gw/transforms/spectrogram.py +12 -11
- ml4gw/transforms/spline_interpolation.py +8 -15
- ml4gw/transforms/transform.py +1 -1
- ml4gw/transforms/whitening.py +36 -36
- ml4gw/utils/slicing.py +40 -40
- ml4gw/waveforms/cbc/phenom_d.py +22 -66
- ml4gw/waveforms/cbc/phenom_p.py +9 -5
- ml4gw/waveforms/cbc/taylorf2.py +8 -7
- ml4gw/waveforms/conversion.py +2 -1
- ml4gw/waveforms/generator.py +33 -32
- {ml4gw-0.7.4.dist-info → ml4gw-0.7.6.dist-info}/METADATA +7 -1
- ml4gw-0.7.6.dist-info/RECORD +55 -0
- ml4gw-0.7.4.dist-info/RECORD +0 -55
- {ml4gw-0.7.4.dist-info → ml4gw-0.7.6.dist-info}/WHEEL +0 -0
- {ml4gw-0.7.4.dist-info → ml4gw-0.7.6.dist-info}/licenses/LICENSE +0 -0
ml4gw/spectral.py
CHANGED
|
@@ -27,8 +27,8 @@ def median(x: Float[Tensor, "... size"], axis: int) -> Float[Tensor, "..."]:
|
|
|
27
27
|
"""
|
|
28
28
|
Implements a median calculation that matches numpy's
|
|
29
29
|
behavior for an even number of elements and includes
|
|
30
|
-
the same bias correction used by
|
|
31
|
-
|
|
30
|
+
the same bias correction used by
|
|
31
|
+
`scipy's implementation <https://github.com/scipy/scipy/blob/main/scipy/signal/_spectral_py.py#L2066>`_.
|
|
32
32
|
""" # noqa: E501
|
|
33
33
|
n = x.shape[axis]
|
|
34
34
|
ii_2 = 2 * torch.arange(1.0, (n - 1) // 2 + 1)
|
|
@@ -111,50 +111,50 @@ def fast_spectral_density(
|
|
|
111
111
|
The timeseries tensor whose power spectral density
|
|
112
112
|
to compute, or for cross spectral density the
|
|
113
113
|
timeseries whose fft will be conjugated. Can have
|
|
114
|
-
shape
|
|
115
|
-
|
|
116
|
-
|
|
114
|
+
shape ``(batch_size, num_channels, length * sample_rate)``,
|
|
115
|
+
``(num_channels, length * sample_rate)``, or
|
|
116
|
+
``(length * sample_rate)``.
|
|
117
117
|
nperseg:
|
|
118
118
|
Number of samples included in each FFT window
|
|
119
119
|
nstride:
|
|
120
120
|
Stride between FFT windows
|
|
121
121
|
window:
|
|
122
122
|
Window array to multiply by each FFT window before
|
|
123
|
-
FFT computation. Should have length
|
|
123
|
+
FFT computation. Should have length ``nperseg // 2 + 1``.
|
|
124
124
|
scale:
|
|
125
125
|
Scale factor to multiply the FFT'd data by, related to
|
|
126
126
|
desired units for output tensor (e.g. letting this equal
|
|
127
|
-
|
|
128
|
-
units of density,
|
|
127
|
+
``1 / (sample_rate * (window**2).sum())`` will give output
|
|
128
|
+
units of density, :math:`\\text{Hz}^{-1}`.
|
|
129
129
|
average:
|
|
130
130
|
How to aggregate the contributions of each FFT window to
|
|
131
|
-
the spectral density. Allowed options are
|
|
132
|
-
|
|
131
|
+
the spectral density. Allowed options are ``'mean'`` and
|
|
132
|
+
``'median'``.
|
|
133
133
|
y:
|
|
134
134
|
Timeseries tensor to compute cross spectral density
|
|
135
|
-
with
|
|
136
|
-
density will be returned. Otherwise, if
|
|
137
|
-
|
|
135
|
+
with ``x``. If left as ``None``, ``x``'s power spectral
|
|
136
|
+
density will be returned. Otherwise, if ``x`` is 1D,
|
|
137
|
+
``y`` must also be 1D. If ``x`` is 2D, the assumption
|
|
138
138
|
is that this represents a single multi-channel timeseries,
|
|
139
|
-
and
|
|
139
|
+
and ``y`` must be either 2D or 1D. In the former case,
|
|
140
140
|
the cross-spectral densities of each channel will be
|
|
141
|
-
computed individually, so
|
|
142
|
-
Otherwise, this will compute the CSD of each of
|
|
143
|
-
with
|
|
144
|
-
of multi-channel timeseries. In this case,
|
|
141
|
+
computed individually, so ``y`` must have the same shape as ``x``.
|
|
142
|
+
Otherwise, this will compute the CSD of each of ``x``'s channels
|
|
143
|
+
with ``y``. If ``x`` is 3D, this will be assumed to be a batch
|
|
144
|
+
of multi-channel timeseries. In this case, ``y`` can either
|
|
145
145
|
be 3D, in which case each channel of each batch element will
|
|
146
146
|
have its CSD calculated or 2D, which has two different options.
|
|
147
|
-
If
|
|
148
|
-
be assumed that
|
|
147
|
+
If ``y``'s 0th dimension matches ``x``'s 0th dimension, it will
|
|
148
|
+
be assumed that ``y`` represents a batch of 1D timeseries, and
|
|
149
149
|
for each batch element this timeseries will have its CSD with
|
|
150
|
-
each channel of the corresponding batch element of
|
|
151
|
-
calculated. Otherwise, it sill be assumed that
|
|
150
|
+
each channel of the corresponding batch element of ``x``
|
|
151
|
+
calculated. Otherwise, it will be assumed that ``y`` represents
|
|
152
152
|
a single multi-channel timeseries, in which case each channel
|
|
153
|
-
of
|
|
154
|
-
channel in
|
|
153
|
+
of ``y`` will have its CSD calculated with the corresponding
|
|
154
|
+
channel in ``x`` across _all_ of ``x``'s batch elements.
|
|
155
155
|
Returns:
|
|
156
|
-
Tensor of power spectral densities of
|
|
157
|
-
density with the timeseries in
|
|
156
|
+
Tensor of power spectral densities of ``x`` or its cross spectral
|
|
157
|
+
density with the timeseries in ``y``.
|
|
158
158
|
"""
|
|
159
159
|
|
|
160
160
|
_validate_shapes(x, nperseg, y)
|
|
@@ -262,25 +262,25 @@ def spectral_density(
|
|
|
262
262
|
The timeseries tensor whose power spectral density
|
|
263
263
|
to compute, or for cross spectral density the
|
|
264
264
|
timeseries whose fft will be conjugated. Can have
|
|
265
|
-
shape
|
|
266
|
-
|
|
267
|
-
|
|
265
|
+
shape ``(batch_size, num_channels, length * sample_rate)``,
|
|
266
|
+
``(num_channels, length * sample_rate)``, or
|
|
267
|
+
``(length * sample_rate)``.
|
|
268
268
|
nperseg:
|
|
269
269
|
Number of samples included in each FFT window
|
|
270
270
|
nstride:
|
|
271
271
|
Stride between FFT windows
|
|
272
272
|
window:
|
|
273
273
|
Window array to multiply by each FFT window before
|
|
274
|
-
FFT computation. Should have length
|
|
274
|
+
FFT computation. Should have length ``nperseg // 2 + 1``.
|
|
275
275
|
scale:
|
|
276
276
|
Scale factor to multiply the FFT'd data by, related to
|
|
277
277
|
desired units for output tensor (e.g. letting this equal
|
|
278
|
-
|
|
279
|
-
units of density,
|
|
278
|
+
``1 / (sample_rate * (window**2).sum())`` will give output
|
|
279
|
+
units of density, :math:`\\text{Hz}^{-1}`.
|
|
280
280
|
average:
|
|
281
281
|
How to aggregate the contributions of each FFT window to
|
|
282
|
-
the spectral density. Allowed options are
|
|
283
|
-
|
|
282
|
+
the spectral density. Allowed options are ``'mean'`` and
|
|
283
|
+
``'median'``.
|
|
284
284
|
"""
|
|
285
285
|
|
|
286
286
|
_validate_shapes(x, nperseg)
|
|
@@ -348,18 +348,18 @@ def truncate_inverse_power_spectrum(
|
|
|
348
348
|
"""
|
|
349
349
|
Truncate the length of the time domain response
|
|
350
350
|
of a whitening filter built using the specified
|
|
351
|
-
|
|
351
|
+
``psd`` so that it has maximum length ``fduration``
|
|
352
352
|
seconds. This is meant to mitigate the impact
|
|
353
353
|
of sharp features in the background PSD causing
|
|
354
354
|
time domain responses longer than the segments
|
|
355
355
|
to which the whitening filter will be applied.
|
|
356
356
|
|
|
357
357
|
Implementation details adapted from
|
|
358
|
-
https://github.com/vivinousi/gw-detection-deep-learning/blob/203966cc2ee47c32c292be000fb009a16824b7d9/modules/whiten.py#L8
|
|
358
|
+
`here <https://github.com/vivinousi/gw-detection-deep-learning/blob/203966cc2ee47c32c292be000fb009a16824b7d9/modules/whiten.py#L8>`_.
|
|
359
359
|
|
|
360
360
|
Args:
|
|
361
361
|
psd:
|
|
362
|
-
The one-sided power
|
|
362
|
+
The one-sided power spectral density used
|
|
363
363
|
to construct a whitening filter.
|
|
364
364
|
fduration:
|
|
365
365
|
Desired length in seconds of the time domain
|
|
@@ -375,14 +375,14 @@ def truncate_inverse_power_spectrum(
|
|
|
375
375
|
highpass:
|
|
376
376
|
If specified, will zero out the frequency response
|
|
377
377
|
of all frequencies below this value in Hz. If left
|
|
378
|
-
as
|
|
378
|
+
as ``None``, no highpass filtering will be applied.
|
|
379
379
|
lowpass:
|
|
380
380
|
If specified, will zero out the frequency response
|
|
381
381
|
of all frequencies above this value in Hz. If left
|
|
382
|
-
as
|
|
382
|
+
as ``None``, no lowpass filtering will be applied.
|
|
383
383
|
Returns:
|
|
384
384
|
The PSD with its time domain response truncated
|
|
385
|
-
to
|
|
385
|
+
to ``fduration`` and any filtered frequencies
|
|
386
386
|
tapered.
|
|
387
387
|
""" # noqa: E501
|
|
388
388
|
|
|
@@ -469,7 +469,7 @@ def whiten(
|
|
|
469
469
|
Whiten a batch of timeseries using the specified
|
|
470
470
|
background one-sided power spectral densities (PSDs),
|
|
471
471
|
modified to have the desired time domain response length
|
|
472
|
-
|
|
472
|
+
``fduration`` and possibly to highpass/lowpass filter.
|
|
473
473
|
|
|
474
474
|
Args:
|
|
475
475
|
X:
|
|
@@ -480,11 +480,11 @@ def whiten(
|
|
|
480
480
|
the inverse of the square root of this PSD, ensuring
|
|
481
481
|
that data from the same distribution will have
|
|
482
482
|
approximately uniform power after whitening.
|
|
483
|
-
If 2D, each batch element in
|
|
483
|
+
If 2D, each batch element in ``X`` will be whitened
|
|
484
484
|
using the same PSDs. If 3D, each batch element will
|
|
485
485
|
be whitened by the PSDs contained along the 0th
|
|
486
|
-
dimenion of
|
|
487
|
-
of
|
|
486
|
+
dimension of ``psd``, and so the first two dimensions
|
|
487
|
+
of ``X`` and ``psd`` should match.
|
|
488
488
|
fduration:
|
|
489
489
|
Desired length in seconds of the time domain
|
|
490
490
|
response of a whitening filter built using
|
|
@@ -496,20 +496,20 @@ def whiten(
|
|
|
496
496
|
the whitened timeseries to account for filter
|
|
497
497
|
settle-in time.
|
|
498
498
|
sample_rate:
|
|
499
|
-
Rate at which the data in
|
|
499
|
+
Rate at which the data in ``X`` has been sampled
|
|
500
500
|
highpass:
|
|
501
501
|
The frequency in Hz at which to highpass filter
|
|
502
502
|
the data, setting the frequency response in the
|
|
503
|
-
whitening filter to 0. If left as
|
|
503
|
+
whitening filter to 0. If left as ``None``, no
|
|
504
504
|
highpass filtering will be applied.
|
|
505
505
|
lowpass:
|
|
506
506
|
The frequency in Hz at which to lowpass filter
|
|
507
507
|
the data, setting the frequency response in the
|
|
508
|
-
whitening filter to 0. If left as
|
|
508
|
+
whitening filter to 0. If left as ``None``, no
|
|
509
509
|
lowpass filtering will be applied.
|
|
510
510
|
Returns:
|
|
511
511
|
Batch of whitened multichannel timeseries with
|
|
512
|
-
|
|
512
|
+
``fduration / 2`` seconds trimmed from each side.
|
|
513
513
|
"""
|
|
514
514
|
|
|
515
515
|
# figure out how much data we'll need to slice
|
ml4gw/transforms/iirfilter.py
CHANGED
|
@@ -9,7 +9,7 @@ class IIRFilter(torch.nn.Module):
|
|
|
9
9
|
r"""
|
|
10
10
|
IIR digital and analog filter design given order and critical points.
|
|
11
11
|
Design an Nth-order digital or analog filter and apply it to a signal.
|
|
12
|
-
Uses SciPy's
|
|
12
|
+
Uses SciPy's ``iirfilter`` function to create the filter coefficients.
|
|
13
13
|
https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirfilter.html
|
|
14
14
|
|
|
15
15
|
The forward call of this module accepts a batch tensor of shape
|
|
@@ -24,8 +24,8 @@ class IIRFilter(torch.nn.Module):
|
|
|
24
24
|
default, fs is 2 half-cycles/sample, so these are normalized
|
|
25
25
|
from 0 to 1, where 1 is the Nyquist frequency. (Wn is thus in
|
|
26
26
|
half-cycles / sample). For analog filters, Wn is an angular
|
|
27
|
-
frequency (e.g., rad/s). When Wn is a length-2 sequence
|
|
28
|
-
must be less than
|
|
27
|
+
frequency (e.g., rad/s). When Wn is a length-2 sequence, ``Wn[0]``
|
|
28
|
+
must be less than ``Wn[1]``.
|
|
29
29
|
rp:
|
|
30
30
|
For Chebyshev and elliptic filters, provides the maximum ripple in
|
|
31
31
|
the passband. (dB)
|
ml4gw/transforms/pearson.py
CHANGED
|
@@ -8,20 +8,19 @@ from ..utils.slicing import unfold_windows
|
|
|
8
8
|
|
|
9
9
|
class ShiftedPearsonCorrelation(torch.nn.Module):
|
|
10
10
|
"""
|
|
11
|
-
Compute the
|
|
12
|
-
(https://en.wikipedia.org/wiki/Pearson_correlation_coefficient)
|
|
11
|
+
Compute the `Pearson correlation <https://en.wikipedia.org/wiki/Pearson_correlation_coefficient>`_
|
|
13
12
|
for two equal-length timeseries over a pre-defined number of time
|
|
14
13
|
shifts in each direction. Useful for when you want a
|
|
15
14
|
correlation, but not over every possible shift (i.e.
|
|
16
15
|
a convolution).
|
|
17
16
|
|
|
18
|
-
The number of dimensions of the second timeseries
|
|
17
|
+
The number of dimensions of the second timeseries ``y``
|
|
19
18
|
passed at call time should always be less than or equal
|
|
20
|
-
to the number of dimensions of the first timeseries
|
|
19
|
+
to the number of dimensions of the first timeseries ``x``,
|
|
21
20
|
and each dimension should match the corresponding one of
|
|
22
|
-
|
|
23
|
-
then
|
|
24
|
-
|
|
21
|
+
``x`` in reverse order (i.e. if ``x`` has shape ``(B, C, T)``
|
|
22
|
+
then ``y`` should either have shape ``(T,)``, ``(C, T)``, or
|
|
23
|
+
``(B, C, T)``).
|
|
25
24
|
|
|
26
25
|
Note that no windowing to either timeseries is applied
|
|
27
26
|
at call time. Users should do any requisite windowing
|
|
@@ -36,7 +35,7 @@ class ShiftedPearsonCorrelation(torch.nn.Module):
|
|
|
36
35
|
The maximum number of 1-step time shifts in
|
|
37
36
|
each direction over which to compute the
|
|
38
37
|
Pearson coefficient. Output shape will then
|
|
39
|
-
be
|
|
38
|
+
be ``(2 * max_shifts + 1, B, C)``.
|
|
40
39
|
"""
|
|
41
40
|
|
|
42
41
|
def __init__(self, max_shift: int) -> None:
|
ml4gw/transforms/qtransform.py
CHANGED
|
@@ -22,7 +22,7 @@ class QTile(torch.nn.Module):
|
|
|
22
22
|
"""
|
|
23
23
|
Compute the row of Q-tiles for a single Q value and a single
|
|
24
24
|
frequency for a batch of multi-channel frequency series data.
|
|
25
|
-
Should really be called
|
|
25
|
+
Should really be called ``QRow``, but I want to match GWpy.
|
|
26
26
|
Input data should have three dimensions or fewer.
|
|
27
27
|
If fewer, dimensions will be added until the input is
|
|
28
28
|
three-dimensional.
|
|
@@ -112,17 +112,17 @@ class QTile(torch.nn.Module):
|
|
|
112
112
|
fseries:
|
|
113
113
|
Frequency series of data. Should correspond to data with
|
|
114
114
|
the duration and sample rate used to initialize this object.
|
|
115
|
-
Expected input shape is
|
|
115
|
+
Expected input shape is ``(B, C, F)``, where F is the number
|
|
116
116
|
of samples, C is the number of channels, and B is the number
|
|
117
117
|
of batches. If less than three-dimensional, axes will be
|
|
118
118
|
added.
|
|
119
119
|
norm:
|
|
120
120
|
The method of normalization. Options are "median", "mean", or
|
|
121
|
-
|
|
121
|
+
``None``.
|
|
122
122
|
|
|
123
123
|
Returns:
|
|
124
124
|
The row of Q-tiles for the given Q and frequency. Output is
|
|
125
|
-
three-dimensional:
|
|
125
|
+
three-dimensional: ``(B, C, T)``
|
|
126
126
|
"""
|
|
127
127
|
if len(fseries.shape) > 3:
|
|
128
128
|
raise ValueError("Input data has more than 3 dimensions")
|
|
@@ -164,7 +164,7 @@ class SingleQTransform(torch.nn.Module):
|
|
|
164
164
|
Sample rate of the data in Hz
|
|
165
165
|
spectrogram_shape:
|
|
166
166
|
The shape of the interpolated spectrogram, specified as
|
|
167
|
-
|
|
167
|
+
``(num_f_bins, num_t_bins)``. Because the
|
|
168
168
|
frequency spacing of the Q-tiles is in log-space, the frequency
|
|
169
169
|
interpolation is log-spaced as well.
|
|
170
170
|
q:
|
|
@@ -176,14 +176,14 @@ class SingleQTransform(torch.nn.Module):
|
|
|
176
176
|
mismatch:
|
|
177
177
|
The maximum fractional mismatch between neighboring tiles
|
|
178
178
|
interpolation_method:
|
|
179
|
-
The method by which to interpolate each
|
|
179
|
+
The method by which to interpolate each ``QTile`` to the specified
|
|
180
180
|
number of time and frequency bins. The acceptable values are
|
|
181
181
|
"bilinear", "bicubic", and "spline". The "bilinear" and "bicubic"
|
|
182
182
|
options will use PyTorch's built-in interpolation modes, while
|
|
183
183
|
"spline" will use the custom Torch-based implementation in
|
|
184
|
-
|
|
184
|
+
``ml4gw``, as PyTorch does not have spline-based interpolation.
|
|
185
185
|
The "spline" mode is most similar to the results of GWpy's
|
|
186
|
-
Q-transform, which uses
|
|
186
|
+
Q-transform, which uses ``scipy`` to do spline interpolation.
|
|
187
187
|
However, it is also the slowest and most memory intensive due to
|
|
188
188
|
the matrix equation solving steps. Therefore, the default method
|
|
189
189
|
is "bicubic" as it produces the most similar results while
|
|
@@ -291,7 +291,7 @@ class SingleQTransform(torch.nn.Module):
|
|
|
291
291
|
def get_freqs(self) -> Float[Tensor, " nfreq"]:
|
|
292
292
|
"""
|
|
293
293
|
Calculate the frequencies that will be used in this transform.
|
|
294
|
-
For each frequency, a
|
|
294
|
+
For each frequency, a ``QTile`` is created.
|
|
295
295
|
"""
|
|
296
296
|
minf, maxf = self.frange
|
|
297
297
|
fcum_mismatch = (
|
|
@@ -320,7 +320,7 @@ class SingleQTransform(torch.nn.Module):
|
|
|
320
320
|
be slow, so this isn't used yet.
|
|
321
321
|
|
|
322
322
|
Optionally, a pair of frequency values can be specified for
|
|
323
|
-
|
|
323
|
+
``fsearch_range`` to restrict the frequencies in which the maximum
|
|
324
324
|
energy value is sought.
|
|
325
325
|
"""
|
|
326
326
|
allowed_dimensions = ["both", "neither", "channel", "batch"]
|
|
@@ -360,7 +360,7 @@ class SingleQTransform(torch.nn.Module):
|
|
|
360
360
|
) -> None:
|
|
361
361
|
"""
|
|
362
362
|
Take the FFT of the input timeseries and calculate the transform
|
|
363
|
-
for each
|
|
363
|
+
for each ``QTile``
|
|
364
364
|
"""
|
|
365
365
|
# Computing the FFT with the same normalization and scaling as GWpy
|
|
366
366
|
X = torch.fft.rfft(X, norm="forward")
|
|
@@ -416,9 +416,9 @@ class SingleQTransform(torch.nn.Module):
|
|
|
416
416
|
X:
|
|
417
417
|
Time series of data. Should have the duration and sample rate
|
|
418
418
|
used to initialize this object. Expected input shape is
|
|
419
|
-
|
|
420
|
-
of channels, and B is the number of batches. If less
|
|
421
|
-
three-dimensional, axes will be added during Q-tile
|
|
419
|
+
``(B, C, T)``, where T is the number of samples, C is the
|
|
420
|
+
number of channels, and B is the number of batches. If less
|
|
421
|
+
than three-dimensional, axes will be added during Q-tile
|
|
422
422
|
computation.
|
|
423
423
|
norm:
|
|
424
424
|
The method of normalization used by each QTile
|
|
@@ -445,13 +445,13 @@ class QScan(torch.nn.Module):
|
|
|
445
445
|
Sample rate of the data in Hz
|
|
446
446
|
spectrogram_shape:
|
|
447
447
|
The shape of the interpolated spectrogram, specified as
|
|
448
|
-
|
|
448
|
+
``(num_f_bins, num_t_bins)``. Because the
|
|
449
449
|
frequency spacing of the Q-tiles is in log-space, the frequency
|
|
450
450
|
interpolation is log-spaced as well.
|
|
451
451
|
qrange:
|
|
452
452
|
The lower and upper values of Q to consider. The
|
|
453
453
|
actual values of Q used for the transforms are
|
|
454
|
-
determined by the
|
|
454
|
+
determined by the ``get_qs`` method
|
|
455
455
|
frange:
|
|
456
456
|
The lower and upper frequency limit to consider for
|
|
457
457
|
the transform. If unspecified, default values will
|
|
@@ -535,9 +535,9 @@ class QScan(torch.nn.Module):
|
|
|
535
535
|
X:
|
|
536
536
|
Time series of data. Should have the duration and sample rate
|
|
537
537
|
used to initialize this object. Expected input shape is
|
|
538
|
-
|
|
539
|
-
of channels, and B is the number of batches. If less
|
|
540
|
-
three-dimensional, axes will be added during Q-tile
|
|
538
|
+
``(B, C, T)``, where T is the number of samples, C is the
|
|
539
|
+
number of channels, and B is the number of batches. If less
|
|
540
|
+
than three-dimensional, axes will be added during Q-tile
|
|
541
541
|
computation.
|
|
542
542
|
fsearch_range:
|
|
543
543
|
The lower and upper frequency values within which to search
|
ml4gw/transforms/scaler.py
CHANGED
|
@@ -13,14 +13,14 @@ class ChannelWiseScaler(FittableTransform):
|
|
|
13
13
|
Scales timeseries channels by the mean and standard
|
|
14
14
|
deviation of the channels of the timeseries used to
|
|
15
15
|
fit the module. To reverse the scaling, provide the
|
|
16
|
-
|
|
16
|
+
``reverse=True`` keyword argument at call time.
|
|
17
17
|
By default, the scaling parameters are set to zero mean
|
|
18
18
|
and unit variance, amounting to an identity transform.
|
|
19
19
|
|
|
20
20
|
Args:
|
|
21
21
|
num_channels:
|
|
22
22
|
The number of channels of the target timeseries.
|
|
23
|
-
If left as
|
|
23
|
+
If left as ``None``, the timeseries will be assumed
|
|
24
24
|
to be 1D (single channel).
|
|
25
25
|
"""
|
|
26
26
|
|
|
@@ -42,8 +42,8 @@ class ChannelWiseScaler(FittableTransform):
|
|
|
42
42
|
"""Fit the scaling parameters to a timeseries
|
|
43
43
|
|
|
44
44
|
Computes the channel-wise mean and standard deviation
|
|
45
|
-
of the timeseries
|
|
46
|
-
|
|
45
|
+
of the timeseries ``X`` and sets these values to the
|
|
46
|
+
``mean`` and ``std`` parameters of the scaler.
|
|
47
47
|
"""
|
|
48
48
|
|
|
49
49
|
if X.ndim == 1:
|
ml4gw/transforms/spectral.py
CHANGED
|
@@ -14,39 +14,39 @@ class SpectralDensity(torch.nn.Module):
|
|
|
14
14
|
of a batch of multichannel timeseries, or the cross spectral
|
|
15
15
|
density of two batches of multichannel timeseries.
|
|
16
16
|
|
|
17
|
-
On
|
|
17
|
+
On ``SpectralDensity.forward`` call, if only one tensor is provided,
|
|
18
18
|
this transform will compute its power spectral density. If a second
|
|
19
19
|
tensor is provided, the cross spectral density between the two
|
|
20
20
|
timeseries will be computed. For information about the allowed
|
|
21
21
|
relationships between these two tensors, see the documentation to
|
|
22
|
-
|
|
22
|
+
:meth:`~ml4gw.spectral.fast_spectral_density`.
|
|
23
23
|
|
|
24
24
|
Note that the cross spectral density computation is currently
|
|
25
|
-
only available for
|
|
26
|
-
|
|
27
|
-
a
|
|
25
|
+
only available for :meth:`~ml4gw.spectral.fast_spectral_density`. If
|
|
26
|
+
``fast=False`` and a second tensor is passed to ``SpectralDensity.forward``,
|
|
27
|
+
a ``NotImplementedError`` will be raised.
|
|
28
28
|
|
|
29
29
|
Args:
|
|
30
30
|
sample_rate:
|
|
31
|
-
Rate at which tensors passed to
|
|
31
|
+
Rate at which tensors passed to ``forward`` will be sampled
|
|
32
32
|
fftlength:
|
|
33
33
|
Length of the window, in seconds, to use for FFT estimates
|
|
34
34
|
overlap:
|
|
35
35
|
Overlap between windows used for FFT calculation. If left
|
|
36
|
-
as
|
|
36
|
+
as ``None``, this will be set to ``fftlength / 2``.
|
|
37
37
|
average:
|
|
38
38
|
Aggregation method to use for combining windowed FFTs.
|
|
39
|
-
Allowed values are
|
|
39
|
+
Allowed values are ``"mean"`` and ``"median"``.
|
|
40
40
|
window:
|
|
41
41
|
Window array to multiply by each FFT window before
|
|
42
|
-
FFT computation. Should have length
|
|
42
|
+
FFT computation. Should have length ``nperseg``.
|
|
43
43
|
Defaults to a hanning window.
|
|
44
44
|
fast:
|
|
45
45
|
Whether to use a faster spectral density computation that
|
|
46
46
|
supports cross spectral density, or a slower one which does
|
|
47
47
|
not. The cost of the fast implementation is that it is not
|
|
48
48
|
exact for the two lowest frequency bins.
|
|
49
|
-
"""
|
|
49
|
+
""" # noqa E501
|
|
50
50
|
|
|
51
51
|
def __init__(
|
|
52
52
|
self,
|
ml4gw/transforms/spectrogram.py
CHANGED
|
@@ -14,18 +14,18 @@ class MultiResolutionSpectrogram(torch.nn.Module):
|
|
|
14
14
|
"""
|
|
15
15
|
Create a batch of multi-resolution spectrograms
|
|
16
16
|
from a batch of timeseries. Input is expected to
|
|
17
|
-
have the shape
|
|
18
|
-
of batches,
|
|
17
|
+
have the shape ``(B, C, T)``, where ``B`` is the number
|
|
18
|
+
of batches, ``C`` is the number of channels, and ``T``
|
|
19
19
|
is the number of time samples.
|
|
20
20
|
|
|
21
21
|
For each timeseries, calculate multiple normalized
|
|
22
|
-
spectrograms based on the
|
|
22
|
+
spectrograms based on the ``Spectrogram`` ``kwargs`` given.
|
|
23
23
|
Combine the spectrograms by taking the maximum value
|
|
24
24
|
from the nearest time-frequency bin.
|
|
25
25
|
|
|
26
26
|
If the largest number of time bins among the spectrograms
|
|
27
|
-
is
|
|
28
|
-
the output will have dimensions
|
|
27
|
+
is ``N`` and the largest number of frequency bins is ``M``,
|
|
28
|
+
the output will have dimensions ``(B, C, M, N)``
|
|
29
29
|
|
|
30
30
|
Args:
|
|
31
31
|
kernel_length:
|
|
@@ -34,10 +34,11 @@ class MultiResolutionSpectrogram(torch.nn.Module):
|
|
|
34
34
|
spectrogram
|
|
35
35
|
sample_rate:
|
|
36
36
|
The sample rate of the timeseries in Hz
|
|
37
|
-
kwargs:
|
|
37
|
+
**kwargs:
|
|
38
38
|
Arguments passed in kwargs will be used to create
|
|
39
|
-
|
|
40
|
-
|
|
39
|
+
``torchaudio.transforms.Spectrogram`` (see
|
|
40
|
+
`documentation <https://docs.pytorch.org/audio/main/generated/torchaudio.transforms.Spectrogram.html>`_).
|
|
41
|
+
Each argument should be a list of values. Any list
|
|
41
42
|
of length greater than 1 should be the same
|
|
42
43
|
length
|
|
43
44
|
"""
|
|
@@ -140,9 +141,9 @@ class MultiResolutionSpectrogram(torch.nn.Module):
|
|
|
140
141
|
Batch of multichannel timeseries which will
|
|
141
142
|
be used to calculate the multi-resolution
|
|
142
143
|
spectrogram. Should have the shape
|
|
143
|
-
|
|
144
|
-
batches,
|
|
145
|
-
and
|
|
144
|
+
``(B, C, T)``, where ``B`` is the number of
|
|
145
|
+
batches, ``C`` is the number of channels,
|
|
146
|
+
and ``T`` is the number of time samples.
|
|
146
147
|
"""
|
|
147
148
|
if X.shape[-1] != self.kernel_size:
|
|
148
149
|
raise ValueError(
|
|
@@ -13,16 +13,16 @@ class SplineInterpolate(torch.nn.Module):
|
|
|
13
13
|
"""
|
|
14
14
|
Perform 1D or 2D spline interpolation based on De Boor's method.
|
|
15
15
|
Supports batched, multi-channel inputs, so acceptable data
|
|
16
|
-
shapes are
|
|
17
|
-
|
|
18
|
-
|
|
16
|
+
shapes are ``(width)``, ``(height, width)``, ``(batch, width)``,
|
|
17
|
+
``(batch, height, width)``, ``(batch, channel, width)``, and
|
|
18
|
+
``(batch, channel, height, width)``.
|
|
19
19
|
|
|
20
20
|
During initialization of this Module, both the desired input
|
|
21
21
|
and output coordinate Tensors can be specified to allow
|
|
22
22
|
pre-computation of the B-spline basis matrices, though the only
|
|
23
23
|
mandatory argument is the coordinates of the data along the
|
|
24
|
-
|
|
25
|
-
the
|
|
24
|
+
``width`` dimension. If no argument is given for coordinates along
|
|
25
|
+
the ``height`` dimension, it is assumed that 1D interpolation is
|
|
26
26
|
desired.
|
|
27
27
|
|
|
28
28
|
Unlike scipy's implementation of spline interpolation, the data
|
|
@@ -55,7 +55,7 @@ class SplineInterpolate(torch.nn.Module):
|
|
|
55
55
|
sx:
|
|
56
56
|
Regularization factor to avoid singularities during matrix
|
|
57
57
|
inversion for interpolation along the width dimension. Not
|
|
58
|
-
to be confused with the
|
|
58
|
+
to be confused with the ``s`` parameter in scipy's spline
|
|
59
59
|
methods, which controls the number of knots.
|
|
60
60
|
sy:
|
|
61
61
|
Regularization factor to avoid singularities during matrix
|
|
@@ -256,11 +256,6 @@ class SplineInterpolate(torch.nn.Module):
|
|
|
256
256
|
return b[:, :, -1]
|
|
257
257
|
|
|
258
258
|
def bivariate_spline_fit_natural(self, Z):
|
|
259
|
-
if len(Z.shape) == 3:
|
|
260
|
-
Z_Bx = torch.matmul(Z, self.Bx)
|
|
261
|
-
# ((BxT @ Bx)^-1 @ (Z @ Bx)T)T = Z @ BxT^-1
|
|
262
|
-
return torch.linalg.solve(self.BxT_Bx, Z_Bx.mT).mT
|
|
263
|
-
|
|
264
259
|
# Adding batch/channel dimension handling
|
|
265
260
|
# ByT @ Z @ BxW
|
|
266
261
|
ByT_Z_Bx = torch.einsum("ij,bcik,kl->bcjl", self.By, Z, self.Bx)
|
|
@@ -280,8 +275,6 @@ class SplineInterpolate(torch.nn.Module):
|
|
|
280
275
|
Z_interp: Interpolated values at the grid points.
|
|
281
276
|
"""
|
|
282
277
|
# Perform matrix multiplication using einsum to get Z_interp
|
|
283
|
-
if len(C.shape) == 3:
|
|
284
|
-
return torch.matmul(C, self.Bx_out.mT)
|
|
285
278
|
return torch.einsum("ik,bckm,mj->bcij", self.By_out, C, self.Bx_out.mT)
|
|
286
279
|
|
|
287
280
|
def _validate_inputs(self, Z, x_out, y_out):
|
|
@@ -339,7 +332,7 @@ class SplineInterpolate(torch.nn.Module):
|
|
|
339
332
|
Z:
|
|
340
333
|
Tensor of data to be interpolated. Must be between 1 and 4
|
|
341
334
|
dimensions. The shape of the tensor must agree with the
|
|
342
|
-
input coordinates given on initialization. If
|
|
335
|
+
input coordinates given on initialization. If ``y_in`` was
|
|
343
336
|
not specified during initialization, it is assumed that
|
|
344
337
|
Z does not have a height dimension.
|
|
345
338
|
x_out:
|
|
@@ -352,7 +345,7 @@ class SplineInterpolate(torch.nn.Module):
|
|
|
352
345
|
initialization.
|
|
353
346
|
|
|
354
347
|
Returns:
|
|
355
|
-
A 4D tensor with shape
|
|
348
|
+
A 4D tensor with shape ``(batch, channel, height, width)``.
|
|
356
349
|
Depending on the input data shape, many of these dimensions
|
|
357
350
|
may have length 1.
|
|
358
351
|
"""
|
ml4gw/transforms/transform.py
CHANGED
|
@@ -70,7 +70,7 @@ class FittableSpectralTransform(FittableTransform):
|
|
|
70
70
|
)
|
|
71
71
|
|
|
72
72
|
# add two dummy dimensions in case we need to interpolate
|
|
73
|
-
# the frequency dimension, since
|
|
73
|
+
# the frequency dimension, since ``interpolate`` expects
|
|
74
74
|
# a (batch, channel, spatial) formatted tensor as input
|
|
75
75
|
x = x.view(1, 1, -1)
|
|
76
76
|
if x.size(-1) != num_freqs:
|