CosmoDJ 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cosmodj/__init__.py ADDED
@@ -0,0 +1,33 @@
1
+ """Cosmological distance utilities."""
2
+
3
+ from .distances import (
4
+ Cosmology,
5
+ Planck18Cosmology,
6
+ angular_diameter_distance,
7
+ angular_diameter_distance_z1z2,
8
+ angular_diameter_distances,
9
+ comoving_radial_distance,
10
+ dark_energy_scale,
11
+ e_z,
12
+ luminosity_distance,
13
+ nu_relative_density,
14
+ time_delay_distance,
15
+ transverse_comoving_distance,
16
+ )
17
+ from .quadrature import gauss_legendre_integrate
18
+
19
# Public API re-exported by ``from cosmodj import *``: the cosmology
# containers and distance functions from ``.distances`` plus the quadrature
# helper from ``.quadrature``. Entries are kept in ASCII-sorted order.
__all__ = [
    "Cosmology",
    "Planck18Cosmology",
    "angular_diameter_distance",
    "angular_diameter_distance_z1z2",
    "angular_diameter_distances",
    "comoving_radial_distance",
    "dark_energy_scale",
    "e_z",
    "gauss_legendre_integrate",
    "luminosity_distance",
    "nu_relative_density",
    "time_delay_distance",
    "transverse_comoving_distance",
]
cosmodj/distances.py ADDED
@@ -0,0 +1,277 @@
1
+ """JAX cosmological distance calculations.
2
+
3
+ Distances use the CPL dark-energy parameterization,
4
+ ``w(a) = w0 + wa * (1 - a)``. Planck18 is represented only as a
5
+ parameter container; runtime distance calculations are performed here with JAX.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from dataclasses import dataclass
11
+ from typing import Mapping
12
+
13
+ from jax import config
14
+
15
+ config.update("jax_enable_x64", True)
16
+
17
+ from astropy.constants import c as speed_of_light
18
+ import jax.numpy as jnp
19
+
20
+ from .quadrature import gauss_legendre_integrate
21
+
22
+
23
# Speed of light in km/s, taken from ``astropy.constants.c``. Dividing it by
# H0 (km/s/Mpc) gives the Hubble distance c / H0 in Mpc, the scale applied to
# every dimensionless distance in this module.
C_KM_S = float(speed_of_light.to("km/s").value)
24
+
25
+
26
@dataclass(frozen=True)
class Cosmology:
    """Immutable CPL (w0-wa) cosmology parameter container.

    Units: ``H0`` is in km/s/Mpc, ``Tcmb0`` in K and ``m_nu_eV`` in eV; every
    other field is dimensionless. ``nu_y`` stores ``m_nu / (k_B T_nu0)`` for
    each *massive* neutrino species.

    Fields:
        Omegam: non-relativistic matter density today. Neutrinos are not
            included here; they enter through ``Ogamma0`` and ``nu_y``.
        Omegak: curvature density today (> 0 open, < 0 closed).
        w0, wa: CPL equation of state ``w(a) = w0 + wa * (1 - a)``.
        H0: Hubble constant.
        Omegade: dark-energy density today. When ``None`` it is derived from
            closure, ``1 - Omegam - Omegak - Ogamma0 - Omega_nu0``.
        Ogamma0: photon density today. The neutrino density is computed
            relative to it, so ``Ogamma0 == 0`` also removes neutrinos.
        Neff: effective number of neutrino species.
        neff_per_nu: ``Neff`` carried by each species. When ``None`` it is
            ``Neff / (nmasslessnu + len(nu_y))``.
        nmasslessnu: number of massless neutrino species.
        nu_y: see above; one entry per massive species.
        m_nu_eV, Tcmb0: informational only. No distance calculation reads
            them, and massive neutrinos enter the calculation exclusively
            through ``nu_y``. Update ``nu_y`` when changing neutrino masses.

    The defaults describe a flat LambdaCDM model with no radiation
    (``Ogamma0 = Neff = 0``), unlike ``Planck18Cosmology``.
    """

    Omegam: float = 0.32
    Omegak: float = 0.0
    w0: float = -1.0
    wa: float = 0.0
    H0: float = 70.0
    Omegade: float | None = None
    Ogamma0: float = 0.0
    Neff: float = 0.0
    neff_per_nu: float | None = None
    nmasslessnu: int = 0
    nu_y: tuple[float, ...] = ()
    m_nu_eV: tuple[float, ...] = ()
    Tcmb0: float = 0.0


# Flat LambdaCDM with three neutrino species: two massless and one massive
# (0.06 eV, entering via ``nu_y``). The numbers appear to mirror Astropy's
# ``Planck18`` realization, which the README says is used as the test reference.
Planck18Cosmology = Cosmology(
    Omegam=0.30966,
    Omegak=0.0,
    w0=-1.0,
    wa=0.0,
    H0=67.66,
    Omegade=0.6888463055445441,
    Ogamma0=5.402015137139352e-05,
    Neff=3.046,
    neff_per_nu=1.0153333333333332,
    nmasslessnu=2,
    nu_y=(357.9121209673803,),
    m_nu_eV=(0.0, 0.0, 0.06),
    Tcmb0=2.7255,
)
"""Planck 2018 parameters; calculations are performed by CosmoDJ/JAX."""
65
+
66
def _resolve_cosmology(cosmology=None):
    """Return *cosmology*, substituting ``Planck18Cosmology`` when it is ``None``."""
    return Planck18Cosmology if cosmology is None else cosmology
70
+
71
+
72
def _as_cosmology(cosmology: Cosmology | Mapping[str, float] | None = None) -> Cosmology:
    """Normalise any accepted cosmology input into a ``Cosmology``.

    ``None`` selects ``Planck18Cosmology``, and an existing ``Cosmology`` is
    returned unchanged. Anything else is read as a mapping:

    - ``"Omegam"`` is required.
    - The Hubble constant (km/s/Mpc) may be given as ``"H0"`` or ``"h0"``.
    - ``"Ode0"`` is accepted as an alias for ``"Omegade"``, and ``"m_nu"``
      for ``"m_nu_eV"``.
    - Every other key falls back to the same default as the ``Cosmology``
      field.

    Raises:
        KeyError: if neither ``"H0"`` nor ``"h0"`` is present, or if
            ``"Omegam"`` is missing.
    """
    resolved = _resolve_cosmology(cosmology)
    if isinstance(resolved, Cosmology):
        return resolved

    lookup = resolved.get
    hubble = lookup("H0", lookup("h0"))
    if hubble is None:
        raise KeyError("Cosmology mapping must include 'H0' or 'h0'.")

    return Cosmology(
        Omegam=resolved["Omegam"],
        Omegak=lookup("Omegak", 0.0),
        w0=lookup("w0", -1.0),
        wa=lookup("wa", 0.0),
        H0=hubble,
        Omegade=lookup("Omegade", lookup("Ode0", None)),
        Ogamma0=lookup("Ogamma0", 0.0),
        Neff=lookup("Neff", 0.0),
        neff_per_nu=lookup("neff_per_nu", None),
        nmasslessnu=lookup("nmasslessnu", 0),
        # Sequences may arrive as lists or arrays; the container stores tuples.
        nu_y=tuple(lookup("nu_y", ())),
        m_nu_eV=tuple(lookup("m_nu_eV", lookup("m_nu", ()))),
        Tcmb0=lookup("Tcmb0", 0.0),
    )
96
+
97
+
98
def _neff_per_nu(cosmology: Cosmology):
    """Effective number of neutrino species carried by each species.

    An explicit ``neff_per_nu`` takes precedence. Otherwise ``Neff`` is spread
    evenly over the ``nmasslessnu + len(nu_y)`` species, with 0.0 returned
    when there are no species at all.
    """
    explicit = cosmology.neff_per_nu
    if explicit is not None:
        return explicit
    species = cosmology.nmasslessnu + len(cosmology.nu_y)
    return 0.0 if species == 0 else cosmology.Neff / species
105
+
106
+
107
def nu_relative_density(z, cosmology: Cosmology | Mapping[str, float] | None = None):
    """Neutrino-to-photon energy-density ratio at redshift *z*.

    Massive species use the Komatsu et al. (2011) fitting function, the same
    form Astropy uses. Massless species contribute a redshift-independent
    term. The result follows the shape of *z* and is identically zero when
    ``Neff == 0``.
    """
    params = _as_cosmology(cosmology)
    redshift = jnp.asarray(z, dtype=jnp.float64)
    # 7/8 * (4/11)**(4/3): density of one massless species relative to photons.
    prefactor = 0.22710731766

    if params.Neff == 0:
        return jnp.zeros_like(redshift)

    if len(params.nu_y) == 0:
        # Only massless species, so the ratio does not depend on redshift.
        return prefactor * params.Neff * jnp.ones_like(redshift)

    # Fitting-function constants: exponent p, its reciprocal 1/p, and scale k.
    p, inv_p, k = 1.83, 0.54644808743, 0.3173
    y_today = jnp.asarray(params.nu_y, dtype=jnp.float64)
    # y = m / (k_B T_nu) scales as 1/(1+z); the new trailing axis indexes
    # the massive species.
    y_at_z = y_today / (1.0 + redshift[..., None])
    per_species = (1.0 + (k * y_at_z) ** p) ** inv_p
    effective_species = per_species.sum(axis=-1) + params.nmasslessnu
    return prefactor * _neff_per_nu(params) * effective_species
132
+
133
+
134
def _omega_nu0(cosmology: Cosmology):
    """Present-day neutrino density parameter, ``Ogamma0 * nu_relative_density(0)``."""
    ratio_today = nu_relative_density(0.0, cosmology)
    return cosmology.Ogamma0 * ratio_today
136
+
137
+
138
def _omega_de0(cosmology: Cosmology):
    """Present-day dark-energy density parameter.

    Uses ``Omegade`` when it is set. Otherwise it closes the energy budget so
    that matter, curvature, photons, neutrinos and dark energy sum to one.
    """
    explicit = cosmology.Omegade
    if explicit is not None:
        return explicit
    remainder = 1.0 - cosmology.Omegam
    remainder = remainder - cosmology.Omegak
    remainder = remainder - cosmology.Ogamma0
    return remainder - _omega_nu0(cosmology)
148
+
149
+
150
def dark_energy_scale(z, cosmology: Cosmology | Mapping[str, float] | None = None):
    """CPL dark-energy density at *z* relative to today, ``rho_de(z) / rho_de(0)``.

    This is the closed form for ``w(a) = w0 + wa * (1 - a)``:
    ``(1+z)**(3 (1 + w0 + wa)) * exp(-3 wa z / (1+z))``.
    """
    params = _as_cosmology(cosmology)
    redshift = jnp.asarray(z, dtype=jnp.float64)
    one_plus_z = 1.0 + redshift
    power_law = one_plus_z ** (3.0 * (1.0 + params.w0 + params.wa))
    exp_factor = jnp.exp(-3.0 * params.wa * redshift / one_plus_z)
    return power_law * exp_factor
159
+
160
+
161
def e_z(z, cosmology: Cosmology | Mapping[str, float] | None = None):
    """Dimensionless Hubble rate ``E(z) = H(z) / H0``.

    ``E(z)**2`` is the sum of these terms:

    - matter, scaled by ``(1+z)**3``;
    - curvature, scaled by ``(1+z)**2``;
    - photons and neutrinos, scaled by ``(1+z)**4``, with neutrinos
      expressed through ``nu_relative_density``;
    - CPL dark energy, scaled by ``dark_energy_scale``.
    """
    params = _as_cosmology(cosmology)
    redshift = jnp.asarray(z, dtype=jnp.float64)
    one_plus_z = 1.0 + redshift
    omega_nu = params.Ogamma0 * nu_relative_density(redshift, params)

    matter = params.Omegam * one_plus_z**3
    curvature = params.Omegak * one_plus_z**2
    photons = params.Ogamma0 * one_plus_z**4
    neutrinos = omega_nu * one_plus_z**4
    dark_energy = _omega_de0(params) * dark_energy_scale(redshift, params)
    return jnp.sqrt(matter + curvature + photons + neutrinos + dark_energy)
176
+
177
+
178
def _transverse_from_radial(chi, Omegak):
    """Map a dimensionless radial comoving distance to a transverse one.

    With ``k = sqrt(|Omegak|)`` the result is:

    - ``sinh(k chi) / k`` for ``Omegak > 0``;
    - ``sin(k chi) / k`` for ``Omegak < 0``;
    - ``chi`` unchanged when ``|Omegak| < 1e-14``.

    Both curved branches are always evaluated because ``jnp.where`` is used.
    The ``1e-300`` floor keeps ``k`` non-zero so the unused branch does not
    divide by zero when ``Omegak == 0``.
    """
    radial = jnp.asarray(chi, dtype=jnp.float64)
    omega_k = jnp.asarray(Omegak, dtype=jnp.float64)
    k = jnp.sqrt(jnp.maximum(jnp.abs(omega_k), 1.0e-300))
    scaled = k * radial
    curved = jnp.where(omega_k > 0.0, jnp.sinh(scaled) / k, jnp.sin(scaled) / k)
    is_flat = jnp.abs(omega_k) < 1.0e-14
    return jnp.where(is_flat, radial, curved)
186
+
187
+
188
def comoving_radial_distance(
    z,
    cosmology: Cosmology | Mapping[str, float] | None = None,
    n: int = 256,
):
    """Line-of-sight comoving distance from the observer to redshift *z*, in Mpc.

    Evaluates ``(c / H0) * integral_0^z dz' / E(z')`` with an ``n``-point
    Gauss-Legendre rule.
    """
    params = _as_cosmology(cosmology)

    def inverse_e(z_eval):
        return 1.0 / e_z(z_eval, params)

    radial = gauss_legendre_integrate(inverse_e, 0.0, z, n=n)
    return (C_KM_S / params.H0) * radial
198
+
199
+
200
def transverse_comoving_distance(
    z,
    cosmology: Cosmology | Mapping[str, float] | None = None,
    n: int = 256,
):
    """Transverse comoving distance ``D_M`` to redshift *z*, in Mpc.

    The dimensionless radial integral is mapped through the curvature-dependent
    ``sinh``/``sin`` transform. For a flat geometry that transform is the
    identity. The result is then scaled by ``c / H0``.
    """
    params = _as_cosmology(cosmology)

    def inverse_e(z_eval):
        return 1.0 / e_z(z_eval, params)

    radial = gauss_legendre_integrate(inverse_e, 0.0, z, n=n)
    transverse = _transverse_from_radial(radial, params.Omegak)
    return (C_KM_S / params.H0) * transverse
211
+
212
+
213
def angular_diameter_distance(
    z,
    cosmology: Cosmology | Mapping[str, float] | None = None,
    n: int = 256,
):
    """Angular-diameter distance ``D_A = D_M / (1 + z)`` to redshift *z*, in Mpc."""
    redshift = jnp.asarray(z, dtype=jnp.float64)
    d_m = transverse_comoving_distance(redshift, cosmology, n=n)
    return d_m / (1.0 + redshift)
222
+
223
+
224
def angular_diameter_distance_z1z2(
    z1,
    z2,
    cosmology: Cosmology | Mapping[str, float] | None = None,
    n: int = 256,
):
    """Angular-diameter distance between redshifts *z1* and *z2*, in Mpc.

    The curvature transform is applied to the radial difference
    ``chi(z2) - chi(z1)``, and the result is divided by ``1 + z2``. It is
    intended for ``z1 < z2``, such as a lens and a source. For ``z1 > z2`` the
    radial difference is negative.
    """
    z_near = jnp.asarray(z1, dtype=jnp.float64)
    z_far = jnp.asarray(z2, dtype=jnp.float64)
    params = _as_cosmology(cosmology)

    def inverse_e(z_eval):
        return 1.0 / e_z(z_eval, params)

    chi_near = gauss_legendre_integrate(inverse_e, 0.0, z_near, n=n)
    chi_far = gauss_legendre_integrate(inverse_e, 0.0, z_far, n=n)
    separation = _transverse_from_radial(chi_far - chi_near, params.Omegak)
    return (C_KM_S / params.H0) * separation / (1.0 + z_far)
240
+
241
+
242
def angular_diameter_distances(
    zl,
    zs,
    cosmology: Cosmology | Mapping[str, float] | None = None,
    n: int = 256,
):
    """Return ``(D_l, D_s, D_ls)`` angular-diameter distances in Mpc.

    ``D_l`` and ``D_s`` are the observer-lens and observer-source distances,
    and ``D_ls`` is the lens-source distance. The values are identical to
    those of ``angular_diameter_distance(zl)``,
    ``angular_diameter_distance(zs)`` and
    ``angular_diameter_distance_z1z2(zl, zs)``. The difference is that each
    comoving integral here is evaluated once rather than twice, and the
    cosmology is normalised only once.

    Args:
        zl: lens redshift(s), scalar or array.
        zs: source redshift(s), scalar or array; must be broadcast-compatible
            with ``zl`` for ``D_ls``.
        cosmology: a ``Cosmology``, a mapping, or ``None`` for Planck18.
        n: number of Gauss-Legendre nodes per integral.

    Returns:
        Tuple ``(D_l, D_s, D_ls)``. ``D_l`` has the shape of ``zl``, ``D_s``
        has the shape of ``zs``, and ``D_ls`` has their broadcast shape.
    """
    zl_arr = jnp.asarray(zl, dtype=jnp.float64)
    zs_arr = jnp.asarray(zs, dtype=jnp.float64)
    cosmo = _as_cosmology(cosmology)

    def inverse_e(z_eval):
        return 1.0 / e_z(z_eval, cosmo)

    # Dimensionless radial comoving distances, in units of c / H0, shared by
    # all three outputs.
    chi_l = gauss_legendre_integrate(inverse_e, 0.0, zl_arr, n=n)
    chi_s = gauss_legendre_integrate(inverse_e, 0.0, zs_arr, n=n)

    # Keep the same operation order as the single-distance functions, so the
    # results match them exactly: (c/H0) * D_M_dimensionless / (1 + z).
    hubble_distance = C_KM_S / cosmo.H0
    d_l = hubble_distance * _transverse_from_radial(chi_l, cosmo.Omegak) / (1.0 + zl_arr)
    d_s = hubble_distance * _transverse_from_radial(chi_s, cosmo.Omegak) / (1.0 + zs_arr)
    d_ls = (
        hubble_distance
        * _transverse_from_radial(chi_s - chi_l, cosmo.Omegak)
        / (1.0 + zs_arr)
    )
    return d_l, d_s, d_ls
255
+
256
+
257
def luminosity_distance(
    z,
    cosmology: Cosmology | Mapping[str, float] | None = None,
    n: int = 256,
):
    """Luminosity distance ``D_L = (1 + z)**2 * D_A`` to redshift *z*, in Mpc."""
    redshift = jnp.asarray(z, dtype=jnp.float64)
    d_a = angular_diameter_distance(redshift, cosmology, n=n)
    return (1.0 + redshift) ** 2 * d_a
266
+
267
+
268
def time_delay_distance(
    zl,
    zs,
    cosmology: Cosmology | Mapping[str, float] | None = None,
    n: int = 256,
):
    """Lensing time-delay distance ``(1 + zl) * D_l * D_s / D_ls``, in Mpc.

    The result diverges when ``zs == zl``, because ``D_ls`` is then zero.
    """
    d_l, d_s, d_ls = angular_diameter_distances(zl, zs, cosmology, n=n)
    lens_factor = 1.0 + jnp.asarray(zl, dtype=jnp.float64)
    return lens_factor * d_l * d_s / d_ls
cosmodj/quadrature.py ADDED
@@ -0,0 +1,52 @@
1
+ """Numerical integration helpers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from functools import lru_cache
6
+ from typing import Callable
7
+
8
+ from jax import config
9
+
10
+ config.update("jax_enable_x64", True)
11
+
12
+ import jax.numpy as jnp
13
+ import numpy as np
14
+
15
+
16
@lru_cache(maxsize=64)
def _legendre_nodes_weights(n: int):
    """Return cached ``n``-point Gauss-Legendre nodes and weights on [-1, 1].

    The same array objects are returned to every caller, so they are marked
    read-only; an accidental in-place edit would otherwise corrupt every
    later integration. The cache is bounded so that sweeping over many ``n``
    values cannot grow it without limit. NumPy rejects an ``n`` that is not a
    positive integer.
    """
    nodes, weights = np.polynomial.legendre.leggauss(n)
    nodes.setflags(write=False)
    weights.setflags(write=False)
    return nodes, weights


def gauss_legendre_integrate(
    func: Callable,
    a,
    b,
    *args,
    n: int = 64,
    **kwargs,
):
    """Integrate ``func(x, *args, **kwargs)`` from ``a`` to ``b``.

    A fixed ``n``-point Gauss-Legendre rule is mapped from [-1, 1] onto
    ``[a, b]``. It is exact for polynomials of degree up to ``2n - 1`` and
    uses only ``jax.numpy`` operations on the limits and integrand values.

    Args:
        func: integrand. It is called once with ``x`` of shape
            ``broadcast(a, b).shape + (n,)``, one row of nodes per interval,
            and must return values broadcastable to that shape.
        a, b: lower and upper limits, as scalars or broadcast-compatible
            arrays.
        *args, **kwargs: extra arguments forwarded to ``func``.
        n: number of nodes. Must be a concrete Python int, because it is used
            as a cache key.

    Returns:
        Array of integrals with shape ``broadcast(a, b).shape``.
    """
    lower = jnp.asarray(a, dtype=jnp.float64)
    upper = jnp.asarray(b, dtype=jnp.float64)
    nodes_np, weights_np = _legendre_nodes_weights(n)
    nodes = jnp.asarray(nodes_np, dtype=jnp.float64)
    weights = jnp.asarray(weights_np, dtype=jnp.float64)

    shape = jnp.broadcast_shapes(lower.shape, upper.shape)
    lower = jnp.broadcast_to(lower, shape)
    upper = jnp.broadcast_to(upper, shape)

    # Prepend singleton axes so the (n,) node and weight vectors broadcast
    # against every interval along a new trailing axis.
    node_shape = (1,) * len(shape) + (nodes.shape[0],)
    # Affine map from [-1, 1] to [lower, upper]: x = ((b - a) t + (b + a)) / 2.
    x_eval = 0.5 * (
        (upper - lower)[..., None] * nodes.reshape(node_shape)
        + (upper + lower)[..., None]
    )
    values = func(x_eval, *args, **kwargs)
    # The Jacobian of the affine map is (b - a) / 2.
    return 0.5 * (upper - lower) * jnp.sum(weights.reshape(node_shape) * values, axis=-1)
@@ -0,0 +1,126 @@
1
+ Metadata-Version: 2.4
2
+ Name: CosmoDJ
3
+ Version: 0.0.1
4
+ Summary: Lightweight cosmological distance utilities for lensing forecasts.
5
+ Author: Tian Li
6
+ Project-URL: Homepage, https://github.com/astroskylee/CosmoDJ
7
+ Project-URL: Repository, https://github.com/astroskylee/CosmoDJ
8
+ Project-URL: Issues, https://github.com/astroskylee/CosmoDJ/issues
9
+ Requires-Python: >=3.10
10
+ Description-Content-Type: text/markdown
11
+ Requires-Dist: astropy>=6
12
+ Requires-Dist: jax>=0.4
13
+ Requires-Dist: numpy>=1.23
14
+ Provides-Extra: test
15
+ Requires-Dist: numpyro>=0.15; extra == "test"
16
+ Requires-Dist: pytest>=7; extra == "test"
17
+
18
+ # CosmoDJ
19
+
20
+ CosmoDJ is a lightweight JAX package for cosmological distance calculations, written for strong-lensing cosmology forecasts and NumPyro workflows. The package currently provides angular-diameter distances, transverse and radial comoving distances, luminosity distances, and time-delay distances for CPL dark-energy cosmologies.
21
+
22
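+ The default cosmology is `Planck18Cosmology`, implemented as a CosmoDJ parameter object. Runtime distance calculations are performed with JAX rather than Astropy. Astropy supplies physical constants (the speed of light) and serves as the reference implementation in the tests. Note that a `Cosmology` built from its defaults neglects radiation (photons and neutrinos), whereas `Planck18Cosmology` includes them.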
23
+
24
+ ## Installation
25
+
26
+ For local development:
27
+
28
+ ```bash
29
+ cd /path/to/CosmoDJ  # root of your local clone
30
+ pip install -e .
31
+ ```
32
+
33
+ From PyPI:
34
+
35
+ ```bash
36
+ pip install CosmoDJ
37
+ ```
38
+
39
+ ## Basic Usage
40
+
41
+ ```python
42
+ from cosmodj import angular_diameter_distance, angular_diameter_distances
43
+
44
+ Da = angular_diameter_distance(1.0) # Mpc, default Planck18Cosmology
45
+ Dl, Ds, Dls = angular_diameter_distances(0.5, 2.0) # Mpc
46
+ ```
47
+
48
+ `angular_diameter_distance` accepts scalar or array-like redshifts:
49
+
50
+ ```python
51
+ import jax.numpy as jnp
52
+ from cosmodj import angular_diameter_distance
53
+
54
+ z = jnp.array([0.5, 1.0, 2.0])
55
+ Da = angular_diameter_distance(z)
56
+ ```
57
+
58
+ ## Custom Cosmology
59
+
60
+ ```python
61
+ from cosmodj import Cosmology, angular_diameter_distance
62
+
63
+ cosmo = Cosmology(
64
+ Omegam=0.32,
65
+ Omegak=0.0,
66
+ w0=-1.0,
67
+ wa=0.0,
68
+ H0=70.0,
69
+ )
70
+
71
+ Da = angular_diameter_distance(1.0, cosmo)
72
+ ```
73
+
74
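+ Dictionary inputs are also supported. `"Omegam"` is required, and the Hubble constant may be given as `"H0"` or `"h0"`; both are read in km/s/Mpc, not as the dimensionless *h*. Omitted keys take the `Cosmology` defaults: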
75
+
76
+ ```python
77
+ from cosmodj import angular_diameter_distances
78
+
79
+ cosmo = {"Omegam": 0.32, "Omegak": 0.0, "w0": -1.0, "wa": 0.0, "h0": 70.0}
80
+ Dl, Ds, Dls = angular_diameter_distances(0.5, 2.0, cosmo)
81
+ ```
82
+
83
+ ## NumPyro Example
84
+
85
+ ```python
86
+ import jax.numpy as jnp
87
+ import numpyro
88
+ import numpyro.distributions as dist
89
+
90
+ from cosmodj import Cosmology, angular_diameter_distance
91
+
92
+
93
+ def model():
94
+ Omegam = numpyro.sample("Omegam", dist.Uniform(0.2, 0.4))
95
+ H0 = numpyro.sample("H0", dist.Uniform(60.0, 80.0))
96
+
97
+ cosmo = Cosmology(Omegam=Omegam, Omegak=0.0, w0=-1.0, wa=0.0, H0=H0)
98
+ z = jnp.array([0.5, 1.0])
99
+ Da = angular_diameter_distance(z, cosmo)
100
+
101
+ numpyro.sample("Da_obs", dist.Normal(Da, 20.0), obs=jnp.array([1250.0, 1650.0]))
102
+ ```
103
+
104
+ ## Citation
105
+
106
+ If you use this package in a publication, please cite:
107
+
108
+ ```bibtex
109
+ @ARTICLE{2024MNRAS.527.5311L,
110
+ author = {{Li}, Tian and {Collett}, Thomas E. and {Krawczyk}, Coleman M. and {Enzi}, Wolfgang},
111
+ title = "{Cosmology from large populations of galaxy-galaxy strong gravitational lenses}",
112
+ journal = {\mnras},
113
+ keywords = {gravitational lensing: strong, galaxies: structure, cosmological parameters, dark energy, cosmology: observations, Astrophysics - Cosmology and Nongalactic Astrophysics},
114
+ year = 2024,
115
+ month = jan,
116
+ volume = {527},
117
+ number = {3},
118
+ pages = {5311-5323},
119
+ doi = {10.1093/mnras/stad3514},
120
+ archivePrefix = {arXiv},
121
+ eprint = {2307.09271},
122
+ primaryClass = {astro-ph.CO},
123
+ adsurl = {https://ui.adsabs.harvard.edu/abs/2024MNRAS.527.5311L},
124
+ adsnote = {Provided by the SAO/NASA Astrophysics Data System}
125
+ }
126
+ ```
@@ -0,0 +1,7 @@
1
+ cosmodj/__init__.py,sha256=pCZKxKgEPibPSP3WurT8LBK-_egp6KlLme2l7NyjdSk,803
2
+ cosmodj/distances.py,sha256=ilyZsyONa0HGaRARFhfeuWywDaG087hQjWHiTEyF4KI,8353
3
+ cosmodj/quadrature.py,sha256=n0HvfEZc_GJsF0mNAaSdIAfpjseocqK4RkdFwo2uV6Y,1392
4
+ cosmodj-0.0.1.dist-info/METADATA,sha256=AhdIXUxHRDquAb_YDVI1s0jEvDXN29X4f_hqbjStYlk,3692
5
+ cosmodj-0.0.1.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
6
+ cosmodj-0.0.1.dist-info/top_level.txt,sha256=8Nk3T9ZZAP7_s1HPQfHYgpD4bsasocmpxRGQ9sMhxXY,8
7
+ cosmodj-0.0.1.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (82.0.1)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1 @@
1
+ cosmodj