ml4gw 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ml4gw might be problematic; see the registry's release page for more details.

Files changed (44)
  1. ml4gw/augmentations.py +8 -2
  2. ml4gw/constants.py +10 -19
  3. ml4gw/dataloading/chunked_dataset.py +4 -2
  4. ml4gw/dataloading/hdf5_dataset.py +1 -1
  5. ml4gw/dataloading/in_memory_dataset.py +8 -4
  6. ml4gw/distributions.py +5 -3
  7. ml4gw/gw.py +21 -27
  8. ml4gw/nn/autoencoder/base.py +11 -6
  9. ml4gw/nn/autoencoder/convolutional.py +7 -4
  10. ml4gw/nn/autoencoder/skip_connection.py +7 -6
  11. ml4gw/nn/autoencoder/utils.py +2 -1
  12. ml4gw/nn/norm.py +5 -1
  13. ml4gw/nn/streaming/online_average.py +7 -5
  14. ml4gw/nn/streaming/snapshotter.py +7 -5
  15. ml4gw/spectral.py +41 -37
  16. ml4gw/transforms/__init__.py +1 -0
  17. ml4gw/transforms/pearson.py +7 -3
  18. ml4gw/transforms/qtransform.py +151 -53
  19. ml4gw/transforms/scaler.py +9 -3
  20. ml4gw/transforms/snr_rescaler.py +6 -5
  21. ml4gw/transforms/spectral.py +9 -2
  22. ml4gw/transforms/spectrogram.py +7 -1
  23. ml4gw/transforms/spline_interpolation.py +370 -0
  24. ml4gw/transforms/transform.py +4 -3
  25. ml4gw/transforms/waveforms.py +10 -7
  26. ml4gw/transforms/whitening.py +12 -4
  27. ml4gw/types.py +25 -10
  28. ml4gw/utils/interferometer.py +1 -1
  29. ml4gw/utils/slicing.py +24 -16
  30. ml4gw/waveforms/__init__.py +2 -5
  31. ml4gw/waveforms/adhoc/__init__.py +2 -0
  32. ml4gw/waveforms/{ringdown.py → adhoc/ringdown.py} +8 -9
  33. ml4gw/waveforms/{sine_gaussian.py → adhoc/sine_gaussian.py} +6 -6
  34. ml4gw/waveforms/cbc/__init__.py +3 -0
  35. ml4gw/waveforms/{phenom_d.py → cbc/phenom_d.py} +20 -18
  36. ml4gw/waveforms/{phenom_p.py → cbc/phenom_p.py} +106 -95
  37. ml4gw/waveforms/{taylorf2.py → cbc/taylorf2.py} +33 -27
  38. ml4gw/waveforms/conversion.py +187 -0
  39. ml4gw/waveforms/generator.py +9 -5
  40. {ml4gw-0.5.0.dist-info → ml4gw-0.6.0.dist-info}/METADATA +4 -3
  41. ml4gw-0.6.0.dist-info/RECORD +51 -0
  42. {ml4gw-0.5.0.dist-info → ml4gw-0.6.0.dist-info}/WHEEL +1 -1
  43. ml4gw-0.5.0.dist-info/RECORD +0 -47
  44. /ml4gw/waveforms/{phenom_d_data.py → cbc/phenom_d_data.py} +0 -0
ml4gw/spectral.py CHANGED
@@ -12,14 +12,18 @@ https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.csd.html
12
12
  from typing import Optional, Union
13
13
 
14
14
  import torch
15
- from torchtyping import TensorType
15
+ from jaxtyping import Float
16
+ from torch import Tensor
16
17
 
17
- from ml4gw import types
18
+ from ml4gw.types import (
19
+ FrequencySeries1to3d,
20
+ PSDTensor,
21
+ TimeSeries1to3d,
22
+ WaveformTensor,
23
+ )
18
24
 
19
- time = None
20
25
 
21
-
22
- def median(x, axis):
26
+ def median(x: Float[Tensor, "... size"], axis: int) -> Float[Tensor, "..."]:
23
27
  """
24
28
  Implements a median calculation that matches numpy's
25
29
  behavior for an even number of elements and includes
@@ -33,7 +37,7 @@ def median(x, axis):
33
37
 
34
38
 
35
39
  def _validate_shapes(
36
- x: torch.Tensor, nperseg: int, y: Optional[torch.Tensor] = None
40
+ x: Tensor, nperseg: int, y: Optional[Tensor] = None
37
41
  ) -> None:
38
42
  if x.shape[-1] < nperseg:
39
43
  raise ValueError(
@@ -83,14 +87,14 @@ def _validate_shapes(
83
87
 
84
88
 
85
89
  def fast_spectral_density(
86
- x: torch.Tensor,
90
+ x: TimeSeries1to3d,
87
91
  nperseg: int,
88
92
  nstride: int,
89
- window: torch.Tensor,
90
- scale: torch.Tensor,
93
+ window: Float[Tensor, " {nperseg//2+1}"],
94
+ scale: float,
91
95
  average: str = "median",
92
- y: Optional[torch.Tensor] = None,
93
- ) -> torch.Tensor:
96
+ y: Optional[TimeSeries1to3d] = None,
97
+ ) -> FrequencySeries1to3d:
94
98
  """
95
99
  Compute the power spectral density of a multichannel
96
100
  timeseries or a batch of multichannel timeseries, or
@@ -107,9 +111,9 @@ def fast_spectral_density(
107
111
  The timeseries tensor whose power spectral density
108
112
  to compute, or for cross spectral density the
109
113
  timeseries whose fft will be conjugated. Can have
110
- shape either
111
- `(batch_size, num_channels, length * sample_rate)`
112
- or `(num_channels, length * sample_rate)`.
114
+ shape `(batch_size, num_channels, length * sample_rate)`,
115
+ `(num_channels, length * sample_rate)`, or
116
+ `(length * sample_rate)`.
113
117
  nperseg:
114
118
  Number of samples included in each FFT window
115
119
  nstride:
@@ -150,7 +154,7 @@ def fast_spectral_density(
150
154
  channel in `x` across _all_ of `x`'s batch elements.
151
155
  Returns:
152
156
  Tensor of power spectral densities of `x` or its cross spectral
153
- density with the timeseries in `y`.
157
+ density with the timeseries in `y`.
154
158
  """
155
159
 
156
160
  _validate_shapes(x, nperseg, y)
@@ -240,17 +244,16 @@ def fast_spectral_density(
240
244
 
241
245
 
242
246
  def spectral_density(
243
- x: torch.Tensor,
247
+ x: TimeSeries1to3d,
244
248
  nperseg: int,
245
249
  nstride: int,
246
- window: torch.Tensor,
247
- scale: torch.Tensor,
250
+ window: Float[Tensor, " {nperseg//2+1}"],
251
+ scale: float,
248
252
  average: str = "median",
249
- ) -> torch.Tensor:
253
+ ) -> FrequencySeries1to3d:
250
254
  """
251
255
  Compute the power spectral density of a multichannel
252
- timeseries or a batch of multichannel timeseries, or
253
- the cross power spectral density of two such timeseries.
256
+ timeseries or a batch of multichannel timeseries.
254
257
  This implementation is exact for all frequency bins, but
255
258
  slower than the fast implementation.
256
259
 
@@ -259,9 +262,9 @@ def spectral_density(
259
262
  The timeseries tensor whose power spectral density
260
263
  to compute, or for cross spectral density the
261
264
  timeseries whose fft will be conjugated. Can have
262
- shape either
263
- `(batch_size, num_channels, length * sample_rate)`
264
- or `(num_channels, length * sample_rate)`.
265
+ shape `(batch_size, num_channels, length * sample_rate)`,
266
+ `(num_channels, length * sample_rate)`, or
267
+ `(length * sample_rate)`.
265
268
  nperseg:
266
269
  Number of samples included in each FFT window
267
270
  nstride:
@@ -336,11 +339,11 @@ def spectral_density(
336
339
 
337
340
 
338
341
  def truncate_inverse_power_spectrum(
339
- psd: types.PSDTensor,
340
- fduration: Union[TensorType["time"], float],
342
+ psd: PSDTensor,
343
+ fduration: Union[Float[Tensor, " time"], float],
341
344
  sample_rate: float,
342
345
  highpass: Optional[float] = None,
343
- ) -> types.PSDTensor:
346
+ ) -> PSDTensor:
344
347
  """
345
348
  Truncate the length of the time domain response
346
349
  of a whitening filter built using the specified
@@ -399,7 +402,7 @@ def truncate_inverse_power_spectrum(
399
402
  q = torch.fft.irfft(inv_asd, n=N, norm="forward", dim=-1)
400
403
 
401
404
  # taper the edges of the TD filter
402
- if isinstance(fduration, torch.Tensor):
405
+ if isinstance(fduration, Tensor):
403
406
  pad = fduration.size(-1) // 2
404
407
  window = fduration
405
408
  else:
@@ -422,8 +425,8 @@ def truncate_inverse_power_spectrum(
422
425
 
423
426
 
424
427
  def normalize_by_psd(
425
- X: types.WaveformTensor,
426
- psd: types.PSDTensor,
428
+ X: WaveformTensor,
429
+ psd: PSDTensor,
427
430
  sample_rate: float,
428
431
  pad: int,
429
432
  ):
@@ -438,7 +441,7 @@ def normalize_by_psd(
438
441
 
439
442
  # convert back to the time domain and normalize
440
443
  # TODO: what's this normalization factor?
441
- X = torch.fft.irfft(X_tilde, norm="forward", dim=-1)
444
+ X = torch.fft.irfft(X_tilde, n=X.shape[-1], norm="forward", dim=-1)
442
445
  X = X.float() / sample_rate**0.5
443
446
 
444
447
  # slice off corrupted data at edges of kernel
@@ -447,12 +450,12 @@ def normalize_by_psd(
447
450
 
448
451
 
449
452
  def whiten(
450
- X: types.WaveformTensor,
451
- psd: types.PSDTensor,
452
- fduration: Union[TensorType["time"], float],
453
+ X: WaveformTensor,
454
+ psd: PSDTensor,
455
+ fduration: Union[Float[Tensor, " time"], float],
453
456
  sample_rate: float,
454
457
  highpass: Optional[float] = None,
455
- ) -> types.WaveformTensor:
458
+ ) -> WaveformTensor:
456
459
  """
457
460
  Whiten a batch of timeseries using the specified
458
461
  background one-sided power spectral densities (PSDs),
@@ -460,7 +463,8 @@ def whiten(
460
463
  `fduration` and possibly to highpass filter.
461
464
 
462
465
  Args:
463
- X: batch of multichannel timeseries to whiten
466
+ X:
467
+ batch of multichannel timeseries to whiten
464
468
  psd:
465
469
  PSDs use to whiten the data. The frequency
466
470
  response of the whitening filter will be roughly
@@ -496,7 +500,7 @@ def whiten(
496
500
 
497
501
  # figure out how much data we'll need to slice
498
502
  # off after whitening
499
- if isinstance(fduration, torch.Tensor):
503
+ if isinstance(fduration, Tensor):
500
504
  pad = fduration.size(-1) // 2
501
505
  else:
502
506
  pad = int(fduration * sample_rate / 2)
@@ -4,5 +4,6 @@ from .scaler import ChannelWiseScaler
4
4
  from .snr_rescaler import SnrRescaler
5
5
  from .spectral import SpectralDensity
6
6
  from .spectrogram import MultiResolutionSpectrogram
7
+ from .spline_interpolation import SplineInterpolate
7
8
  from .waveforms import WaveformProjector, WaveformSampler
8
9
  from .whitening import FixedWhiten, Whiten
@@ -1,5 +1,8 @@
1
1
  import torch
2
+ from jaxtyping import Float
3
+ from torch import Tensor
2
4
 
5
+ from ml4gw.types import TimeSeries1to3d
3
6
  from ml4gw.utils.slicing import unfold_windows
4
7
 
5
8
 
@@ -40,7 +43,7 @@ class ShiftedPearsonCorrelation(torch.nn.Module):
40
43
  super().__init__()
41
44
  self.max_shift = max_shift
42
45
 
43
- def _shape_checks(self, x: torch.Tensor, y: torch.Tensor):
46
+ def _shape_checks(self, x: TimeSeries1to3d, y: TimeSeries1to3d):
44
47
  if x.ndim > 3:
45
48
  raise ValueError(
46
49
  "Tensor x can only have up to 3 dimensions "
@@ -61,8 +64,9 @@ class ShiftedPearsonCorrelation(torch.nn.Module):
61
64
  )
62
65
  )
63
66
 
64
- # TODO: torchtyping annotate
65
- def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
67
+ def forward(
68
+ self, x: TimeSeries1to3d, y: TimeSeries1to3d
69
+ ) -> Float[Tensor, "windows ..."]:
66
70
  self._shape_checks(x, y)
67
71
  dim = x.size(-1)
68
72
 
@@ -1,8 +1,14 @@
1
1
  import math
2
- from typing import List, Optional, Tuple
2
+ import warnings
3
+ from typing import List, Tuple
3
4
 
4
5
  import torch
5
6
  import torch.nn.functional as F
7
+ from jaxtyping import Float, Int
8
+ from torch import Tensor
9
+
10
+ from ml4gw.transforms.spline_interpolation import SplineInterpolate
11
+ from ml4gw.types import FrequencySeries1to3d, TimeSeries1to3d, TimeSeries3d
6
12
 
7
13
  """
8
14
  All based on https://github.com/gwpy/gwpy/blob/v3.0.8/gwpy/signal/qtransform.py
@@ -34,7 +40,6 @@ class QTile(torch.nn.Module):
34
40
  mismatch:
35
41
  The maximum fractional mismatch between neighboring tiles
36
42
 
37
-
38
43
  """
39
44
 
40
45
  def __init__(
@@ -44,7 +49,7 @@ class QTile(torch.nn.Module):
44
49
  duration: float,
45
50
  sample_rate: float,
46
51
  mismatch: float,
47
- ):
52
+ ) -> None:
48
53
  super().__init__()
49
54
  self.mismatch = mismatch
50
55
  self.q = q
@@ -63,18 +68,18 @@ class QTile(torch.nn.Module):
63
68
  self.register_buffer("indices", self.get_data_indices())
64
69
  self.register_buffer("window", self.get_window())
65
70
 
66
- def ntiles(self):
71
+ def ntiles(self) -> int:
67
72
  """
68
73
  Number of tiles in this frequency row
69
74
  """
70
75
  tcum_mismatch = self.duration * 2 * torch.pi * self.frequency / self.q
71
76
  return int(2 ** torch.ceil(torch.log2(tcum_mismatch / self.deltam)))
72
77
 
73
- def _get_indices(self):
78
+ def _get_indices(self) -> Int[Tensor, " windowsize"]:
74
79
  half = int((self.windowsize - 1) / 2)
75
80
  return torch.arange(-half, half + 1)
76
81
 
77
- def get_window(self):
82
+ def get_window(self) -> Float[Tensor, " windowsize"]:
78
83
  """
79
84
  Generate the bi-square window for this row
80
85
  """
@@ -87,7 +92,7 @@ class QTile(torch.nn.Module):
87
92
  )
88
93
  return torch.Tensor((1 - xfrequencies**2) ** 2 * norm)
89
94
 
90
- def get_data_indices(self):
95
+ def get_data_indices(self) -> Int[Tensor, " windowsize"]:
91
96
  """
92
97
  Get the index array of relevant frequencies for this row
93
98
  """
@@ -95,7 +100,11 @@ class QTile(torch.nn.Module):
95
100
  self._get_indices() + 1 + self.frequency * self.duration,
96
101
  ).type(torch.long)
97
102
 
98
- def forward(self, fseries: torch.Tensor, norm: str = "median"):
103
+ def forward(
104
+ self,
105
+ fseries: FrequencySeries1to3d,
106
+ norm: str = "median",
107
+ ) -> TimeSeries1to3d:
99
108
  """
100
109
  Compute the transform for this row
101
110
 
@@ -138,7 +147,7 @@ class QTile(torch.nn.Module):
138
147
  energy /= means
139
148
  else:
140
149
  raise ValueError("Invalid normalisation %r" % norm)
141
- return energy.type(torch.float32)
150
+ energy = energy.type(torch.float32)
142
151
  return energy
143
152
 
144
153
 
@@ -166,6 +175,19 @@ class SingleQTransform(torch.nn.Module):
166
175
  be chosen based on q, sample_rate, and duration
167
176
  mismatch:
168
177
  The maximum fractional mismatch between neighboring tiles
178
+ interpolation_method:
179
+ The method by which to interpolate each `QTile` to the specified
180
+ number of time and frequency bins. The acceptable values are
181
+ "bilinear", "bicubic", and "spline". The "bilinear" and "bicubic"
182
+ options will use PyTorch's built-in interpolation modes, while
183
+ "spline" will use the custom Torch-based implementation in
184
+ `ml4gw`, as PyTorch does not have spline-based intertpolation.
185
+ The "spline" mode is most similar to the results of GWpy's
186
+ Q-transform, which uses `scipy` to do spline interpolation.
187
+ However, it is also the slowest and most memory intensive due to
188
+ the matrix equation solving steps. Therefore, the default method
189
+ is "bicubic" as it produces the most similar results while
190
+ optimizing for computing performance.
169
191
  """
170
192
 
171
193
  def __init__(
@@ -176,7 +198,8 @@ class SingleQTransform(torch.nn.Module):
176
198
  q: float = 12,
177
199
  frange: List[float] = [0, torch.inf],
178
200
  mismatch: float = 0.2,
179
- ):
201
+ interpolation_method: str = "bicubic",
202
+ ) -> None:
180
203
  super().__init__()
181
204
  self.q = q
182
205
  self.spectrogram_shape = spectrogram_shape
@@ -184,21 +207,88 @@ class SingleQTransform(torch.nn.Module):
184
207
  self.duration = duration
185
208
  self.mismatch = mismatch
186
209
 
210
+ # If q is too large, the minimum of the frange computed
211
+ # below will be larger than the maximum
212
+ max_q = torch.pi * duration * sample_rate / 50 - 11 ** (0.5)
213
+ if q >= max_q:
214
+ raise ValueError(
215
+ "The given q value is too large for the given duration and "
216
+ f"sample rate. The maximum allowable value is {max_q}"
217
+ )
218
+
219
+ if interpolation_method not in ["bilinear", "bicubic", "spline"]:
220
+ raise ValueError(
221
+ "Interpolation method must be either 'bilinear', 'bicubic', "
222
+ f"or 'spline'; got {interpolation_method}"
223
+ )
224
+ self.interpolation_method = interpolation_method
225
+
187
226
  qprime = self.q / 11 ** (1 / 2.0)
188
227
  if self.frange[0] <= 0: # set non-zero lower frequency
189
228
  self.frange[0] = 50 * self.q / (2 * torch.pi * duration)
190
229
  if math.isinf(self.frange[1]): # set non-infinite upper frequency
191
230
  self.frange[1] = sample_rate / 2 / (1 + 1 / qprime)
231
+
192
232
  self.freqs = self.get_freqs()
193
233
  self.qtile_transforms = torch.nn.ModuleList(
194
234
  [
195
- QTile(self.q, freq, self.duration, sample_rate, self.mismatch)
235
+ QTile(
236
+ q=self.q,
237
+ frequency=freq,
238
+ duration=self.duration,
239
+ sample_rate=sample_rate,
240
+ mismatch=self.mismatch,
241
+ )
196
242
  for freq in self.freqs
197
243
  ]
198
244
  )
199
245
  self.qtiles = None
200
246
 
201
- def get_freqs(self):
247
+ if self.interpolation_method == "spline":
248
+ self._set_up_spline_interp()
249
+
250
+ def _set_up_spline_interp(self):
251
+ ntiles = [qtile.ntiles() for qtile in self.qtile_transforms]
252
+ # For efficiency, we'll stack all qtiles of the same length before
253
+ # interpolating, so we need to figure out which those are
254
+ unique_ntiles = sorted(list(set(ntiles)))
255
+ idx = torch.arange(len(ntiles))
256
+ self.stack_idx = [idx[Tensor(ntiles) == n] for n in unique_ntiles]
257
+
258
+ t_out = torch.arange(
259
+ 0, self.duration, self.duration / self.spectrogram_shape[1]
260
+ )
261
+ self.qtile_interpolators = torch.nn.ModuleList(
262
+ [
263
+ SplineInterpolate(
264
+ kx=3,
265
+ x_in=torch.arange(0, self.duration, self.duration / tiles),
266
+ y_in=torch.arange(len(idx)),
267
+ x_out=t_out,
268
+ y_out=torch.arange(len(idx)),
269
+ )
270
+ for tiles, idx in zip(unique_ntiles, self.stack_idx)
271
+ ]
272
+ )
273
+
274
+ t_in = t_out
275
+ f_in = self.freqs
276
+ f_out = torch.logspace(
277
+ math.log10(self.frange[0]),
278
+ math.log10(self.frange[-1]),
279
+ self.spectrogram_shape[0],
280
+ )
281
+
282
+ self.interpolator = SplineInterpolate(
283
+ kx=3,
284
+ ky=3,
285
+ x_in=t_in,
286
+ y_in=f_in,
287
+ x_out=t_out,
288
+ y_out=f_out,
289
+ )
290
+
291
+ def get_freqs(self) -> Float[Tensor, " nfreq"]:
202
292
  """
203
293
  Calculate the frequencies that will be used in this transform.
204
294
  For each frequency, a `QTile` is created.
@@ -214,7 +304,8 @@ class SingleQTransform(torch.nn.Module):
214
304
 
215
305
  freq_base = math.exp(2 / ((2 + self.q**2) ** (1 / 2.0)) * fstep)
216
306
  freqs = torch.Tensor([freq_base ** (i + 0.5) for i in range(nfreq)])
217
- freqs = (minf * freqs // fstepmin) * fstepmin
307
+ # Cast freqs to float64 to avoid off-by-ones from rounding
308
+ freqs = (minf * freqs.double() // fstepmin) * fstepmin
218
309
  return torch.unique(freqs)
219
310
 
220
311
  def get_max_energy(
@@ -262,7 +353,11 @@ class SingleQTransform(torch.nn.Module):
262
353
  if dimension == "batch":
263
354
  return torch.max(max_across_ft, dim=-1).values
264
355
 
265
- def compute_qtiles(self, X: torch.Tensor, norm: str = "median"):
356
+ def compute_qtiles(
357
+ self,
358
+ X: TimeSeries1to3d,
359
+ norm: str = "median",
360
+ ) -> None:
266
361
  """
267
362
  Take the FFT of the input timeseries and calculate the transform
268
363
  for each `QTile`
@@ -272,36 +367,47 @@ class SingleQTransform(torch.nn.Module):
272
367
  X[..., 1:] *= 2
273
368
  self.qtiles = [qtile(X, norm) for qtile in self.qtile_transforms]
274
369
 
275
- def interpolate(self, num_f_bins: int, num_t_bins: int):
276
- """
277
- Interpolate each `QTile` to the specified number of time and
278
- frequency bins. Note that PyTorch does not have the same
279
- interpolation methods that GWpy uses, and so the interpolated
280
- spectrograms will be different even though the uninterpolated
281
- values match. The `bicubic` interpolation method is used as
282
- it seems to match GWpy most closely.
283
- """
370
+ def interpolate(self) -> TimeSeries3d:
284
371
  if self.qtiles is None:
285
372
  raise RuntimeError(
286
373
  "Q-tiles must first be computed with .compute_qtiles()"
287
374
  )
375
+ if self.interpolation_method == "spline":
376
+ qtiles = [
377
+ torch.stack([self.qtiles[i] for i in idx], dim=-2)
378
+ for idx in self.stack_idx
379
+ ]
380
+ time_interped = torch.cat(
381
+ [
382
+ interpolator(qtile)
383
+ for qtile, interpolator in zip(
384
+ qtiles, self.qtile_interpolators
385
+ )
386
+ ],
387
+ dim=-2,
388
+ )
389
+ return self.interpolator(time_interped)
390
+ num_f_bins, num_t_bins = self.spectrogram_shape
288
391
  resampled = [
289
392
  F.interpolate(
290
- qtile[None], (qtile.shape[-2], num_t_bins), mode="bicubic"
393
+ qtile[None],
394
+ (qtile.shape[-2], num_t_bins),
395
+ mode=self.interpolation_method,
291
396
  )
292
397
  for qtile in self.qtiles
293
398
  ]
294
399
  resampled = torch.stack(resampled, dim=-2)
295
400
  resampled = F.interpolate(
296
- resampled[0], (num_f_bins, num_t_bins), mode="bicubic"
401
+ resampled[0],
402
+ (num_f_bins, num_t_bins),
403
+ mode=self.interpolation_method,
297
404
  )
298
405
  return torch.squeeze(resampled)
299
406
 
300
407
  def forward(
301
408
  self,
302
- X: torch.Tensor,
409
+ X: TimeSeries1to3d,
303
410
  norm: str = "median",
304
- spectrogram_shape: Optional[Tuple[int, int]] = None,
305
411
  ):
306
412
  """
307
413
  Compute the Q-tiles and interpolate
@@ -315,24 +421,15 @@ class SingleQTransform(torch.nn.Module):
315
421
  three-dimensional, axes will be added during Q-tile
316
422
  computation.
317
423
  norm:
318
- The method of interpolation used by each QTile
319
- spectrogram_shape:
320
- The shape of the interpolated spectrogram, specified as
321
- `(num_f_bins, num_t_bins)`. Because the
322
- frequency spacing of the Q-tiles is in log-space, the frequency
323
- interpolation is log-spaced as well. If not given, the shape
324
- used to initialize the transform will be used.
424
+ The method of normalization used by each QTile
325
425
 
326
426
  Returns:
327
427
  The interpolated Q-transform for the batch of data. Output will
328
428
  have one more dimension than the input
329
429
  """
330
430
 
331
- if spectrogram_shape is None:
332
- spectrogram_shape = self.spectrogram_shape
333
- num_f_bins, num_t_bins = spectrogram_shape
334
431
  self.compute_qtiles(X, norm)
335
- return self.interpolate(num_f_bins, num_t_bins)
432
+ return self.interpolate()
336
433
 
337
434
 
338
435
  class QScan(torch.nn.Module):
@@ -370,14 +467,22 @@ class QScan(torch.nn.Module):
370
467
  spectrogram_shape: Tuple[int, int],
371
468
  qrange: List[float] = [4, 64],
372
469
  frange: List[float] = [0, torch.inf],
470
+ interpolation_method="bicubic",
373
471
  mismatch: float = 0.2,
374
- ):
472
+ ) -> None:
375
473
  super().__init__()
376
474
  self.qrange = qrange
377
475
  self.mismatch = mismatch
378
- self.qs = self.get_qs()
379
476
  self.frange = frange
380
477
  self.spectrogram_shape = spectrogram_shape
478
+ max_q = torch.pi * duration * sample_rate / 50 - 11 ** (0.5)
479
+ self.qs = self.get_qs()
480
+ if self.qs[-1] >= max_q:
481
+ warnings.warn(
482
+ "Some Q values exceed the maximum allowable Q value of "
483
+ f"{max_q}. The list of Q values to be tested in this "
484
+ "scan will be truncated to avoid those values."
485
+ )
381
486
 
382
487
  # Deliberately doing something different from GWpy here.
383
488
  # Their final frange is the intersection of the frange
@@ -391,13 +496,15 @@ class QScan(torch.nn.Module):
391
496
  spectrogram_shape=spectrogram_shape,
392
497
  q=q,
393
498
  frange=self.frange.copy(),
499
+ interpolation_method=interpolation_method,
394
500
  mismatch=self.mismatch,
395
501
  )
396
502
  for q in self.qs
503
+ if q < max_q
397
504
  ]
398
505
  )
399
506
 
400
- def get_qs(self):
507
+ def get_qs(self) -> List[float]:
401
508
  """
402
509
  Determine the values of Q to try for the set of Q-transforms
403
510
  """
@@ -409,14 +516,14 @@ class QScan(torch.nn.Module):
409
516
  self.qrange[0] * math.exp(2 ** (1 / 2.0) * dq * (i + 0.5))
410
517
  for i in range(nplanes)
411
518
  ]
519
+
412
520
  return qs
413
521
 
414
522
  def forward(
415
523
  self,
416
- X: torch.Tensor,
524
+ X: TimeSeries1to3d,
417
525
  fsearch_range: List[float] = None,
418
526
  norm: str = "median",
419
- spectrogram_shape: Optional[Tuple[int, int]] = None,
420
527
  ):
421
528
  """
422
529
  Compute the set of QTiles for each Q transform and determine which
@@ -436,12 +543,6 @@ class QScan(torch.nn.Module):
436
543
  for the maximum energy
437
544
  norm:
438
545
  The method of interpolation used by each QTile
439
- spectrogram_shape:
440
- The shape of the interpolated spectrogram, specified as
441
- `(num_f_bins, num_t_bins)`. Because the
442
- frequency spacing of the Q-tiles is in log-space, the frequency
443
- interpolation is log-spaced as well. If not given, the shape
444
- used to initialize the transform will be used.
445
546
 
446
547
  Returns:
447
548
  An interpolated Q-transform for the batch of data. Output will
@@ -457,7 +558,4 @@ class QScan(torch.nn.Module):
457
558
  ]
458
559
  )
459
560
  )
460
- if spectrogram_shape is None:
461
- spectrogram_shape = self.spectrogram_shape
462
- num_f_bins, num_t_bins = spectrogram_shape
463
- return self.q_transforms[idx].interpolate(num_f_bins, num_t_bins)
561
+ return self.q_transforms[idx].interpolate()
@@ -1,6 +1,8 @@
1
1
  from typing import Optional
2
2
 
3
3
  import torch
4
+ from jaxtyping import Float
5
+ from torch import Tensor
4
6
 
5
7
  from ml4gw.transforms.transform import FittableTransform
6
8
 
@@ -34,7 +36,9 @@ class ChannelWiseScaler(FittableTransform):
34
36
  self.register_buffer("mean", mean)
35
37
  self.register_buffer("std", std)
36
38
 
37
- def fit(self, X: torch.Tensor) -> None:
39
+ def fit(
40
+ self, X: Float[Tensor, "... time"], std_reg: Optional[float] = 0.0
41
+ ) -> None:
38
42
  """Fit the scaling parameters to a timeseries
39
43
 
40
44
  Computes the channel-wise mean and standard deviation
@@ -57,10 +61,12 @@ class ChannelWiseScaler(FittableTransform):
57
61
  "Can't fit channel wise mean and standard deviation "
58
62
  "from tensor of shape {}".format(X.shape)
59
63
  )
60
-
64
+ std += std_reg * torch.ones_like(std)
61
65
  super().build(mean=mean, std=std)
62
66
 
63
- def forward(self, X: torch.Tensor, reverse: bool = False) -> torch.Tensor:
67
+ def forward(
68
+ self, X: Float[Tensor, "... time"], reverse: bool = False
69
+ ) -> Float[Tensor, "... time"]:
64
70
  if not reverse:
65
71
  return (X - self.mean) / self.std
66
72
  else:
@@ -2,8 +2,9 @@ from typing import Optional
2
2
 
3
3
  import torch
4
4
 
5
- from ml4gw import gw
5
+ from ml4gw.gw import compute_network_snr
6
6
  from ml4gw.transforms.transform import FittableSpectralTransform
7
+ from ml4gw.types import BatchTensor, TimeSeries2d, WaveformTensor
7
8
 
8
9
 
9
10
  class SnrRescaler(FittableSpectralTransform):
@@ -34,7 +35,7 @@ class SnrRescaler(FittableSpectralTransform):
34
35
 
35
36
  def fit(
36
37
  self,
37
- *background: torch.Tensor,
38
+ *background: TimeSeries2d,
38
39
  fftlength: Optional[float] = None,
39
40
  overlap: Optional[float] = None,
40
41
  ):
@@ -58,10 +59,10 @@ class SnrRescaler(FittableSpectralTransform):
58
59
 
59
60
  def forward(
60
61
  self,
61
- responses: gw.WaveformTensor,
62
- target_snrs: Optional[gw.ScalarTensor] = None,
62
+ responses: WaveformTensor,
63
+ target_snrs: Optional[BatchTensor] = None,
63
64
  ):
64
- snrs = gw.compute_network_snr(
65
+ snrs = compute_network_snr(
65
66
  responses, self.background, self.sample_rate, self.mask
66
67
  )
67
68
  if target_snrs is None: