ml4gw 0.7.4__py3-none-any.whl → 0.7.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ml4gw might be problematic. See the package registry's advisory for more details.

ml4gw/distributions.py CHANGED
@@ -6,13 +6,18 @@ from the corresponding distribution.
6
6
  """
7
7
 
8
8
  import math
9
- from typing import Optional
9
+ from typing import Callable, Optional
10
10
 
11
11
  import torch
12
12
  import torch.distributions as dist
13
13
  from jaxtyping import Float
14
14
  from torch import Tensor
15
15
 
16
+ from ml4gw.constants import C
17
+
18
+ _PLANCK18_H0 = 67.66 # Hubble constant in km/s/Mpc
19
+ _PLANCK18_OMEGA_M = 0.30966 # Matter density parameter
20
+
16
21
 
17
22
  class Cosine(dist.Distribution):
18
23
  """
@@ -173,3 +178,202 @@ class DeltaFunction(dist.Distribution):
173
178
  return self.peak * torch.ones(
174
179
  sample_shape, device=self.peak.device, dtype=torch.float32
175
180
  )
181
+
182
+
183
class UniformComovingVolume(dist.Distribution):
    """
    Sample either redshift, comoving distance, or luminosity distance
    such that they are uniform in comoving volume, assuming a flat
    lambda-CDM cosmology. Default H0 and Omega_M values match
    astropy.cosmology.Planck18

    The distribution is tabulated on a uniform redshift grid of
    ``grid_size`` points spanning ``[0, z_grid_max]``, restricted to the
    grid points whose redshift lies within the requested range. Sampling
    is done by inverse-CDF interpolation and densities by linear
    interpolation of the tabulated pdf. Comoving and luminosity
    distances are in Mpc.

    Args:
        minimum: Minimum distance in the specified distance type
        maximum: Maximum distance in the specified distance type
        distance_type:
            Type of distance to sample from. Can be 'redshift',
            'comoving_distance', or 'luminosity_distance'
        h0: Hubble constant in km/s/Mpc
        omega_m: Matter density parameter
        z_grid_max:
            Maximum redshift of the interpolation grid. The redshift
            corresponding to ``maximum`` must not exceed this value.
        grid_size: Number of points in the grid for interpolation
        validate_args: Whether to validate arguments

    Raises:
        ValueError:
            If ``distance_type`` is not recognized, if
            ``minimum >= maximum``, if ``maximum`` corresponds to a
            redshift above ``z_grid_max``, or if fewer than two grid
            points fall inside the requested range.
    """

    arg_constraints = {}
    support = dist.constraints.nonnegative

    def __init__(
        self,
        minimum: float,
        maximum: float,
        distance_type: str = "redshift",
        h0: float = _PLANCK18_H0,
        omega_m: float = _PLANCK18_OMEGA_M,
        z_grid_max: float = 5,
        grid_size: int = 10000,
        validate_args: Optional[bool] = None,
    ):
        super().__init__(validate_args=validate_args)
        if distance_type not in [
            "redshift",
            "comoving_distance",
            "luminosity_distance",
        ]:
            raise ValueError(
                "Distance type must be 'redshift', 'comoving_distance', "
                f"or 'luminosity_distance'; got {distance_type}"
            )
        if minimum >= maximum:
            raise ValueError(
                f"minimum ({minimum}) must be less than maximum ({maximum})"
            )

        self.minimum = minimum
        self.maximum = maximum
        self.distance_type = distance_type
        self.grid_size = grid_size
        self.z_grid_max = z_grid_max
        self.h0 = h0
        self.omega_m = omega_m

        # Compute redshift range based on the given min and max distances.
        # This also builds the full (unrestricted) grids as a side effect.
        z_min, z_max = self._get_z_bounds()
        if z_max > z_grid_max:
            raise ValueError(
                f"Maximum {distance_type} {maximum} "
                f"exceeds given z_max {z_grid_max}."
            )

        # Restrict distance grids to the specified redshift range
        mask = (self.z_grid >= z_min) & (self.z_grid <= z_max)
        # torch.gradient / trapezoid integration need at least two points;
        # fail here with an actionable message instead of deep inside torch
        if int(mask.sum()) < 2:
            raise ValueError(
                f"Fewer than 2 grid points lie between {minimum} and "
                f"{maximum}; increase grid_size."
            )
        self.distance_grid = self.distance_grid[mask]
        self.z_grid = self.z_grid[mask]
        self.comoving_dist_grid = self.comoving_dist_grid[mask]
        self.luminosity_dist_grid = self.luminosity_dist_grid[mask]
        # Compute probability arrays from those grids
        self._generate_probability_grids()

    def _hubble_function(self):
        """
        Compute H(z) in km/s/Mpc on ``self.z_grid``, assuming a flat
        lambda-CDM cosmology (Omega_Lambda = 1 - Omega_M).
        """
        omega_l = 1 - self.omega_m
        return self.h0 * torch.sqrt(
            self.omega_m * (1 + self.z_grid) ** 3 + omega_l
        )

    def _get_z_bounds(self):
        """
        Compute the bounds on redshift based on the given minimum and maximum
        distances, using the specified distance type.

        Rebuilds the full distance grids first, then inverts the
        distance(z) relation by linear interpolation. Bounds beyond the
        grid are linearly extrapolated.
        """
        self._generate_distance_grids()
        bounds = torch.tensor([self.minimum, self.maximum])
        z_min, z_max = self._linear_interp_1d(
            self.distance_grid, self.z_grid, bounds
        )

        return z_min, z_max

    def _generate_distance_grids(self):
        """
        Generate the redshift grid over ``[0, z_grid_max]`` together with
        the matching comoving and luminosity distance grids (in Mpc), and
        select ``self.distance_grid`` according to ``distance_type``.
        """
        self.z_grid = torch.linspace(0, self.z_grid_max, self.grid_size)
        self.dz = self.z_grid[1] - self.z_grid[0]
        # C is specified in m/s, h0 in km/s/Mpc, so divide by 1000 to
        # convert the integral of C / H(z) dz to Mpc
        comoving_dist_grid = (
            torch.cumulative_trapezoid(
                (C / self._hubble_function()), self.z_grid
            )
            / 1000
        )
        # cumulative_trapezoid returns N-1 values; comoving distance is 0
        # at z = 0
        zero_prefix = torch.zeros(1, dtype=comoving_dist_grid.dtype)
        self.comoving_dist_grid = torch.cat([zero_prefix, comoving_dist_grid])
        self.luminosity_dist_grid = self.comoving_dist_grid * (1 + self.z_grid)

        if self.distance_type == "redshift":
            self.distance_grid = self.z_grid
        elif self.distance_type == "comoving_distance":
            self.distance_grid = self.comoving_dist_grid
        else:  # luminosity_distance
            self.distance_grid = self.luminosity_dist_grid

    def _p_of_distance(self):
        """
        Compute the unnormalized probability as a function of distance:
        dV/dz (up to a constant) divided by d(distance)/dz.
        """
        dV_dz = self.comoving_dist_grid**2 / self._hubble_function()
        # This is a tensor of ones if the distance type is redshift
        jacobian = torch.gradient(self.distance_grid, spacing=self.dz)[0]
        return dV_dz / jacobian

    def _generate_probability_grids(self):
        """
        Compute the pdf, cdf, and log pdf based on the
        comoving volume differential and distance grid.
        """
        p_of_distance = self._p_of_distance()
        self.pdf = p_of_distance / torch.trapz(
            p_of_distance, self.distance_grid
        )
        cdf = torch.cumulative_trapezoid(self.pdf, self.distance_grid)
        zero_prefix = torch.zeros(1, dtype=cdf.dtype)
        self.cdf = torch.cat([zero_prefix, cdf])
        # Kept as a public attribute; may contain -inf where pdf is 0
        self.log_pdf = torch.log(self.pdf)

    def _linear_interp_1d(self, x_grid, y_grid, x_query):
        """
        Piecewise-linear interpolation of ``y_grid`` (sampled at the
        increasing ``x_grid``) at ``x_query``. Queries outside the grid are
        linearly extrapolated from the first/last segment.
        """
        idx = torch.bucketize(x_query, x_grid, right=True)
        # Clamp so every query uses a valid segment [idx - 1, idx]
        idx = idx.clamp(min=1, max=len(x_grid) - 1)

        x0 = x_grid[idx - 1]
        x1 = x_grid[idx]
        y0 = y_grid[idx - 1]
        y1 = y_grid[idx]

        t = (x_query - x0) / (x1 - x0)
        return y0 + t * (y1 - y0)

    def rsample(self, sample_shape: Optional[torch.Size] = None) -> Tensor:
        """
        Draw samples by inverse-CDF interpolation of uniform variates.

        Args:
            sample_shape: Shape of the returned sample tensor.
        """
        sample_shape = sample_shape or torch.Size()
        u = torch.rand(sample_shape)
        return self._linear_interp_1d(self.cdf, self.distance_grid, u)

    def log_prob(self, value: Tensor) -> Tensor:
        """
        Log probability density at ``value``.

        The tabulated pdf is linearly interpolated before taking the log.
        Interpolating ``log_pdf`` directly yields NaN next to a grid point
        with zero density (e.g. zero distance, where the comoving volume
        element vanishes), because ``-inf`` enters the arithmetic.
        Values outside ``[minimum, maximum]`` get ``-inf``.
        """
        pdf = self._linear_interp_1d(self.distance_grid, self.pdf, value)
        # Extrapolation just beyond the grid ends can dip below zero;
        # clamp so the log is -inf rather than NaN
        log_prob = torch.log(pdf.clamp(min=0))
        inside_range = (value >= self.minimum) & (value <= self.maximum)
        return log_prob.masked_fill(~inside_range, float("-inf"))
346
+
347
+
348
class RateEvolution(UniformComovingVolume):
    """
    Wrapper around `UniformComovingVolume` to allow for
    arbitrary rate evolution functions. E.g., if
    `rate_function = 1 / (1 + z)`, then the distribution
    will sample values such that they occur uniform in
    source frame time.

    Args:
        rate_function: Callable that takes redshift as input
            and returns the rate evolution factor.
        *args, **kwargs: Arguments passed to `UniformComovingVolume`
            constructor.

    Raises:
        TypeError: If ``rate_function`` is not callable.
    """

    def __init__(
        self,
        rate_function: Callable,
        *args,
        **kwargs,
    ):
        if not callable(rate_function):
            raise TypeError(
                f"rate_function must be callable; got {type(rate_function)}"
            )
        # Must be assigned before the base constructor runs: it builds the
        # probability grids, which calls `_p_of_distance` below.
        self.rate_function = rate_function
        super().__init__(*args, **kwargs)

    def _p_of_distance(self):
        """
        Compute the unnormalized probability as a function of distance:
        the uniform-in-comoving-volume density from the base class,
        weighted by `rate_function` evaluated on the redshift grid.
        """
        return super()._p_of_distance() * self.rate_function(self.z_grid)
@@ -1,9 +1,15 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ml4gw
3
- Version: 0.7.4
3
+ Version: 0.7.5
4
4
  Summary: Tools for training torch models on gravitational wave data
5
5
  Author-email: Ethan Marx <emarx@mit.edu>, Will Benoit <benoi090@umn.edu>, Deep Chatterjee <deep1018@mit.edu>, Alec Gunny <alec.gunny@ligo.org>
6
6
  License-File: LICENSE
7
+ Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
8
+ Classifier: Programming Language :: Python :: 3.9
9
+ Classifier: Programming Language :: Python :: 3.10
10
+ Classifier: Programming Language :: Python :: 3.11
11
+ Classifier: Programming Language :: Python :: 3.12
12
+ Classifier: Programming Language :: Python :: 3.13
7
13
  Requires-Python: <3.13,>=3.9
8
14
  Requires-Dist: jaxtyping<0.3,>=0.2
9
15
  Requires-Dist: numpy<2.0.0
@@ -1,7 +1,7 @@
1
1
  ml4gw/__init__.py,sha256=81quoggCuIypZjZs3bbf1Ty70KHdva5RGEJxi0oC57E,25
2
2
  ml4gw/augmentations.py,sha256=pZH9tjEpXV0AIqvHHDkpUE-BorG02beOz2pmSipw2EY,1232
3
3
  ml4gw/constants.py,sha256=RQPXwavlw_cWu3ByltvTejPsi6EWXHDJQ1HaV9iE3Lg,850
4
- ml4gw/distributions.py,sha256=6UOgq8W-Bs-9170Jor_0hyeRnmC74zwbUrwAcJEz1jI,5082
4
+ ml4gw/distributions.py,sha256=YbkPqeYBDC91aM59R7-n6NpBMgvMZcZoOAW_U-Jgrdo,12420
5
5
  ml4gw/gw.py,sha256=0ovW_HJ3j2b5Yq3mduYtGLSl2RrvFyNNcOsZFf7koHY,19794
6
6
  ml4gw/spectral.py,sha256=sao_D0ceeMEatABfiabpqb-xxRfQO8Tz7yk9N7ciOAU,19858
7
7
  ml4gw/types.py,sha256=CcctqDcNajR7khGT6BD-WYsfRKpiP0udoSAB0k1qcFw,863
@@ -49,7 +49,7 @@ ml4gw/waveforms/cbc/phenom_d_data.py,sha256=WA1FBxUp9fo1IQaV_OLJ_5g5gI166mY1FtG9
49
49
  ml4gw/waveforms/cbc/phenom_p.py,sha256=RZzzKQzqZW3rQuWZ41htTZOwwulYP61ow87HRRrel5A,27612
50
50
  ml4gw/waveforms/cbc/taylorf2.py,sha256=cmYrVL29dwX2Icp7I6SXqRIjtPmoljK5DP_ofx2heiM,10505
51
51
  ml4gw/waveforms/cbc/utils.py,sha256=LT1ky10_6ZrbwTcxIrWP1O75GUEuU5q2ZE2yYDhadQE,3037
52
- ml4gw-0.7.4.dist-info/METADATA,sha256=5nM8sBFDpqrKHQNqtssbhFYFkwze3IE_HLM8Zb8qXQU,3049
53
- ml4gw-0.7.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
54
- ml4gw-0.7.4.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
55
- ml4gw-0.7.4.dist-info/RECORD,,
52
+ ml4gw-0.7.5.dist-info/METADATA,sha256=KI_oVRLxKfCFsBESasNrDh2QfMR4lsriRZwCCAYFt4c,3380
53
+ ml4gw-0.7.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
54
+ ml4gw-0.7.5.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
55
+ ml4gw-0.7.5.dist-info/RECORD,,
File without changes