ml4gw 0.7.4__py3-none-any.whl → 0.7.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ml4gw might be problematic. Click here for more details.
ml4gw/distributions.py
CHANGED
|
@@ -6,13 +6,18 @@ from the corresponding distribution.
|
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
8
|
import math
|
|
9
|
-
from typing import Optional
|
|
9
|
+
from typing import Callable, Optional
|
|
10
10
|
|
|
11
11
|
import torch
|
|
12
12
|
import torch.distributions as dist
|
|
13
13
|
from jaxtyping import Float
|
|
14
14
|
from torch import Tensor
|
|
15
15
|
|
|
16
|
+
from ml4gw.constants import C
|
|
17
|
+
|
|
18
|
+
# Flat lambda-CDM parameters matching astropy.cosmology.Planck18; used as
# the default cosmology for the comoving-volume distributions below.
_PLANCK18_H0 = 67.66  # Hubble constant [km/s/Mpc]
_PLANCK18_OMEGA_M = 0.30966  # Present-day matter density parameter
|
|
20
|
+
|
|
16
21
|
|
|
17
22
|
class Cosine(dist.Distribution):
|
|
18
23
|
"""
|
|
@@ -173,3 +178,202 @@ class DeltaFunction(dist.Distribution):
|
|
|
173
178
|
return self.peak * torch.ones(
|
|
174
179
|
sample_shape, device=self.peak.device, dtype=torch.float32
|
|
175
180
|
)
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
class UniformComovingVolume(dist.Distribution):
    """
    Sample either redshift, comoving distance, or luminosity distance
    such that they are uniform in comoving volume, assuming a flat
    lambda-CDM cosmology. Default H0 and Omega_M values match
    astropy.cosmology.Planck18

    The density is tabulated on a uniform redshift grid over
    ``[0, z_grid_max]``. That grid is then restricted to the points lying
    between the redshifts corresponding to ``minimum`` and ``maximum``.
    Sampling inverts the tabulated CDF by linear interpolation.
    Comoving and luminosity distances are expressed in Mpc.

    Args:
        minimum: Minimum distance in the specified distance type
        maximum: Maximum distance in the specified distance type
        distance_type:
            Type of distance to sample from. Can be 'redshift',
            'comoving_distance', or 'luminosity_distance'
        h0: Hubble constant in km/s/Mpc
        omega_m: Matter density parameter
        z_grid_max: Maximum redshift for the grid
        grid_size: Number of points in the grid for interpolation
        validate_args: Whether to validate arguments

    Raises:
        ValueError:
            Raised in any of these cases:
            - ``distance_type`` is not one of the supported values.
            - ``maximum`` corresponds to a redshift above ``z_grid_max``.
            - Fewer than two redshift grid points fall inside
              ``[minimum, maximum]``.
    """

    arg_constraints = {}
    support = dist.constraints.nonnegative

    def __init__(
        self,
        minimum: float,
        maximum: float,
        distance_type: str = "redshift",
        h0: float = _PLANCK18_H0,
        omega_m: float = _PLANCK18_OMEGA_M,
        z_grid_max: float = 5,
        grid_size: int = 10000,
        validate_args: Optional[bool] = None,
    ):
        super().__init__(validate_args=validate_args)
        if distance_type not in [
            "redshift",
            "comoving_distance",
            "luminosity_distance",
        ]:
            raise ValueError(
                "Distance type must be 'redshift', 'comoving_distance', "
                f"or 'luminosity_distance'; got {distance_type}"
            )

        self.minimum = minimum
        self.maximum = maximum
        self.distance_type = distance_type
        self.grid_size = grid_size
        self.z_grid_max = z_grid_max
        self.h0 = h0
        self.omega_m = omega_m

        # Compute redshift range based on the given min and max distances
        z_min, z_max = self._get_z_bounds()
        if z_max > z_grid_max:
            raise ValueError(
                f"Maximum {distance_type} {maximum} "
                f"exceeds given z_max {z_grid_max}."
            )

        # Restrict distance grids to the specified redshift range
        mask = (self.z_grid >= z_min) & (self.z_grid <= z_max)
        # Integration and interpolation need at least two points. With
        # fewer, the pdf normalization degenerates and every grid
        # becomes NaN or empty, so fail loudly instead.
        if int(mask.sum()) < 2:
            raise ValueError(
                f"Range [{minimum}, {maximum}] of {distance_type} contains "
                "fewer than two redshift grid points; widen the range or "
                "increase grid_size."
            )
        self.distance_grid = self.distance_grid[mask]
        self.z_grid = self.z_grid[mask]
        self.comoving_dist_grid = self.comoving_dist_grid[mask]
        self.luminosity_dist_grid = self.luminosity_dist_grid[mask]
        # Compute probability arrays from those grids
        self._generate_probability_grids()

    def _hubble_function(self):
        """
        Compute H(z) in km/s/Mpc on ``self.z_grid``, assuming a flat
        lambda-CDM cosmology (Omega_Lambda = 1 - Omega_M, no radiation).
        """
        omega_l = 1 - self.omega_m
        return self.h0 * torch.sqrt(
            self.omega_m * (1 + self.z_grid) ** 3 + omega_l
        )

    def _get_z_bounds(self):
        """
        Compute the bounds on redshift based on the given minimum and maximum
        distances, using the specified distance type.

        This (re)builds the full distance grids over ``[0, z_grid_max]``.
        Bounds beyond the grid are linearly extrapolated, which lets the
        caller detect a ``maximum`` above ``z_grid_max``.

        Returns:
            Tuple of 0-d tensors ``(z_min, z_max)``.
        """
        self._generate_distance_grids()
        bounds = torch.tensor([self.minimum, self.maximum])
        z_min, z_max = self._linear_interp_1d(
            self.distance_grid, self.z_grid, bounds
        )

        return z_min, z_max

    def _generate_distance_grids(self):
        """
        Tabulate redshift, comoving distance and luminosity distance on a
        uniform redshift grid over ``[0, z_grid_max]``. Then point
        ``self.distance_grid`` at the grid matching ``distance_type``.
        """
        self.z_grid = torch.linspace(0, self.z_grid_max, self.grid_size)
        self.dz = self.z_grid[1] - self.z_grid[0]
        # D_C(z) = c * int_0^z dz' / H(z'). C is specified in m/s and h0
        # in km/s/Mpc, so divide by 1000 to get D_C in Mpc.
        comoving_dist_grid = (
            torch.cumulative_trapezoid(
                (C / self._hubble_function()), self.z_grid
            )
            / 1000
        )
        # cumulative_trapezoid returns n - 1 values; prepend D_C(0) = 0
        zero_prefix = torch.zeros(1, dtype=comoving_dist_grid.dtype)
        self.comoving_dist_grid = torch.cat([zero_prefix, comoving_dist_grid])
        # In a flat universe D_L = (1 + z) * D_C
        self.luminosity_dist_grid = self.comoving_dist_grid * (1 + self.z_grid)

        if self.distance_type == "redshift":
            self.distance_grid = self.z_grid
        elif self.distance_type == "comoving_distance":
            self.distance_grid = self.comoving_dist_grid
        else:  # luminosity_distance
            self.distance_grid = self.luminosity_dist_grid

    def _p_of_distance(self):
        """
        Compute the unnormalized probability as a function of distance.

        The result is dV/dz (up to constant factors) divided by
        d(distance)/dz, which converts the density from redshift to the
        chosen distance variable.
        """
        dV_dz = self.comoving_dist_grid**2 / self._hubble_function()
        # This is a tensor of ones if the distance type is redshift
        jacobian = torch.gradient(self.distance_grid, spacing=self.dz)[0]
        return dV_dz / jacobian

    def _generate_probability_grids(self):
        """
        Compute the pdf, cdf, and log pdf based on the
        comoving volume differential and distance grid.

        The pdf is normalized with the trapezoid rule over
        ``self.distance_grid``, and the cdf starts at exactly 0.
        """
        p_of_distance = self._p_of_distance()
        self.pdf = p_of_distance / torch.trapezoid(
            p_of_distance, self.distance_grid
        )
        cdf = torch.cumulative_trapezoid(self.pdf, self.distance_grid)
        zero_prefix = torch.zeros(1, dtype=cdf.dtype)
        self.cdf = torch.cat([zero_prefix, cdf])
        self.log_pdf = torch.log(self.pdf)

    def _linear_interp_1d(self, x_grid, y_grid, x_query):
        """
        Piecewise-linearly interpolate ``y_grid`` at ``x_query``.

        ``x_grid`` must be monotonically increasing, as required by
        ``torch.bucketize``. Queries outside ``x_grid`` are linearly
        extrapolated from the first or last segment.
        """
        idx = torch.bucketize(x_query, x_grid, right=True)
        idx = idx.clamp(min=1, max=len(x_grid) - 1)

        x0 = x_grid[idx - 1]
        x1 = x_grid[idx]
        y0 = y_grid[idx - 1]
        y1 = y_grid[idx]

        t = (x_query - x0) / (x1 - x0)
        return y0 + t * (y1 - y0)

    def rsample(self, sample_shape: Optional[torch.Size] = None) -> Tensor:
        """
        Draw samples by inverse-transform sampling of the tabulated cdf.

        Args:
            sample_shape: Shape of the returned samples; scalar if None.

        Returns:
            Samples in units of the configured distance type.
        """
        sample_shape = sample_shape or torch.Size()
        u = torch.rand(sample_shape)
        return self._linear_interp_1d(self.cdf, self.distance_grid, u)

    def log_prob(self, value: Tensor) -> Tensor:
        """
        Evaluate the log density at ``value``. The density is with respect
        to the configured distance type.

        Values outside ``[minimum, maximum]`` get ``-inf``. Non-tensor
        inputs such as Python floats are converted to a tensor with the
        grid's dtype.
        """
        if not isinstance(value, Tensor):
            value = torch.as_tensor(value, dtype=self.distance_grid.dtype)
        log_prob = self._linear_interp_1d(
            self.distance_grid, self.log_pdf, value
        )
        # Where the pdf vanishes on a grid node (e.g. zero distance, where
        # dV/dz = 0), log_pdf is -inf. Interpolating it there gives
        # -inf + t * inf = nan, so use the log of the interpolated pdf
        # instead; that correctly tends to -inf at the node.
        pdf = self._linear_interp_1d(self.distance_grid, self.pdf, value)
        log_prob = torch.where(
            torch.isnan(log_prob), torch.log(pdf.clamp(min=0)), log_prob
        )
        inside_range = (value >= self.minimum) & (value <= self.maximum)
        return log_prob.masked_fill(~inside_range, float("-inf"))
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
class RateEvolution(UniformComovingVolume):
    """
    Wrapper around `UniformComovingVolume` to allow for
    arbitrary rate evolution functions. E.g., if
    `rate_function = 1 / (1 + z)`, then the distribution
    will sample values such that they occur uniform in
    source frame time.

    The unnormalized density is the uniform-in-comoving-volume density
    multiplied by ``rate_function`` evaluated on the redshift grid.

    Args:
        rate_function: Callable that takes a tensor of redshifts as input
            and returns the rate evolution factor at each redshift.
            It presumably should be non-negative so that the result
            is a valid density.
        *args, **kwargs: Arguments passed to `UniformComovingVolume`
            constructor.
    """

    def __init__(
        self,
        rate_function: Callable[[Tensor], Tensor],
        *args,
        **kwargs,
    ):
        # Must be set before the parent constructor runs: that
        # constructor builds the probability grids, which calls
        # `_p_of_distance` and therefore needs `rate_function`.
        self.rate_function = rate_function
        super().__init__(*args, **kwargs)

    def _p_of_distance(self):
        """
        Compute the unnormalized probability as a function of distance.

        This is the parent's uniform-in-comoving-volume density weighted
        by the rate evolution factor at each redshift grid point.
        """
        # Delegate to the parent so the base density is defined in one
        # place; the result equals dV_dz / jacobian * rate.
        return super()._p_of_distance() * self.rate_function(self.z_grid)
|
|
@@ -1,9 +1,15 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ml4gw
|
|
3
|
-
Version: 0.7.4
|
|
3
|
+
Version: 0.7.5
|
|
4
4
|
Summary: Tools for training torch models on gravitational wave data
|
|
5
5
|
Author-email: Ethan Marx <emarx@mit.edu>, Will Benoit <benoi090@umn.edu>, Deep Chatterjee <deep1018@mit.edu>, Alec Gunny <alec.gunny@ligo.org>
|
|
6
6
|
License-File: LICENSE
|
|
7
|
+
Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
|
|
8
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
9
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
7
13
|
Requires-Python: <3.13,>=3.9
|
|
8
14
|
Requires-Dist: jaxtyping<0.3,>=0.2
|
|
9
15
|
Requires-Dist: numpy<2.0.0
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
ml4gw/__init__.py,sha256=81quoggCuIypZjZs3bbf1Ty70KHdva5RGEJxi0oC57E,25
|
|
2
2
|
ml4gw/augmentations.py,sha256=pZH9tjEpXV0AIqvHHDkpUE-BorG02beOz2pmSipw2EY,1232
|
|
3
3
|
ml4gw/constants.py,sha256=RQPXwavlw_cWu3ByltvTejPsi6EWXHDJQ1HaV9iE3Lg,850
|
|
4
|
-
ml4gw/distributions.py,sha256=
|
|
4
|
+
ml4gw/distributions.py,sha256=YbkPqeYBDC91aM59R7-n6NpBMgvMZcZoOAW_U-Jgrdo,12420
|
|
5
5
|
ml4gw/gw.py,sha256=0ovW_HJ3j2b5Yq3mduYtGLSl2RrvFyNNcOsZFf7koHY,19794
|
|
6
6
|
ml4gw/spectral.py,sha256=sao_D0ceeMEatABfiabpqb-xxRfQO8Tz7yk9N7ciOAU,19858
|
|
7
7
|
ml4gw/types.py,sha256=CcctqDcNajR7khGT6BD-WYsfRKpiP0udoSAB0k1qcFw,863
|
|
@@ -49,7 +49,7 @@ ml4gw/waveforms/cbc/phenom_d_data.py,sha256=WA1FBxUp9fo1IQaV_OLJ_5g5gI166mY1FtG9
|
|
|
49
49
|
ml4gw/waveforms/cbc/phenom_p.py,sha256=RZzzKQzqZW3rQuWZ41htTZOwwulYP61ow87HRRrel5A,27612
|
|
50
50
|
ml4gw/waveforms/cbc/taylorf2.py,sha256=cmYrVL29dwX2Icp7I6SXqRIjtPmoljK5DP_ofx2heiM,10505
|
|
51
51
|
ml4gw/waveforms/cbc/utils.py,sha256=LT1ky10_6ZrbwTcxIrWP1O75GUEuU5q2ZE2yYDhadQE,3037
|
|
52
|
-
ml4gw-0.7.
|
|
53
|
-
ml4gw-0.7.
|
|
54
|
-
ml4gw-0.7.
|
|
55
|
-
ml4gw-0.7.
|
|
52
|
+
ml4gw-0.7.5.dist-info/METADATA,sha256=KI_oVRLxKfCFsBESasNrDh2QfMR4lsriRZwCCAYFt4c,3380
|
|
53
|
+
ml4gw-0.7.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
54
|
+
ml4gw-0.7.5.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
|
55
|
+
ml4gw-0.7.5.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|