ml4gw 0.5.1__tar.gz → 0.6.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ml4gw might be problematic. Click here for more details.
- {ml4gw-0.5.1 → ml4gw-0.6.0}/PKG-INFO +1 -1
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/constants.py +10 -19
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/spectral.py +1 -1
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/transforms/__init__.py +1 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/transforms/qtransform.py +134 -42
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/transforms/scaler.py +4 -2
- ml4gw-0.6.0/ml4gw/transforms/spline_interpolation.py +370 -0
- ml4gw-0.6.0/ml4gw/waveforms/__init__.py +2 -0
- ml4gw-0.6.0/ml4gw/waveforms/adhoc/__init__.py +2 -0
- {ml4gw-0.5.1/ml4gw/waveforms → ml4gw-0.6.0/ml4gw/waveforms/cbc}/__init__.py +0 -2
- {ml4gw-0.5.1/ml4gw/waveforms → ml4gw-0.6.0/ml4gw/waveforms/cbc}/phenom_p.py +36 -42
- ml4gw-0.6.0/ml4gw/waveforms/conversion.py +187 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/pyproject.toml +1 -1
- {ml4gw-0.5.1 → ml4gw-0.6.0}/README.md +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/__init__.py +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/augmentations.py +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/dataloading/__init__.py +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/dataloading/chunked_dataset.py +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/dataloading/hdf5_dataset.py +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/dataloading/in_memory_dataset.py +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/distributions.py +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/gw.py +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/nn/__init__.py +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/nn/autoencoder/__init__.py +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/nn/autoencoder/base.py +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/nn/autoencoder/convolutional.py +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/nn/autoencoder/skip_connection.py +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/nn/autoencoder/utils.py +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/nn/norm.py +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/nn/resnet/__init__.py +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/nn/resnet/resnet_1d.py +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/nn/resnet/resnet_2d.py +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/nn/streaming/__init__.py +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/nn/streaming/online_average.py +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/nn/streaming/snapshotter.py +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/transforms/pearson.py +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/transforms/snr_rescaler.py +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/transforms/spectral.py +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/transforms/spectrogram.py +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/transforms/transform.py +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/transforms/waveforms.py +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/transforms/whitening.py +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/types.py +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/utils/interferometer.py +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/utils/slicing.py +0 -0
- {ml4gw-0.5.1/ml4gw/waveforms → ml4gw-0.6.0/ml4gw/waveforms/adhoc}/ringdown.py +0 -0
- {ml4gw-0.5.1/ml4gw/waveforms → ml4gw-0.6.0/ml4gw/waveforms/adhoc}/sine_gaussian.py +0 -0
- {ml4gw-0.5.1/ml4gw/waveforms → ml4gw-0.6.0/ml4gw/waveforms/cbc}/phenom_d.py +0 -0
- {ml4gw-0.5.1/ml4gw/waveforms → ml4gw-0.6.0/ml4gw/waveforms/cbc}/phenom_d_data.py +0 -0
- {ml4gw-0.5.1/ml4gw/waveforms → ml4gw-0.6.0/ml4gw/waveforms/cbc}/taylorf2.py +0 -0
- {ml4gw-0.5.1 → ml4gw-0.6.0}/ml4gw/waveforms/generator.py +0 -0
|
@@ -4,42 +4,33 @@ Various constants, all in SI units.
|
|
|
4
4
|
|
|
5
5
|
EulerGamma = 0.577215664901532860606512090082402431
|
|
6
6
|
|
|
7
|
+
# solar mass
|
|
7
8
|
MSUN = 1.988409902147041637325262574352366540e30 # kg
|
|
8
|
-
"""Solar mass"""
|
|
9
9
|
|
|
10
|
+
# Geometrized nominal solar mass, m
|
|
10
11
|
MRSUN = 1.476625038050124729627979840144936351e3
|
|
11
|
-
"""Geometrized nominal solar mass, m"""
|
|
12
12
|
|
|
13
|
+
# Newton's gravitational constant
|
|
13
14
|
G = 6.67430e-11 # m^3 / kg / s^2
|
|
14
|
-
"""Newton's gravitational constant"""
|
|
15
15
|
|
|
16
|
+
# Speed of light
|
|
16
17
|
C = 299792458.0 # m / s
|
|
17
|
-
"""Speed of light"""
|
|
18
18
|
|
|
19
|
-
|
|
19
|
+
# pi and 2pi
|
|
20
20
|
PI = 3.141592653589793238462643383279502884
|
|
21
|
-
|
|
22
21
|
TWO_PI = 6.283185307179586476925286766559005768
|
|
23
22
|
|
|
23
|
+
# G MSUN / C^3 in seconds
|
|
24
24
|
gt = G * MSUN / (C**3.0)
|
|
25
|
-
"""
|
|
26
|
-
G MSUN / C^3 in seconds
|
|
27
|
-
"""
|
|
28
25
|
|
|
26
|
+
# 1 solar mass in seconds. Same value as lal.MTSUN_SI
|
|
29
27
|
MTSUN_SI = 4.925490947641266978197229498498379006e-6
|
|
30
|
-
"""1 solar mass in seconds. Same value as lal.MTSUN_SI"""
|
|
31
28
|
|
|
29
|
+
# Meters per Mpc.
|
|
32
30
|
m_per_Mpc = 3.085677581491367278913937957796471611e22
|
|
33
|
-
"""
|
|
34
|
-
Meters per Mpc.
|
|
35
|
-
"""
|
|
36
31
|
|
|
32
|
+
# 1 Mpc in seconds.
|
|
37
33
|
MPC_SEC = m_per_Mpc / C
|
|
38
|
-
"""
|
|
39
|
-
1 Mpc in seconds.
|
|
40
|
-
"""
|
|
41
34
|
|
|
35
|
+
# Speed of light in vacuum (:math:`c`), in gigaparsecs per second
|
|
42
36
|
clightGpc = C / 3.0856778570831e22
|
|
43
|
-
"""
|
|
44
|
-
Speed of light in vacuum (:math:`c`), in gigaparsecs per second
|
|
45
|
-
"""
|
|
@@ -441,7 +441,7 @@ def normalize_by_psd(
|
|
|
441
441
|
|
|
442
442
|
# convert back to the time domain and normalize
|
|
443
443
|
# TODO: what's this normalization factor?
|
|
444
|
-
X = torch.fft.irfft(X_tilde, norm="forward", dim=-1)
|
|
444
|
+
X = torch.fft.irfft(X_tilde, n=X.shape[-1], norm="forward", dim=-1)
|
|
445
445
|
X = X.float() / sample_rate**0.5
|
|
446
446
|
|
|
447
447
|
# slice off corrupted data at edges of kernel
|
|
@@ -4,5 +4,6 @@ from .scaler import ChannelWiseScaler
|
|
|
4
4
|
from .snr_rescaler import SnrRescaler
|
|
5
5
|
from .spectral import SpectralDensity
|
|
6
6
|
from .spectrogram import MultiResolutionSpectrogram
|
|
7
|
+
from .spline_interpolation import SplineInterpolate
|
|
7
8
|
from .waveforms import WaveformProjector, WaveformSampler
|
|
8
9
|
from .whitening import FixedWhiten, Whiten
|
|
@@ -1,11 +1,13 @@
|
|
|
1
1
|
import math
|
|
2
|
-
|
|
2
|
+
import warnings
|
|
3
|
+
from typing import List, Tuple
|
|
3
4
|
|
|
4
5
|
import torch
|
|
5
6
|
import torch.nn.functional as F
|
|
6
7
|
from jaxtyping import Float, Int
|
|
7
8
|
from torch import Tensor
|
|
8
9
|
|
|
10
|
+
from ml4gw.transforms.spline_interpolation import SplineInterpolate
|
|
9
11
|
from ml4gw.types import FrequencySeries1to3d, TimeSeries1to3d, TimeSeries3d
|
|
10
12
|
|
|
11
13
|
"""
|
|
@@ -38,7 +40,6 @@ class QTile(torch.nn.Module):
|
|
|
38
40
|
mismatch:
|
|
39
41
|
The maximum fractional mismatch between neighboring tiles
|
|
40
42
|
|
|
41
|
-
|
|
42
43
|
"""
|
|
43
44
|
|
|
44
45
|
def __init__(
|
|
@@ -100,7 +101,9 @@ class QTile(torch.nn.Module):
|
|
|
100
101
|
).type(torch.long)
|
|
101
102
|
|
|
102
103
|
def forward(
|
|
103
|
-
self,
|
|
104
|
+
self,
|
|
105
|
+
fseries: FrequencySeries1to3d,
|
|
106
|
+
norm: str = "median",
|
|
104
107
|
) -> TimeSeries1to3d:
|
|
105
108
|
"""
|
|
106
109
|
Compute the transform for this row
|
|
@@ -144,7 +147,7 @@ class QTile(torch.nn.Module):
|
|
|
144
147
|
energy /= means
|
|
145
148
|
else:
|
|
146
149
|
raise ValueError("Invalid normalisation %r" % norm)
|
|
147
|
-
|
|
150
|
+
energy = energy.type(torch.float32)
|
|
148
151
|
return energy
|
|
149
152
|
|
|
150
153
|
|
|
@@ -172,6 +175,19 @@ class SingleQTransform(torch.nn.Module):
|
|
|
172
175
|
be chosen based on q, sample_rate, and duration
|
|
173
176
|
mismatch:
|
|
174
177
|
The maximum fractional mismatch between neighboring tiles
|
|
178
|
+
interpolation_method:
|
|
179
|
+
The method by which to interpolate each `QTile` to the specified
|
|
180
|
+
number of time and frequency bins. The acceptable values are
|
|
181
|
+
"bilinear", "bicubic", and "spline". The "bilinear" and "bicubic"
|
|
182
|
+
options will use PyTorch's built-in interpolation modes, while
|
|
183
|
+
"spline" will use the custom Torch-based implementation in
|
|
184
|
+
`ml4gw`, as PyTorch does not have spline-based intertpolation.
|
|
185
|
+
The "spline" mode is most similar to the results of GWpy's
|
|
186
|
+
Q-transform, which uses `scipy` to do spline interpolation.
|
|
187
|
+
However, it is also the slowest and most memory intensive due to
|
|
188
|
+
the matrix equation solving steps. Therefore, the default method
|
|
189
|
+
is "bicubic" as it produces the most similar results while
|
|
190
|
+
optimizing for computing performance.
|
|
175
191
|
"""
|
|
176
192
|
|
|
177
193
|
def __init__(
|
|
@@ -182,6 +198,7 @@ class SingleQTransform(torch.nn.Module):
|
|
|
182
198
|
q: float = 12,
|
|
183
199
|
frange: List[float] = [0, torch.inf],
|
|
184
200
|
mismatch: float = 0.2,
|
|
201
|
+
interpolation_method: str = "bicubic",
|
|
185
202
|
) -> None:
|
|
186
203
|
super().__init__()
|
|
187
204
|
self.q = q
|
|
@@ -190,20 +207,87 @@ class SingleQTransform(torch.nn.Module):
|
|
|
190
207
|
self.duration = duration
|
|
191
208
|
self.mismatch = mismatch
|
|
192
209
|
|
|
210
|
+
# If q is too large, the minimum of the frange computed
|
|
211
|
+
# below will be larger than the maximum
|
|
212
|
+
max_q = torch.pi * duration * sample_rate / 50 - 11 ** (0.5)
|
|
213
|
+
if q >= max_q:
|
|
214
|
+
raise ValueError(
|
|
215
|
+
"The given q value is too large for the given duration and "
|
|
216
|
+
f"sample rate. The maximum allowable value is {max_q}"
|
|
217
|
+
)
|
|
218
|
+
|
|
219
|
+
if interpolation_method not in ["bilinear", "bicubic", "spline"]:
|
|
220
|
+
raise ValueError(
|
|
221
|
+
"Interpolation method must be either 'bilinear', 'bicubic', "
|
|
222
|
+
f"or 'spline'; got {interpolation_method}"
|
|
223
|
+
)
|
|
224
|
+
self.interpolation_method = interpolation_method
|
|
225
|
+
|
|
193
226
|
qprime = self.q / 11 ** (1 / 2.0)
|
|
194
227
|
if self.frange[0] <= 0: # set non-zero lower frequency
|
|
195
228
|
self.frange[0] = 50 * self.q / (2 * torch.pi * duration)
|
|
196
229
|
if math.isinf(self.frange[1]): # set non-infinite upper frequency
|
|
197
230
|
self.frange[1] = sample_rate / 2 / (1 + 1 / qprime)
|
|
231
|
+
|
|
198
232
|
self.freqs = self.get_freqs()
|
|
199
233
|
self.qtile_transforms = torch.nn.ModuleList(
|
|
200
234
|
[
|
|
201
|
-
QTile(
|
|
235
|
+
QTile(
|
|
236
|
+
q=self.q,
|
|
237
|
+
frequency=freq,
|
|
238
|
+
duration=self.duration,
|
|
239
|
+
sample_rate=sample_rate,
|
|
240
|
+
mismatch=self.mismatch,
|
|
241
|
+
)
|
|
202
242
|
for freq in self.freqs
|
|
203
243
|
]
|
|
204
244
|
)
|
|
205
245
|
self.qtiles = None
|
|
206
246
|
|
|
247
|
+
if self.interpolation_method == "spline":
|
|
248
|
+
self._set_up_spline_interp()
|
|
249
|
+
|
|
250
|
+
def _set_up_spline_interp(self):
|
|
251
|
+
ntiles = [qtile.ntiles() for qtile in self.qtile_transforms]
|
|
252
|
+
# For efficiency, we'll stack all qtiles of the same length before
|
|
253
|
+
# interpolating, so we need to figure out which those are
|
|
254
|
+
unique_ntiles = sorted(list(set(ntiles)))
|
|
255
|
+
idx = torch.arange(len(ntiles))
|
|
256
|
+
self.stack_idx = [idx[Tensor(ntiles) == n] for n in unique_ntiles]
|
|
257
|
+
|
|
258
|
+
t_out = torch.arange(
|
|
259
|
+
0, self.duration, self.duration / self.spectrogram_shape[1]
|
|
260
|
+
)
|
|
261
|
+
self.qtile_interpolators = torch.nn.ModuleList(
|
|
262
|
+
[
|
|
263
|
+
SplineInterpolate(
|
|
264
|
+
kx=3,
|
|
265
|
+
x_in=torch.arange(0, self.duration, self.duration / tiles),
|
|
266
|
+
y_in=torch.arange(len(idx)),
|
|
267
|
+
x_out=t_out,
|
|
268
|
+
y_out=torch.arange(len(idx)),
|
|
269
|
+
)
|
|
270
|
+
for tiles, idx in zip(unique_ntiles, self.stack_idx)
|
|
271
|
+
]
|
|
272
|
+
)
|
|
273
|
+
|
|
274
|
+
t_in = t_out
|
|
275
|
+
f_in = self.freqs
|
|
276
|
+
f_out = torch.logspace(
|
|
277
|
+
math.log10(self.frange[0]),
|
|
278
|
+
math.log10(self.frange[-1]),
|
|
279
|
+
self.spectrogram_shape[0],
|
|
280
|
+
)
|
|
281
|
+
|
|
282
|
+
self.interpolator = SplineInterpolate(
|
|
283
|
+
kx=3,
|
|
284
|
+
ky=3,
|
|
285
|
+
x_in=t_in,
|
|
286
|
+
y_in=f_in,
|
|
287
|
+
x_out=t_out,
|
|
288
|
+
y_out=f_out,
|
|
289
|
+
)
|
|
290
|
+
|
|
207
291
|
def get_freqs(self) -> Float[Tensor, " nfreq"]:
|
|
208
292
|
"""
|
|
209
293
|
Calculate the frequencies that will be used in this transform.
|
|
@@ -220,7 +304,8 @@ class SingleQTransform(torch.nn.Module):
|
|
|
220
304
|
|
|
221
305
|
freq_base = math.exp(2 / ((2 + self.q**2) ** (1 / 2.0)) * fstep)
|
|
222
306
|
freqs = torch.Tensor([freq_base ** (i + 0.5) for i in range(nfreq)])
|
|
223
|
-
freqs
|
|
307
|
+
# Cast freqs to float64 to avoid off-by-ones from rounding
|
|
308
|
+
freqs = (minf * freqs.double() // fstepmin) * fstepmin
|
|
224
309
|
return torch.unique(freqs)
|
|
225
310
|
|
|
226
311
|
def get_max_energy(
|
|
@@ -268,7 +353,11 @@ class SingleQTransform(torch.nn.Module):
|
|
|
268
353
|
if dimension == "batch":
|
|
269
354
|
return torch.max(max_across_ft, dim=-1).values
|
|
270
355
|
|
|
271
|
-
def compute_qtiles(
|
|
356
|
+
def compute_qtiles(
|
|
357
|
+
self,
|
|
358
|
+
X: TimeSeries1to3d,
|
|
359
|
+
norm: str = "median",
|
|
360
|
+
) -> None:
|
|
272
361
|
"""
|
|
273
362
|
Take the FFT of the input timeseries and calculate the transform
|
|
274
363
|
for each `QTile`
|
|
@@ -278,28 +367,40 @@ class SingleQTransform(torch.nn.Module):
|
|
|
278
367
|
X[..., 1:] *= 2
|
|
279
368
|
self.qtiles = [qtile(X, norm) for qtile in self.qtile_transforms]
|
|
280
369
|
|
|
281
|
-
def interpolate(self
|
|
282
|
-
"""
|
|
283
|
-
Interpolate each `QTile` to the specified number of time and
|
|
284
|
-
frequency bins. Note that PyTorch does not have the same
|
|
285
|
-
interpolation methods that GWpy uses, and so the interpolated
|
|
286
|
-
spectrograms will be different even though the uninterpolated
|
|
287
|
-
values match. The `bicubic` interpolation method is used as
|
|
288
|
-
it seems to match GWpy most closely.
|
|
289
|
-
"""
|
|
370
|
+
def interpolate(self) -> TimeSeries3d:
|
|
290
371
|
if self.qtiles is None:
|
|
291
372
|
raise RuntimeError(
|
|
292
373
|
"Q-tiles must first be computed with .compute_qtiles()"
|
|
293
374
|
)
|
|
375
|
+
if self.interpolation_method == "spline":
|
|
376
|
+
qtiles = [
|
|
377
|
+
torch.stack([self.qtiles[i] for i in idx], dim=-2)
|
|
378
|
+
for idx in self.stack_idx
|
|
379
|
+
]
|
|
380
|
+
time_interped = torch.cat(
|
|
381
|
+
[
|
|
382
|
+
interpolator(qtile)
|
|
383
|
+
for qtile, interpolator in zip(
|
|
384
|
+
qtiles, self.qtile_interpolators
|
|
385
|
+
)
|
|
386
|
+
],
|
|
387
|
+
dim=-2,
|
|
388
|
+
)
|
|
389
|
+
return self.interpolator(time_interped)
|
|
390
|
+
num_f_bins, num_t_bins = self.spectrogram_shape
|
|
294
391
|
resampled = [
|
|
295
392
|
F.interpolate(
|
|
296
|
-
qtile[None],
|
|
393
|
+
qtile[None],
|
|
394
|
+
(qtile.shape[-2], num_t_bins),
|
|
395
|
+
mode=self.interpolation_method,
|
|
297
396
|
)
|
|
298
397
|
for qtile in self.qtiles
|
|
299
398
|
]
|
|
300
399
|
resampled = torch.stack(resampled, dim=-2)
|
|
301
400
|
resampled = F.interpolate(
|
|
302
|
-
resampled[0],
|
|
401
|
+
resampled[0],
|
|
402
|
+
(num_f_bins, num_t_bins),
|
|
403
|
+
mode=self.interpolation_method,
|
|
303
404
|
)
|
|
304
405
|
return torch.squeeze(resampled)
|
|
305
406
|
|
|
@@ -307,7 +408,6 @@ class SingleQTransform(torch.nn.Module):
|
|
|
307
408
|
self,
|
|
308
409
|
X: TimeSeries1to3d,
|
|
309
410
|
norm: str = "median",
|
|
310
|
-
spectrogram_shape: Optional[Tuple[int, int]] = None,
|
|
311
411
|
):
|
|
312
412
|
"""
|
|
313
413
|
Compute the Q-tiles and interpolate
|
|
@@ -321,24 +421,15 @@ class SingleQTransform(torch.nn.Module):
|
|
|
321
421
|
three-dimensional, axes will be added during Q-tile
|
|
322
422
|
computation.
|
|
323
423
|
norm:
|
|
324
|
-
The method of
|
|
325
|
-
spectrogram_shape:
|
|
326
|
-
The shape of the interpolated spectrogram, specified as
|
|
327
|
-
`(num_f_bins, num_t_bins)`. Because the
|
|
328
|
-
frequency spacing of the Q-tiles is in log-space, the frequency
|
|
329
|
-
interpolation is log-spaced as well. If not given, the shape
|
|
330
|
-
used to initialize the transform will be used.
|
|
424
|
+
The method of normalization used by each QTile
|
|
331
425
|
|
|
332
426
|
Returns:
|
|
333
427
|
The interpolated Q-transform for the batch of data. Output will
|
|
334
428
|
have one more dimension than the input
|
|
335
429
|
"""
|
|
336
430
|
|
|
337
|
-
if spectrogram_shape is None:
|
|
338
|
-
spectrogram_shape = self.spectrogram_shape
|
|
339
|
-
num_f_bins, num_t_bins = spectrogram_shape
|
|
340
431
|
self.compute_qtiles(X, norm)
|
|
341
|
-
return self.interpolate(
|
|
432
|
+
return self.interpolate()
|
|
342
433
|
|
|
343
434
|
|
|
344
435
|
class QScan(torch.nn.Module):
|
|
@@ -376,14 +467,22 @@ class QScan(torch.nn.Module):
|
|
|
376
467
|
spectrogram_shape: Tuple[int, int],
|
|
377
468
|
qrange: List[float] = [4, 64],
|
|
378
469
|
frange: List[float] = [0, torch.inf],
|
|
470
|
+
interpolation_method="bicubic",
|
|
379
471
|
mismatch: float = 0.2,
|
|
380
472
|
) -> None:
|
|
381
473
|
super().__init__()
|
|
382
474
|
self.qrange = qrange
|
|
383
475
|
self.mismatch = mismatch
|
|
384
|
-
self.qs = self.get_qs()
|
|
385
476
|
self.frange = frange
|
|
386
477
|
self.spectrogram_shape = spectrogram_shape
|
|
478
|
+
max_q = torch.pi * duration * sample_rate / 50 - 11 ** (0.5)
|
|
479
|
+
self.qs = self.get_qs()
|
|
480
|
+
if self.qs[-1] >= max_q:
|
|
481
|
+
warnings.warn(
|
|
482
|
+
"Some Q values exceed the maximum allowable Q value of "
|
|
483
|
+
f"{max_q}. The list of Q values to be tested in this "
|
|
484
|
+
"scan will be truncated to avoid those values."
|
|
485
|
+
)
|
|
387
486
|
|
|
388
487
|
# Deliberately doing something different from GWpy here.
|
|
389
488
|
# Their final frange is the intersection of the frange
|
|
@@ -397,9 +496,11 @@ class QScan(torch.nn.Module):
|
|
|
397
496
|
spectrogram_shape=spectrogram_shape,
|
|
398
497
|
q=q,
|
|
399
498
|
frange=self.frange.copy(),
|
|
499
|
+
interpolation_method=interpolation_method,
|
|
400
500
|
mismatch=self.mismatch,
|
|
401
501
|
)
|
|
402
502
|
for q in self.qs
|
|
503
|
+
if q < max_q
|
|
403
504
|
]
|
|
404
505
|
)
|
|
405
506
|
|
|
@@ -415,6 +516,7 @@ class QScan(torch.nn.Module):
|
|
|
415
516
|
self.qrange[0] * math.exp(2 ** (1 / 2.0) * dq * (i + 0.5))
|
|
416
517
|
for i in range(nplanes)
|
|
417
518
|
]
|
|
519
|
+
|
|
418
520
|
return qs
|
|
419
521
|
|
|
420
522
|
def forward(
|
|
@@ -422,7 +524,6 @@ class QScan(torch.nn.Module):
|
|
|
422
524
|
X: TimeSeries1to3d,
|
|
423
525
|
fsearch_range: List[float] = None,
|
|
424
526
|
norm: str = "median",
|
|
425
|
-
spectrogram_shape: Optional[Tuple[int, int]] = None,
|
|
426
527
|
):
|
|
427
528
|
"""
|
|
428
529
|
Compute the set of QTiles for each Q transform and determine which
|
|
@@ -442,12 +543,6 @@ class QScan(torch.nn.Module):
|
|
|
442
543
|
for the maximum energy
|
|
443
544
|
norm:
|
|
444
545
|
The method of interpolation used by each QTile
|
|
445
|
-
spectrogram_shape:
|
|
446
|
-
The shape of the interpolated spectrogram, specified as
|
|
447
|
-
`(num_f_bins, num_t_bins)`. Because the
|
|
448
|
-
frequency spacing of the Q-tiles is in log-space, the frequency
|
|
449
|
-
interpolation is log-spaced as well. If not given, the shape
|
|
450
|
-
used to initialize the transform will be used.
|
|
451
546
|
|
|
452
547
|
Returns:
|
|
453
548
|
An interpolated Q-transform for the batch of data. Output will
|
|
@@ -463,7 +558,4 @@ class QScan(torch.nn.Module):
|
|
|
463
558
|
]
|
|
464
559
|
)
|
|
465
560
|
)
|
|
466
|
-
|
|
467
|
-
spectrogram_shape = self.spectrogram_shape
|
|
468
|
-
num_f_bins, num_t_bins = spectrogram_shape
|
|
469
|
-
return self.q_transforms[idx].interpolate(num_f_bins, num_t_bins)
|
|
561
|
+
return self.q_transforms[idx].interpolate()
|
|
@@ -36,7 +36,9 @@ class ChannelWiseScaler(FittableTransform):
|
|
|
36
36
|
self.register_buffer("mean", mean)
|
|
37
37
|
self.register_buffer("std", std)
|
|
38
38
|
|
|
39
|
-
def fit(
|
|
39
|
+
def fit(
|
|
40
|
+
self, X: Float[Tensor, "... time"], std_reg: Optional[float] = 0.0
|
|
41
|
+
) -> None:
|
|
40
42
|
"""Fit the scaling parameters to a timeseries
|
|
41
43
|
|
|
42
44
|
Computes the channel-wise mean and standard deviation
|
|
@@ -59,7 +61,7 @@ class ChannelWiseScaler(FittableTransform):
|
|
|
59
61
|
"Can't fit channel wise mean and standard deviation "
|
|
60
62
|
"from tensor of shape {}".format(X.shape)
|
|
61
63
|
)
|
|
62
|
-
|
|
64
|
+
std += std_reg * torch.ones_like(std)
|
|
63
65
|
super().build(mean=mean, std=std)
|
|
64
66
|
|
|
65
67
|
def forward(
|