ml4gw 0.4.2__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ml4gw might be problematic. See the registry's advisory page for this release for more details.

ml4gw/spectral.py CHANGED
@@ -12,14 +12,18 @@ https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.csd.html
12
12
  from typing import Optional, Union
13
13
 
14
14
  import torch
15
- from torchtyping import TensorType
15
+ from jaxtyping import Float
16
+ from torch import Tensor
16
17
 
17
- from ml4gw import types
18
+ from ml4gw.types import (
19
+ FrequencySeries1to3d,
20
+ PSDTensor,
21
+ TimeSeries1to3d,
22
+ WaveformTensor,
23
+ )
18
24
 
19
- time = None
20
25
 
21
-
22
- def median(x, axis):
26
+ def median(x: Float[Tensor, "... size"], axis: int) -> Float[Tensor, "..."]:
23
27
  """
24
28
  Implements a median calculation that matches numpy's
25
29
  behavior for an even number of elements and includes
@@ -33,7 +37,7 @@ def median(x, axis):
33
37
 
34
38
 
35
39
  def _validate_shapes(
36
- x: torch.Tensor, nperseg: int, y: Optional[torch.Tensor] = None
40
+ x: Tensor, nperseg: int, y: Optional[Tensor] = None
37
41
  ) -> None:
38
42
  if x.shape[-1] < nperseg:
39
43
  raise ValueError(
@@ -83,14 +87,14 @@ def _validate_shapes(
83
87
 
84
88
 
85
89
  def fast_spectral_density(
86
- x: torch.Tensor,
90
+ x: TimeSeries1to3d,
87
91
  nperseg: int,
88
92
  nstride: int,
89
- window: torch.Tensor,
90
- scale: torch.Tensor,
93
+ window: Float[Tensor, " {nperseg//2+1}"],
94
+ scale: float,
91
95
  average: str = "median",
92
- y: Optional[torch.Tensor] = None,
93
- ) -> torch.Tensor:
96
+ y: Optional[TimeSeries1to3d] = None,
97
+ ) -> FrequencySeries1to3d:
94
98
  """
95
99
  Compute the power spectral density of a multichannel
96
100
  timeseries or a batch of multichannel timeseries, or
@@ -107,9 +111,9 @@ def fast_spectral_density(
107
111
  The timeseries tensor whose power spectral density
108
112
  to compute, or for cross spectral density the
109
113
  timeseries whose fft will be conjugated. Can have
110
- shape either
111
- `(batch_size, num_channels, length * sample_rate)`
112
- or `(num_channels, length * sample_rate)`.
114
+ shape `(batch_size, num_channels, length * sample_rate)`,
115
+ `(num_channels, length * sample_rate)`, or
116
+ `(length * sample_rate)`.
113
117
  nperseg:
114
118
  Number of samples included in each FFT window
115
119
  nstride:
@@ -150,7 +154,7 @@ def fast_spectral_density(
150
154
  channel in `x` across _all_ of `x`'s batch elements.
151
155
  Returns:
152
156
  Tensor of power spectral densities of `x` or its cross spectral
153
- density with the timeseries in `y`.
157
+ density with the timeseries in `y`.
154
158
  """
155
159
 
156
160
  _validate_shapes(x, nperseg, y)
@@ -240,17 +244,16 @@ def fast_spectral_density(
240
244
 
241
245
 
242
246
  def spectral_density(
243
- x: torch.Tensor,
247
+ x: TimeSeries1to3d,
244
248
  nperseg: int,
245
249
  nstride: int,
246
- window: torch.Tensor,
247
- scale: torch.Tensor,
250
+ window: Float[Tensor, " {nperseg//2+1}"],
251
+ scale: float,
248
252
  average: str = "median",
249
- ) -> torch.Tensor:
253
+ ) -> FrequencySeries1to3d:
250
254
  """
251
255
  Compute the power spectral density of a multichannel
252
- timeseries or a batch of multichannel timeseries, or
253
- the cross power spectral density of two such timeseries.
256
+ timeseries or a batch of multichannel timeseries.
254
257
  This implementation is exact for all frequency bins, but
255
258
  slower than the fast implementation.
256
259
 
@@ -259,9 +262,9 @@ def spectral_density(
259
262
  The timeseries tensor whose power spectral density
260
263
  to compute, or for cross spectral density the
261
264
  timeseries whose fft will be conjugated. Can have
262
- shape either
263
- `(batch_size, num_channels, length * sample_rate)`
264
- or `(num_channels, length * sample_rate)`.
265
+ shape `(batch_size, num_channels, length * sample_rate)`,
266
+ `(num_channels, length * sample_rate)`, or
267
+ `(length * sample_rate)`.
265
268
  nperseg:
266
269
  Number of samples included in each FFT window
267
270
  nstride:
@@ -336,11 +339,11 @@ def spectral_density(
336
339
 
337
340
 
338
341
  def truncate_inverse_power_spectrum(
339
- psd: types.PSDTensor,
340
- fduration: Union[TensorType["time"], float],
342
+ psd: PSDTensor,
343
+ fduration: Union[Float[Tensor, " time"], float],
341
344
  sample_rate: float,
342
345
  highpass: Optional[float] = None,
343
- ) -> types.PSDTensor:
346
+ ) -> PSDTensor:
344
347
  """
345
348
  Truncate the length of the time domain response
346
349
  of a whitening filter built using the specified
@@ -399,7 +402,7 @@ def truncate_inverse_power_spectrum(
399
402
  q = torch.fft.irfft(inv_asd, n=N, norm="forward", dim=-1)
400
403
 
401
404
  # taper the edges of the TD filter
402
- if isinstance(fduration, torch.Tensor):
405
+ if isinstance(fduration, Tensor):
403
406
  pad = fduration.size(-1) // 2
404
407
  window = fduration
405
408
  else:
@@ -422,8 +425,8 @@ def truncate_inverse_power_spectrum(
422
425
 
423
426
 
424
427
  def normalize_by_psd(
425
- X: types.WaveformTensor,
426
- psd: types.PSDTensor,
428
+ X: WaveformTensor,
429
+ psd: PSDTensor,
427
430
  sample_rate: float,
428
431
  pad: int,
429
432
  ):
@@ -447,12 +450,12 @@ def normalize_by_psd(
447
450
 
448
451
 
449
452
  def whiten(
450
- X: types.WaveformTensor,
451
- psd: types.PSDTensor,
452
- fduration: Union[TensorType["time"], float],
453
+ X: WaveformTensor,
454
+ psd: PSDTensor,
455
+ fduration: Union[Float[Tensor, " time"], float],
453
456
  sample_rate: float,
454
457
  highpass: Optional[float] = None,
455
- ) -> types.WaveformTensor:
458
+ ) -> WaveformTensor:
456
459
  """
457
460
  Whiten a batch of timeseries using the specified
458
461
  background one-sided power spectral densities (PSDs),
@@ -460,7 +463,8 @@ def whiten(
460
463
  `fduration` and possibly to highpass filter.
461
464
 
462
465
  Args:
463
- X: batch of multichannel timeseries to whiten
466
+ X:
467
+ batch of multichannel timeseries to whiten
464
468
  psd:
465
469
  PSDs use to whiten the data. The frequency
466
470
  response of the whitening filter will be roughly
@@ -496,7 +500,7 @@ def whiten(
496
500
 
497
501
  # figure out how much data we'll need to slice
498
502
  # off after whitening
499
- if isinstance(fduration, torch.Tensor):
503
+ if isinstance(fduration, Tensor):
500
504
  pad = fduration.size(-1) // 2
501
505
  else:
502
506
  pad = int(fduration * sample_rate / 2)
@@ -1,5 +1,8 @@
1
1
  import torch
2
+ from jaxtyping import Float
3
+ from torch import Tensor
2
4
 
5
+ from ml4gw.types import TimeSeries1to3d
3
6
  from ml4gw.utils.slicing import unfold_windows
4
7
 
5
8
 
@@ -40,7 +43,7 @@ class ShiftedPearsonCorrelation(torch.nn.Module):
40
43
  super().__init__()
41
44
  self.max_shift = max_shift
42
45
 
43
- def _shape_checks(self, x: torch.Tensor, y: torch.Tensor):
46
+ def _shape_checks(self, x: TimeSeries1to3d, y: TimeSeries1to3d):
44
47
  if x.ndim > 3:
45
48
  raise ValueError(
46
49
  "Tensor x can only have up to 3 dimensions "
@@ -61,8 +64,9 @@ class ShiftedPearsonCorrelation(torch.nn.Module):
61
64
  )
62
65
  )
63
66
 
64
- # TODO: torchtyping annotate
65
- def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
67
+ def forward(
68
+ self, x: TimeSeries1to3d, y: TimeSeries1to3d
69
+ ) -> Float[Tensor, "windows ..."]:
66
70
  self._shape_checks(x, y)
67
71
  dim = x.size(-1)
68
72
 
@@ -3,6 +3,10 @@ from typing import List, Optional, Tuple
3
3
 
4
4
  import torch
5
5
  import torch.nn.functional as F
6
+ from jaxtyping import Float, Int
7
+ from torch import Tensor
8
+
9
+ from ml4gw.types import FrequencySeries1to3d, TimeSeries1to3d, TimeSeries3d
6
10
 
7
11
  """
8
12
  All based on https://github.com/gwpy/gwpy/blob/v3.0.8/gwpy/signal/qtransform.py
@@ -44,7 +48,7 @@ class QTile(torch.nn.Module):
44
48
  duration: float,
45
49
  sample_rate: float,
46
50
  mismatch: float,
47
- ):
51
+ ) -> None:
48
52
  super().__init__()
49
53
  self.mismatch = mismatch
50
54
  self.q = q
@@ -63,18 +67,18 @@ class QTile(torch.nn.Module):
63
67
  self.register_buffer("indices", self.get_data_indices())
64
68
  self.register_buffer("window", self.get_window())
65
69
 
66
- def ntiles(self):
70
+ def ntiles(self) -> int:
67
71
  """
68
72
  Number of tiles in this frequency row
69
73
  """
70
74
  tcum_mismatch = self.duration * 2 * torch.pi * self.frequency / self.q
71
75
  return int(2 ** torch.ceil(torch.log2(tcum_mismatch / self.deltam)))
72
76
 
73
- def _get_indices(self):
77
+ def _get_indices(self) -> Int[Tensor, " windowsize"]:
74
78
  half = int((self.windowsize - 1) / 2)
75
79
  return torch.arange(-half, half + 1)
76
80
 
77
- def get_window(self):
81
+ def get_window(self) -> Float[Tensor, " windowsize"]:
78
82
  """
79
83
  Generate the bi-square window for this row
80
84
  """
@@ -87,7 +91,7 @@ class QTile(torch.nn.Module):
87
91
  )
88
92
  return torch.Tensor((1 - xfrequencies**2) ** 2 * norm)
89
93
 
90
- def get_data_indices(self):
94
+ def get_data_indices(self) -> Int[Tensor, " windowsize"]:
91
95
  """
92
96
  Get the index array of relevant frequencies for this row
93
97
  """
@@ -95,7 +99,9 @@ class QTile(torch.nn.Module):
95
99
  self._get_indices() + 1 + self.frequency * self.duration,
96
100
  ).type(torch.long)
97
101
 
98
- def forward(self, fseries: torch.Tensor, norm: str = "median"):
102
+ def forward(
103
+ self, fseries: FrequencySeries1to3d, norm: str = "median"
104
+ ) -> TimeSeries1to3d:
99
105
  """
100
106
  Compute the transform for this row
101
107
 
@@ -176,7 +182,7 @@ class SingleQTransform(torch.nn.Module):
176
182
  q: float = 12,
177
183
  frange: List[float] = [0, torch.inf],
178
184
  mismatch: float = 0.2,
179
- ):
185
+ ) -> None:
180
186
  super().__init__()
181
187
  self.q = q
182
188
  self.spectrogram_shape = spectrogram_shape
@@ -198,7 +204,7 @@ class SingleQTransform(torch.nn.Module):
198
204
  )
199
205
  self.qtiles = None
200
206
 
201
- def get_freqs(self):
207
+ def get_freqs(self) -> Float[Tensor, " nfreq"]:
202
208
  """
203
209
  Calculate the frequencies that will be used in this transform.
204
210
  For each frequency, a `QTile` is created.
@@ -262,7 +268,7 @@ class SingleQTransform(torch.nn.Module):
262
268
  if dimension == "batch":
263
269
  return torch.max(max_across_ft, dim=-1).values
264
270
 
265
- def compute_qtiles(self, X: torch.Tensor, norm: str = "median"):
271
+ def compute_qtiles(self, X: TimeSeries1to3d, norm: str = "median") -> None:
266
272
  """
267
273
  Take the FFT of the input timeseries and calculate the transform
268
274
  for each `QTile`
@@ -272,7 +278,7 @@ class SingleQTransform(torch.nn.Module):
272
278
  X[..., 1:] *= 2
273
279
  self.qtiles = [qtile(X, norm) for qtile in self.qtile_transforms]
274
280
 
275
- def interpolate(self, num_f_bins: int, num_t_bins: int):
281
+ def interpolate(self, num_f_bins: int, num_t_bins: int) -> TimeSeries3d:
276
282
  """
277
283
  Interpolate each `QTile` to the specified number of time and
278
284
  frequency bins. Note that PyTorch does not have the same
@@ -299,7 +305,7 @@ class SingleQTransform(torch.nn.Module):
299
305
 
300
306
  def forward(
301
307
  self,
302
- X: torch.Tensor,
308
+ X: TimeSeries1to3d,
303
309
  norm: str = "median",
304
310
  spectrogram_shape: Optional[Tuple[int, int]] = None,
305
311
  ):
@@ -371,7 +377,7 @@ class QScan(torch.nn.Module):
371
377
  qrange: List[float] = [4, 64],
372
378
  frange: List[float] = [0, torch.inf],
373
379
  mismatch: float = 0.2,
374
- ):
380
+ ) -> None:
375
381
  super().__init__()
376
382
  self.qrange = qrange
377
383
  self.mismatch = mismatch
@@ -397,7 +403,7 @@ class QScan(torch.nn.Module):
397
403
  ]
398
404
  )
399
405
 
400
- def get_qs(self):
406
+ def get_qs(self) -> List[float]:
401
407
  """
402
408
  Determine the values of Q to try for the set of Q-transforms
403
409
  """
@@ -413,7 +419,7 @@ class QScan(torch.nn.Module):
413
419
 
414
420
  def forward(
415
421
  self,
416
- X: torch.Tensor,
422
+ X: TimeSeries1to3d,
417
423
  fsearch_range: List[float] = None,
418
424
  norm: str = "median",
419
425
  spectrogram_shape: Optional[Tuple[int, int]] = None,
@@ -1,6 +1,8 @@
1
1
  from typing import Optional
2
2
 
3
3
  import torch
4
+ from jaxtyping import Float
5
+ from torch import Tensor
4
6
 
5
7
  from ml4gw.transforms.transform import FittableTransform
6
8
 
@@ -34,7 +36,7 @@ class ChannelWiseScaler(FittableTransform):
34
36
  self.register_buffer("mean", mean)
35
37
  self.register_buffer("std", std)
36
38
 
37
- def fit(self, X: torch.Tensor) -> None:
39
+ def fit(self, X: Float[Tensor, "... time"]) -> None:
38
40
  """Fit the scaling parameters to a timeseries
39
41
 
40
42
  Computes the channel-wise mean and standard deviation
@@ -60,7 +62,9 @@ class ChannelWiseScaler(FittableTransform):
60
62
 
61
63
  super().build(mean=mean, std=std)
62
64
 
63
- def forward(self, X: torch.Tensor, reverse: bool = False) -> torch.Tensor:
65
+ def forward(
66
+ self, X: Float[Tensor, "... time"], reverse: bool = False
67
+ ) -> Float[Tensor, "... time"]:
64
68
  if not reverse:
65
69
  return (X - self.mean) / self.std
66
70
  else:
@@ -2,8 +2,9 @@ from typing import Optional
2
2
 
3
3
  import torch
4
4
 
5
- from ml4gw import gw
5
+ from ml4gw.gw import compute_network_snr
6
6
  from ml4gw.transforms.transform import FittableSpectralTransform
7
+ from ml4gw.types import BatchTensor, TimeSeries2d, WaveformTensor
7
8
 
8
9
 
9
10
  class SnrRescaler(FittableSpectralTransform):
@@ -34,7 +35,7 @@ class SnrRescaler(FittableSpectralTransform):
34
35
 
35
36
  def fit(
36
37
  self,
37
- *background: torch.Tensor,
38
+ *background: TimeSeries2d,
38
39
  fftlength: Optional[float] = None,
39
40
  overlap: Optional[float] = None,
40
41
  ):
@@ -58,10 +59,10 @@ class SnrRescaler(FittableSpectralTransform):
58
59
 
59
60
  def forward(
60
61
  self,
61
- responses: gw.WaveformTensor,
62
- target_snrs: Optional[gw.ScalarTensor] = None,
62
+ responses: WaveformTensor,
63
+ target_snrs: Optional[BatchTensor] = None,
63
64
  ):
64
- snrs = gw.compute_network_snr(
65
+ snrs = compute_network_snr(
65
66
  responses, self.background, self.sample_rate, self.mask
66
67
  )
67
68
  if target_snrs is None:
@@ -1,8 +1,11 @@
1
1
  from typing import Optional
2
2
 
3
3
  import torch
4
+ from jaxtyping import Float
5
+ from torch import Tensor
4
6
 
5
7
  from ml4gw.spectral import fast_spectral_density, spectral_density
8
+ from ml4gw.types import FrequencySeries1to3d, TimeSeries1to3d
6
9
 
7
10
 
8
11
  class SpectralDensity(torch.nn.Module):
@@ -34,6 +37,10 @@ class SpectralDensity(torch.nn.Module):
34
37
  average:
35
38
  Aggregation method to use for combining windowed FFTs.
36
39
  Allowed values are `"mean"` and `"median"`.
40
+ window:
41
+ Window array to multiply by each FFT window before
42
+ FFT computation. Should have length `nperseg`.
43
+ Defaults to a hanning window.
37
44
  fast:
38
45
  Whether to use a faster spectral density computation that
39
46
  support cross spectral density, or a slower one which does
@@ -47,6 +54,9 @@ class SpectralDensity(torch.nn.Module):
47
54
  fftlength: float,
48
55
  overlap: Optional[float] = None,
49
56
  average: str = "mean",
57
+ window: Optional[
58
+ Float[Tensor, " {int(fftlength*sample_rate)}"]
59
+ ] = None,
50
60
  fast: bool = False,
51
61
  ) -> None:
52
62
  if overlap is None:
@@ -63,11 +73,18 @@ class SpectralDensity(torch.nn.Module):
63
73
  self.nperseg = int(fftlength * sample_rate)
64
74
  self.nstride = self.nperseg - int(overlap * sample_rate)
65
75
 
66
- # TODOs: Do we allow for arbitrary windows?
67
- # Making this buffer persistent in case we want
68
- # to implement this down the line, so that custom
69
- # windows can be loaded in.
70
- self.register_buffer("window", torch.hann_window(self.nperseg))
76
+ # if no window is provided, default to a hanning window;
77
+ # validate that window is correct size
78
+ if window is None:
79
+ window = torch.hann_window(self.nperseg)
80
+
81
+ if window.size(0) != self.nperseg:
82
+ raise ValueError(
83
+ "Window must have length {} got {}".format(
84
+ self.nperseg, window.size(0)
85
+ )
86
+ )
87
+ self.register_buffer("window", window)
71
88
 
72
89
  # scale corresponds to "density" normalization, worth
73
90
  # considering adding this as a kwarg and changing this calc
@@ -81,7 +98,9 @@ class SpectralDensity(torch.nn.Module):
81
98
  self.average = average
82
99
  self.fast = fast
83
100
 
84
- def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None):
101
+ def forward(
102
+ self, x: TimeSeries1to3d, y: Optional[TimeSeries1to3d] = None
103
+ ) -> FrequencySeries1to3d:
85
104
  if self.fast:
86
105
  return fast_spectral_density(
87
106
  x,
@@ -3,8 +3,12 @@ from typing import Dict, List
3
3
 
4
4
  import torch
5
5
  import torch.nn.functional as F
6
+ from jaxtyping import Float
7
+ from torch import Tensor
6
8
  from torchaudio.transforms import Spectrogram
7
9
 
10
+ from ml4gw.types import TimeSeries3d
11
+
8
12
 
9
13
  class MultiResolutionSpectrogram(torch.nn.Module):
10
14
  """
@@ -122,7 +126,9 @@ class MultiResolutionSpectrogram(torch.nn.Module):
122
126
 
123
127
  return [dict(zip(kwargs, col)) for col in zip(*kwargs.values())]
124
128
 
125
- def forward(self, X: torch.Tensor) -> torch.Tensor:
129
+ def forward(
130
+ self, X: TimeSeries3d
131
+ ) -> Float[Tensor, "batch channel frequency time"]:
126
132
  """
127
133
  Calculate spectrograms of the input tensor and
128
134
  combine them into a single spectrogram
@@ -3,6 +3,7 @@ from typing import Optional
3
3
  import torch
4
4
 
5
5
  from ml4gw.spectral import spectral_density
6
+ from ml4gw.types import FrequencySeries1to3d, TimeSeries1to3d
6
7
 
7
8
 
8
9
  class FittableTransform(torch.nn.Module):
@@ -43,12 +44,12 @@ class FittableTransform(torch.nn.Module):
43
44
  class FittableSpectralTransform(FittableTransform):
44
45
  def normalize_psd(
45
46
  self,
46
- x,
47
+ x: TimeSeries1to3d,
47
48
  sample_rate: float,
48
49
  num_freqs: int,
49
50
  fftlength: Optional[float] = None,
50
51
  overlap: Optional[float] = None,
51
- ):
52
+ ) -> FrequencySeries1to3d:
52
53
  # if we specified an FFT length, convert
53
54
  # the (assumed) time-domain data to the
54
55
  # frequency domain
@@ -68,7 +69,7 @@ class FittableSpectralTransform(FittableTransform):
68
69
  scale=scale,
69
70
  )
70
71
 
71
- # add two dummy dimensions in case we need to inerpolate
72
+ # add two dummy dimensions in case we need to interpolate
72
73
  # the frequency dimension, since `interpolate` expects
73
74
  # a (batch, channel, spatial) formatted tensor as input
74
75
  x = x.view(1, 1, -1)
@@ -1,8 +1,11 @@
1
1
  from typing import List, Optional
2
2
 
3
3
  import torch
4
+ from jaxtyping import Float
5
+ from torch import Tensor
4
6
 
5
7
  from ml4gw import gw
8
+ from ml4gw.types import BatchTensor
6
9
 
7
10
 
8
11
  # TODO: should these live in ml4gw.waveforms submodule?
@@ -10,8 +13,8 @@ from ml4gw import gw
10
13
  class WaveformSampler(torch.nn.Module):
11
14
  def __init__(
12
15
  self,
13
- parameters: Optional[torch.Tensor] = None,
14
- **polarizations: torch.Tensor,
16
+ parameters: Optional[Float[Tensor, "batch num_params"]] = None,
17
+ **polarizations: Float[Tensor, "batch time"],
15
18
  ):
16
19
  super().__init__()
17
20
  # make sure we have the same number of waveforms
@@ -29,7 +32,7 @@ class WaveformSampler(torch.nn.Module):
29
32
  elif num_waveforms is None:
30
33
  num_waveforms = tensor.shape[0]
31
34
 
32
- self.polarizations[polarization] = torch.Tensor(tensor)
35
+ self.polarizations[polarization] = Tensor(tensor)
33
36
 
34
37
  if parameters is not None and len(parameters) != num_waveforms:
35
38
  raise ValueError(
@@ -73,10 +76,10 @@ class WaveformProjector(torch.nn.Module):
73
76
 
74
77
  def forward(
75
78
  self,
76
- dec: gw.ScalarTensor,
77
- psi: gw.ScalarTensor,
78
- phi: gw.ScalarTensor,
79
- **polarizations,
79
+ dec: BatchTensor,
80
+ psi: BatchTensor,
81
+ phi: BatchTensor,
82
+ **polarizations: Float[Tensor, "batch time"],
80
83
  ):
81
84
  ifo_responses = gw.compute_observed_strain(
82
85
  dec,
@@ -1,9 +1,15 @@
1
- from typing import Optional
1
+ from typing import Optional, Union
2
2
 
3
3
  import torch
4
4
 
5
5
  from ml4gw import spectral
6
6
  from ml4gw.transforms.transform import FittableSpectralTransform
7
+ from ml4gw.types import (
8
+ FrequencySeries1d,
9
+ FrequencySeries1to3d,
10
+ TimeSeries1d,
11
+ TimeSeries3d,
12
+ )
7
13
 
8
14
 
9
15
  class Whiten(torch.nn.Module):
@@ -58,7 +64,9 @@ class Whiten(torch.nn.Module):
58
64
  window = torch.hann_window(size, dtype=torch.float64)
59
65
  self.register_buffer("window", window)
60
66
 
61
- def forward(self, X: torch.Tensor, psd: torch.Tensor) -> torch.Tensor:
67
+ def forward(
68
+ self, X: TimeSeries3d, psd: FrequencySeries1to3d
69
+ ) -> TimeSeries3d:
62
70
  """
63
71
  Whiten a batch of multichannel timeseries by a
64
72
  background power spectral density.
@@ -142,7 +150,7 @@ class FixedWhiten(FittableSpectralTransform):
142
150
  def fit(
143
151
  self,
144
152
  fduration: float,
145
- *background: torch.Tensor,
153
+ *background: Union[TimeSeries1d, FrequencySeries1d],
146
154
  fftlength: Optional[float] = None,
147
155
  highpass: Optional[float] = None,
148
156
  overlap: Optional[float] = None
@@ -224,7 +232,7 @@ class FixedWhiten(FittableSpectralTransform):
224
232
  fduration = torch.Tensor([fduration])
225
233
  self.build(psd=psd, fduration=fduration)
226
234
 
227
- def forward(self, X: torch.Tensor) -> torch.Tensor:
235
+ def forward(self, X: TimeSeries3d) -> TimeSeries3d:
228
236
  """
229
237
  Whiten the input timeseries tensor using the
230
238
  PSD fit by the `.fit` method, which must be
ml4gw/types.py CHANGED
@@ -1,10 +1,25 @@
1
- from torchtyping import TensorType
2
-
3
- WaveformTensor = TensorType["batch", "num_ifos", "time"]
4
- PSDTensor = TensorType["num_ifos", "frequency"]
5
- ScalarTensor = TensorType["batch"]
6
- VectorGeometry = TensorType["batch", "space"]
7
- TensorGeometry = TensorType["batch", "space", "space"]
8
- NetworkVertices = TensorType["num_ifos", 3]
9
- NetworkDetectorTensors = TensorType["num_ifos", 3, 3]
10
- TimeSeriesTensor = TensorType["num_channels", "time"]
1
+ from typing import Union
2
+
3
+ from jaxtyping import Float
4
+ from torch import Tensor
5
+
6
+ WaveformTensor = Float[Tensor, "batch num_ifos time"]
7
+ PSDTensor = Float[Tensor, "num_ifos frequency"]
8
+ BatchTensor = Float[Tensor, "batch"]
9
+ VectorGeometry = Float[Tensor, "batch space"]
10
+ TensorGeometry = Float[Tensor, "batch space space"]
11
+ NetworkVertices = Float[Tensor, "num_ifos 3"]
12
+ NetworkDetectorTensors = Float[Tensor, "num_ifos 3 3"]
13
+
14
+
15
+ TimeSeries1d = Float[Tensor, "time"]
16
+ TimeSeries2d = Float[TimeSeries1d, "channel"]
17
+ TimeSeries3d = Float[TimeSeries2d, "batch"]
18
+ TimeSeries1to3d = Union[TimeSeries1d, TimeSeries2d, TimeSeries3d]
19
+
20
+ FrequencySeries1d = Float[Tensor, "frequency"]
21
+ FrequencySeries2d = Float[FrequencySeries1d, "channel"]
22
+ FrequencySeries3d = Float[FrequencySeries2d, "batch"]
23
+ FrequencySeries1to3d = Union[
24
+ FrequencySeries1d, FrequencySeries2d, FrequencySeries3d
25
+ ]
@@ -4,7 +4,7 @@ import torch
4
4
  # based on values from
5
5
  # https://lscsoft.docs.ligo.org/lalsuite/lal/_l_a_l_detectors_8h_source.html
6
6
  class InterferometerGeometry:
7
- def __init__(self, name: str):
7
+ def __init__(self, name: str) -> None:
8
8
  if name == "H1":
9
9
  self.x_arm = torch.Tensor(
10
10
  (-0.22389266154, +0.79983062746, +0.55690487831)
@@ -35,6 +35,12 @@ class InterferometerGeometry:
35
35
  self.vertex = torch.Tensor(
36
36
  (4.54637409900e06, 8.42989697626e05, 4.37857696241e06)
37
37
  )
38
+ elif name == "K1":
39
+ self.x_arm = torch.Tensor((-0.3759040, -0.8361583, 0.3994189))
40
+ self.y_arm = torch.Tensor((0.7164378, 0.01114076, 0.6975620))
41
+ self.vertex = torch.Tensor(
42
+ (-3777336.024, 3484898.411, 3765313.697)
43
+ )
38
44
  else:
39
45
  raise ValueError(
40
46
  f"{name} is not recognized as an interferometer, "