ml4gw 0.5.0__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.

Potentially problematic release.


This version of ml4gw might be problematic; see the registry's release page for more details.

@@ -1,5 +1,8 @@
1
1
  import torch
2
+ from jaxtyping import Float
3
+ from torch import Tensor
2
4
 
5
+ from ml4gw.types import TimeSeries1to3d
3
6
  from ml4gw.utils.slicing import unfold_windows
4
7
 
5
8
 
@@ -40,7 +43,7 @@ class ShiftedPearsonCorrelation(torch.nn.Module):
40
43
  super().__init__()
41
44
  self.max_shift = max_shift
42
45
 
43
- def _shape_checks(self, x: torch.Tensor, y: torch.Tensor):
46
+ def _shape_checks(self, x: TimeSeries1to3d, y: TimeSeries1to3d):
44
47
  if x.ndim > 3:
45
48
  raise ValueError(
46
49
  "Tensor x can only have up to 3 dimensions "
@@ -61,8 +64,9 @@ class ShiftedPearsonCorrelation(torch.nn.Module):
61
64
  )
62
65
  )
63
66
 
64
- # TODO: torchtyping annotate
65
- def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
67
+ def forward(
68
+ self, x: TimeSeries1to3d, y: TimeSeries1to3d
69
+ ) -> Float[Tensor, "windows ..."]:
66
70
  self._shape_checks(x, y)
67
71
  dim = x.size(-1)
68
72
 
@@ -3,6 +3,10 @@ from typing import List, Optional, Tuple
3
3
 
4
4
  import torch
5
5
  import torch.nn.functional as F
6
+ from jaxtyping import Float, Int
7
+ from torch import Tensor
8
+
9
+ from ml4gw.types import FrequencySeries1to3d, TimeSeries1to3d, TimeSeries3d
6
10
 
7
11
  """
8
12
  All based on https://github.com/gwpy/gwpy/blob/v3.0.8/gwpy/signal/qtransform.py
@@ -44,7 +48,7 @@ class QTile(torch.nn.Module):
44
48
  duration: float,
45
49
  sample_rate: float,
46
50
  mismatch: float,
47
- ):
51
+ ) -> None:
48
52
  super().__init__()
49
53
  self.mismatch = mismatch
50
54
  self.q = q
@@ -63,18 +67,18 @@ class QTile(torch.nn.Module):
63
67
  self.register_buffer("indices", self.get_data_indices())
64
68
  self.register_buffer("window", self.get_window())
65
69
 
66
- def ntiles(self):
70
+ def ntiles(self) -> int:
67
71
  """
68
72
  Number of tiles in this frequency row
69
73
  """
70
74
  tcum_mismatch = self.duration * 2 * torch.pi * self.frequency / self.q
71
75
  return int(2 ** torch.ceil(torch.log2(tcum_mismatch / self.deltam)))
72
76
 
73
- def _get_indices(self):
77
+ def _get_indices(self) -> Int[Tensor, " windowsize"]:
74
78
  half = int((self.windowsize - 1) / 2)
75
79
  return torch.arange(-half, half + 1)
76
80
 
77
- def get_window(self):
81
+ def get_window(self) -> Float[Tensor, " windowsize"]:
78
82
  """
79
83
  Generate the bi-square window for this row
80
84
  """
@@ -87,7 +91,7 @@ class QTile(torch.nn.Module):
87
91
  )
88
92
  return torch.Tensor((1 - xfrequencies**2) ** 2 * norm)
89
93
 
90
- def get_data_indices(self):
94
+ def get_data_indices(self) -> Int[Tensor, " windowsize"]:
91
95
  """
92
96
  Get the index array of relevant frequencies for this row
93
97
  """
@@ -95,7 +99,9 @@ class QTile(torch.nn.Module):
95
99
  self._get_indices() + 1 + self.frequency * self.duration,
96
100
  ).type(torch.long)
97
101
 
98
- def forward(self, fseries: torch.Tensor, norm: str = "median"):
102
+ def forward(
103
+ self, fseries: FrequencySeries1to3d, norm: str = "median"
104
+ ) -> TimeSeries1to3d:
99
105
  """
100
106
  Compute the transform for this row
101
107
 
@@ -176,7 +182,7 @@ class SingleQTransform(torch.nn.Module):
176
182
  q: float = 12,
177
183
  frange: List[float] = [0, torch.inf],
178
184
  mismatch: float = 0.2,
179
- ):
185
+ ) -> None:
180
186
  super().__init__()
181
187
  self.q = q
182
188
  self.spectrogram_shape = spectrogram_shape
@@ -198,7 +204,7 @@ class SingleQTransform(torch.nn.Module):
198
204
  )
199
205
  self.qtiles = None
200
206
 
201
- def get_freqs(self):
207
+ def get_freqs(self) -> Float[Tensor, " nfreq"]:
202
208
  """
203
209
  Calculate the frequencies that will be used in this transform.
204
210
  For each frequency, a `QTile` is created.
@@ -262,7 +268,7 @@ class SingleQTransform(torch.nn.Module):
262
268
  if dimension == "batch":
263
269
  return torch.max(max_across_ft, dim=-1).values
264
270
 
265
- def compute_qtiles(self, X: torch.Tensor, norm: str = "median"):
271
+ def compute_qtiles(self, X: TimeSeries1to3d, norm: str = "median") -> None:
266
272
  """
267
273
  Take the FFT of the input timeseries and calculate the transform
268
274
  for each `QTile`
@@ -272,7 +278,7 @@ class SingleQTransform(torch.nn.Module):
272
278
  X[..., 1:] *= 2
273
279
  self.qtiles = [qtile(X, norm) for qtile in self.qtile_transforms]
274
280
 
275
- def interpolate(self, num_f_bins: int, num_t_bins: int):
281
+ def interpolate(self, num_f_bins: int, num_t_bins: int) -> TimeSeries3d:
276
282
  """
277
283
  Interpolate each `QTile` to the specified number of time and
278
284
  frequency bins. Note that PyTorch does not have the same
@@ -299,7 +305,7 @@ class SingleQTransform(torch.nn.Module):
299
305
 
300
306
  def forward(
301
307
  self,
302
- X: torch.Tensor,
308
+ X: TimeSeries1to3d,
303
309
  norm: str = "median",
304
310
  spectrogram_shape: Optional[Tuple[int, int]] = None,
305
311
  ):
@@ -371,7 +377,7 @@ class QScan(torch.nn.Module):
371
377
  qrange: List[float] = [4, 64],
372
378
  frange: List[float] = [0, torch.inf],
373
379
  mismatch: float = 0.2,
374
- ):
380
+ ) -> None:
375
381
  super().__init__()
376
382
  self.qrange = qrange
377
383
  self.mismatch = mismatch
@@ -397,7 +403,7 @@ class QScan(torch.nn.Module):
397
403
  ]
398
404
  )
399
405
 
400
- def get_qs(self):
406
+ def get_qs(self) -> List[float]:
401
407
  """
402
408
  Determine the values of Q to try for the set of Q-transforms
403
409
  """
@@ -413,7 +419,7 @@ class QScan(torch.nn.Module):
413
419
 
414
420
  def forward(
415
421
  self,
416
- X: torch.Tensor,
422
+ X: TimeSeries1to3d,
417
423
  fsearch_range: List[float] = None,
418
424
  norm: str = "median",
419
425
  spectrogram_shape: Optional[Tuple[int, int]] = None,
@@ -1,6 +1,8 @@
1
1
  from typing import Optional
2
2
 
3
3
  import torch
4
+ from jaxtyping import Float
5
+ from torch import Tensor
4
6
 
5
7
  from ml4gw.transforms.transform import FittableTransform
6
8
 
@@ -34,7 +36,7 @@ class ChannelWiseScaler(FittableTransform):
34
36
  self.register_buffer("mean", mean)
35
37
  self.register_buffer("std", std)
36
38
 
37
- def fit(self, X: torch.Tensor) -> None:
39
+ def fit(self, X: Float[Tensor, "... time"]) -> None:
38
40
  """Fit the scaling parameters to a timeseries
39
41
 
40
42
  Computes the channel-wise mean and standard deviation
@@ -60,7 +62,9 @@ class ChannelWiseScaler(FittableTransform):
60
62
 
61
63
  super().build(mean=mean, std=std)
62
64
 
63
- def forward(self, X: torch.Tensor, reverse: bool = False) -> torch.Tensor:
65
+ def forward(
66
+ self, X: Float[Tensor, "... time"], reverse: bool = False
67
+ ) -> Float[Tensor, "... time"]:
64
68
  if not reverse:
65
69
  return (X - self.mean) / self.std
66
70
  else:
@@ -2,8 +2,9 @@ from typing import Optional
2
2
 
3
3
  import torch
4
4
 
5
- from ml4gw import gw
5
+ from ml4gw.gw import compute_network_snr
6
6
  from ml4gw.transforms.transform import FittableSpectralTransform
7
+ from ml4gw.types import BatchTensor, TimeSeries2d, WaveformTensor
7
8
 
8
9
 
9
10
  class SnrRescaler(FittableSpectralTransform):
@@ -34,7 +35,7 @@ class SnrRescaler(FittableSpectralTransform):
34
35
 
35
36
  def fit(
36
37
  self,
37
- *background: torch.Tensor,
38
+ *background: TimeSeries2d,
38
39
  fftlength: Optional[float] = None,
39
40
  overlap: Optional[float] = None,
40
41
  ):
@@ -58,10 +59,10 @@ class SnrRescaler(FittableSpectralTransform):
58
59
 
59
60
  def forward(
60
61
  self,
61
- responses: gw.WaveformTensor,
62
- target_snrs: Optional[gw.ScalarTensor] = None,
62
+ responses: WaveformTensor,
63
+ target_snrs: Optional[BatchTensor] = None,
63
64
  ):
64
- snrs = gw.compute_network_snr(
65
+ snrs = compute_network_snr(
65
66
  responses, self.background, self.sample_rate, self.mask
66
67
  )
67
68
  if target_snrs is None:
@@ -1,8 +1,11 @@
1
1
  from typing import Optional
2
2
 
3
3
  import torch
4
+ from jaxtyping import Float
5
+ from torch import Tensor
4
6
 
5
7
  from ml4gw.spectral import fast_spectral_density, spectral_density
8
+ from ml4gw.types import FrequencySeries1to3d, TimeSeries1to3d
6
9
 
7
10
 
8
11
  class SpectralDensity(torch.nn.Module):
@@ -51,7 +54,9 @@ class SpectralDensity(torch.nn.Module):
51
54
  fftlength: float,
52
55
  overlap: Optional[float] = None,
53
56
  average: str = "mean",
54
- window: Optional[torch.Tensor] = None,
57
+ window: Optional[
58
+ Float[Tensor, " {int(fftlength*sample_rate)}"]
59
+ ] = None,
55
60
  fast: bool = False,
56
61
  ) -> None:
57
62
  if overlap is None:
@@ -93,7 +98,9 @@ class SpectralDensity(torch.nn.Module):
93
98
  self.average = average
94
99
  self.fast = fast
95
100
 
96
- def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None):
101
+ def forward(
102
+ self, x: TimeSeries1to3d, y: Optional[TimeSeries1to3d] = None
103
+ ) -> FrequencySeries1to3d:
97
104
  if self.fast:
98
105
  return fast_spectral_density(
99
106
  x,
@@ -3,8 +3,12 @@ from typing import Dict, List
3
3
 
4
4
  import torch
5
5
  import torch.nn.functional as F
6
+ from jaxtyping import Float
7
+ from torch import Tensor
6
8
  from torchaudio.transforms import Spectrogram
7
9
 
10
+ from ml4gw.types import TimeSeries3d
11
+
8
12
 
9
13
  class MultiResolutionSpectrogram(torch.nn.Module):
10
14
  """
@@ -122,7 +126,9 @@ class MultiResolutionSpectrogram(torch.nn.Module):
122
126
 
123
127
  return [dict(zip(kwargs, col)) for col in zip(*kwargs.values())]
124
128
 
125
- def forward(self, X: torch.Tensor) -> torch.Tensor:
129
+ def forward(
130
+ self, X: TimeSeries3d
131
+ ) -> Float[Tensor, "batch channel frequency time"]:
126
132
  """
127
133
  Calculate spectrograms of the input tensor and
128
134
  combine them into a single spectrogram
@@ -3,6 +3,7 @@ from typing import Optional
3
3
  import torch
4
4
 
5
5
  from ml4gw.spectral import spectral_density
6
+ from ml4gw.types import FrequencySeries1to3d, TimeSeries1to3d
6
7
 
7
8
 
8
9
  class FittableTransform(torch.nn.Module):
@@ -43,12 +44,12 @@ class FittableTransform(torch.nn.Module):
43
44
  class FittableSpectralTransform(FittableTransform):
44
45
  def normalize_psd(
45
46
  self,
46
- x,
47
+ x: TimeSeries1to3d,
47
48
  sample_rate: float,
48
49
  num_freqs: int,
49
50
  fftlength: Optional[float] = None,
50
51
  overlap: Optional[float] = None,
51
- ):
52
+ ) -> FrequencySeries1to3d:
52
53
  # if we specified an FFT length, convert
53
54
  # the (assumed) time-domain data to the
54
55
  # frequency domain
@@ -68,7 +69,7 @@ class FittableSpectralTransform(FittableTransform):
68
69
  scale=scale,
69
70
  )
70
71
 
71
- # add two dummy dimensions in case we need to inerpolate
72
+ # add two dummy dimensions in case we need to interpolate
72
73
  # the frequency dimension, since `interpolate` expects
73
74
  # a (batch, channel, spatial) formatted tensor as input
74
75
  x = x.view(1, 1, -1)
@@ -1,8 +1,11 @@
1
1
  from typing import List, Optional
2
2
 
3
3
  import torch
4
+ from jaxtyping import Float
5
+ from torch import Tensor
4
6
 
5
7
  from ml4gw import gw
8
+ from ml4gw.types import BatchTensor
6
9
 
7
10
 
8
11
  # TODO: should these live in ml4gw.waveforms submodule?
@@ -10,8 +13,8 @@ from ml4gw import gw
10
13
  class WaveformSampler(torch.nn.Module):
11
14
  def __init__(
12
15
  self,
13
- parameters: Optional[torch.Tensor] = None,
14
- **polarizations: torch.Tensor,
16
+ parameters: Optional[Float[Tensor, "batch num_params"]] = None,
17
+ **polarizations: Float[Tensor, "batch time"],
15
18
  ):
16
19
  super().__init__()
17
20
  # make sure we have the same number of waveforms
@@ -29,7 +32,7 @@ class WaveformSampler(torch.nn.Module):
29
32
  elif num_waveforms is None:
30
33
  num_waveforms = tensor.shape[0]
31
34
 
32
- self.polarizations[polarization] = torch.Tensor(tensor)
35
+ self.polarizations[polarization] = Tensor(tensor)
33
36
 
34
37
  if parameters is not None and len(parameters) != num_waveforms:
35
38
  raise ValueError(
@@ -73,10 +76,10 @@ class WaveformProjector(torch.nn.Module):
73
76
 
74
77
  def forward(
75
78
  self,
76
- dec: gw.ScalarTensor,
77
- psi: gw.ScalarTensor,
78
- phi: gw.ScalarTensor,
79
- **polarizations,
79
+ dec: BatchTensor,
80
+ psi: BatchTensor,
81
+ phi: BatchTensor,
82
+ **polarizations: Float[Tensor, "batch time"],
80
83
  ):
81
84
  ifo_responses = gw.compute_observed_strain(
82
85
  dec,
@@ -1,9 +1,15 @@
1
- from typing import Optional
1
+ from typing import Optional, Union
2
2
 
3
3
  import torch
4
4
 
5
5
  from ml4gw import spectral
6
6
  from ml4gw.transforms.transform import FittableSpectralTransform
7
+ from ml4gw.types import (
8
+ FrequencySeries1d,
9
+ FrequencySeries1to3d,
10
+ TimeSeries1d,
11
+ TimeSeries3d,
12
+ )
7
13
 
8
14
 
9
15
  class Whiten(torch.nn.Module):
@@ -58,7 +64,9 @@ class Whiten(torch.nn.Module):
58
64
  window = torch.hann_window(size, dtype=torch.float64)
59
65
  self.register_buffer("window", window)
60
66
 
61
- def forward(self, X: torch.Tensor, psd: torch.Tensor) -> torch.Tensor:
67
+ def forward(
68
+ self, X: TimeSeries3d, psd: FrequencySeries1to3d
69
+ ) -> TimeSeries3d:
62
70
  """
63
71
  Whiten a batch of multichannel timeseries by a
64
72
  background power spectral density.
@@ -142,7 +150,7 @@ class FixedWhiten(FittableSpectralTransform):
142
150
  def fit(
143
151
  self,
144
152
  fduration: float,
145
- *background: torch.Tensor,
153
+ *background: Union[TimeSeries1d, FrequencySeries1d],
146
154
  fftlength: Optional[float] = None,
147
155
  highpass: Optional[float] = None,
148
156
  overlap: Optional[float] = None
@@ -224,7 +232,7 @@ class FixedWhiten(FittableSpectralTransform):
224
232
  fduration = torch.Tensor([fduration])
225
233
  self.build(psd=psd, fduration=fduration)
226
234
 
227
- def forward(self, X: torch.Tensor) -> torch.Tensor:
235
+ def forward(self, X: TimeSeries3d) -> TimeSeries3d:
228
236
  """
229
237
  Whiten the input timeseries tensor using the
230
238
  PSD fit by the `.fit` method, which must be
ml4gw/types.py CHANGED
@@ -1,10 +1,25 @@
1
- from torchtyping import TensorType
2
-
3
- WaveformTensor = TensorType["batch", "num_ifos", "time"]
4
- PSDTensor = TensorType["num_ifos", "frequency"]
5
- ScalarTensor = TensorType["batch"]
6
- VectorGeometry = TensorType["batch", "space"]
7
- TensorGeometry = TensorType["batch", "space", "space"]
8
- NetworkVertices = TensorType["num_ifos", 3]
9
- NetworkDetectorTensors = TensorType["num_ifos", 3, 3]
10
- TimeSeriesTensor = TensorType["num_channels", "time"]
1
+ from typing import Union
2
+
3
+ from jaxtyping import Float
4
+ from torch import Tensor
5
+
6
+ WaveformTensor = Float[Tensor, "batch num_ifos time"]
7
+ PSDTensor = Float[Tensor, "num_ifos frequency"]
8
+ BatchTensor = Float[Tensor, "batch"]
9
+ VectorGeometry = Float[Tensor, "batch space"]
10
+ TensorGeometry = Float[Tensor, "batch space space"]
11
+ NetworkVertices = Float[Tensor, "num_ifos 3"]
12
+ NetworkDetectorTensors = Float[Tensor, "num_ifos 3 3"]
13
+
14
+
15
+ TimeSeries1d = Float[Tensor, "time"]
16
+ TimeSeries2d = Float[TimeSeries1d, "channel"]
17
+ TimeSeries3d = Float[TimeSeries2d, "batch"]
18
+ TimeSeries1to3d = Union[TimeSeries1d, TimeSeries2d, TimeSeries3d]
19
+
20
+ FrequencySeries1d = Float[Tensor, "frequency"]
21
+ FrequencySeries2d = Float[FrequencySeries1d, "channel"]
22
+ FrequencySeries3d = Float[FrequencySeries2d, "batch"]
23
+ FrequencySeries1to3d = Union[
24
+ FrequencySeries1d, FrequencySeries2d, FrequencySeries3d
25
+ ]
@@ -4,7 +4,7 @@ import torch
4
4
  # based on values from
5
5
  # https://lscsoft.docs.ligo.org/lalsuite/lal/_l_a_l_detectors_8h_source.html
6
6
  class InterferometerGeometry:
7
- def __init__(self, name: str):
7
+ def __init__(self, name: str) -> None:
8
8
  if name == "H1":
9
9
  self.x_arm = torch.Tensor(
10
10
  (-0.22389266154, +0.79983062746, +0.55690487831)
ml4gw/utils/slicing.py CHANGED
@@ -1,25 +1,30 @@
1
1
  from typing import Optional, Union
2
2
 
3
3
  import torch
4
+ from jaxtyping import Float, Int64
5
+ from torch import Tensor
4
6
  from torch.nn.functional import unfold
5
- from torchtyping import TensorType
6
7
 
7
- # need to define these for flake8 compatibility
8
- batch = time = channel = None # noqa
8
+ from ml4gw.types import (
9
+ TimeSeries1d,
10
+ TimeSeries1to3d,
11
+ TimeSeries2d,
12
+ TimeSeries3d,
13
+ )
9
14
 
10
- TimeSeriesTensor = Union[TensorType["time"], TensorType["channel", "time"]]
11
-
12
- BatchTimeSeriesTensor = Union[
13
- TensorType["batch", "time"], TensorType["batch", "channel", "time"]
14
- ]
15
+ BatchTimeSeriesTensor = Union[Float[Tensor, "batch time"], TimeSeries3d]
15
16
 
16
17
 
17
18
  def unfold_windows(
18
- x: torch.Tensor,
19
+ x: TimeSeries1to3d,
19
20
  window_size: int,
20
21
  stride: int,
21
22
  drop_last: bool = True,
22
- ):
23
+ ) -> Union[
24
+ Float[TimeSeries1d, " window"],
25
+ Float[TimeSeries2d, " window"],
26
+ Float[TimeSeries3d, " window"],
27
+ ]:
23
28
  """Unfold a timeseries into windows
24
29
 
25
30
  Args:
@@ -83,8 +88,8 @@ def unfold_windows(
83
88
 
84
89
 
85
90
  def slice_kernels(
86
- x: Union[TimeSeriesTensor, TensorType["batch", "channel", "time"]],
87
- idx: TensorType[..., torch.int64],
91
+ x: TimeSeries1to3d,
92
+ idx: Int64[Tensor, "..."],
88
93
  kernel_size: int,
89
94
  ) -> BatchTimeSeriesTensor:
90
95
  """Slice kernels from single or multichannel timeseries
@@ -96,7 +101,8 @@ def slice_kernels(
96
101
  one more dimension than `x`.
97
102
 
98
103
  Args:
99
- x: The timeseries tensor to slice kernels from
104
+ x:
105
+ The timeseries tensor to slice kernels from
100
106
  idx:
101
107
  The indices in `x` of the first sample of each
102
108
  kernel. If `x` is 1D, `idx` must be 1D as well.
@@ -114,6 +120,7 @@ def slice_kernels(
114
120
  coincidentally among the channels.
115
121
  kernel_size:
116
122
  The length of the kernels to slice from the timeseries
123
+
117
124
  Returns:
118
125
  A tensor of shape `(batch_size, kernel_size)` if `x` is
119
126
  1D and `(batch_size, num_channels, kernel_size)` if `x`
@@ -225,7 +232,7 @@ def slice_kernels(
225
232
 
226
233
 
227
234
  def sample_kernels(
228
- X: TimeSeriesTensor,
235
+ X: TimeSeries1to3d,
229
236
  kernel_size: int,
230
237
  N: Optional[int] = None,
231
238
  max_center_offset: Optional[int] = None,
@@ -245,8 +252,9 @@ def sample_kernels(
245
252
  either be `None` or be equal to `len(X)`.
246
253
 
247
254
  Args:
248
- X: The timeseries tensor from which to sample kernels
249
- kernel_size: The size of the kernels to sample
255
+ X:
256
+ The timeseries tensor from which to sample kernels
257
+ kernel_size: The size of the kernels to sample
250
258
  N:
251
259
  The number of kernels to sample. Can be left as
252
260
  `None` if `X` is 3D, otherwise must be specified
@@ -1,24 +1,26 @@
1
- from typing import Callable
1
+ from typing import Callable, Dict, Tuple
2
2
 
3
3
  import torch
4
+ from jaxtyping import Float
5
+ from torch import Tensor
4
6
 
5
7
 
6
8
  class ParameterSampler(torch.nn.Module):
7
- def __init__(self, **parameters: Callable):
9
+ def __init__(self, **parameters: Callable) -> None:
8
10
  super().__init__()
9
11
  self.parameters = parameters
10
12
 
11
13
  def forward(
12
14
  self,
13
15
  N: int,
14
- ):
16
+ ) -> Dict[str, Float[Tensor, " {N}"]]:
15
17
  return {k: v.sample((N,)) for k, v in self.parameters.items()}
16
18
 
17
19
 
18
20
  class WaveformGenerator(torch.nn.Module):
19
21
  def __init__(
20
22
  self, waveform: Callable, parameter_sampler: ParameterSampler
21
- ):
23
+ ) -> None:
22
24
  """
23
25
  A torch module that generates waveforms from a given waveform function
24
26
  and a parameter sampler.
@@ -34,6 +36,8 @@ class WaveformGenerator(torch.nn.Module):
34
36
  self.waveform = waveform
35
37
  self.parameter_sampler = parameter_sampler
36
38
 
37
- def forward(self, N: int):
39
+ def forward(
40
+ self, N: int
41
+ ) -> Tuple[Float[Tensor, "{N} samples"], Dict[str, Float[Tensor, " {N}"]]]:
38
42
  parameters = self.parameter_sampler(N)
39
43
  return self.waveform(**parameters), parameters
@@ -1,7 +1,9 @@
1
1
  import torch
2
- from torchtyping import TensorType
2
+ from jaxtyping import Float
3
+
4
+ from ml4gw.constants import MTSUN_SI, PI
5
+ from ml4gw.types import BatchTensor, FrequencySeries1d
3
6
 
4
- from ..constants import MTSUN_SI, PI
5
7
  from .phenom_d_data import QNMData_a, QNMData_fdamp, QNMData_fring
6
8
  from .taylorf2 import TaylorF2
7
9
 
@@ -15,14 +17,14 @@ class IMRPhenomD(TaylorF2):
15
17
 
16
18
  def forward(
17
19
  self,
18
- f: TensorType,
19
- chirp_mass: TensorType,
20
- mass_ratio: TensorType,
21
- chi1: TensorType,
22
- chi2: TensorType,
23
- distance: TensorType,
24
- phic: TensorType,
25
- inclination: TensorType,
20
+ f: FrequencySeries1d,
21
+ chirp_mass: BatchTensor,
22
+ mass_ratio: BatchTensor,
23
+ chi1: BatchTensor,
24
+ chi2: BatchTensor,
25
+ distance: BatchTensor,
26
+ phic: BatchTensor,
27
+ inclination: BatchTensor,
26
28
  f_ref: float,
27
29
  ):
28
30
  """
@@ -76,15 +78,15 @@ class IMRPhenomD(TaylorF2):
76
78
 
77
79
  def phenom_d_htilde(
78
80
  self,
79
- f: TensorType,
80
- chirp_mass: TensorType,
81
- mass_ratio: TensorType,
82
- chi1: TensorType,
83
- chi2: TensorType,
84
- distance: TensorType,
85
- phic: TensorType,
81
+ f: FrequencySeries1d,
82
+ chirp_mass: BatchTensor,
83
+ mass_ratio: BatchTensor,
84
+ chi1: BatchTensor,
85
+ chi2: BatchTensor,
86
+ distance: BatchTensor,
87
+ phic: BatchTensor,
86
88
  f_ref: float,
87
- ):
89
+ ) -> Float[FrequencySeries1d, " batch"]:
88
90
  total_mass = chirp_mass * (1 + mass_ratio) ** 1.2 / mass_ratio**0.6
89
91
  mass_1 = total_mass / (1 + mass_ratio)
90
92
  mass_2 = mass_1 * mass_ratio