ml4gw 0.7.4__py3-none-any.whl → 0.7.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ml4gw might be problematic. Click here for more details.
- ml4gw/augmentations.py +4 -4
- ml4gw/dataloading/chunked_dataset.py +3 -3
- ml4gw/dataloading/hdf5_dataset.py +7 -10
- ml4gw/dataloading/in_memory_dataset.py +21 -21
- ml4gw/distributions.py +216 -10
- ml4gw/gw.py +60 -53
- ml4gw/nn/autoencoder/base.py +9 -9
- ml4gw/nn/autoencoder/convolutional.py +4 -4
- ml4gw/nn/resnet/resnet_1d.py +13 -13
- ml4gw/nn/resnet/resnet_2d.py +12 -12
- ml4gw/nn/streaming/online_average.py +1 -1
- ml4gw/nn/streaming/snapshotter.py +14 -14
- ml4gw/spectral.py +48 -48
- ml4gw/transforms/iirfilter.py +3 -3
- ml4gw/transforms/pearson.py +7 -8
- ml4gw/transforms/qtransform.py +19 -19
- ml4gw/transforms/scaler.py +4 -4
- ml4gw/transforms/spectral.py +10 -10
- ml4gw/transforms/spectrogram.py +12 -11
- ml4gw/transforms/spline_interpolation.py +8 -15
- ml4gw/transforms/transform.py +1 -1
- ml4gw/transforms/whitening.py +36 -36
- ml4gw/utils/slicing.py +40 -40
- ml4gw/waveforms/cbc/phenom_d.py +22 -66
- ml4gw/waveforms/cbc/phenom_p.py +9 -5
- ml4gw/waveforms/cbc/taylorf2.py +8 -7
- ml4gw/waveforms/conversion.py +2 -1
- ml4gw/waveforms/generator.py +33 -32
- {ml4gw-0.7.4.dist-info → ml4gw-0.7.6.dist-info}/METADATA +7 -1
- ml4gw-0.7.6.dist-info/RECORD +55 -0
- ml4gw-0.7.4.dist-info/RECORD +0 -55
- {ml4gw-0.7.4.dist-info → ml4gw-0.7.6.dist-info}/WHEEL +0 -0
- {ml4gw-0.7.4.dist-info → ml4gw-0.7.6.dist-info}/licenses/LICENSE +0 -0
ml4gw/augmentations.py
CHANGED
|
@@ -6,8 +6,8 @@ from torch import Tensor
|
|
|
6
6
|
class SignalInverter(torch.nn.Module):
|
|
7
7
|
"""
|
|
8
8
|
Takes a tensor of timeseries of arbitrary dimension
|
|
9
|
-
and randomly inverts
|
|
10
|
-
each timeseries with probability
|
|
9
|
+
and randomly inverts, i.e., :math:`h(t) \\rightarrow -h(t)`,
|
|
10
|
+
each timeseries with probability ``prob``.
|
|
11
11
|
|
|
12
12
|
Args:
|
|
13
13
|
prob:
|
|
@@ -29,8 +29,8 @@ class SignalInverter(torch.nn.Module):
|
|
|
29
29
|
class SignalReverser(torch.nn.Module):
|
|
30
30
|
"""
|
|
31
31
|
Takes a tensor of timeseries of arbitrary dimension
|
|
32
|
-
and randomly reverses
|
|
33
|
-
each timeseries with probability
|
|
32
|
+
and randomly reverses, i.e., :math:`h(t) \\rightarrow h(-t)`,
|
|
33
|
+
each timeseries with probability ``prob``.
|
|
34
34
|
|
|
35
35
|
Args:
|
|
36
36
|
prob:
|
|
@@ -15,9 +15,9 @@ class ChunkedTimeSeriesDataset(torch.utils.data.IterableDataset):
|
|
|
15
15
|
chunk_it:
|
|
16
16
|
Iterator which will produce chunks of timeseries
|
|
17
17
|
data to sample windows from. Should have shape
|
|
18
|
-
|
|
19
|
-
to sample from,
|
|
20
|
-
and
|
|
18
|
+
``(N, C, T)``, where ``N`` is the number of chunks
|
|
19
|
+
to sample from, ``C`` is the number of channels,
|
|
20
|
+
and ``T`` is the number of samples along the
|
|
21
21
|
time dimension for each chunk.
|
|
22
22
|
kernel_size:
|
|
23
23
|
Size of windows to be sampled from each chunk.
|
|
@@ -17,8 +17,7 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
|
|
|
17
17
|
Iterable dataset that samples and loads windows of
|
|
18
18
|
timeseries data uniformly from a set of HDF5 files.
|
|
19
19
|
It is _strongly_ recommended that these files have been
|
|
20
|
-
written using
|
|
21
|
-
(https://docs.h5py.org/en/stable/high/dataset.html#chunked-storage).
|
|
20
|
+
written using `chunked storage <https://docs.h5py.org/en/stable/high/dataset.html#chunked-storage>`_.
|
|
22
21
|
This has been shown to produce increases in read-time speeds
|
|
23
22
|
of over an order of magnitude.
|
|
24
23
|
|
|
@@ -37,27 +36,25 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
|
|
|
37
36
|
Number of windows to sample at each iteration.
|
|
38
37
|
batches_per_epoch:
|
|
39
38
|
Number of batches to generate during each call
|
|
40
|
-
to
|
|
39
|
+
to ``__iter__``.
|
|
41
40
|
coincident:
|
|
42
41
|
Whether windows for each channel in a given batch
|
|
43
42
|
element should be sampled coincidentally, i.e.
|
|
44
43
|
corresponding to the same time indices from the
|
|
45
44
|
same files, or should be sampled independently.
|
|
46
45
|
For the latter case, users can either specify
|
|
47
|
-
|
|
48
|
-
for each channel, or
|
|
46
|
+
``False``, which will sample filenames independently
|
|
47
|
+
for each channel, or ``"files"``, which will sample
|
|
49
48
|
windows independently within a given file for each
|
|
50
49
|
channel. The latter setting limits the amount of
|
|
51
50
|
entropy in the effective dataset, but can provide
|
|
52
51
|
over 2x improvement in total throughput.
|
|
53
52
|
num_files_per_batch:
|
|
54
53
|
The number of unique files from which to sample
|
|
55
|
-
batch elements each epoch. If left as
|
|
54
|
+
batch elements each epoch. If left as ``None``,
|
|
56
55
|
will use all available files. Useful when reading
|
|
57
56
|
from many files is bottlenecking dataloading.
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
"""
|
|
57
|
+
""" # noqa E501
|
|
61
58
|
|
|
62
59
|
def __init__(
|
|
63
60
|
self,
|
|
@@ -117,7 +114,7 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
|
|
|
117
114
|
return self.batches_per_epoch
|
|
118
115
|
|
|
119
116
|
def sample_fnames(self, size) -> np.ndarray:
|
|
120
|
-
# first, randomly select
|
|
117
|
+
# first, randomly select ``self.num_files_per_batch``
|
|
121
118
|
# file indices based on their probabilities
|
|
122
119
|
fname_indices = np.arange(len(self.fnames))
|
|
123
120
|
fname_indices = np.random.choice(
|
|
@@ -20,56 +20,56 @@ class InMemoryDataset(torch.utils.data.IterableDataset):
|
|
|
20
20
|
Args:
|
|
21
21
|
X:
|
|
22
22
|
Timeseries data to be iterated through. Should have
|
|
23
|
-
shape
|
|
23
|
+
shape ``(num_channels, length * sample_rate)``. Windows
|
|
24
24
|
will be sampled from the time (1st) dimension for all
|
|
25
25
|
channels along the channel (0th) dimension.
|
|
26
26
|
kernel_size:
|
|
27
|
-
The length of the windows to sample from
|
|
27
|
+
The length of the windows to sample from ``X`` in units
|
|
28
28
|
of samples.
|
|
29
29
|
y:
|
|
30
30
|
Target timeseries to be iterated through. If specified,
|
|
31
31
|
should be a single channel and have shape
|
|
32
|
-
|
|
33
|
-
sampled from
|
|
32
|
+
``(length * sample_rate,)``. If left as ``None``, only windows
|
|
33
|
+
sampled from ``X`` will be returned during iteration.
|
|
34
34
|
Otherwise, windows sampled from both arrays will be
|
|
35
35
|
returned. Note that if sampling is performed non-coincidentally,
|
|
36
36
|
there's no sensible way to align windows sampled from this
|
|
37
|
-
array with the windows sampled from
|
|
37
|
+
array with the windows sampled from ``X``, so this combination
|
|
38
38
|
of arguments is not permitted.
|
|
39
39
|
batch_size:
|
|
40
40
|
Maximum number of windows to return at each iteration. Will
|
|
41
41
|
be the length of the 0th dimension of the returned array(s).
|
|
42
|
-
If
|
|
43
|
-
of
|
|
42
|
+
If ``batches_per_epoch`` is specified, this will be the length
|
|
43
|
+
of **every** array returned during iteration. Otherwise, it's
|
|
44
44
|
possible that the last array will be shorter due to the number
|
|
45
45
|
of windows in the timeseries being a non-integer multiple of
|
|
46
|
-
|
|
46
|
+
``batch_size``.
|
|
47
47
|
stride:
|
|
48
48
|
The resolution at which windows will be sampled from the
|
|
49
49
|
specified timeseries, in units of samples. E.g. if
|
|
50
|
-
|
|
51
|
-
from an index of
|
|
50
|
+
``stride=2``, the first sample of each window can only be
|
|
51
|
+
from an index of ``X`` which is a multiple of 2. Obviously,
|
|
52
52
|
this reduces the number of windows which can be iterated
|
|
53
|
-
through by a factor of
|
|
53
|
+
through by a factor of ``stride``.
|
|
54
54
|
batches_per_epoch:
|
|
55
55
|
Number of batches of windows to produce during iteration
|
|
56
|
-
before raising a
|
|
56
|
+
before raising a ``StopIteration``. Must be specified if
|
|
57
57
|
performing non-coincident sampling. Otherwise, if left
|
|
58
|
-
as
|
|
58
|
+
as ``None``, windows will be sampled until the entire
|
|
59
59
|
timeseries has been exhausted. Note that
|
|
60
|
-
|
|
60
|
+
``batch_size * batches_per_epoch`` must be small
|
|
61
61
|
enough to be able to be fulfilled by the number of
|
|
62
|
-
windows in the timeseries, otherise a
|
|
62
|
+
windows in the timeseries, otherwise a ``ValueError``
|
|
63
63
|
will be raised.
|
|
64
64
|
coincident:
|
|
65
|
-
Whether to sample windows from the channels of
|
|
65
|
+
Whether to sample windows from the channels of ``X``
|
|
66
66
|
using the same indices or independently. Can't be
|
|
67
|
-
|
|
68
|
-
|
|
67
|
+
``True`` if ``batches_per_epoch`` is ``None`` or ``y`` is
|
|
68
|
+
**not** ``None``.
|
|
69
69
|
shuffle:
|
|
70
70
|
Whether to sample windows from timeseries randomly
|
|
71
|
-
or in order along the time axis. If
|
|
72
|
-
and
|
|
71
|
+
or in order along the time axis. If ``coincident=False``
|
|
72
|
+
and ``shuffle=False``, channels will be iterated through
|
|
73
73
|
with the index along the last channel moving fastest.
|
|
74
74
|
device:
|
|
75
75
|
Which device to host the timeseries arrays on
|
|
@@ -91,7 +91,7 @@ class InMemoryDataset(torch.utils.data.IterableDataset):
|
|
|
91
91
|
|
|
92
92
|
# make sure if we specified a target array that all other
|
|
93
93
|
# necessary conditions are met (it has the same
|
|
94
|
-
# length as
|
|
94
|
+
# length as ``X`` and we're sampling coincidentally)
|
|
95
95
|
if y is not None and y.shape[-1] != X.shape[-1]:
|
|
96
96
|
raise ValueError(
|
|
97
97
|
"Target timeseries must have same length as input"
|
ml4gw/distributions.py
CHANGED
|
@@ -1,24 +1,30 @@
|
|
|
1
1
|
"""
|
|
2
2
|
Module containing callable classes for generating samples
|
|
3
3
|
from specified distributions. Each callable should map from
|
|
4
|
-
an integer
|
|
4
|
+
an integer ``N`` to a 1D torch ``Tensor`` containing ``N`` samples
|
|
5
5
|
from the corresponding distribution.
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
8
|
import math
|
|
9
|
-
from typing import Optional
|
|
9
|
+
from typing import Callable, Optional
|
|
10
10
|
|
|
11
11
|
import torch
|
|
12
12
|
import torch.distributions as dist
|
|
13
13
|
from jaxtyping import Float
|
|
14
14
|
from torch import Tensor
|
|
15
15
|
|
|
16
|
+
from ml4gw.constants import C
|
|
17
|
+
|
|
18
|
+
_PLANCK18_H0 = 67.66 # Hubble constant in km/s/Mpc
|
|
19
|
+
_PLANCK18_OMEGA_M = 0.30966 # Matter density parameter
|
|
20
|
+
|
|
16
21
|
|
|
17
22
|
class Cosine(dist.Distribution):
|
|
18
23
|
"""
|
|
19
24
|
Cosine distribution based on
|
|
20
|
-
``torch.distributions.TransformedDistribution
|
|
21
|
-
|
|
25
|
+
``torch.distributions.TransformedDistribution``
|
|
26
|
+
(see `documentation <https://docs.pytorch.org/docs/stable/distributions.html#transformeddistribution>`_).
|
|
27
|
+
""" # noqa E501
|
|
22
28
|
|
|
23
29
|
arg_constraints = {}
|
|
24
30
|
|
|
@@ -112,18 +118,17 @@ class LogNormal(dist.LogNormal):
|
|
|
112
118
|
class PowerLaw(dist.TransformedDistribution):
|
|
113
119
|
"""
|
|
114
120
|
Sample from a power law distribution,
|
|
115
|
-
|
|
116
|
-
|
|
121
|
+
|
|
122
|
+
.. math:: p(x) \\propto x^{\\alpha}.
|
|
117
123
|
|
|
118
124
|
Index alpha cannot be 0, since it is equivalent to a Uniform distribution.
|
|
119
125
|
This could be used, for example, as a universal distribution of
|
|
120
126
|
signal-to-noise ratios (SNRs) from uniformly volume distributed
|
|
121
127
|
sources
|
|
122
|
-
.. math::
|
|
123
128
|
|
|
124
|
-
|
|
129
|
+
.. math:: p(\\rho) = 3\\;\\rho_0^3 / \\rho^4
|
|
125
130
|
|
|
126
|
-
where :math
|
|
131
|
+
where :math:`\\rho_0` is a representative minimum SNR
|
|
127
132
|
considered for detection. See, for example,
|
|
128
133
|
`Schutz (2011) <https://arxiv.org/abs/1102.5421>`_.
|
|
129
134
|
Or, for example, ``index=2`` for uniform in Euclidean volume.
|
|
@@ -135,7 +140,7 @@ class PowerLaw(dist.TransformedDistribution):
|
|
|
135
140
|
self, minimum: float, maximum: float, index: int, validate_args=None
|
|
136
141
|
):
|
|
137
142
|
if index == 0:
|
|
138
|
-
raise
|
|
143
|
+
raise ValueError("Index of 0 is the same as Uniform")
|
|
139
144
|
elif index == -1:
|
|
140
145
|
base_min = torch.as_tensor(minimum).log()
|
|
141
146
|
base_max = torch.as_tensor(maximum).log()
|
|
@@ -173,3 +178,204 @@ class DeltaFunction(dist.Distribution):
|
|
|
173
178
|
return self.peak * torch.ones(
|
|
174
179
|
sample_shape, device=self.peak.device, dtype=torch.float32
|
|
175
180
|
)
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
class UniformComovingVolume(dist.Distribution):
|
|
184
|
+
"""
|
|
185
|
+
Sample either redshift, comoving distance, or luminosity distance
|
|
186
|
+
such that they are uniform in comoving volume, assuming a flat
|
|
187
|
+
lambda-CDM cosmology. Default H0 and Omega_M values match
|
|
188
|
+
`Planck18 parameters in Astropy <https://docs.astropy.org/en/latest/api/astropy.cosmology.realizations.Planck18.html>`_.
|
|
189
|
+
|
|
190
|
+
Args:
|
|
191
|
+
minimum: Minimum distance in the specified distance type
|
|
192
|
+
maximum: Maximum distance in the specified distance type
|
|
193
|
+
distance_type:
|
|
194
|
+
Type of distance to sample from. Can be ``redshift``,
|
|
195
|
+
``comoving_distance``, or ``luminosity_distance``
|
|
196
|
+
h0: Hubble constant in km/s/Mpc
|
|
197
|
+
omega_m: Matter density parameter
|
|
198
|
+
z_grid_max: Maximum redshift for the grid
|
|
199
|
+
grid_size: Number of points in the grid for interpolation
|
|
200
|
+
validate_args: Whether to validate arguments
|
|
201
|
+
"""
|
|
202
|
+
|
|
203
|
+
arg_constraints = {}
|
|
204
|
+
support = dist.constraints.nonnegative
|
|
205
|
+
|
|
206
|
+
def __init__(
|
|
207
|
+
self,
|
|
208
|
+
minimum: float,
|
|
209
|
+
maximum: float,
|
|
210
|
+
distance_type: str = "redshift",
|
|
211
|
+
h0: float = _PLANCK18_H0,
|
|
212
|
+
omega_m: float = _PLANCK18_OMEGA_M,
|
|
213
|
+
z_grid_max: float = 5,
|
|
214
|
+
grid_size: int = 10000,
|
|
215
|
+
validate_args: bool = None,
|
|
216
|
+
):
|
|
217
|
+
super().__init__(validate_args=validate_args)
|
|
218
|
+
if distance_type not in [
|
|
219
|
+
"redshift",
|
|
220
|
+
"comoving_distance",
|
|
221
|
+
"luminosity_distance",
|
|
222
|
+
]:
|
|
223
|
+
raise ValueError(
|
|
224
|
+
"Distance type must be 'redshift', 'comoving_distance', "
|
|
225
|
+
f"or 'luminosity_distance'; got {distance_type}"
|
|
226
|
+
)
|
|
227
|
+
|
|
228
|
+
self.minimum = minimum
|
|
229
|
+
self.maximum = maximum
|
|
230
|
+
self.distance_type = distance_type
|
|
231
|
+
self.grid_size = grid_size
|
|
232
|
+
self.z_grid_max = z_grid_max
|
|
233
|
+
self.h0 = h0
|
|
234
|
+
self.omega_m = omega_m
|
|
235
|
+
|
|
236
|
+
# Compute redshift range based on the given min and max distances
|
|
237
|
+
z_min, z_max = self._get_z_bounds()
|
|
238
|
+
if z_max > z_grid_max:
|
|
239
|
+
raise ValueError(
|
|
240
|
+
f"Maximum {distance_type} {maximum} "
|
|
241
|
+
f"exceeds given z_max {z_grid_max}."
|
|
242
|
+
)
|
|
243
|
+
|
|
244
|
+
# Restrict distance grids to the specified redshift range
|
|
245
|
+
mask = (self.z_grid >= z_min) & (self.z_grid <= z_max)
|
|
246
|
+
self.distance_grid = self.distance_grid[mask]
|
|
247
|
+
self.z_grid = self.z_grid[mask]
|
|
248
|
+
self.comoving_dist_grid = self.comoving_dist_grid[mask]
|
|
249
|
+
self.luminosity_dist_grid = self.luminosity_dist_grid[mask]
|
|
250
|
+
# Compute probability arrays from those grids
|
|
251
|
+
self._generate_probability_grids()
|
|
252
|
+
|
|
253
|
+
def _hubble_function(self):
|
|
254
|
+
"""
|
|
255
|
+
Compute H(z) assuming a flat lambda-CDM cosmology.
|
|
256
|
+
"""
|
|
257
|
+
omega_l = 1 - self.omega_m
|
|
258
|
+
return self.h0 * torch.sqrt(
|
|
259
|
+
self.omega_m * (1 + self.z_grid) ** 3 + omega_l
|
|
260
|
+
)
|
|
261
|
+
|
|
262
|
+
def _get_z_bounds(self):
|
|
263
|
+
"""
|
|
264
|
+
Compute the bounds on redshift based on the given minimum and maximum
|
|
265
|
+
distances, using the specified distance type.
|
|
266
|
+
"""
|
|
267
|
+
self._generate_distance_grids()
|
|
268
|
+
bounds = torch.tensor([self.minimum, self.maximum])
|
|
269
|
+
z_min, z_max = self._linear_interp_1d(
|
|
270
|
+
self.distance_grid, self.z_grid, bounds
|
|
271
|
+
)
|
|
272
|
+
|
|
273
|
+
return z_min, z_max
|
|
274
|
+
|
|
275
|
+
def _generate_distance_grids(self):
|
|
276
|
+
"""
|
|
277
|
+
Generate distance grids based on the specified redshift range.
|
|
278
|
+
"""
|
|
279
|
+
self.z_grid = torch.linspace(0, self.z_grid_max, self.grid_size)
|
|
280
|
+
self.dz = self.z_grid[1] - self.z_grid[0]
|
|
281
|
+
# C is specified in m/s, h0 in km/s/Mpc, so divide by 1000 to convert
|
|
282
|
+
comoving_dist_grid = (
|
|
283
|
+
torch.cumulative_trapezoid(
|
|
284
|
+
(C / self._hubble_function()), self.z_grid
|
|
285
|
+
)
|
|
286
|
+
/ 1000
|
|
287
|
+
)
|
|
288
|
+
zero_prefix = torch.zeros(1, dtype=comoving_dist_grid.dtype)
|
|
289
|
+
self.comoving_dist_grid = torch.cat([zero_prefix, comoving_dist_grid])
|
|
290
|
+
self.luminosity_dist_grid = self.comoving_dist_grid * (1 + self.z_grid)
|
|
291
|
+
|
|
292
|
+
if self.distance_type == "redshift":
|
|
293
|
+
self.distance_grid = self.z_grid
|
|
294
|
+
elif self.distance_type == "comoving_distance":
|
|
295
|
+
self.distance_grid = self.comoving_dist_grid
|
|
296
|
+
else: # luminosity_distance
|
|
297
|
+
self.distance_grid = self.luminosity_dist_grid
|
|
298
|
+
|
|
299
|
+
def _p_of_distance(self):
|
|
300
|
+
"""
|
|
301
|
+
Compute the unnormalized probability as a function of distance
|
|
302
|
+
"""
|
|
303
|
+
dV_dz = self.comoving_dist_grid**2 / self._hubble_function()
|
|
304
|
+
# This is a tensor of ones if the distance type is redshift
|
|
305
|
+
jacobian = torch.gradient(self.distance_grid, spacing=self.dz)[0]
|
|
306
|
+
return dV_dz / jacobian
|
|
307
|
+
|
|
308
|
+
def _generate_probability_grids(self):
|
|
309
|
+
"""
|
|
310
|
+
Compute the pdf, cdf, and log pdf based on the
|
|
311
|
+
comoving volume differential and distance grid.
|
|
312
|
+
"""
|
|
313
|
+
p_of_distance = self._p_of_distance()
|
|
314
|
+
self.pdf = p_of_distance / torch.trapz(
|
|
315
|
+
p_of_distance, self.distance_grid
|
|
316
|
+
)
|
|
317
|
+
cdf = torch.cumulative_trapezoid(self.pdf, self.distance_grid)
|
|
318
|
+
zero_prefix = torch.zeros(1, dtype=cdf.dtype)
|
|
319
|
+
self.cdf = torch.cat([zero_prefix, cdf])
|
|
320
|
+
self.log_pdf = torch.log(self.pdf)
|
|
321
|
+
|
|
322
|
+
def _linear_interp_1d(self, x_grid, y_grid, x_query):
|
|
323
|
+
idx = torch.bucketize(x_query, x_grid, right=True)
|
|
324
|
+
idx = idx.clamp(min=1, max=len(x_grid) - 1)
|
|
325
|
+
|
|
326
|
+
x0 = x_grid[idx - 1]
|
|
327
|
+
x1 = x_grid[idx]
|
|
328
|
+
y0 = y_grid[idx - 1]
|
|
329
|
+
y1 = y_grid[idx]
|
|
330
|
+
|
|
331
|
+
t = (x_query - x0) / (x1 - x0)
|
|
332
|
+
return y0 + t * (y1 - y0)
|
|
333
|
+
|
|
334
|
+
def rsample(self, sample_shape: torch.Size = None) -> Tensor:
|
|
335
|
+
sample_shape = sample_shape or torch.Size()
|
|
336
|
+
u = torch.rand(sample_shape)
|
|
337
|
+
return self._linear_interp_1d(self.cdf, self.distance_grid, u)
|
|
338
|
+
|
|
339
|
+
def log_prob(self, value: Tensor) -> Tensor:
|
|
340
|
+
log_prob = self._linear_interp_1d(
|
|
341
|
+
self.distance_grid, self.log_pdf, value
|
|
342
|
+
)
|
|
343
|
+
inside_range = (value >= self.minimum) & (value <= self.maximum)
|
|
344
|
+
log_prob[~inside_range] = float("-inf")
|
|
345
|
+
return log_prob
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
class RateEvolution(UniformComovingVolume):
|
|
349
|
+
"""
|
|
350
|
+
Wrapper around :class:`~ml4gw.distributions.UniformComovingVolume` to allow for
|
|
351
|
+
arbitrary rate evolution functions. E.g., if
|
|
352
|
+
``rate_function = lambda z: 1 / (1 + z)``, then the distribution
|
|
353
|
+
will sample values such that they occur uniformly in
|
|
354
|
+
source frame time.
|
|
355
|
+
|
|
356
|
+
Args:
|
|
357
|
+
rate_function: Callable that takes redshift as input
|
|
358
|
+
and returns the rate evolution factor.
|
|
359
|
+
*args: Arguments passed to
|
|
360
|
+
:class:`~ml4gw.distributions.UniformComovingVolume` constructor.
|
|
361
|
+
**kwargs: Keyword arguments passed to
|
|
362
|
+
:class:`~ml4gw.distributions.UniformComovingVolume` constructor.
|
|
363
|
+
""" # noqa E501
|
|
364
|
+
|
|
365
|
+
def __init__(
|
|
366
|
+
self,
|
|
367
|
+
rate_function: Callable,
|
|
368
|
+
*args,
|
|
369
|
+
**kwargs,
|
|
370
|
+
):
|
|
371
|
+
self.rate_function = rate_function
|
|
372
|
+
super().__init__(*args, **kwargs)
|
|
373
|
+
|
|
374
|
+
def _p_of_distance(self):
|
|
375
|
+
"""
|
|
376
|
+
Compute the unnormalized probability as a function of distance
|
|
377
|
+
"""
|
|
378
|
+
dV_dz = self.comoving_dist_grid**2 / self._hubble_function()
|
|
379
|
+
# This is a tensor of ones if the distance type is redshift
|
|
380
|
+
jacobian = torch.gradient(self.distance_grid, spacing=self.dz)[0]
|
|
381
|
+
return dV_dz / jacobian * self.rate_function(self.z_grid)
|