ml4gw-0.4.2-py3-none-any.whl → ml4gw-0.5.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ml4gw might be problematic. (The original registry page linked to further details, which are not reproduced here.)

ml4gw/augmentations.py CHANGED
@@ -1,4 +1,6 @@
1
1
  import torch
2
+ from jaxtyping import Float
3
+ from torch import Tensor
2
4
 
3
5
 
4
6
  class SignalInverter(torch.nn.Module):
@@ -16,7 +18,9 @@ class SignalInverter(torch.nn.Module):
16
18
  super().__init__()
17
19
  self.prob = prob
18
20
 
19
- def forward(self, X):
21
+ def forward(
22
+ self, X: Float[Tensor, "*batch time"]
23
+ ) -> Float[Tensor, "*batch time"]:
20
24
  mask = torch.rand(size=X.shape[:-1]) < self.prob
21
25
  X[mask] *= -1
22
26
  return X
@@ -37,7 +41,9 @@ class SignalReverser(torch.nn.Module):
37
41
  super().__init__()
38
42
  self.prob = prob
39
43
 
40
- def forward(self, X):
44
+ def forward(
45
+ self, X: Float[Tensor, "*batch time"]
46
+ ) -> Float[Tensor, "*batch time"]:
41
47
  mask = torch.rand(size=X.shape[:-1]) < self.prob
42
48
  X[mask] = X[mask].flip(-1)
43
49
  return X
ml4gw/constants.py ADDED
@@ -0,0 +1,45 @@
1
+ """
2
+ Various constants, all in SI units.
3
+ """
4
+
5
+ EulerGamma = 0.577215664901532860606512090082402431
6
+
7
+ MSUN = 1.988409902147041637325262574352366540e30 # kg
8
+ """Solar mass"""
9
+
10
+ MRSUN = 1.476625038050124729627979840144936351e3
11
+ """Geometrized nominal solar mass, m"""
12
+
13
+ G = 6.67430e-11 # m^3 / kg / s^2
14
+ """Newton's gravitational constant"""
15
+
16
+ C = 299792458.0 # m / s
17
+ """Speed of light"""
18
+
19
+ """Pi"""
20
+ PI = 3.141592653589793238462643383279502884
21
+
22
+ TWO_PI = 6.283185307179586476925286766559005768
23
+
24
+ gt = G * MSUN / (C**3.0)
25
+ """
26
+ G MSUN / C^3 in seconds
27
+ """
28
+
29
+ MTSUN_SI = 4.925490947641266978197229498498379006e-6
30
+ """1 solar mass in seconds. Same value as lal.MTSUN_SI"""
31
+
32
+ m_per_Mpc = 3.085677581491367278913937957796471611e22
33
+ """
34
+ Meters per Mpc.
35
+ """
36
+
37
+ MPC_SEC = m_per_Mpc / C
38
+ """
39
+ 1 Mpc in seconds.
40
+ """
41
+
42
+ clightGpc = C / 3.0856778570831e22
43
+ """
44
+ Speed of light in vacuum (:math:`c`), in gigaparsecs per second
45
+ """
@@ -2,6 +2,8 @@ from collections.abc import Iterable
2
2
 
3
3
  import torch
4
4
 
5
+ from ml4gw.types import WaveformTensor
6
+
5
7
 
6
8
  class ChunkedTimeSeriesDataset(torch.utils.data.IterableDataset):
7
9
  """
@@ -55,10 +57,10 @@ class ChunkedTimeSeriesDataset(torch.utils.data.IterableDataset):
55
57
  self.coincident = coincident
56
58
  self.device = device
57
59
 
58
- def __len__(self):
60
+ def __len__(self) -> int:
59
61
  return len(self.chunk_it) * self.batches_per_chunk
60
62
 
61
- def __iter__(self):
63
+ def __iter__(self) -> WaveformTensor:
62
64
  it = iter(self.chunk_it)
63
65
  chunk = next(it)
64
66
  num_chunks, num_channels, chunk_size = chunk.shape
@@ -161,7 +161,7 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
161
161
  x[b, c] = f[self.channels[c]][i : i + self.kernel_size]
162
162
  return torch.Tensor(x)
163
163
 
164
- def __iter__(self) -> torch.Tensor:
164
+ def __iter__(self) -> WaveformTensor:
165
165
  worker_info = torch.utils.data.get_worker_info()
166
166
  if worker_info is None:
167
167
  num_batches = self.batches_per_epoch
@@ -2,8 +2,9 @@ import itertools
2
2
  from typing import Optional, Tuple, Union
3
3
 
4
4
  import torch
5
+ from jaxtyping import Float
6
+ from torch import Tensor
5
7
 
6
- from ml4gw import types
7
8
  from ml4gw.utils.slicing import slice_kernels
8
9
 
9
10
 
@@ -76,9 +77,9 @@ class InMemoryDataset(torch.utils.data.IterableDataset):
76
77
 
77
78
  def __init__(
78
79
  self,
79
- X: types.TimeSeriesTensor,
80
+ X: Float[Tensor, "channels time"],
80
81
  kernel_size: int,
81
- y: Optional[types.ScalarTensor] = None,
82
+ y: Optional[Float[Tensor, " time"]] = None,
82
83
  batch_size: int = 32,
83
84
  stride: int = 1,
84
85
  batches_per_epoch: Optional[int] = None,
@@ -207,7 +208,10 @@ class InMemoryDataset(torch.utils.data.IterableDataset):
207
208
 
208
209
  def __iter__(
209
210
  self,
210
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
211
+ ) -> Union[
212
+ Float[Tensor, "batch channel time"],
213
+ Tuple[Float[Tensor, "batch channel time"], Float[Tensor, " batch"]],
214
+ ]:
211
215
 
212
216
  indices = self.init_indices()
213
217
  for i in range(len(self)):
ml4gw/distributions.py CHANGED
@@ -4,11 +4,13 @@ from specified distributions. Each callable should map from
4
4
  an integer `N` to a 1D torch `Tensor` containing `N` samples
5
5
  from the corresponding distribution.
6
6
  """
7
-
7
+ import math
8
8
  from typing import Optional
9
9
 
10
10
  import torch
11
11
  import torch.distributions as dist
12
+ from jaxtyping import Float
13
+ from torch import Tensor
12
14
 
13
15
 
14
16
  class Cosine(dist.Distribution):
@@ -21,20 +23,21 @@ class Cosine(dist.Distribution):
21
23
 
22
24
  def __init__(
23
25
  self,
24
- low: float = torch.as_tensor(-torch.pi / 2),
25
- high: float = torch.as_tensor(torch.pi / 2),
26
+ low: float = -math.pi / 2,
27
+ high: float = math.pi / 2,
26
28
  validate_args=None,
27
29
  ):
28
30
  batch_shape = torch.Size()
29
31
  super().__init__(batch_shape, validate_args=validate_args)
30
- self.low = low
31
- self.norm = 1 / (torch.sin(high) - torch.sin(low))
32
+ self.low = torch.as_tensor(low)
33
+ self.high = torch.as_tensor(high)
34
+ self.norm = 1 / (torch.sin(self.high) - torch.sin(self.low))
32
35
 
33
- def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
36
+ def rsample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
34
37
  u = torch.rand(sample_shape, device=self.low.device)
35
38
  return torch.arcsin(u / self.norm + torch.sin(self.low))
36
39
 
37
- def log_prob(self, value):
40
+ def log_prob(self, value: float) -> Float[Tensor, ""]:
38
41
  value = torch.as_tensor(value)
39
42
  inside_range = (value >= self.low) & (value <= self.high)
40
43
  return value.cos().log() * inside_range
@@ -48,13 +51,16 @@ class Sine(dist.TransformedDistribution):
48
51
 
49
52
  def __init__(
50
53
  self,
51
- low: float = torch.as_tensor(0),
52
- high: float = torch.as_tensor(torch.pi),
54
+ low: float = 0.0,
55
+ high: float = math.pi,
53
56
  validate_args=None,
54
57
  ):
58
+ low = torch.as_tensor(low)
59
+ high = torch.as_tensor(high)
55
60
  base_dist = Cosine(
56
61
  low - torch.pi / 2, high - torch.pi / 2, validate_args
57
62
  )
63
+
58
64
  super().__init__(
59
65
  base_dist,
60
66
  [
@@ -153,14 +159,14 @@ class DeltaFunction(dist.Distribution):
153
159
 
154
160
  def __init__(
155
161
  self,
156
- peak: float = torch.as_tensor(0.0),
162
+ peak: float = 0.0,
157
163
  validate_args=None,
158
164
  ):
159
165
  batch_shape = torch.Size()
160
166
  super().__init__(batch_shape, validate_args=validate_args)
161
- self.peak = peak
167
+ self.peak = torch.as_tensor(peak)
162
168
 
163
- def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
169
+ def rsample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
164
170
  return self.peak * torch.ones(
165
171
  sample_shape, device=self.peak.device, dtype=torch.float32
166
172
  )
ml4gw/gw.py CHANGED
@@ -13,27 +13,21 @@ https://github.com/lscsoft/bilby/blob/master/bilby/gw/detector/interferometer.py
13
13
  from typing import List, Tuple, Union
14
14
 
15
15
  import torch
16
- from torchtyping import TensorType
16
+ from jaxtyping import Float
17
+ from torch import Tensor
17
18
 
19
+ from ml4gw.constants import C
18
20
  from ml4gw.types import (
21
+ BatchTensor,
19
22
  NetworkDetectorTensors,
20
23
  NetworkVertices,
21
24
  PSDTensor,
22
- ScalarTensor,
23
25
  TensorGeometry,
24
26
  VectorGeometry,
25
27
  WaveformTensor,
26
28
  )
27
29
  from ml4gw.utils.interferometer import InterferometerGeometry
28
30
 
29
- SPEED_OF_LIGHT = 299792458.0 # m/s
30
-
31
-
32
- # define some tensor shapes we'll reuse a bit
33
- # up front. Need to assign these variables so
34
- # that static linters don't give us name errors
35
- batch = num_ifos = polarizations = time = frequency = space = None # noqa
36
-
37
31
 
38
32
  def outer(x: VectorGeometry, y: VectorGeometry) -> TensorGeometry:
39
33
  """
@@ -62,12 +56,12 @@ polarization_funcs = {
62
56
 
63
57
 
64
58
  def compute_antenna_responses(
65
- theta: ScalarTensor,
66
- psi: ScalarTensor,
67
- phi: ScalarTensor,
59
+ theta: BatchTensor,
60
+ psi: BatchTensor,
61
+ phi: BatchTensor,
68
62
  detector_tensors: NetworkDetectorTensors,
69
63
  modes: List[str],
70
- ) -> TensorType["batch", "polarizations", "num_ifos"]:
64
+ ) -> Float[Tensor, "batch polarizations num_ifos"]:
71
65
  """
72
66
  Compute the antenna pattern factors of a batch of
73
67
  waveforms as a function of the sky parameters of
@@ -147,8 +141,8 @@ def compute_antenna_responses(
147
141
 
148
142
  def shift_responses(
149
143
  responses: WaveformTensor,
150
- theta: ScalarTensor,
151
- phi: ScalarTensor,
144
+ theta: BatchTensor,
145
+ phi: BatchTensor,
152
146
  vertices: NetworkVertices,
153
147
  sample_rate: float,
154
148
  ) -> WaveformTensor:
@@ -166,7 +160,7 @@ def shift_responses(
166
160
  # Divide by c in the second line so that we only
167
161
  # need to multiply the array by a single float
168
162
  dt = -(omega * vertices).sum(axis=-1)
169
- dt *= sample_rate / SPEED_OF_LIGHT
163
+ dt *= sample_rate / C
170
164
  dt = torch.trunc(dt).type(torch.int64)
171
165
 
172
166
  # rolling by gathering implementation based on
@@ -191,13 +185,13 @@ def shift_responses(
191
185
 
192
186
 
193
187
  def compute_observed_strain(
194
- dec: ScalarTensor,
195
- psi: ScalarTensor,
196
- phi: ScalarTensor,
188
+ dec: BatchTensor,
189
+ psi: BatchTensor,
190
+ phi: BatchTensor,
197
191
  detector_tensors: NetworkDetectorTensors,
198
192
  detector_vertices: NetworkVertices,
199
193
  sample_rate: float,
200
- **polarizations: TensorType["batch", "time"],
194
+ **polarizations: Float[Tensor, "batch time"],
201
195
  ) -> WaveformTensor:
202
196
  """
203
197
  Compute the strain timeseries $h(t)$ observed by a network
@@ -289,8 +283,8 @@ def compute_ifo_snr(
289
283
  responses: WaveformTensor,
290
284
  psd: PSDTensor,
291
285
  sample_rate: float,
292
- highpass: Union[float, TensorType["frequency"], None] = None,
293
- ) -> TensorType["batch", "num_ifos"]:
286
+ highpass: Union[float, Float[Tensor, " frequency"], None] = None,
287
+ ) -> Float[Tensor, "batch num_ifos"]:
294
288
  r"""Compute the SNRs of a batch of interferometer responses
295
289
 
296
290
  Compute the signal to noise ratio (SNR) of individual
@@ -390,8 +384,8 @@ def compute_network_snr(
390
384
  responses: WaveformTensor,
391
385
  psd: PSDTensor,
392
386
  sample_rate: float,
393
- highpass: Union[float, TensorType["frequency"], None] = None,
394
- ) -> ScalarTensor:
387
+ highpass: Union[float, Float[Tensor, " frequency"], None] = None,
388
+ ) -> BatchTensor:
395
389
  r"""
396
390
  Compute the total SNR from a gravitational waveform
397
391
  from a network of interferometers. The total SNR for
@@ -437,10 +431,10 @@ def compute_network_snr(
437
431
 
438
432
  def reweight_snrs(
439
433
  responses: WaveformTensor,
440
- target_snrs: Union[float, ScalarTensor],
434
+ target_snrs: Union[float, BatchTensor],
441
435
  psd: PSDTensor,
442
436
  sample_rate: float,
443
- highpass: Union[float, TensorType["frequency"], None] = None,
437
+ highpass: Union[float, Float[Tensor, " frequency"], None] = None,
444
438
  ) -> WaveformTensor:
445
439
  """Scale interferometer responses such that they have a desired SNR
446
440
 
@@ -1,7 +1,8 @@
1
1
  from collections.abc import Sequence
2
- from typing import Optional
2
+ from typing import Optional, Tuple, Union
3
3
 
4
4
  import torch
5
+ from torch import Tensor
5
6
 
6
7
  from ml4gw.nn.autoencoder.skip_connection import SkipConnection
7
8
 
@@ -27,12 +28,16 @@ class Autoencoder(torch.nn.Module):
27
28
  and how they operate.
28
29
  """
29
30
 
30
- def __init__(self, skip_connection: Optional[SkipConnection] = None):
31
+ def __init__(
32
+ self, skip_connection: Optional[SkipConnection] = None
33
+ ) -> None:
31
34
  super().__init__()
32
35
  self.skip_connection = skip_connection
33
36
  self.blocks = torch.nn.ModuleList()
34
37
 
35
- def encode(self, *X: torch.Tensor, return_states: bool = False):
38
+ def encode(
39
+ self, *X: Tensor, return_states: bool = False
40
+ ) -> Union[Tensor, Tuple[Tensor, Sequence]]:
36
41
  states = []
37
42
  for block in self.blocks:
38
43
  if isinstance(X, tuple):
@@ -48,7 +53,7 @@ class Autoencoder(torch.nn.Module):
48
53
  return X, states[:-1]
49
54
  return X
50
55
 
51
- def decode(self, *X, states: Optional[Sequence[torch.Tensor]] = None):
56
+ def decode(self, *X, states: Optional[Sequence[Tensor]] = None) -> Tensor:
52
57
  if self.skip_connection is not None and states is None:
53
58
  raise ValueError(
54
59
  "Must pass intermediate states when autoencoder "
@@ -76,7 +81,7 @@ class Autoencoder(torch.nn.Module):
76
81
  X = self.skip_connection(X, state)
77
82
  return X
78
83
 
79
- def forward(self, *X):
84
+ def forward(self, *X: Tensor) -> Tensor:
80
85
  return_states = self.skip_connection is not None
81
86
  X = self.encode(*X, return_states=return_states)
82
87
  if return_states:
@@ -84,6 +89,6 @@ class Autoencoder(torch.nn.Module):
84
89
  else:
85
90
  states = None
86
91
 
87
- if isinstance(X, torch.Tensor):
92
+ if isinstance(X, Tensor):
88
93
  X = (X,)
89
94
  return self.decode(*X, states=states)
@@ -2,6 +2,7 @@ from collections.abc import Callable, Sequence
2
2
  from typing import Optional
3
3
 
4
4
  import torch
5
+ from torch import Tensor
5
6
 
6
7
  from ml4gw.nn.autoencoder.base import Autoencoder
7
8
  from ml4gw.nn.autoencoder.skip_connection import SkipConnection
@@ -64,12 +65,12 @@ class ConvBlock(Autoencoder):
64
65
  self.encode_norm = norm(out_channels)
65
66
  self.decode_norm = norm(decode_channels)
66
67
 
67
- def encode(self, X):
68
+ def encode(self, X: Tensor) -> Tensor:
68
69
  X = self.encode_layer(X)
69
70
  X = self.encode_norm(X)
70
71
  return self.activation(X)
71
72
 
72
- def decode(self, X):
73
+ def decode(self, X: Tensor) -> Tensor:
73
74
  X = self.decode_layer(X)
74
75
  X = self.decode_norm(X)
75
76
  return self.output_activation(X)
@@ -144,13 +145,15 @@ class ConvolutionalAutoencoder(Autoencoder):
144
145
  self.blocks.append(block)
145
146
  in_channels = channels * groups
146
147
 
147
- def decode(self, *X, states=None, input_size: Optional[int] = None):
148
+ def decode(
149
+ self, *X, states=None, input_size: Optional[int] = None
150
+ ) -> Tensor:
148
151
  X = super().decode(*X, states=states)
149
152
  if input_size is not None:
150
153
  return match_size(X, input_size)
151
154
  return X
152
155
 
153
- def forward(self, X):
156
+ def forward(self, X: Tensor) -> Tensor:
154
157
  input_size = X.size(-1)
155
158
  X = super().forward(X)
156
159
  return match_size(X, input_size)
@@ -1,31 +1,32 @@
1
1
  import torch
2
+ from torch import Tensor
2
3
 
3
4
  from ml4gw.nn.autoencoder.utils import match_size
4
5
 
5
6
 
6
7
  class SkipConnection(torch.nn.Module):
7
- def forward(self, X: torch.Tensor, state: torch.Tensor):
8
+ def forward(self, X: Tensor, state: Tensor) -> Tensor:
8
9
  return match_size(X, state.size(-1))
9
10
 
10
- def get_out_channels(self, in_channels):
11
+ def get_out_channels(self, in_channels: int) -> int:
11
12
  return in_channels
12
13
 
13
14
 
14
15
  class AddSkipConnect(SkipConnection):
15
- def forward(self, X, state):
16
+ def forward(self, X: Tensor, state: Tensor) -> Tensor:
16
17
  X = super().forward(X, state)
17
18
  return X + state
18
19
 
19
20
 
20
21
  class ConcatSkipConnect(SkipConnection):
21
- def __init__(self, groups: int = 1):
22
+ def __init__(self, groups: int = 1) -> None:
22
23
  super().__init__()
23
24
  self.groups = groups
24
25
 
25
- def get_out_channels(self, in_channels):
26
+ def get_out_channels(self, in_channels: int) -> int:
26
27
  return 2 * in_channels
27
28
 
28
- def forward(self, X, state):
29
+ def forward(self, X: Tensor, state: Tensor) -> Tensor:
29
30
  X = super().forward(X, state)
30
31
  if self.groups == 1:
31
32
  return torch.cat([X, state], dim=1)
@@ -1,7 +1,8 @@
1
1
  import torch
2
+ from torch import Tensor
2
3
 
3
4
 
4
- def match_size(X: torch.Tensor, target_size: int):
5
+ def match_size(X: Tensor, target_size: int) -> Tensor:
5
6
  diff = target_size - X.size(-1)
6
7
  left = int(diff // 2)
7
8
  right = diff - left
ml4gw/nn/norm.py CHANGED
@@ -1,6 +1,8 @@
1
1
  from typing import Callable, Optional
2
2
 
3
3
  import torch
4
+ from jaxtyping import Float
5
+ from torch import Tensor
4
6
 
5
7
  NormLayer = Callable[[int], torch.nn.Module]
6
8
 
@@ -31,7 +33,15 @@ class GroupNorm1D(torch.nn.Module):
31
33
  self.weight = torch.nn.Parameter(torch.ones(shape))
32
34
  self.bias = torch.nn.Parameter(torch.zeros(shape))
33
35
 
34
- def forward(self, x):
36
+ def forward(
37
+ self, x: Float[Tensor, "batch channel length"]
38
+ ) -> Float[Tensor, "batch channel length"]:
39
+ if len(x.shape) != 3:
40
+ raise ValueError(
41
+ "GroupNorm1D requires 3-dimensional input, "
42
+ f"received {len(x.shape)} dimensional input"
43
+ )
44
+
35
45
  keepdims = self.num_groups == self.num_channels
36
46
 
37
47
  # compute group variance via the E[x**2] - E**2[x] trick
@@ -1,11 +1,11 @@
1
1
  from typing import Optional, Tuple
2
2
 
3
3
  import torch
4
+ from jaxtyping import Float
5
+ from torch import Tensor
4
6
 
5
7
  from ml4gw.utils.slicing import unfold_windows
6
8
 
7
- Tensor = torch.Tensor
8
-
9
9
 
10
10
  class OnlineAverager(torch.nn.Module):
11
11
  """
@@ -70,12 +70,14 @@ class OnlineAverager(torch.nn.Module):
70
70
  weights = unfold_windows(weights, weight_size, update_size)
71
71
  self.register_buffer("weights", weights)
72
72
 
73
- def get_initial_state(self):
73
+ def get_initial_state(self) -> Float[Tensor, "channel time"]:
74
74
  return torch.zeros((self.num_channels, self.state_size))
75
75
 
76
76
  def forward(
77
- self, update: torch.Tensor, state: Optional[torch.Tensor] = None
78
- ) -> Tuple[torch.Tensor, torch.Tensor]:
77
+ self,
78
+ update: Float[Tensor, "batch channel time1"],
79
+ state: Optional[Float[Tensor, "channel time2"]] = None,
80
+ ) -> Tuple[Float[Tensor, "channel time3"], Float[Tensor, "channel time4"]]:
79
81
  if state is None:
80
82
  state = self.get_initial_state()
81
83
 
@@ -1,6 +1,8 @@
1
1
  from typing import Optional, Sequence, Tuple
2
2
 
3
3
  import torch
4
+ from jaxtyping import Float
5
+ from torch import Tensor
4
6
 
5
7
  from ml4gw.utils.slicing import unfold_windows
6
8
 
@@ -82,14 +84,14 @@ class Snapshotter(torch.nn.Module):
82
84
  self.channels_per_snapshot = channels_per_snapshot
83
85
  self.num_channels = num_channels
84
86
 
85
- def get_initial_state(self):
87
+ def get_initial_state(self) -> Float[Tensor, "channel time"]:
86
88
  return torch.zeros((self.num_channels, self.state_size))
87
89
 
88
- # TODO: use torchtyping annotations to make
89
- # clear what the expected shapes are
90
90
  def forward(
91
- self, update: torch.Tensor, snapshot: Optional[torch.Tensor] = None
92
- ) -> Tuple[torch.Tensor, ...]:
91
+ self,
92
+ update: Float[Tensor, "channel time1"],
93
+ snapshot: Optional[Float[Tensor, "channel time2"]] = None,
94
+ ) -> Tuple[Tensor, ...]:
93
95
  if snapshot is None:
94
96
  snapshot = self.get_initial_state()
95
97