ml4gw 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

This release has been flagged as potentially problematic.

ml4gw/constants.py ADDED
@@ -0,0 +1,45 @@
+ """
+ Various constants, all in SI units.
+ """
+
+ EulerGamma = 0.577215664901532860606512090082402431
+
+ MSUN = 1.988409902147041637325262574352366540e30  # kg
+ """Solar mass"""
+
+ MRSUN = 1.476625038050124729627979840144936351e3
+ """Geometrized nominal solar mass, m"""
+
+ G = 6.67430e-11  # m^3 / kg / s^2
+ """Newton's gravitational constant"""
+
+ C = 299792458.0  # m / s
+ """Speed of light"""
+
+ PI = 3.141592653589793238462643383279502884
+ """Pi"""
+
+ TWO_PI = 6.283185307179586476925286766559005768
+
+ gt = G * MSUN / (C**3.0)
+ """
+ G MSUN / C^3 in seconds
+ """
+
+ MTSUN_SI = 4.925490947641266978197229498498379006e-6
+ """1 solar mass in seconds. Same value as lal.MTSUN_SI"""
+
+ m_per_Mpc = 3.085677581491367278913937957796471611e22
+ """
+ Meters per Mpc.
+ """
+
+ MPC_SEC = m_per_Mpc / C
+ """
+ 1 Mpc in seconds.
+ """
+
+ clightGpc = C / 3.0856778570831e22
+ """
+ Speed of light in vacuum (:math:`c`), in gigaparsecs per second
+ """
ml4gw/dataloading/in_memory_dataset.py CHANGED
@@ -7,7 +7,7 @@ from ml4gw import types
  from ml4gw.utils.slicing import slice_kernels
 
 
- class InMemoryDataset:
+ class InMemoryDataset(torch.utils.data.IterableDataset):
      """Dataset for iterating through in-memory multi-channel timeseries
 
      Dataset for arrays of timeseries data which can be stored
@@ -131,7 +131,6 @@ class InMemoryDataset:
          self.batches_per_epoch = batches_per_epoch
          self.shuffle = shuffle
          self.coincident = coincident
-         self._i = self._idx = None
 
      @property
      def num_kernels(self) -> int:
@@ -157,7 +156,7 @@ class InMemoryDataset:
          num_kernels = self.num_kernels ** len(self.X)
          return (num_kernels - 1) // self.batch_size + 1
 
-     def __iter__(self):
+     def init_indices(self):
          """
          Initialize arrays of indices we'll use to slice
          through X and y at iteration time. This helps by
@@ -204,36 +203,23 @@ class InMemoryDataset:
          # the simplest case: deterministic and coincident
          idx = torch.arange(num_kernels, device=device)
 
-         self._idx = idx
-         self._i = 0
-         return self
+         return idx
 
-     def __next__(
+     def __iter__(
          self,
      ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-         if self._i is None or self._idx is None:
-             raise TypeError(
-                 "Must initialize InMemoryDataset iteration "
-                 "before calling __next__"
-             )
 
-         # check if we're out of batches, and if so
-         # make sure to reset before stopping
-         if self._i >= len(self):
-             self._i = self._idx = None
-             raise StopIteration
-
-         # slice the array of _indices_ we'll be using to
-         # slice our timeseries, and scale them by the stride
-         slc = slice(self._i * self.batch_size, (self._i + 1) * self.batch_size)
-         idx = self._idx[slc] * self.stride
-
-         # slice our timeseries
-         X = slice_kernels(self.X, idx, self.kernel_size)
-         if self.y is not None:
-             y = slice_kernels(self.y, idx, self.kernel_size)
-
-         self._i += 1
-         if self.y is not None:
-             return X, y
-         return X
+         indices = self.init_indices()
+         for i in range(len(self)):
+             # slice the array of _indices_ we'll be using to
+             # slice our timeseries, and scale them by the stride
+             slc = slice(i * self.batch_size, (i + 1) * self.batch_size)
+             idx = indices[slc] * self.stride
+
+             # slice our timeseries
+             X = slice_kernels(self.X, idx, self.kernel_size)
+             if self.y is not None:
+                 y = slice_kernels(self.y, idx, self.kernel_size)
+                 yield X, y
+             else:
+                 yield X
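
For illustration, a minimal sketch of how the new generator-based iteration might be consumed; the constructor arguments below are assumptions for the example, not taken from this diff:

    import torch
    from ml4gw.dataloading import InMemoryDataset

    X = torch.randn(2, 16384)  # two channels of in-memory timeseries

    # hypothetical constructor arguments; see the class docstring
    dataset = InMemoryDataset(X, kernel_size=512, batch_size=8, shuffle=True)

    # __iter__ is now a plain generator, so there is no _i/_idx state to
    # reset: each `for` loop re-enters __iter__ and rebuilds its indices
    for X_batch in dataset:
        print(X_batch.shape)  # e.g. torch.Size([8, 2, 512])

Because the class now subclasses torch.utils.data.IterableDataset, it can also be handed directly to torch.utils.data.DataLoader.
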
ml4gw/distributions.py CHANGED
@@ -4,94 +4,114 @@ from specified distributions. Each callable should map from
  an integer `N` to a 1D torch `Tensor` containing `N` samples
  from the corresponding distribution.
  """
-
  import math
  from typing import Optional
 
  import torch
+ import torch.distributions as dist
 
 
- class Uniform:
+ class Cosine(dist.Distribution):
      """
-     Sample uniformly between `low` and `high`.
+     Cosine distribution based on
+     ``torch.distributions.Distribution``.
      """
 
-     def __init__(self, low: float = 0, high: float = 1) -> None:
-         self.low = low
-         self.high = high
-
-     def __call__(self, N: int) -> torch.Tensor:
-         return self.low + torch.rand(size=(N,)) * (self.high - self.low)
+     arg_constraints = {}
 
-
- class Cosine:
+     def __init__(
+         self,
+         low: float = -math.pi / 2,
+         high: float = math.pi / 2,
+         validate_args=None,
+     ):
+         batch_shape = torch.Size()
+         super().__init__(batch_shape, validate_args=validate_args)
+         self.low = torch.as_tensor(low)
+         self.high = torch.as_tensor(high)
+         self.norm = 1 / (torch.sin(self.high) - torch.sin(self.low))
+
+     def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
+         u = torch.rand(sample_shape, device=self.low.device)
+         return torch.arcsin(u / self.norm + torch.sin(self.low))
+
+     def log_prob(self, value):
+         value = torch.as_tensor(value)
+         inside_range = (value >= self.low) & (value <= self.high)
+         return value.cos().log() * inside_range
+
+
+ class Sine(dist.TransformedDistribution):
      """
-     Sample from a raised Cosine distribution between
-     `low` and `high`. Based on the implementation from
-     bilby documented here:
-     https://lscsoft.docs.ligo.org/bilby/api/bilby.core.prior.analytical.Cosine.html # noqa
+     Sine distribution based on
+     ``torch.distributions.TransformedDistribution``.
      """
 
      def __init__(
-         self, low: float = -math.pi / 2, high: float = math.pi / 2
-     ) -> None:
-         self.low = low
-         self.norm = 1 / (math.sin(high) - math.sin(low))
-
-     def __call__(self, N: int) -> torch.Tensor:
-         """
-         Implementation lifted from
-         https://lscsoft.docs.ligo.org/bilby/_modules/bilby/core/prior/analytical.html#Cosine # noqa
-         """
-         u = torch.rand(size=(N,))
-         return torch.arcsin(u / self.norm + math.sin(self.low))
-
-
- class LogNormal:
+         self,
+         low: float = 0.0,
+         high: float = math.pi,
+         validate_args=None,
+     ):
+         low = torch.as_tensor(low)
+         high = torch.as_tensor(high)
+         base_dist = Cosine(
+             low - torch.pi / 2, high - torch.pi / 2, validate_args
+         )
+
+         super().__init__(
+             base_dist,
+             [
+                 dist.AffineTransform(
+                     loc=torch.pi / 2,
+                     scale=1,
+                 )
+             ],
+             validate_args=validate_args,
+         )
+
+
+ class LogUniform(dist.TransformedDistribution):
      """
-     Sample from a log normal distribution with the
-     specified `mean` and standard deviation `std`.
-     If a `low` value is specified, values sampled
-     lower than this will be clipped to `low`.
+     Sample from a log uniform distribution
      """
 
-     def __init__(
-         self, mean: float, std: float, low: Optional[float] = None
-     ) -> None:
-         self.sigma = math.log((std / mean) ** 2 + 1) ** 0.5
-         self.mu = 2 * math.log(mean / (mean**2 + std**2) ** 0.25)
-         self.low = low
+     def __init__(self, low: float, high: float, validate_args=None):
+         base_dist = dist.Uniform(
+             torch.as_tensor(low).log(),
+             torch.as_tensor(high).log(),
+             validate_args,
+         )
+         super().__init__(
+             base_dist,
+             [dist.ExpTransform()],
+             validate_args=validate_args,
+         )
 
-     def __call__(self, N: int) -> torch.Tensor:
 
-         u = self.mu + torch.randn(N) * self.sigma
-         x = torch.exp(u)
+ class LogNormal(dist.LogNormal):
+     def __init__(
+         self,
+         mean: float,
+         std: float,
+         low: Optional[float] = None,
+         validate_args=None,
+     ):
+         self.low = low
+         super().__init__(loc=mean, scale=std, validate_args=validate_args)
 
+     def support(self):
          if self.low is not None:
-             x = torch.clip(x, self.low)
-         return x
-
-
- class LogUniform(Uniform):
-     """
-     Sample from a log uniform distribution
-     """
-
-     def __init__(self, low: float, high: float) -> None:
-         super().__init__(math.log(low), math.log(high))
-
-     def __call__(self, N: int) -> torch.Tensor:
-         u = super().__call__(N)
-         return torch.exp(u)
+             return dist.constraints.greater_than(self.low)
 
 
- class PowerLaw:
+ class PowerLaw(dist.TransformedDistribution):
      """
      Sample from a power law distribution,
      .. math::
-         p(x) \approx x^{-\alpha}.
+         p(x) \propto x^{\alpha}.
 
-     Index alpha must be greater than 1.
+     Index alpha cannot be 0, since it is equivalent to a Uniform distribution.
      This could be used, for example, as a universal distribution of
      signal-to-noise ratios (SNRs) from uniformly volume distributed
      sources
@@ -102,21 +122,49 @@ class PowerLaw:
      where :math:`\rho_0` is a representative minimum SNR
      considered for detection. See, for example,
      `Schutz (2011) <https://arxiv.org/abs/1102.5421>`_.
+     Or, for example, ``index=2`` for uniform in Euclidean volume.
      """
 
+     support = dist.constraints.nonnegative
+
+     def __init__(
+         self, minimum: float, maximum: float, index: int, validate_args=None
+     ):
+         if index == 0:
+             raise RuntimeError("Index of 0 is the same as Uniform")
+         elif index == -1:
+             base_min = torch.as_tensor(minimum).log()
+             base_max = torch.as_tensor(maximum).log()
+             transforms = [dist.ExpTransform()]
+         else:
+             index_plus = index + 1
+             base_min = minimum**index_plus / index_plus
+             base_max = maximum**index_plus / index_plus
+             transforms = [
+                 dist.AffineTransform(loc=0, scale=index_plus),
+                 dist.PowerTransform(1 / index_plus),
+             ]
+         base_dist = dist.Uniform(base_min, base_max, validate_args=False)
+         super().__init__(
+             base_dist,
+             transforms,
+             validate_args=validate_args,
+         )
+
+
+ class DeltaFunction(dist.Distribution):
+     arg_constraints = {}
+
      def __init__(
-         self, x_min: float, x_max: float = float("inf"), alpha: float = 2
-     ) -> None:
-         self.x_min = x_min
-         self.x_max = x_max
-         self.alpha = alpha
-
-         self.normalization = x_min ** (-self.alpha + 1)
-         self.normalization -= x_max ** (-self.alpha + 1)
-
-     def __call__(self, N: int) -> torch.Tensor:
-         u = torch.rand(N)
-         u *= self.normalization
-         samples = self.x_min ** (-self.alpha + 1) - u
-         samples = torch.pow(samples, -1.0 / (self.alpha - 1))
-         return samples
+         self,
+         peak: float = 0.0,
+         validate_args=None,
+     ):
+         batch_shape = torch.Size()
+         super().__init__(batch_shape, validate_args=validate_args)
+         self.peak = torch.as_tensor(peak)
+
+     def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
+         return self.peak * torch.ones(
+             sample_shape, device=self.peak.device, dtype=torch.float32
+         )
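
Downstream code now draws samples through the standard torch.distributions interface instead of calling a distribution object with an integer count. A rough sketch, with the import path assumed from the module above:

    import torch
    from ml4gw.distributions import Cosine, PowerLaw

    # declination-like angle with density proportional to cos(x)
    dec = Cosine()
    dec_samples = dec.sample(torch.Size((10,)))

    # SNR-like power law p(x) ~ x**-4 on [8, 100]; note that `index` is
    # now the exponent itself, not a positive spectral index to negate
    snr = PowerLaw(minimum=8.0, maximum=100.0, index=-4)
    snr_samples = snr.sample(torch.Size((10,)))
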
ml4gw/nn/norm.py CHANGED
@@ -32,6 +32,12 @@ class GroupNorm1D(torch.nn.Module):
          self.bias = torch.nn.Parameter(torch.zeros(shape))
 
      def forward(self, x):
+         if len(x.shape) != 3:
+             raise ValueError(
+                 "GroupNorm1D requires 3-dimensional input, "
+                 f"received {len(x.shape)} dimensional input"
+             )
+
          keepdims = self.num_groups == self.num_channels
 
          # compute group variance via the E[x**2] - E**2[x] trick
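
A short sketch of the effect of the new check; the constructor arguments here are assumptions about GroupNorm1D's signature, not taken from this diff:

    import torch
    from ml4gw.nn.norm import GroupNorm1D

    norm = GroupNorm1D(num_channels=4, num_groups=2)  # hypothetical args

    y = norm(torch.randn(8, 4, 1024))  # (batch, channels, time) is accepted

    try:
        norm(torch.randn(4, 1024))  # missing batch dimension now fails fast
    except ValueError as err:
        print(err)
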
ml4gw/transforms/__init__.py CHANGED
@@ -1,4 +1,5 @@
  from .pearson import ShiftedPearsonCorrelation
+ from .qtransform import QScan, SingleQTransform
  from .scaler import ChannelWiseScaler
  from .snr_rescaler import SnrRescaler
  from .spectral import SpectralDensity