ml4gw 0.7.5__py3-none-any.whl → 0.7.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ml4gw might be problematic. Click here for more details.
- ml4gw/augmentations.py +4 -4
- ml4gw/dataloading/chunked_dataset.py +3 -3
- ml4gw/dataloading/hdf5_dataset.py +7 -10
- ml4gw/dataloading/in_memory_dataset.py +21 -21
- ml4gw/distributions.py +20 -18
- ml4gw/gw.py +60 -53
- ml4gw/nn/autoencoder/base.py +9 -9
- ml4gw/nn/autoencoder/convolutional.py +4 -4
- ml4gw/nn/resnet/resnet_1d.py +13 -13
- ml4gw/nn/resnet/resnet_2d.py +12 -12
- ml4gw/nn/streaming/online_average.py +1 -1
- ml4gw/nn/streaming/snapshotter.py +14 -14
- ml4gw/spectral.py +48 -48
- ml4gw/transforms/__init__.py +1 -1
- ml4gw/transforms/iirfilter.py +3 -3
- ml4gw/transforms/pearson.py +7 -8
- ml4gw/transforms/qtransform.py +29 -34
- ml4gw/transforms/scaler.py +4 -4
- ml4gw/transforms/spectral.py +10 -10
- ml4gw/transforms/spectrogram.py +12 -11
- ml4gw/transforms/spline_interpolation.py +310 -146
- ml4gw/transforms/transform.py +1 -1
- ml4gw/transforms/whitening.py +36 -36
- ml4gw/utils/slicing.py +40 -40
- ml4gw/waveforms/cbc/phenom_d.py +22 -66
- ml4gw/waveforms/cbc/phenom_p.py +9 -5
- ml4gw/waveforms/cbc/taylorf2.py +8 -7
- ml4gw/waveforms/conversion.py +2 -1
- ml4gw/waveforms/generator.py +33 -32
- {ml4gw-0.7.5.dist-info → ml4gw-0.7.7.dist-info}/METADATA +6 -5
- ml4gw-0.7.7.dist-info/RECORD +56 -0
- {ml4gw-0.7.5.dist-info → ml4gw-0.7.7.dist-info}/WHEEL +2 -1
- ml4gw-0.7.7.dist-info/top_level.txt +1 -0
- ml4gw-0.7.5.dist-info/RECORD +0 -55
- {ml4gw-0.7.5.dist-info → ml4gw-0.7.7.dist-info}/licenses/LICENSE +0 -0
ml4gw/nn/resnet/resnet_2d.py
CHANGED
|
@@ -105,10 +105,10 @@ class BasicBlock(nn.Module):
|
|
|
105
105
|
class Bottleneck(nn.Module):
|
|
106
106
|
"""
|
|
107
107
|
Bottleneck blocks implement one extra convolution
|
|
108
|
-
compared to basic blocks. In this layers, the
|
|
108
|
+
compared to basic blocks. In these layers, the ``planes``
|
|
109
109
|
parameter is generally meant to _downsize_ the number
|
|
110
110
|
of feature maps first, which then get expanded out to
|
|
111
|
-
|
|
111
|
+
``planes * Bottleneck.expansion`` feature maps at the
|
|
112
112
|
output of the layer.
|
|
113
113
|
"""
|
|
114
114
|
|
|
@@ -188,9 +188,9 @@ class ResNet2D(nn.Module):
|
|
|
188
188
|
A list representing the number of residual
|
|
189
189
|
blocks to include in each "layer" of the
|
|
190
190
|
network. Total layers (e.g. 50 in ResNet50)
|
|
191
|
-
is
|
|
192
|
-
is
|
|
193
|
-
|
|
191
|
+
is ``2 + sum(layers) * factor``, where factor
|
|
192
|
+
is ``2`` for vanilla ``ResNet`` and ``3`` for
|
|
193
|
+
``BottleneckResNet``.
|
|
194
194
|
kernel_size:
|
|
195
195
|
The size of the convolutional kernel to
|
|
196
196
|
use in all residual layers. _NOT_ the size
|
|
@@ -207,22 +207,22 @@ class ResNet2D(nn.Module):
|
|
|
207
207
|
connections between feature maps at subsequent
|
|
208
208
|
layers rather than global. Generally won't
|
|
209
209
|
need this to be >1, and will raise an error if
|
|
210
|
-
>1 when using vanilla
|
|
210
|
+
>1 when using vanilla ``ResNet``.
|
|
211
211
|
width_per_group:
|
|
212
212
|
Base width of each of the feature map groups,
|
|
213
213
|
which is scaled up by the typical expansion
|
|
214
214
|
factor at each layer of the network. Meaningless
|
|
215
|
-
for vanilla
|
|
215
|
+
for vanilla ``ResNet``.
|
|
216
216
|
stride_type:
|
|
217
217
|
Whether to achieve downsampling on the time axis
|
|
218
218
|
by strided or dilated convolutions for each layer.
|
|
219
|
-
If left as
|
|
220
|
-
used at each layer. Otherwise,
|
|
221
|
-
be one element shorter than
|
|
222
|
-
|
|
219
|
+
If left as ``None``, strided convolutions will be
|
|
220
|
+
used at each layer. Otherwise, ``stride_type`` should
|
|
221
|
+
be one element shorter than ``layers`` and indicate either
|
|
222
|
+
``stride`` or ``dilation`` for each layer after the first.
|
|
223
223
|
norm_groups:
|
|
224
224
|
The number of groups to use in GroupNorm layers
|
|
225
|
-
throughout the model. If left as
|
|
225
|
+
throughout the model. If left as ``-1``, the number
|
|
226
226
|
of groups will be equal to the number of channels,
|
|
227
227
|
making this equivalent to LayerNorm
|
|
228
228
|
"""
|
|
@@ -11,7 +11,7 @@ class OnlineAverager(torch.nn.Module):
|
|
|
11
11
|
"""
|
|
12
12
|
Module for performing stateful online averaging of
|
|
13
13
|
batches of overlapping timeseries. At present, the
|
|
14
|
-
first
|
|
14
|
+
first ``num_updates`` predictions produced by this
|
|
15
15
|
model will underestimate the true average.
|
|
16
16
|
|
|
17
17
|
Args:
|
|
@@ -12,24 +12,24 @@ class Snapshotter(torch.nn.Module):
|
|
|
12
12
|
Model for converting streaming state updates into
|
|
13
13
|
a batch of overlapping snapshots of a multichannel
|
|
14
14
|
timeseries. Can support multiple timeseries in a
|
|
15
|
-
single state update via the
|
|
15
|
+
single state update via the ``channels_per_snapshot``
|
|
16
16
|
kwarg.
|
|
17
17
|
|
|
18
18
|
Specifically, maps tensors of shape
|
|
19
|
-
|
|
20
|
-
of shape
|
|
21
|
-
If
|
|
22
|
-
|
|
19
|
+
``(num_channels, batch_size * stride_size)`` to a tensor
|
|
20
|
+
of shape ``(batch_size, num_channels, snapshot_size)``.
|
|
21
|
+
If ``channels_per_snapshot`` is specified, it will return
|
|
22
|
+
``len(channels_per_snapshot)`` tensors of this shape,
|
|
23
23
|
with the channel dimension replaced by the corresponding
|
|
24
|
-
value of
|
|
24
|
+
value of ``channels_per_snapshot``. The last tensor returned
|
|
25
25
|
at call time will be the current state that can be passed
|
|
26
|
-
to the next
|
|
26
|
+
to the next ``forward`` call.
|
|
27
27
|
|
|
28
28
|
Args:
|
|
29
29
|
num_channels:
|
|
30
30
|
Number of channels in the timeseries. If
|
|
31
|
-
|
|
32
|
-
this should be equal to
|
|
31
|
+
``channels_per_snapshot`` is not ``None``,
|
|
32
|
+
this should be equal to ``sum(channels_per_snapshot)``.
|
|
33
33
|
snapshot_size:
|
|
34
34
|
The size of the output snapshot windows in
|
|
35
35
|
number of samples
|
|
@@ -39,17 +39,17 @@ class Snapshotter(torch.nn.Module):
|
|
|
39
39
|
batch_size:
|
|
40
40
|
The number of snapshots to produce at each
|
|
41
41
|
update. The last dimension of the input
|
|
42
|
-
tensor should have size
|
|
42
|
+
tensor should have size ``batch_size * stride_size``.
|
|
43
43
|
channels_per_snapshot:
|
|
44
44
|
How to split up the channels in the timeseries
|
|
45
|
-
for different tensors. If left as
|
|
45
|
+
for different tensors. If left as ``None``, all
|
|
46
46
|
the channels will be returned in a single tensor.
|
|
47
47
|
Otherwise, the channels will be split up into
|
|
48
|
-
|
|
48
|
+
``len(channels_per_snapshot)`` tensors, with each
|
|
49
49
|
tensor's channel dimension being equal to the
|
|
50
|
-
corresponding value in
|
|
50
|
+
corresponding value in ``channels_per_snapshot``.
|
|
51
51
|
Therefore, if specified, these values should
|
|
52
|
-
add up to
|
|
52
|
+
add up to ``num_channels``.
|
|
53
53
|
"""
|
|
54
54
|
|
|
55
55
|
def __init__(
|
ml4gw/spectral.py
CHANGED
|
@@ -27,8 +27,8 @@ def median(x: Float[Tensor, "... size"], axis: int) -> Float[Tensor, "..."]:
|
|
|
27
27
|
"""
|
|
28
28
|
Implements a median calculation that matches numpy's
|
|
29
29
|
behavior for an even number of elements and includes
|
|
30
|
-
the same bias correction used by
|
|
31
|
-
|
|
30
|
+
the same bias correction used by
|
|
31
|
+
`scipy's implementation <https://github.com/scipy/scipy/blob/main/scipy/signal/_spectral_py.py#L2066>`_.
|
|
32
32
|
""" # noqa: E501
|
|
33
33
|
n = x.shape[axis]
|
|
34
34
|
ii_2 = 2 * torch.arange(1.0, (n - 1) // 2 + 1)
|
|
@@ -111,50 +111,50 @@ def fast_spectral_density(
|
|
|
111
111
|
The timeseries tensor whose power spectral density
|
|
112
112
|
to compute, or for cross spectral density the
|
|
113
113
|
timeseries whose fft will be conjugated. Can have
|
|
114
|
-
shape
|
|
115
|
-
|
|
116
|
-
|
|
114
|
+
shape ``(batch_size, num_channels, length * sample_rate)``,
|
|
115
|
+
``(num_channels, length * sample_rate)``, or
|
|
116
|
+
``(length * sample_rate)``.
|
|
117
117
|
nperseg:
|
|
118
118
|
Number of samples included in each FFT window
|
|
119
119
|
nstride:
|
|
120
120
|
Stride between FFT windows
|
|
121
121
|
window:
|
|
122
122
|
Window array to multiply by each FFT window before
|
|
123
|
-
FFT computation. Should have length
|
|
123
|
+
FFT computation. Should have length ``nperseg // 2 + 1``.
|
|
124
124
|
scale:
|
|
125
125
|
Scale factor to multiply the FFT'd data by, related to
|
|
126
126
|
desired units for output tensor (e.g. letting this equal
|
|
127
|
-
|
|
128
|
-
units of density,
|
|
127
|
+
``1 / (sample_rate * (window**2).sum())`` will give output
|
|
128
|
+
units of density, :math:`\\text{Hz}^{-1}`.
|
|
129
129
|
average:
|
|
130
130
|
How to aggregate the contributions of each FFT window to
|
|
131
|
-
the spectral density. Allowed options are
|
|
132
|
-
|
|
131
|
+
the spectral density. Allowed options are ``'mean'`` and
|
|
132
|
+
``'median'``.
|
|
133
133
|
y:
|
|
134
134
|
Timeseries tensor to compute cross spectral density
|
|
135
|
-
with
|
|
136
|
-
density will be returned. Otherwise, if
|
|
137
|
-
|
|
135
|
+
with ``x``. If left as ``None``, ``x``'s power spectral
|
|
136
|
+
density will be returned. Otherwise, if ``x`` is 1D,
|
|
137
|
+
``y`` must also be 1D. If ``x`` is 2D, the assumption
|
|
138
138
|
is that this represents a single multi-channel timeseries,
|
|
139
|
-
and
|
|
139
|
+
and ``y`` must be either 2D or 1D. In the former case,
|
|
140
140
|
the cross-spectral densities of each channel will be
|
|
141
|
-
computed individually, so
|
|
142
|
-
Otherwise, this will compute the CSD of each of
|
|
143
|
-
with
|
|
144
|
-
of multi-channel timeseries. In this case,
|
|
141
|
+
computed individually, so ``y`` must have the same shape as ``x``.
|
|
142
|
+
Otherwise, this will compute the CSD of each of ``x``'s channels
|
|
143
|
+
with ``y``. If ``x`` is 3D, this will be assumed to be a batch
|
|
144
|
+
of multi-channel timeseries. In this case, ``y`` can either
|
|
145
145
|
be 3D, in which case each channel of each batch element will
|
|
146
146
|
have its CSD calculated or 2D, which has two different options.
|
|
147
|
-
If
|
|
148
|
-
be assumed that
|
|
147
|
+
If ``y``'s 0th dimension matches ``x``'s 0th dimension, it will
|
|
148
|
+
be assumed that ``y`` represents a batch of 1D timeseries, and
|
|
149
149
|
for each batch element this timeseries will have its CSD with
|
|
150
|
-
each channel of the corresponding batch element of
|
|
151
|
-
calculated. Otherwise, it sill be assumed that
|
|
150
|
+
each channel of the corresponding batch element of ``x``
|
|
151
|
+
calculated. Otherwise, it will be assumed that ``y`` represents
|
|
152
152
|
a single multi-channel timeseries, in which case each channel
|
|
153
|
-
of
|
|
154
|
-
channel in
|
|
153
|
+
of ``y`` will have its CSD calculated with the corresponding
|
|
154
|
+
channel in ``x`` across _all_ of ``x``'s batch elements.
|
|
155
155
|
Returns:
|
|
156
|
-
Tensor of power spectral densities of
|
|
157
|
-
density with the timeseries in
|
|
156
|
+
Tensor of power spectral densities of ``x`` or its cross spectral
|
|
157
|
+
density with the timeseries in ``y``.
|
|
158
158
|
"""
|
|
159
159
|
|
|
160
160
|
_validate_shapes(x, nperseg, y)
|
|
@@ -262,25 +262,25 @@ def spectral_density(
|
|
|
262
262
|
The timeseries tensor whose power spectral density
|
|
263
263
|
to compute, or for cross spectral density the
|
|
264
264
|
timeseries whose fft will be conjugated. Can have
|
|
265
|
-
shape
|
|
266
|
-
|
|
267
|
-
|
|
265
|
+
shape ``(batch_size, num_channels, length * sample_rate)``,
|
|
266
|
+
``(num_channels, length * sample_rate)``, or
|
|
267
|
+
``(length * sample_rate)``.
|
|
268
268
|
nperseg:
|
|
269
269
|
Number of samples included in each FFT window
|
|
270
270
|
nstride:
|
|
271
271
|
Stride between FFT windows
|
|
272
272
|
window:
|
|
273
273
|
Window array to multiply by each FFT window before
|
|
274
|
-
FFT computation. Should have length
|
|
274
|
+
FFT computation. Should have length ``nperseg // 2 + 1``.
|
|
275
275
|
scale:
|
|
276
276
|
Scale factor to multiply the FFT'd data by, related to
|
|
277
277
|
desired units for output tensor (e.g. letting this equal
|
|
278
|
-
|
|
279
|
-
units of density,
|
|
278
|
+
``1 / (sample_rate * (window**2).sum())`` will give output
|
|
279
|
+
units of density, :math:`\\text{Hz}^{-1}`.
|
|
280
280
|
average:
|
|
281
281
|
How to aggregate the contributions of each FFT window to
|
|
282
|
-
the spectral density. Allowed options are
|
|
283
|
-
|
|
282
|
+
the spectral density. Allowed options are ``'mean'`` and
|
|
283
|
+
``'median'``.
|
|
284
284
|
"""
|
|
285
285
|
|
|
286
286
|
_validate_shapes(x, nperseg)
|
|
@@ -348,18 +348,18 @@ def truncate_inverse_power_spectrum(
|
|
|
348
348
|
"""
|
|
349
349
|
Truncate the length of the time domain response
|
|
350
350
|
of a whitening filter built using the specified
|
|
351
|
-
|
|
351
|
+
``psd`` so that it has maximum length ``fduration``
|
|
352
352
|
seconds. This is meant to mitigate the impact
|
|
353
353
|
of sharp features in the background PSD causing
|
|
354
354
|
time domain responses longer than the segments
|
|
355
355
|
to which the whitening filter will be applied.
|
|
356
356
|
|
|
357
357
|
Implementation details adapted from
|
|
358
|
-
https://github.com/vivinousi/gw-detection-deep-learning/blob/203966cc2ee47c32c292be000fb009a16824b7d9/modules/whiten.py#L8
|
|
358
|
+
`here <https://github.com/vivinousi/gw-detection-deep-learning/blob/203966cc2ee47c32c292be000fb009a16824b7d9/modules/whiten.py#L8>`_.
|
|
359
359
|
|
|
360
360
|
Args:
|
|
361
361
|
psd:
|
|
362
|
-
The one-sided power
|
|
362
|
+
The one-sided power spectral density used
|
|
363
363
|
to construct a whitening filter.
|
|
364
364
|
fduration:
|
|
365
365
|
Desired length in seconds of the time domain
|
|
@@ -375,14 +375,14 @@ def truncate_inverse_power_spectrum(
|
|
|
375
375
|
highpass:
|
|
376
376
|
If specified, will zero out the frequency response
|
|
377
377
|
of all frequencies below this value in Hz. If left
|
|
378
|
-
as
|
|
378
|
+
as ``None``, no highpass filtering will be applied.
|
|
379
379
|
lowpass:
|
|
380
380
|
If specified, will zero out the frequency response
|
|
381
381
|
of all frequencies above this value in Hz. If left
|
|
382
|
-
as
|
|
382
|
+
as ``None``, no lowpass filtering will be applied.
|
|
383
383
|
Returns:
|
|
384
384
|
The PSD with its time domain response truncated
|
|
385
|
-
to
|
|
385
|
+
to ``fduration`` and any filtered frequencies
|
|
386
386
|
tapered.
|
|
387
387
|
""" # noqa: E501
|
|
388
388
|
|
|
@@ -469,7 +469,7 @@ def whiten(
|
|
|
469
469
|
Whiten a batch of timeseries using the specified
|
|
470
470
|
background one-sided power spectral densities (PSDs),
|
|
471
471
|
modified to have the desired time domain response length
|
|
472
|
-
|
|
472
|
+
``fduration`` and possibly to highpass/lowpass filter.
|
|
473
473
|
|
|
474
474
|
Args:
|
|
475
475
|
X:
|
|
@@ -480,11 +480,11 @@ def whiten(
|
|
|
480
480
|
the inverse of the square root of this PSD, ensuring
|
|
481
481
|
that data from the same distribution will have
|
|
482
482
|
approximately uniform power after whitening.
|
|
483
|
-
If 2D, each batch element in
|
|
483
|
+
If 2D, each batch element in ``X`` will be whitened
|
|
484
484
|
using the same PSDs. If 3D, each batch element will
|
|
485
485
|
be whitened by the PSDs contained along the 0th
|
|
486
|
-
dimenion of
|
|
487
|
-
of
|
|
486
|
+
dimension of ``psd``, and so the first two dimensions
|
|
487
|
+
of ``X`` and ``psd`` should match.
|
|
488
488
|
fduration:
|
|
489
489
|
Desired length in seconds of the time domain
|
|
490
490
|
response of a whitening filter built using
|
|
@@ -496,20 +496,20 @@ def whiten(
|
|
|
496
496
|
the whitened timeseries to account for filter
|
|
497
497
|
settle-in time.
|
|
498
498
|
sample_rate:
|
|
499
|
-
Rate at which the data in
|
|
499
|
+
Rate at which the data in ``X`` has been sampled
|
|
500
500
|
highpass:
|
|
501
501
|
The frequency in Hz at which to highpass filter
|
|
502
502
|
the data, setting the frequency response in the
|
|
503
|
-
whitening filter to 0. If left as
|
|
503
|
+
whitening filter to 0. If left as ``None``, no
|
|
504
504
|
highpass filtering will be applied.
|
|
505
505
|
lowpass:
|
|
506
506
|
The frequency in Hz at which to lowpass filter
|
|
507
507
|
the data, setting the frequency response in the
|
|
508
|
-
whitening filter to 0. If left as
|
|
508
|
+
whitening filter to 0. If left as ``None``, no
|
|
509
509
|
lowpass filtering will be applied.
|
|
510
510
|
Returns:
|
|
511
511
|
Batch of whitened multichannel timeseries with
|
|
512
|
-
|
|
512
|
+
``fduration / 2`` seconds trimmed from each side.
|
|
513
513
|
"""
|
|
514
514
|
|
|
515
515
|
# figure out how much data we'll need to slice
|
ml4gw/transforms/__init__.py
CHANGED
|
@@ -5,6 +5,6 @@ from .scaler import ChannelWiseScaler
|
|
|
5
5
|
from .snr_rescaler import SnrRescaler
|
|
6
6
|
from .spectral import SpectralDensity
|
|
7
7
|
from .spectrogram import MultiResolutionSpectrogram
|
|
8
|
-
from .spline_interpolation import
|
|
8
|
+
from .spline_interpolation import SplineInterpolate1D, SplineInterpolate2D
|
|
9
9
|
from .waveforms import WaveformProjector, WaveformSampler
|
|
10
10
|
from .whitening import FixedWhiten, Whiten
|
ml4gw/transforms/iirfilter.py
CHANGED
|
@@ -9,7 +9,7 @@ class IIRFilter(torch.nn.Module):
|
|
|
9
9
|
r"""
|
|
10
10
|
IIR digital and analog filter design given order and critical points.
|
|
11
11
|
Design an Nth-order digital or analog filter and apply it to a signal.
|
|
12
|
-
Uses SciPy's
|
|
12
|
+
Uses SciPy's ``iirfilter`` function to create the filter coefficients.
|
|
13
13
|
https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirfilter.html
|
|
14
14
|
|
|
15
15
|
The forward call of this module accepts a batch tensor of shape
|
|
@@ -24,8 +24,8 @@ class IIRFilter(torch.nn.Module):
|
|
|
24
24
|
default, fs is 2 half-cycles/sample, so these are normalized
|
|
25
25
|
from 0 to 1, where 1 is the Nyquist frequency. (Wn is thus in
|
|
26
26
|
half-cycles / sample). For analog filters, Wn is an angular
|
|
27
|
-
frequency (e.g., rad/s). When Wn is a length-2 sequence
|
|
28
|
-
must be less than
|
|
27
|
+
frequency (e.g., rad/s). When Wn is a length-2 sequence, ``Wn[0]``
|
|
28
|
+
must be less than ``Wn[1]``.
|
|
29
29
|
rp:
|
|
30
30
|
For Chebyshev and elliptic filters, provides the maximum ripple in
|
|
31
31
|
the passband. (dB)
|
ml4gw/transforms/pearson.py
CHANGED
|
@@ -8,20 +8,19 @@ from ..utils.slicing import unfold_windows
|
|
|
8
8
|
|
|
9
9
|
class ShiftedPearsonCorrelation(torch.nn.Module):
|
|
10
10
|
"""
|
|
11
|
-
Compute the
|
|
12
|
-
(https://en.wikipedia.org/wiki/Pearson_correlation_coefficient)
|
|
11
|
+
Compute the `Pearson correlation <https://en.wikipedia.org/wiki/Pearson_correlation_coefficient>`_
|
|
13
12
|
for two equal-length timeseries over a pre-defined number of time
|
|
14
13
|
shifts in each direction. Useful for when you want a
|
|
15
14
|
correlation, but not over every possible shift (i.e.
|
|
16
15
|
a convolution).
|
|
17
16
|
|
|
18
|
-
The number of dimensions of the second timeseries
|
|
17
|
+
The number of dimensions of the second timeseries ``y``
|
|
19
18
|
passed at call time should always be less than or equal
|
|
20
|
-
to the number of dimensions of the first timeseries
|
|
19
|
+
to the number of dimensions of the first timeseries ``x``,
|
|
21
20
|
and each dimension should match the corresponding one of
|
|
22
|
-
|
|
23
|
-
then
|
|
24
|
-
|
|
21
|
+
``x`` in reverse order (i.e. if ``x`` has shape ``(B, C, T)``
|
|
22
|
+
then ``y`` should either have shape ``(T,)``, ``(C, T)``, or
|
|
23
|
+
``(B, C, T)``).
|
|
25
24
|
|
|
26
25
|
Note that no windowing to either timeseries is applied
|
|
27
26
|
at call time. Users should do any requisite windowing
|
|
@@ -36,7 +35,7 @@ class ShiftedPearsonCorrelation(torch.nn.Module):
|
|
|
36
35
|
The maximum number of 1-step time shifts in
|
|
37
36
|
each direction over which to compute the
|
|
38
37
|
Pearson coefficient. Output shape will then
|
|
39
|
-
be
|
|
38
|
+
be ``(2 * max_shifts + 1, B, C)``.
|
|
40
39
|
"""
|
|
41
40
|
|
|
42
41
|
def __init__(self, max_shift: int) -> None:
|
ml4gw/transforms/qtransform.py
CHANGED
|
@@ -8,7 +8,7 @@ from jaxtyping import Float, Int
|
|
|
8
8
|
from torch import Tensor
|
|
9
9
|
|
|
10
10
|
from ..types import FrequencySeries1to3d, TimeSeries1to3d, TimeSeries3d
|
|
11
|
-
from .spline_interpolation import
|
|
11
|
+
from .spline_interpolation import SplineInterpolate1D
|
|
12
12
|
|
|
13
13
|
"""
|
|
14
14
|
All based on https://github.com/gwpy/gwpy/blob/v3.0.8/gwpy/signal/qtransform.py
|
|
@@ -22,7 +22,7 @@ class QTile(torch.nn.Module):
|
|
|
22
22
|
"""
|
|
23
23
|
Compute the row of Q-tiles for a single Q value and a single
|
|
24
24
|
frequency for a batch of multi-channel frequency series data.
|
|
25
|
-
Should really be called
|
|
25
|
+
Should really be called ``QRow``, but I want to match GWpy.
|
|
26
26
|
Input data should have three dimensions or fewer.
|
|
27
27
|
If fewer, dimensions will be added until the input is
|
|
28
28
|
three-dimensional.
|
|
@@ -112,17 +112,17 @@ class QTile(torch.nn.Module):
|
|
|
112
112
|
fseries:
|
|
113
113
|
Frequency series of data. Should correspond to data with
|
|
114
114
|
the duration and sample rate used to initialize this object.
|
|
115
|
-
Expected input shape is
|
|
115
|
+
Expected input shape is ``(B, C, F)``, where F is the number
|
|
116
116
|
of samples, C is the number of channels, and B is the number
|
|
117
117
|
of batches. If less than three-dimensional, axes will be
|
|
118
118
|
added.
|
|
119
119
|
norm:
|
|
120
120
|
The method of normalization. Options are "median", "mean", or
|
|
121
|
-
|
|
121
|
+
``None``.
|
|
122
122
|
|
|
123
123
|
Returns:
|
|
124
124
|
The row of Q-tiles for the given Q and frequency. Output is
|
|
125
|
-
three-dimensional:
|
|
125
|
+
three-dimensional: ``(B, C, T)``
|
|
126
126
|
"""
|
|
127
127
|
if len(fseries.shape) > 3:
|
|
128
128
|
raise ValueError("Input data has more than 3 dimensions")
|
|
@@ -164,7 +164,7 @@ class SingleQTransform(torch.nn.Module):
|
|
|
164
164
|
Sample rate of the data in Hz
|
|
165
165
|
spectrogram_shape:
|
|
166
166
|
The shape of the interpolated spectrogram, specified as
|
|
167
|
-
|
|
167
|
+
``(num_f_bins, num_t_bins)``. Because the
|
|
168
168
|
frequency spacing of the Q-tiles is in log-space, the frequency
|
|
169
169
|
interpolation is log-spaced as well.
|
|
170
170
|
q:
|
|
@@ -176,14 +176,14 @@ class SingleQTransform(torch.nn.Module):
|
|
|
176
176
|
mismatch:
|
|
177
177
|
The maximum fractional mismatch between neighboring tiles
|
|
178
178
|
interpolation_method:
|
|
179
|
-
The method by which to interpolate each
|
|
179
|
+
The method by which to interpolate each ``QTile`` to the specified
|
|
180
180
|
number of time and frequency bins. The acceptable values are
|
|
181
181
|
"bilinear", "bicubic", and "spline". The "bilinear" and "bicubic"
|
|
182
182
|
options will use PyTorch's built-in interpolation modes, while
|
|
183
183
|
"spline" will use the custom Torch-based implementation in
|
|
184
|
-
|
|
184
|
+
``ml4gw``, as PyTorch does not have spline-based interpolation.
|
|
185
185
|
The "spline" mode is most similar to the results of GWpy's
|
|
186
|
-
Q-transform, which uses
|
|
186
|
+
Q-transform, which uses ``scipy`` to do spline interpolation.
|
|
187
187
|
However, it is also the slowest and most memory intensive due to
|
|
188
188
|
the matrix equation solving steps. Therefore, the default method
|
|
189
189
|
is "bicubic" as it produces the most similar results while
|
|
@@ -260,18 +260,15 @@ class SingleQTransform(torch.nn.Module):
|
|
|
260
260
|
)
|
|
261
261
|
self.qtile_interpolators = torch.nn.ModuleList(
|
|
262
262
|
[
|
|
263
|
-
|
|
263
|
+
SplineInterpolate1D(
|
|
264
264
|
kx=3,
|
|
265
265
|
x_in=torch.arange(0, self.duration, self.duration / tiles),
|
|
266
|
-
y_in=torch.arange(len(idx)),
|
|
267
266
|
x_out=t_out,
|
|
268
|
-
y_out=torch.arange(len(idx)),
|
|
269
267
|
)
|
|
270
|
-
for tiles
|
|
268
|
+
for tiles in unique_ntiles
|
|
271
269
|
]
|
|
272
270
|
)
|
|
273
271
|
|
|
274
|
-
t_in = t_out
|
|
275
272
|
f_in = self.freqs
|
|
276
273
|
f_out = torch.logspace(
|
|
277
274
|
math.log10(self.frange[0]),
|
|
@@ -279,19 +276,16 @@ class SingleQTransform(torch.nn.Module):
|
|
|
279
276
|
self.spectrogram_shape[0],
|
|
280
277
|
)
|
|
281
278
|
|
|
282
|
-
self.interpolator =
|
|
279
|
+
self.interpolator = SplineInterpolate1D(
|
|
283
280
|
kx=3,
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
y_in=f_in,
|
|
287
|
-
x_out=t_out,
|
|
288
|
-
y_out=f_out,
|
|
281
|
+
x_in=f_in,
|
|
282
|
+
x_out=f_out,
|
|
289
283
|
)
|
|
290
284
|
|
|
291
285
|
def get_freqs(self) -> Float[Tensor, " nfreq"]:
|
|
292
286
|
"""
|
|
293
287
|
Calculate the frequencies that will be used in this transform.
|
|
294
|
-
For each frequency, a
|
|
288
|
+
For each frequency, a ``QTile`` is created.
|
|
295
289
|
"""
|
|
296
290
|
minf, maxf = self.frange
|
|
297
291
|
fcum_mismatch = (
|
|
@@ -320,7 +314,7 @@ class SingleQTransform(torch.nn.Module):
|
|
|
320
314
|
be slow, so this isn't used yet.
|
|
321
315
|
|
|
322
316
|
Optionally, a pair of frequency values can be specified for
|
|
323
|
-
|
|
317
|
+
``fsearch_range`` to restrict the frequencies in which the maximum
|
|
324
318
|
energy value is sought.
|
|
325
319
|
"""
|
|
326
320
|
allowed_dimensions = ["both", "neither", "channel", "batch"]
|
|
@@ -360,7 +354,7 @@ class SingleQTransform(torch.nn.Module):
|
|
|
360
354
|
) -> None:
|
|
361
355
|
"""
|
|
362
356
|
Take the FFT of the input timeseries and calculate the transform
|
|
363
|
-
for each
|
|
357
|
+
for each ``QTile``
|
|
364
358
|
"""
|
|
365
359
|
# Computing the FFT with the same normalization and scaling as GWpy
|
|
366
360
|
X = torch.fft.rfft(X, norm="forward")
|
|
@@ -379,14 +373,15 @@ class SingleQTransform(torch.nn.Module):
|
|
|
379
373
|
]
|
|
380
374
|
time_interped = torch.cat(
|
|
381
375
|
[
|
|
382
|
-
|
|
383
|
-
for qtile,
|
|
376
|
+
qtile_interpolator(qtile)
|
|
377
|
+
for qtile, qtile_interpolator in zip(
|
|
384
378
|
qtiles, self.qtile_interpolators
|
|
385
379
|
)
|
|
386
380
|
],
|
|
387
381
|
dim=-2,
|
|
388
382
|
)
|
|
389
|
-
|
|
383
|
+
# Transpose because the final dimension gets interpolated
|
|
384
|
+
return self.interpolator(time_interped.mT).mT
|
|
390
385
|
num_f_bins, num_t_bins = self.spectrogram_shape
|
|
391
386
|
resampled = [
|
|
392
387
|
F.interpolate(
|
|
@@ -416,9 +411,9 @@ class SingleQTransform(torch.nn.Module):
|
|
|
416
411
|
X:
|
|
417
412
|
Time series of data. Should have the duration and sample rate
|
|
418
413
|
used to initialize this object. Expected input shape is
|
|
419
|
-
|
|
420
|
-
of channels, and B is the number of batches. If less
|
|
421
|
-
three-dimensional, axes will be added during Q-tile
|
|
414
|
+
``(B, C, T)``, where T is the number of samples, C is the
|
|
415
|
+
number of channels, and B is the number of batches. If less
|
|
416
|
+
than three-dimensional, axes will be added during Q-tile
|
|
422
417
|
computation.
|
|
423
418
|
norm:
|
|
424
419
|
The method of normalization used by each QTile
|
|
@@ -445,13 +440,13 @@ class QScan(torch.nn.Module):
|
|
|
445
440
|
Sample rate of the data in Hz
|
|
446
441
|
spectrogram_shape:
|
|
447
442
|
The shape of the interpolated spectrogram, specified as
|
|
448
|
-
|
|
443
|
+
``(num_f_bins, num_t_bins)``. Because the
|
|
449
444
|
frequency spacing of the Q-tiles is in log-space, the frequency
|
|
450
445
|
interpolation is log-spaced as well.
|
|
451
446
|
qrange:
|
|
452
447
|
The lower and upper values of Q to consider. The
|
|
453
448
|
actual values of Q used for the transforms are
|
|
454
|
-
determined by the
|
|
449
|
+
determined by the ``get_qs`` method
|
|
455
450
|
frange:
|
|
456
451
|
The lower and upper frequency limit to consider for
|
|
457
452
|
the transform. If unspecified, default values will
|
|
@@ -535,9 +530,9 @@ class QScan(torch.nn.Module):
|
|
|
535
530
|
X:
|
|
536
531
|
Time series of data. Should have the duration and sample rate
|
|
537
532
|
used to initialize this object. Expected input shape is
|
|
538
|
-
|
|
539
|
-
of channels, and B is the number of batches. If less
|
|
540
|
-
three-dimensional, axes will be added during Q-tile
|
|
533
|
+
``(B, C, T)``, where T is the number of samples, C is the
|
|
534
|
+
number of channels, and B is the number of batches. If less
|
|
535
|
+
than three-dimensional, axes will be added during Q-tile
|
|
541
536
|
computation.
|
|
542
537
|
fsearch_range:
|
|
543
538
|
The lower and upper frequency values within which to search
|
ml4gw/transforms/scaler.py
CHANGED
|
@@ -13,14 +13,14 @@ class ChannelWiseScaler(FittableTransform):
|
|
|
13
13
|
Scales timeseries channels by the mean and standard
|
|
14
14
|
deviation of the channels of the timeseries used to
|
|
15
15
|
fit the module. To reverse the scaling, provide the
|
|
16
|
-
|
|
16
|
+
``reverse=True`` keyword argument at call time.
|
|
17
17
|
By default, the scaling parameters are set to zero mean
|
|
18
18
|
and unit variance, amounting to an identity transform.
|
|
19
19
|
|
|
20
20
|
Args:
|
|
21
21
|
num_channels:
|
|
22
22
|
The number of channels of the target timeseries.
|
|
23
|
-
If left as
|
|
23
|
+
If left as ``None``, the timeseries will be assumed
|
|
24
24
|
to be 1D (single channel).
|
|
25
25
|
"""
|
|
26
26
|
|
|
@@ -42,8 +42,8 @@ class ChannelWiseScaler(FittableTransform):
|
|
|
42
42
|
"""Fit the scaling parameters to a timeseries
|
|
43
43
|
|
|
44
44
|
Computes the channel-wise mean and standard deviation
|
|
45
|
-
of the timeseries
|
|
46
|
-
|
|
45
|
+
of the timeseries ``X`` and sets these values to the
|
|
46
|
+
``mean`` and ``std`` parameters of the scaler.
|
|
47
47
|
"""
|
|
48
48
|
|
|
49
49
|
if X.ndim == 1:
|