ml4gw 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ml4gw might be problematic. Click here for more details.
- ml4gw/constants.py +45 -0
- ml4gw/dataloading/in_memory_dataset.py +18 -32
- ml4gw/distributions.py +124 -76
- ml4gw/nn/norm.py +6 -0
- ml4gw/transforms/__init__.py +1 -0
- ml4gw/transforms/qtransform.py +463 -0
- ml4gw/transforms/spectral.py +17 -5
- ml4gw/utils/interferometer.py +6 -0
- ml4gw/waveforms/__init__.py +2 -0
- ml4gw/waveforms/generator.py +1 -1
- ml4gw/waveforms/phenom_d.py +1334 -1252
- ml4gw/waveforms/phenom_p.py +779 -0
- ml4gw/waveforms/ringdown.py +110 -0
- ml4gw/waveforms/sine_gaussian.py +4 -5
- ml4gw/waveforms/taylorf2.py +297 -278
- {ml4gw-0.4.1.dist-info → ml4gw-0.5.0.dist-info}/METADATA +2 -1
- {ml4gw-0.4.1.dist-info → ml4gw-0.5.0.dist-info}/RECORD +18 -14
- {ml4gw-0.4.1.dist-info → ml4gw-0.5.0.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,463 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from typing import List, Optional, Tuple
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn.functional as F
|
|
6
|
+
|
|
7
|
+
"""
|
|
8
|
+
All based on https://github.com/gwpy/gwpy/blob/v3.0.8/gwpy/signal/qtransform.py
|
|
9
|
+
The methods, names, and descriptions come almost entirely from GWpy.
|
|
10
|
+
This code allows the Q-transform to be performed on batches of multi-channel
|
|
11
|
+
input on GPU.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class QTile(torch.nn.Module):
    """
    Compute the row of Q-tiles for a single Q value and a single
    frequency for a batch of multi-channel frequency series data.
    Should really be called `QRow`, but I want to match GWpy.
    Input data should have three dimensions or fewer.
    If fewer, dimensions will be added until the input is
    three-dimensional.

    Args:
        q:
            The Q value to use in computing the Q tile
        frequency:
            The frequency for which to compute the Q tile in Hz
        duration:
            The length of time in seconds that the input frequency
            series represents
        sample_rate:
            The sample rate of the original time series in Hz
        mismatch:
            The maximum fractional mismatch between neighboring tiles
    """

    def __init__(
        self,
        q: float,
        frequency: float,
        duration: float,
        sample_rate: float,
        mismatch: float,
    ):
        super().__init__()
        self.mismatch = mismatch
        self.q = q
        # Tile spacing implied by the allowed mismatch (GWpy's formula)
        self.deltam = torch.tensor(2 * (self.mismatch / 3.0) ** (1 / 2.0))
        self.qprime = self.q / 11 ** (1 / 2.0)
        self.frequency = frequency
        self.duration = duration
        self.sample_rate = sample_rate

        # Number of frequency bins spanned by the bi-square window
        self.windowsize = (
            2 * int(self.frequency / self.qprime * self.duration) + 1
        )
        # The windowed data is zero-padded out to `ntiles()` samples,
        # splitting the padding as evenly as possible between both sides.
        # Kept as a registered buffer for state_dict compatibility.
        pad = self.ntiles() - self.windowsize
        padding = torch.Tensor((int((pad - 1) / 2.0), int((pad + 1) / 2.0)))
        self.register_buffer("padding", padding)
        self.register_buffer("indices", self.get_data_indices())
        self.register_buffer("window", self.get_window())

    def ntiles(self):
        """
        Number of tiles in this frequency row: the cumulative time
        mismatch divided by the tile spacing, rounded up to the next
        power of two.
        """
        tcum_mismatch = self.duration * 2 * torch.pi * self.frequency / self.q
        return int(2 ** torch.ceil(torch.log2(tcum_mismatch / self.deltam)))

    def _get_indices(self):
        # Symmetric integer offsets -half..half covering the window
        half = int((self.windowsize - 1) / 2)
        return torch.arange(-half, half + 1)

    def get_window(self):
        """
        Generate the bi-square window for this row
        """
        wfrequencies = self._get_indices() / self.duration
        xfrequencies = wfrequencies * self.qprime / self.frequency
        norm = (
            self.ntiles()
            / (self.duration * self.sample_rate)
            * (315 * self.qprime / (128 * self.frequency)) ** (1 / 2.0)
        )
        return torch.Tensor((1 - xfrequencies**2) ** 2 * norm)

    def get_data_indices(self):
        """
        Get the index array of relevant frequencies for this row
        """
        return torch.round(
            self._get_indices() + 1 + self.frequency * self.duration,
        ).type(torch.long)

    def forward(self, fseries: torch.Tensor, norm: str = "median"):
        """
        Compute the transform for this row

        Args:
            fseries:
                Frequency series of data. Should correspond to data with
                the duration and sample rate used to initialize this object.
                Expected input shape is `(B, C, F)`, where F is the number
                of samples, C is the number of channels, and B is the number
                of batches. If less than three-dimensional, axes will be
                added.
            norm:
                The method of normalization. Options are "median", "mean", or
                `None`. As in GWpy, `True` is accepted as an alias for
                "median"; any other falsy value disables normalization.

        Returns:
            The row of Q-tiles for the given Q and frequency. Output is
            three-dimensional: `(B, C, T)`

        Raises:
            ValueError:
                If `fseries` has more than three dimensions, or `norm`
                is not a recognized normalization method.
        """
        if len(fseries.shape) > 3:
            raise ValueError("Input data has more than 3 dimensions")

        while len(fseries.shape) < 3:
            fseries = fseries[None]

        windowed = fseries[..., self.indices] * self.window
        left, right = self.padding
        padded = F.pad(windowed, (int(left), int(right)), mode="constant")
        # Move the negative frequencies to the end before the inverse FFT
        wenergy = torch.fft.ifftshift(padded, dim=-1)

        tdenergy = torch.fft.ifft(wenergy)
        energy = tdenergy.real**2.0 + tdenergy.imag**2.0
        if not norm:
            return energy

        if isinstance(norm, str):
            norm = norm.lower()
        # `norm is True` mirrors GWpy's `norm in (True, "median")`
        if norm is True or norm == "median":
            scale = torch.quantile(energy, q=0.5, dim=-1, keepdim=True)
        elif norm == "mean":
            scale = torch.mean(energy, dim=-1, keepdim=True)
        else:
            raise ValueError("Invalid normalisation %r" % norm)
        # Out-of-place division keeps `energy` intact for autograd
        return (energy / scale).type(torch.float32)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
class SingleQTransform(torch.nn.Module):
    """
    Compute the Q-transform for a single Q value for a batch of
    multi-channel time series data. Input data should have
    three dimensions or fewer.

    Args:
        duration:
            Length of the time series data in seconds
        sample_rate:
            Sample rate of the data in Hz
        spectrogram_shape:
            The shape of the interpolated spectrogram, specified as
            `(num_f_bins, num_t_bins)`. Because the
            frequency spacing of the Q-tiles is in log-space, the frequency
            interpolation is log-spaced as well.
        q:
            The Q value to use for the Q transform
        frange:
            The lower and upper frequency limit to consider for
            the transform. If unspecified, default values will
            be chosen based on q, sample_rate, and duration.
            The passed sequence is copied and never modified.
        mismatch:
            The maximum fractional mismatch between neighboring tiles
    """

    def __init__(
        self,
        duration: float,
        sample_rate: float,
        spectrogram_shape: Tuple[int, int],
        q: float = 12,
        frange: List[float] = [0, torch.inf],
        mismatch: float = 0.2,
    ):
        super().__init__()
        self.q = q
        self.spectrogram_shape = spectrogram_shape
        # Copy before filling in default limits below: assigning into the
        # original list would mutate the caller's argument or, worse, the
        # shared default, leaking one instance's limits into the next.
        self.frange = list(frange)
        self.duration = duration
        self.mismatch = mismatch

        qprime = self.q / 11 ** (1 / 2.0)
        if self.frange[0] <= 0:  # set non-zero lower frequency
            self.frange[0] = 50 * self.q / (2 * torch.pi * duration)
        if math.isinf(self.frange[1]):  # set non-infinite upper frequency
            self.frange[1] = sample_rate / 2 / (1 + 1 / qprime)
        self.freqs = self.get_freqs()
        self.qtile_transforms = torch.nn.ModuleList(
            [
                QTile(self.q, freq, self.duration, sample_rate, self.mismatch)
                for freq in self.freqs
            ]
        )
        self.qtiles = None

    def get_freqs(self):
        """
        Calculate the frequencies that will be used in this transform.
        For each frequency, a `QTile` is created. The returned tensor is
        sorted and contains no duplicates (via `torch.unique`).
        """
        minf, maxf = self.frange
        fcum_mismatch = (
            math.log(maxf / minf) * (2 + self.q**2) ** (1 / 2.0) / 2.0
        )
        deltam = 2 * (self.mismatch / 3.0) ** (1 / 2.0)
        nfreq = int(max(1, math.ceil(fcum_mismatch / deltam)))
        fstep = fcum_mismatch / nfreq
        fstepmin = 1 / self.duration

        freq_base = math.exp(2 / ((2 + self.q**2) ** (1 / 2.0)) * fstep)
        freqs = torch.Tensor([freq_base ** (i + 0.5) for i in range(nfreq)])
        # Snap each frequency down to a multiple of the FFT resolution
        freqs = (minf * freqs // fstepmin) * fstepmin
        return torch.unique(freqs)

    def get_max_energy(
        self, fsearch_range: List[float] = None, dimension: str = "both"
    ):
        """
        Gets the maximum energy value among the QTiles. The maximum can
        be computed across all batches and channels, across all channels,
        across all batches, or individually for each channel/batch
        combination. This could be useful for allowing the use of different
        Q values for different channels and batches, but the slicing would
        be slow, so this isn't used yet.

        Optionally, a pair of frequency values can be specified for
        `fsearch_range` to restrict the frequencies in which the maximum
        energy value is sought. Q-tiles whose frequency `f` satisfies
        `fsearch_range[0] < f <= fsearch_range[1]` are searched.

        Raises:
            ValueError:
                If `dimension` is invalid, or no Q-tile frequency falls
                within `fsearch_range`.
            RuntimeError:
                If `compute_qtiles` has not been called yet.
        """
        allowed_dimensions = ["both", "neither", "channel", "batch"]
        if dimension not in allowed_dimensions:
            raise ValueError(f"Dimension must be one of {allowed_dimensions}")

        if self.qtiles is None:
            raise RuntimeError(
                "Q-tiles must first be computed with .compute_qtiles()"
            )

        if fsearch_range is not None:
            lower, upper = fsearch_range[0], fsearch_range[1]
            # `self.freqs` is sorted, so the number of frequencies at or
            # below a bound is the index of the first one above it. Unlike
            # `min(argwhere(...))`, this stays valid when a bound lies
            # above the highest Q-tile frequency.
            start = int((self.freqs <= lower).sum())
            stop = int((self.freqs <= upper).sum())
            qtiles = self.qtiles[start:stop]
            if not qtiles:
                raise ValueError(
                    "No Q-tile frequencies lie in the search range "
                    f"({lower}, {upper}]"
                )
        else:
            qtiles = self.qtiles

        if dimension == "both":
            return max([torch.max(qtile) for qtile in qtiles])

        # Each qtile is (B, C, T): reduce over time, then over frequency
        max_across_t = [torch.max(qtile, dim=-1).values for qtile in qtiles]
        max_across_t = torch.stack(max_across_t, dim=-1)
        max_across_ft = torch.max(max_across_t, dim=-1).values

        if dimension == "neither":
            return max_across_ft
        if dimension == "channel":
            return torch.max(max_across_ft, dim=-2).values
        if dimension == "batch":
            return torch.max(max_across_ft, dim=-1).values

    def compute_qtiles(self, X: torch.Tensor, norm: str = "median"):
        """
        Take the FFT of the input timeseries and calculate the transform
        for each `QTile`. Results are stored in `self.qtiles`.
        """
        # Computing the FFT with the same normalization and scaling as GWpy
        X = torch.fft.rfft(X, norm="forward")
        X[..., 1:] *= 2
        self.qtiles = [qtile(X, norm) for qtile in self.qtile_transforms]

    def interpolate(self, num_f_bins: int, num_t_bins: int):
        """
        Interpolate each `QTile` to the specified number of time and
        frequency bins. Note that PyTorch does not have the same
        interpolation methods that GWpy uses, and so the interpolated
        spectrograms will be different even though the uninterpolated
        values match. The `bicubic` interpolation method is used as
        it seems to match GWpy most closely.

        Returns:
            A tensor of shape `(B, C, num_f_bins, num_t_bins)` with all
            size-1 dimensions squeezed out.

        Raises:
            RuntimeError:
                If `compute_qtiles` has not been called yet.
        """
        if self.qtiles is None:
            raise RuntimeError(
                "Q-tiles must first be computed with .compute_qtiles()"
            )
        # Each (B, C, T_i) row is viewed as a 4D (1, B, C, T_i) image and
        # resampled along time only (the C axis keeps its size).
        resampled = [
            F.interpolate(
                qtile[None], (qtile.shape[-2], num_t_bins), mode="bicubic"
            )
            for qtile in self.qtiles
        ]
        # (1, B, C, nfreq, num_t_bins) -> resample along frequency
        resampled = torch.stack(resampled, dim=-2)
        resampled = F.interpolate(
            resampled[0], (num_f_bins, num_t_bins), mode="bicubic"
        )
        # NOTE: squeezes every size-1 axis, including a batch/channel axis
        # of size one, not only the axes added for low-dimensional input
        return torch.squeeze(resampled)

    def forward(
        self,
        X: torch.Tensor,
        norm: str = "median",
        spectrogram_shape: Optional[Tuple[int, int]] = None,
    ):
        """
        Compute the Q-tiles and interpolate

        Args:
            X:
                Time series of data. Should have the duration and sample rate
                used to initialize this object. Expected input shape is
                `(B, C, T)`, where T is the number of samples, C is the number
                of channels, and B is the number of batches. If less than
                three-dimensional, axes will be added during Q-tile
                computation.
            norm:
                The method of normalization used by each QTile
            spectrogram_shape:
                The shape of the interpolated spectrogram, specified as
                `(num_f_bins, num_t_bins)`. Because the
                frequency spacing of the Q-tiles is in log-space, the frequency
                interpolation is log-spaced as well. If not given, the shape
                used to initialize the transform will be used.

        Returns:
            The interpolated Q-transform for the batch of data, with
            size-1 axes squeezed out (see `interpolate`)
        """

        if spectrogram_shape is None:
            spectrogram_shape = self.spectrogram_shape
        num_f_bins, num_t_bins = spectrogram_shape
        self.compute_qtiles(X, norm)
        return self.interpolate(num_f_bins, num_t_bins)
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
class QScan(torch.nn.Module):
    """
    Calculate the Q-transform of a batch of multi-channel
    time series data for a range of Q values and return
    the interpolated Q-transform with the highest energy.

    Args:
        duration:
            Length of the time series data in seconds
        sample_rate:
            Sample rate of the data in Hz
        spectrogram_shape:
            The shape of the interpolated spectrogram, specified as
            `(num_f_bins, num_t_bins)`. Because the
            frequency spacing of the Q-tiles is in log-space, the frequency
            interpolation is log-spaced as well.
        qrange:
            The lower and upper values of Q to consider. The
            actual values of Q used for the transforms are
            determined by the `get_qs` method
        frange:
            The lower and upper frequency limit to consider for
            the transform. If unspecified, default values will
            be chosen based on q, sample_rate, and duration.
            Any two-element sequence (list or tuple) is accepted.
        mismatch:
            The maximum fractional mismatch between neighboring tiles
    """

    def __init__(
        self,
        duration: float,
        sample_rate: float,
        spectrogram_shape: Tuple[int, int],
        qrange: List[float] = [4, 64],
        frange: List[float] = [0, torch.inf],
        mismatch: float = 0.2,
    ):
        super().__init__()
        self.qrange = qrange
        self.mismatch = mismatch
        self.qs = self.get_qs()
        self.frange = frange
        self.spectrogram_shape = spectrogram_shape

        # Deliberately doing something different from GWpy here.
        # Their final frange is the intersection of the frange
        # from each q. This implementation uses the frange of
        # the chosen q.
        self.q_transforms = torch.nn.ModuleList(
            [
                SingleQTransform(
                    duration=duration,
                    sample_rate=sample_rate,
                    spectrogram_shape=spectrogram_shape,
                    q=q,
                    # Give each transform its own list so that filling in
                    # per-Q default limits cannot leak between transforms
                    frange=list(self.frange),
                    mismatch=self.mismatch,
                )
                for q in self.qs
            ]
        )

    def get_qs(self):
        """
        Determine the values of Q to try for the set of Q-transforms.
        The values are log-spaced between `qrange[0]` and `qrange[1]`,
        with at least one plane.
        """
        deltam = 2 * (self.mismatch / 3.0) ** (1 / 2.0)
        cumum = math.log(self.qrange[1] / self.qrange[0]) / 2 ** (1 / 2.0)
        nplanes = int(max(math.ceil(cumum / deltam), 1))
        dq = cumum / nplanes
        qs = [
            self.qrange[0] * math.exp(2 ** (1 / 2.0) * dq * (i + 0.5))
            for i in range(nplanes)
        ]
        return qs

    def forward(
        self,
        X: torch.Tensor,
        fsearch_range: List[float] = None,
        norm: str = "median",
        spectrogram_shape: Optional[Tuple[int, int]] = None,
    ):
        """
        Compute the set of QTiles for each Q transform and determine which
        has the highest energy value. Interpolate and return the
        corresponding set of tiles.

        Args:
            X:
                Time series of data. Should have the duration and sample rate
                used to initialize this object. Expected input shape is
                `(B, C, T)`, where T is the number of samples, C is the number
                of channels, and B is the number of batches. If less than
                three-dimensional, axes will be added during Q-tile
                computation.
            fsearch_range:
                The lower and upper frequency values within which to search
                for the maximum energy
            norm:
                The method of normalization used by each QTile
            spectrogram_shape:
                The shape of the interpolated spectrogram, specified as
                `(num_f_bins, num_t_bins)`. Because the
                frequency spacing of the Q-tiles is in log-space, the frequency
                interpolation is log-spaced as well. If not given, the shape
                used to initialize the transform will be used.

        Returns:
            An interpolated Q-transform for the batch of data, as returned
            by the chosen transform's `interpolate` method
        """
        for transform in self.q_transforms:
            transform.compute_qtiles(X, norm)
        # One scalar maximum per Q plane; stacking keeps them on their
        # original device instead of re-wrapping them with torch.Tensor
        max_energies = torch.stack(
            [
                transform.get_max_energy(fsearch_range=fsearch_range)
                for transform in self.q_transforms
            ]
        )
        # argmax picks the first plane in case of ties
        idx = int(torch.argmax(max_energies))
        if spectrogram_shape is None:
            spectrogram_shape = self.spectrogram_shape
        num_f_bins, num_t_bins = spectrogram_shape
        return self.q_transforms[idx].interpolate(num_f_bins, num_t_bins)
|
ml4gw/transforms/spectral.py
CHANGED
|
@@ -34,6 +34,10 @@ class SpectralDensity(torch.nn.Module):
|
|
|
34
34
|
average:
|
|
35
35
|
Aggregation method to use for combining windowed FFTs.
|
|
36
36
|
Allowed values are `"mean"` and `"median"`.
|
|
37
|
+
window:
|
|
38
|
+
Window array to multiply by each FFT window before
|
|
39
|
+
FFT computation. Should have length `nperseg`.
|
|
40
|
+
Defaults to a hanning window.
|
|
37
41
|
fast:
|
|
38
42
|
Whether to use a faster spectral density computation that
|
|
39
43
|
support cross spectral density, or a slower one which does
|
|
@@ -47,6 +51,7 @@ class SpectralDensity(torch.nn.Module):
|
|
|
47
51
|
fftlength: float,
|
|
48
52
|
overlap: Optional[float] = None,
|
|
49
53
|
average: str = "mean",
|
|
54
|
+
window: Optional[torch.Tensor] = None,
|
|
50
55
|
fast: bool = False,
|
|
51
56
|
) -> None:
|
|
52
57
|
if overlap is None:
|
|
@@ -63,11 +68,18 @@ class SpectralDensity(torch.nn.Module):
|
|
|
63
68
|
self.nperseg = int(fftlength * sample_rate)
|
|
64
69
|
self.nstride = self.nperseg - int(overlap * sample_rate)
|
|
65
70
|
|
|
66
|
-
#
|
|
67
|
-
#
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
+
# if no window is provided, default to a hanning window;
|
|
72
|
+
# validate that window is correct size
|
|
73
|
+
if window is None:
|
|
74
|
+
window = torch.hann_window(self.nperseg)
|
|
75
|
+
|
|
76
|
+
if window.size(0) != self.nperseg:
|
|
77
|
+
raise ValueError(
|
|
78
|
+
"Window must have length {} got {}".format(
|
|
79
|
+
self.nperseg, window.size(0)
|
|
80
|
+
)
|
|
81
|
+
)
|
|
82
|
+
self.register_buffer("window", window)
|
|
71
83
|
|
|
72
84
|
# scale corresponds to "density" normalization, worth
|
|
73
85
|
# considering adding this as a kwarg and changing this calc
|
ml4gw/utils/interferometer.py
CHANGED
|
@@ -35,6 +35,12 @@ class InterferometerGeometry:
|
|
|
35
35
|
self.vertex = torch.Tensor(
|
|
36
36
|
(4.54637409900e06, 8.42989697626e05, 4.37857696241e06)
|
|
37
37
|
)
|
|
38
|
+
elif name == "K1":
|
|
39
|
+
self.x_arm = torch.Tensor((-0.3759040, -0.8361583, 0.3994189))
|
|
40
|
+
self.y_arm = torch.Tensor((0.7164378, 0.01114076, 0.6975620))
|
|
41
|
+
self.vertex = torch.Tensor(
|
|
42
|
+
(-3777336.024, 3484898.411, 3765313.697)
|
|
43
|
+
)
|
|
38
44
|
else:
|
|
39
45
|
raise ValueError(
|
|
40
46
|
f"{name} is not recognized as an interferometer, "
|
ml4gw/waveforms/__init__.py
CHANGED
ml4gw/waveforms/generator.py
CHANGED