ml4gw 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ml4gw might be problematic. See the package's page on its registry for more details.

@@ -0,0 +1,463 @@
1
+ import math
2
+ from typing import List, Optional, Tuple
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+ """
8
+ All based on https://github.com/gwpy/gwpy/blob/v3.0.8/gwpy/signal/qtransform.py
9
+ The methods, names, and descriptions come almost entirely from GWpy.
10
+ This code allows the Q-transform to be performed on batches of multi-channel
11
+ input on GPU.
12
+ """
13
+
14
+
15
class QTile(torch.nn.Module):
    """
    Row of Q-tiles at a single (Q, frequency) pair, evaluated for a batch
    of multi-channel frequency series.

    This really computes a whole row of tiles (a `QRow`); the `QTile`
    name is kept for parity with GWpy. Inputs with fewer than three
    dimensions get leading axes prepended until they are
    three-dimensional; inputs with more than three are rejected.

    Args:
        q:
            Q value of this row
        frequency:
            Frequency of this row in Hz
        duration:
            Duration in seconds of the data that the input frequency
            series was computed from
        sample_rate:
            Sample rate in Hz of the original time series
        mismatch:
            Maximum fractional mismatch allowed between neighboring tiles
    """

    def __init__(
        self,
        q: float,
        frequency: float,
        duration: float,
        sample_rate: float,
        mismatch: float,
    ):
        super().__init__()
        self.q = q
        self.frequency = frequency
        self.duration = duration
        self.sample_rate = sample_rate
        self.mismatch = mismatch
        # Mismatch step between tiles and the "Q prime" scale factor,
        # using the same definitions as GWpy.
        self.deltam = torch.tensor(2 * (mismatch / 3.0) ** 0.5)
        self.qprime = q / 11**0.5

        half_span = int(self.frequency / self.qprime * self.duration)
        self.windowsize = 2 * half_span + 1

        # Zero-padding that brings the windowed row up to ntiles()
        # samples; when the count is odd the right side gets one more.
        extra = self.ntiles() - self.windowsize
        pad_left = int((extra - 1) / 2.0)
        pad_right = int((extra + 1) / 2.0)
        self.register_buffer("padding", torch.Tensor((pad_left, pad_right)))
        self.register_buffer("indices", self.get_data_indices())
        self.register_buffer("window", self.get_window())

    def ntiles(self):
        """
        Number of tiles in this row: the cumulative mismatch across the
        row divided by `deltam`, rounded up to a power of two.
        """
        cumulative = self.duration * 2 * torch.pi * self.frequency / self.q
        exponent = torch.ceil(torch.log2(cumulative / self.deltam))
        return int(2**exponent)

    def _get_indices(self):
        # Symmetric integer offsets -half .. +half covering the window
        half = (self.windowsize - 1) // 2
        return torch.arange(-half, half + 1)

    def get_window(self):
        """
        Bi-square window for this row, scaled by the row normalization.
        """
        offsets_hz = self._get_indices() / self.duration
        scaled = offsets_hz * self.qprime / self.frequency
        amplitude = (
            self.ntiles()
            / (self.duration * self.sample_rate)
            * (315 * self.qprime / (128 * self.frequency)) ** 0.5
        )
        bisquare = (1 - scaled**2) ** 2
        return torch.Tensor(bisquare * amplitude)

    def get_data_indices(self):
        """
        Integer indices of the input frequency bins covered by this row.
        """
        centred = self._get_indices() + 1 + self.frequency * self.duration
        return torch.round(centred).long()

    def forward(self, fseries: torch.Tensor, norm: str = "median"):
        """
        Evaluate the energy of every tile in this row.

        Args:
            fseries:
                Frequency series matching the duration and sample rate
                this row was built with, shaped `(B, C, F)` (batch,
                channel, frequency bin). Leading axes are added to
                inputs with fewer than three dimensions.
            norm:
                Normalization applied along the time axis: "median",
                "mean" (case-insensitive), or a falsy value for none.

        Returns:
            Tile energies shaped `(B, C, T)`. When normalized, the
            result is cast to float32.

        Raises:
            ValueError: if `fseries` has more than three dimensions or
                `norm` is not a recognized normalization.
        """
        if fseries.ndim > 3:
            raise ValueError("Input data has more than 3 dimensions")
        while fseries.ndim < 3:
            fseries = fseries.unsqueeze(0)

        pad_left, pad_right = (int(p) for p in self.padding)
        windowed = fseries[..., self.indices] * self.window
        padded = F.pad(windowed, (pad_left, pad_right), mode="constant")
        # Shift negative frequencies to the end, then go back to time
        tdenergy = torch.fft.ifft(torch.fft.ifftshift(padded, dim=-1))
        energy = tdenergy.real**2.0 + tdenergy.imag**2.0

        if not norm:
            return energy

        mode = norm.lower() if isinstance(norm, str) else norm
        if mode == "median":
            scale = torch.quantile(energy, q=0.5, dim=-1, keepdim=True)
        elif mode == "mean":
            scale = torch.mean(energy, dim=-1, keepdim=True)
        else:
            raise ValueError("Invalid normalisation %r" % mode)
        energy /= scale
        return energy.type(torch.float32)
143
+
144
+
145
class SingleQTransform(torch.nn.Module):
    """
    Compute the Q-transform for a single Q value for a batch of
    multi-channel time series data. Input data should have
    three dimensions or fewer.

    Args:
        duration:
            Length of the time series data in seconds
        sample_rate:
            Sample rate of the data in Hz
        spectrogram_shape:
            The shape of the interpolated spectrogram, specified as
            `(num_f_bins, num_t_bins)`. Because the
            frequency spacing of the Q-tiles is in log-space, the frequency
            interpolation is log-spaced as well.
        q:
            The Q value to use for the Q transform
        frange:
            The lower and upper frequency limit to consider for
            the transform. If unspecified (`None`), or if the lower
            limit is <= 0 or the upper limit is infinite, default
            values are chosen based on q, sample_rate, and duration.
            The sequence is copied, so the caller's object is never
            modified.
        mismatch:
            The maximum fractional mismatch between neighboring tiles
    """

    def __init__(
        self,
        duration: float,
        sample_rate: float,
        spectrogram_shape: Tuple[int, int],
        q: float = 12,
        frange: Optional[List[float]] = None,
        mismatch: float = 0.2,
    ):
        super().__init__()
        self.q = q
        self.spectrogram_shape = spectrogram_shape
        # Work on a private copy: defaults are written into self.frange
        # below, and mutating the argument would alter the caller's list
        # or, for a shared default list, leak one instance's limits into
        # every later instance built with the default.
        self.frange = [0, torch.inf] if frange is None else list(frange)
        self.duration = duration
        self.mismatch = mismatch

        qprime = self.q / 11 ** (1 / 2.0)
        if self.frange[0] <= 0:  # set non-zero lower frequency
            self.frange[0] = 50 * self.q / (2 * torch.pi * duration)
        if math.isinf(self.frange[1]):  # set non-infinite upper frequency
            self.frange[1] = sample_rate / 2 / (1 + 1 / qprime)
        self.freqs = self.get_freqs()
        # One QTile (row of tiles) per frequency in self.freqs
        self.qtile_transforms = torch.nn.ModuleList(
            [
                QTile(self.q, freq, self.duration, sample_rate, self.mismatch)
                for freq in self.freqs
            ]
        )
        # Filled in by compute_qtiles()
        self.qtiles = None

    def get_freqs(self):
        """
        Calculate the frequencies that will be used in this transform.
        For each frequency, a `QTile` is created.

        Frequencies are log-spaced between the limits in `self.frange`,
        floored to a multiple of `1 / duration`, and de-duplicated.

        Returns:
            A 1D tensor of unique frequencies in ascending order
        """
        minf, maxf = self.frange
        fcum_mismatch = (
            math.log(maxf / minf) * (2 + self.q**2) ** (1 / 2.0) / 2.0
        )
        deltam = 2 * (self.mismatch / 3.0) ** (1 / 2.0)
        nfreq = int(max(1, math.ceil(fcum_mismatch / deltam)))
        fstep = fcum_mismatch / nfreq
        fstepmin = 1 / self.duration

        freq_base = math.exp(2 / ((2 + self.q**2) ** (1 / 2.0)) * fstep)
        freqs = torch.Tensor([freq_base ** (i + 0.5) for i in range(nfreq)])
        freqs = (minf * freqs // fstepmin) * fstepmin
        return torch.unique(freqs)

    def get_max_energy(
        self,
        fsearch_range: Optional[List[float]] = None,
        dimension: str = "both",
    ):
        """
        Gets the maximum energy value among the QTiles. The maximum can
        be computed across all batches and channels, across all channels,
        across all batches, or individually for each channel/batch
        combination. This could be useful for allowing the use of different
        Q values for different channels and batches, but the slicing would
        be slow, so this isn't used yet.

        Args:
            fsearch_range:
                Optional `(fmin, fmax)` pair. Only rows whose frequency
                `f` satisfies `fmin < f <= fmax` are searched. Either
                bound may lie outside the range of `self.freqs`.
            dimension:
                "both" returns a single scalar over everything; "neither"
                returns a `(B, C)` tensor; "channel" reduces over the
                batch axis, returning `(C,)`; "batch" reduces over the
                channel axis, returning `(B,)`.

        Raises:
            ValueError: if `dimension` is not recognized, or if no row
                frequency lies within `fsearch_range`.
            RuntimeError: if `compute_qtiles` has not been called.
        """
        allowed_dimensions = ["both", "neither", "channel", "batch"]
        if dimension not in allowed_dimensions:
            raise ValueError(f"Dimension must be one of {allowed_dimensions}")

        if self.qtiles is None:
            raise RuntimeError(
                "Q-tiles must first be computed with .compute_qtiles()"
            )

        if fsearch_range is not None:
            fmin, fmax = fsearch_range[0], fsearch_range[1]
            # self.freqs is sorted ascending (torch.unique), so the number
            # of frequencies <= a bound is the index of the first frequency
            # above it. Unlike searching for that index directly, counting
            # also works when a bound lies beyond the highest frequency.
            start = int((self.freqs <= fmin).sum())
            stop = int((self.freqs <= fmax).sum())
            qtiles = self.qtiles[start:stop]
            if not qtiles:
                raise ValueError(
                    f"No Q-tile frequencies lie in the search range "
                    f"({fmin}, {fmax}]"
                )
        else:
            qtiles = self.qtiles

        if dimension == "both":
            return max([torch.max(qtile) for qtile in qtiles])

        # Each qtile is (B, C, T): reduce time, then frequency rows
        max_across_t = [torch.max(qtile, dim=-1).values for qtile in qtiles]
        max_across_t = torch.stack(max_across_t, dim=-1)
        max_across_ft = torch.max(max_across_t, dim=-1).values

        if dimension == "neither":
            return max_across_ft
        if dimension == "channel":
            return torch.max(max_across_ft, dim=-2).values
        if dimension == "batch":
            return torch.max(max_across_ft, dim=-1).values

    def compute_qtiles(self, X: torch.Tensor, norm: str = "median"):
        """
        Take the FFT of the input timeseries and calculate the transform
        for each `QTile`, storing the list of results in `self.qtiles`.

        Args:
            X:
                Time series data with the duration and sample rate used
                to initialize this object
            norm:
                Normalization passed to each `QTile`: "median", "mean",
                or `None`
        """
        # Computing the FFT with the same normalization and scaling as GWpy
        X = torch.fft.rfft(X, norm="forward")
        X[..., 1:] *= 2
        self.qtiles = [qtile(X, norm) for qtile in self.qtile_transforms]

    def interpolate(self, num_f_bins: int, num_t_bins: int):
        """
        Interpolate each `QTile` to the specified number of time and
        frequency bins. Note that PyTorch does not have the same
        interpolation methods that GWpy uses, and so the interpolated
        spectrograms will be different even though the uninterpolated
        values match. The `bicubic` interpolation method is used as
        it seems to match GWpy most closely.

        Returns:
            A `(B, C, num_f_bins, num_t_bins)` tensor with any size-1
            dimensions squeezed out

        Raises:
            RuntimeError: if `compute_qtiles` has not been called.
        """
        if self.qtiles is None:
            raise RuntimeError(
                "Q-tiles must first be computed with .compute_qtiles()"
            )
        # First bring every row to a common number of time bins ...
        resampled = [
            F.interpolate(
                qtile[None], (qtile.shape[-2], num_t_bins), mode="bicubic"
            )
            for qtile in self.qtiles
        ]
        # ... then stack rows along frequency and resample that axis
        resampled = torch.stack(resampled, dim=-2)
        resampled = F.interpolate(
            resampled[0], (num_f_bins, num_t_bins), mode="bicubic"
        )
        return torch.squeeze(resampled)

    def forward(
        self,
        X: torch.Tensor,
        norm: str = "median",
        spectrogram_shape: Optional[Tuple[int, int]] = None,
    ):
        """
        Compute the Q-tiles and interpolate

        Args:
            X:
                Time series of data. Should have the duration and sample rate
                used to initialize this object. Expected input shape is
                `(B, C, T)`, where T is the number of samples, C is the number
                of channels, and B is the number of batches. If less than
                three-dimensional, axes will be added during Q-tile
                computation.
            norm:
                The normalization method used by each QTile: "median",
                "mean", or `None`
            spectrogram_shape:
                The shape of the interpolated spectrogram, specified as
                `(num_f_bins, num_t_bins)`. Because the
                frequency spacing of the Q-tiles is in log-space, the frequency
                interpolation is log-spaced as well. If not given, the shape
                used to initialize the transform will be used.

        Returns:
            The interpolated Q-transform for the batch of data, shaped
            `(B, C, num_f_bins, num_t_bins)` with any size-1 dimensions
            squeezed out
        """

        if spectrogram_shape is None:
            spectrogram_shape = self.spectrogram_shape
        num_f_bins, num_t_bins = spectrogram_shape
        self.compute_qtiles(X, norm)
        return self.interpolate(num_f_bins, num_t_bins)
336
+
337
+
338
class QScan(torch.nn.Module):
    """
    Calculate the Q-transform of a batch of multi-channel
    time series data for a range of Q values and return
    the interpolated Q-transform with the highest energy.

    Args:
        duration:
            Length of the time series data in seconds
        sample_rate:
            Sample rate of the data in Hz
        spectrogram_shape:
            The shape of the interpolated spectrogram, specified as
            `(num_f_bins, num_t_bins)`. Because the
            frequency spacing of the Q-tiles is in log-space, the frequency
            interpolation is log-spaced as well.
        qrange:
            The lower and upper values of Q to consider. The
            actual values of Q used for the transforms are
            determined by the `get_qs` method. Defaults to `[4, 64]`.
        frange:
            The lower and upper frequency limit to consider for
            the transform. If unspecified, default values will
            be chosen based on q, sample_rate, and duration.
            Any two-element sequence (list or tuple) is accepted.
        mismatch:
            The maximum fractional mismatch between neighboring tiles
    """

    def __init__(
        self,
        duration: float,
        sample_rate: float,
        spectrogram_shape: Tuple[int, int],
        qrange: Optional[List[float]] = None,
        frange: Optional[List[float]] = None,
        mismatch: float = 0.2,
    ):
        super().__init__()
        # Private copies: instances never alias each other's or the
        # caller's sequences, and tuples work as well as lists.
        self.qrange = [4, 64] if qrange is None else list(qrange)
        self.mismatch = mismatch
        self.qs = self.get_qs()
        self.frange = [0, torch.inf] if frange is None else list(frange)
        self.spectrogram_shape = spectrogram_shape

        # Deliberately doing something different from GWpy here.
        # Their final frange is the intersection of the frange
        # from each q. This implementation uses the frange of
        # the chosen q. Each transform gets its own list because
        # SingleQTransform may fill in defaults on the one it receives.
        self.q_transforms = torch.nn.ModuleList(
            [
                SingleQTransform(
                    duration=duration,
                    sample_rate=sample_rate,
                    spectrogram_shape=spectrogram_shape,
                    q=q,
                    frange=list(self.frange),
                    mismatch=self.mismatch,
                )
                for q in self.qs
            ]
        )

    def get_qs(self):
        """
        Determine the values of Q to try for the set of Q-transforms

        Returns:
            A list of Q values, log-spaced across `self.qrange`, with
            the count set by `self.mismatch` (at least one)
        """
        deltam = 2 * (self.mismatch / 3.0) ** (1 / 2.0)
        cumum = math.log(self.qrange[1] / self.qrange[0]) / 2 ** (1 / 2.0)
        nplanes = int(max(math.ceil(cumum / deltam), 1))
        dq = cumum / nplanes
        qs = [
            self.qrange[0] * math.exp(2 ** (1 / 2.0) * dq * (i + 0.5))
            for i in range(nplanes)
        ]
        return qs

    def forward(
        self,
        X: torch.Tensor,
        fsearch_range: Optional[List[float]] = None,
        norm: str = "median",
        spectrogram_shape: Optional[Tuple[int, int]] = None,
    ):
        """
        Compute the set of QTiles for each Q transform and determine which
        has the highest energy value. Interpolate and return the
        corresponding set of tiles.

        Args:
            X:
                Time series of data. Should have the duration and sample rate
                used to initialize this object. Expected input shape is
                `(B, C, T)`, where T is the number of samples, C is the number
                of channels, and B is the number of batches. If less than
                three-dimensional, axes will be added during Q-tile
                computation.
            fsearch_range:
                The lower and upper frequency values within which to search
                for the maximum energy
            norm:
                The normalization method used by each QTile: "median",
                "mean", or `None`
            spectrogram_shape:
                The shape of the interpolated spectrogram, specified as
                `(num_f_bins, num_t_bins)`. Because the
                frequency spacing of the Q-tiles is in log-space, the frequency
                interpolation is log-spaced as well. If not given, the shape
                used to initialize the transform will be used.

        Returns:
            An interpolated Q-transform for the batch of data, as
            returned by the winning transform's `interpolate` method
        """
        for transform in self.q_transforms:
            transform.compute_qtiles(X, norm)
        # Stack the per-Q maxima so the comparison stays on the data's
        # device and dtype, then convert the winner to a plain int index.
        energies = torch.stack(
            [
                transform.get_max_energy(fsearch_range=fsearch_range)
                for transform in self.q_transforms
            ]
        )
        idx = int(torch.argmax(energies))
        if spectrogram_shape is None:
            spectrogram_shape = self.spectrogram_shape
        num_f_bins, num_t_bins = spectrogram_shape
        return self.q_transforms[idx].interpolate(num_f_bins, num_t_bins)
@@ -34,6 +34,10 @@ class SpectralDensity(torch.nn.Module):
34
34
  average:
35
35
  Aggregation method to use for combining windowed FFTs.
36
36
  Allowed values are `"mean"` and `"median"`.
37
+ window:
38
+ Window array to multiply by each FFT window before
39
+ FFT computation. Should have length `nperseg`.
40
+ Defaults to a hanning window.
37
41
  fast:
38
42
  Whether to use a faster spectral density computation that
39
43
  support cross spectral density, or a slower one which does
@@ -47,6 +51,7 @@ class SpectralDensity(torch.nn.Module):
47
51
  fftlength: float,
48
52
  overlap: Optional[float] = None,
49
53
  average: str = "mean",
54
+ window: Optional[torch.Tensor] = None,
50
55
  fast: bool = False,
51
56
  ) -> None:
52
57
  if overlap is None:
@@ -63,11 +68,18 @@ class SpectralDensity(torch.nn.Module):
63
68
  self.nperseg = int(fftlength * sample_rate)
64
69
  self.nstride = self.nperseg - int(overlap * sample_rate)
65
70
 
66
- # TODOs: Do we allow for arbitrary windows?
67
- # Making this buffer persistent in case we want
68
- # to implement this down the line, so that custom
69
- # windows can be loaded in.
70
- self.register_buffer("window", torch.hann_window(self.nperseg))
71
+ # if no window is provided, default to a hanning window;
72
+ # validate that window is correct size
73
+ if window is None:
74
+ window = torch.hann_window(self.nperseg)
75
+
76
+ if window.size(0) != self.nperseg:
77
+ raise ValueError(
78
+ "Window must have length {} got {}".format(
79
+ self.nperseg, window.size(0)
80
+ )
81
+ )
82
+ self.register_buffer("window", window)
71
83
 
72
84
  # scale corresponds to "density" normalization, worth
73
85
  # considering adding this as a kwarg and changing this calc
@@ -35,6 +35,12 @@ class InterferometerGeometry:
35
35
  self.vertex = torch.Tensor(
36
36
  (4.54637409900e06, 8.42989697626e05, 4.37857696241e06)
37
37
  )
38
+ elif name == "K1":
39
+ self.x_arm = torch.Tensor((-0.3759040, -0.8361583, 0.3994189))
40
+ self.y_arm = torch.Tensor((0.7164378, 0.01114076, 0.6975620))
41
+ self.vertex = torch.Tensor(
42
+ (-3777336.024, 3484898.411, 3765313.697)
43
+ )
38
44
  else:
39
45
  raise ValueError(
40
46
  f"{name} is not recognized as an interferometer, "
@@ -1,3 +1,5 @@
1
1
  from .phenom_d import IMRPhenomD
2
+ from .phenom_p import IMRPhenomPv2
3
+ from .ringdown import Ringdown
2
4
  from .sine_gaussian import SineGaussian
3
5
  from .taylorf2 import TaylorF2
@@ -12,7 +12,7 @@ class ParameterSampler(torch.nn.Module):
12
12
  self,
13
13
  N: int,
14
14
  ):
15
- return {k: v(N) for k, v in self.parameters.items()}
15
+ return {k: v.sample((N,)) for k, v in self.parameters.items()}
16
16
 
17
17
 
18
18
  class WaveformGenerator(torch.nn.Module):