ABCMB 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,33 @@
1
+ Metadata-Version: 2.4
2
+ Name: ABCMB
3
+ Version: 0.1.0
4
+ Summary: A fast, differentiable, and extensible CMB code
5
+ Home-page: https://github.com/TonyZhou729/ABCMB
6
+ Author: Zilu Zhou, Cara Giovanetti, Hongwan Liu
7
+ Author-email: cgiovanetti@lbl.gov
8
+ License: MIT
9
+ Classifier: Development Status :: 1 - Planning
10
+ Classifier: Intended Audience :: Science/Research
11
+ Classifier: Operating System :: POSIX :: Linux
12
+ Classifier: Programming Language :: Python :: 3.8
13
+ Classifier: Programming Language :: Python :: 3.9
14
+ Classifier: Programming Language :: Python :: 3.10
15
+ Classifier: Programming Language :: Python :: 3.11
16
+ Classifier: Programming Language :: Python :: 3.12
17
+ Classifier: Programming Language :: Python :: 3.13
18
+ Classifier: Programming Language :: Python :: 3.14
19
+ Requires-Dist: numpy
20
+ Requires-Dist: scipy
21
+ Requires-Dist: matplotlib
22
+ Requires-Dist: diffrax
23
+ Requires-Dist: equinox
24
+ Requires-Dist: interpax
25
+ Requires-Dist: jax
26
+ Requires-Dist: pytest
27
+ Dynamic: author
28
+ Dynamic: author-email
29
+ Dynamic: classifier
30
+ Dynamic: home-page
31
+ Dynamic: license
32
+ Dynamic: requires-dist
33
+ Dynamic: summary
@@ -0,0 +1,17 @@
1
+ README.md
2
+ setup.py
3
+ ABCMB.egg-info/PKG-INFO
4
+ ABCMB.egg-info/SOURCES.txt
5
+ ABCMB.egg-info/dependency_links.txt
6
+ ABCMB.egg-info/requires.txt
7
+ ABCMB.egg-info/top_level.txt
8
+ abcmb/ABCMBTools.py
9
+ abcmb/__init__.py
10
+ abcmb/background.py
11
+ abcmb/constants.py
12
+ abcmb/main.py
13
+ abcmb/model_specs.py
14
+ abcmb/perturbations.py
15
+ abcmb/setup.py
16
+ abcmb/species.py
17
+ abcmb/spectrum.py
@@ -0,0 +1,8 @@
1
+ numpy
2
+ scipy
3
+ matplotlib
4
+ diffrax
5
+ equinox
6
+ interpax
7
+ jax
8
+ pytest
@@ -0,0 +1 @@
1
+ abcmb
abcmb-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,33 @@
1
+ Metadata-Version: 2.4
2
+ Name: ABCMB
3
+ Version: 0.1.0
4
+ Summary: A fast, differentiable, and extensible CMB code
5
+ Home-page: https://github.com/TonyZhou729/ABCMB
6
+ Author: Zilu Zhou, Cara Giovanetti, Hongwan Liu
7
+ Author-email: cgiovanetti@lbl.gov
8
+ License: MIT
9
+ Classifier: Development Status :: 1 - Planning
10
+ Classifier: Intended Audience :: Science/Research
11
+ Classifier: Operating System :: POSIX :: Linux
12
+ Classifier: Programming Language :: Python :: 3.8
13
+ Classifier: Programming Language :: Python :: 3.9
14
+ Classifier: Programming Language :: Python :: 3.10
15
+ Classifier: Programming Language :: Python :: 3.11
16
+ Classifier: Programming Language :: Python :: 3.12
17
+ Classifier: Programming Language :: Python :: 3.13
18
+ Classifier: Programming Language :: Python :: 3.14
19
+ Requires-Dist: numpy
20
+ Requires-Dist: scipy
21
+ Requires-Dist: matplotlib
22
+ Requires-Dist: diffrax
23
+ Requires-Dist: equinox
24
+ Requires-Dist: interpax
25
+ Requires-Dist: jax
26
+ Requires-Dist: pytest
27
+ Dynamic: author
28
+ Dynamic: author-email
29
+ Dynamic: classifier
30
+ Dynamic: home-page
31
+ Dynamic: license
32
+ Dynamic: requires-dist
33
+ Dynamic: summary
abcmb-0.1.0/README.md ADDED
@@ -0,0 +1,31 @@
1
+ <h1 align="center">
2
+ ABCMB<!-- omit from toc -->
3
+ </h1>
4
+ <h4 align="center">
5
+
6
+ [![License: MIT](https://img.shields.io/badge/License-MIT-red.svg)](https://opensource.org/licenses/MIT)
7
+ [![Run Tests](https://github.com/TonyZhou729/ABCMB/actions/workflows/accuracy.yml/badge.svg)](https://github.com/TonyZhou729/ABCMB/actions/workflows/accuracy.yml)
8
+ <!--[![arXiv](https://img.shields.io/badge/arXiv-2408.14538%20-green.svg)](https://arxiv.org/abs/2408.14538) -->
9
+
10
+ </h4>
11
+
12
+ Autodifferentiable Boltzmann solver for the CMB (ABCMB) is a Python+JAX package for differentiable computation of the Cosmic Microwave Background. ABCMB is **complete to linear order** in $\Lambda\rm{CDM}$ cosmology. It computes the matter and CMB power spectra and includes effects like lensing, massive neutrinos, and a state-of-the-art treatment of the physics of recombination through the companion code [HyRex](https://github.com/TonyZhou729/HyRex).
13
+
14
+ ## Installation
15
+ We recommend installing ABCMB in a clean conda environment. After downloading and unpacking the code, run the following from the code directory:
16
+ ```
17
+ conda create -n ABCMB
18
+ conda activate ABCMB
19
+ pip install -U -r requirements.txt
20
+
21
+ ```
22
+ Optionally, specify your preferred Python version after the environment name (e.g. `conda create -n ABCMB python=3.11`); an environment created without any packages has no Python or `pip` of its own, so `pip` would otherwise come from a different installation. Note that this will automatically attempt to install JAX for CUDA 12; if you need a different CUDA version or are running on a CPU-only platform, refer to the [JAX documentation](https://docs.jax.dev/en/latest/installation.html) for a quick JAX installation guide.
23
+
24
+ ## Examples
25
+ We have included several pedagogical Jupyter notebooks to walk you through getting started with ABCMB in our [example_notebooks](https://github.com/TonyZhou729/ABCMB/tree/main/example_notebooks) folder. We suggest you start with [ABCMB_basics](https://github.com/TonyZhou729/ABCMB/blob/main/example_notebooks/ABCMB_basics.ipynb) to get a sense of how to run the code. If you'd like to add new physics to ABCMB, check out [ABCMB_Fluids](https://github.com/TonyZhou729/ABCMB/blob/main/example_notebooks/ABCMB_Fluids.ipynb). If you'd like to run ABCMB with the Big Bang Nucleosynthesis (BBN) code [LINX](https://github.com/cgiovanetti/LINX/tree/main) to do BBN+CMB joint analyses, check out [ABCMB_with_LINX](https://github.com/TonyZhou729/ABCMB/blob/main/example_notebooks/ABCMB_with_LINX.ipynb).
26
+
27
+ ## Issues
28
+ Please feel free to open an issue if something is amiss in ABCMB!
29
+
30
+
31
+
@@ -0,0 +1,286 @@
1
+ """
2
+ Script for helper numerical tools
3
+ """
4
+ import jax
5
+ from jax import grad, lax, config, jit, vmap
6
+ from jax.scipy.special import gamma, factorial
7
+ from functools import partial
8
+ import numpy as np
9
+ import jax.numpy as jnp
10
+ import equinox as eqx
11
+
12
# Turn on 64-bit (double-precision) types in JAX. This is a process-wide
# setting: it affects every JAX computation run after this module is imported.
jax.config.update("jax_enable_x64", True)
13
+
14
+ ### BEGINNING OF WIGNER ROTATION FOR LENSING ###
15
+
16
def wigner_d_matrix(mu, ells, m, n):
    """
    Compute reduced Wigner d-matrix elements d^ell_{mn}(beta) for CMB lensing.

    The elements are generated by the three-term upward recurrence in ell.
    The recurrence runs on the normalised functions
    sqrt((2 ell + 1) / 2) d^ell_{mn} and is seeded at ell = m with the
    closed-form value
        sqrt((2m)! / ((m+n)! (m-n)!)) cos(beta/2)^(m+n) (-sin(beta/2))^(m-n).
    The normalisation is removed before returning.

    Parameters:
    -----------
    mu : array
        1D array of cosines of the rotation angle beta. The recursion is
        vmapped over its entries.
    ells : array
        Consecutive multipoles [m, m+1, m+2, ..., ellmax]. The first entry
        must equal m, because that is where the recursion is seeded.
    m : int
        First index (non-negative, m >= |n|).
    n : int
        Second index (|n| <= m).

    Returns:
    --------
    array
        Wigner d-matrix elements, shape (len(mu), len(ells)); column j holds
        d^{ells[j]}_{mn}.
    """
    # Prefactor of the closed-form seed at ell = m (independent of beta),
    # including the sqrt((2m + 1)/2) normalisation.
    seed_prefactor = jnp.sqrt((2*m+1)/2) * jnp.sqrt(factorial(2*m)/(factorial(m+n)*factorial(m-n)))

    # Recurrence coefficients, one per entry of ``ells``, for the normalised
    # functions dt^ell = sqrt((2 ell + 1)/2) d^ell:
    #     dt^{ell+1} = (a * mu + b) * dt^ell + c * dt^{ell-1}
    # NaNs (0/0, or the square root of a negative number when ells contains 0)
    # are replaced by 0.
    scale_a = jnp.sqrt((2*ells+3)/(2*ells+1))
    scale_c = jnp.sqrt((2*ells+3)/(2*ells-1))
    shared_denom = jnp.sqrt((ells+1)**2-m**2) * jnp.sqrt((ells+1)**2-n**2)
    coef_a = jnp.nan_to_num(scale_a * (ells+1)*(2*ells+1) / shared_denom, nan=0.0)
    coef_b = jnp.nan_to_num(-coef_a * m * n / ells / (ells+1), nan=0.0)
    coef_c = jnp.nan_to_num(
        -scale_c * jnp.sqrt(ells**2-m**2) * jnp.sqrt(ells**2-n**2) / shared_denom * (ells+1)/ells,
        nan=0.0,
    )

    # Undoes the sqrt((2 ell + 1)/2) normalisation at the end.
    to_standard = jnp.sqrt(2./(2.*ells+1))

    def recurse_single(cos_beta):
        half_angle = jnp.arccos(cos_beta)/2.
        seed = seed_prefactor * jnp.cos(half_angle)**(m+n) * (-jnp.sin(half_angle))**(m-n)

        def advance(state, coefs):
            current, previous = state
            a, b, c = coefs
            following = a*cos_beta*current + b*current + c*previous
            # Emit the value at the current ell and shift the window up by one.
            return (following, current), current

        # The value one step below the seed (ell = m - 1) is taken as zero.
        _, normalised = lax.scan(advance, (seed, 0.), (coef_a, coef_b, coef_c))
        return normalised * to_standard

    return vmap(recurse_single)(mu)
73
+
74
def d00(mu, ells):
    """
    Compute Wigner d-matrix elements d^ell_{00}(beta).

    Parameters:
    -----------
    mu : array
        Cosines of the rotation angle beta.
    ells : array
        Consecutive multipoles starting at ell = 2, i.e. [2, 3, ..., ellmax].

    Returns:
    --------
    array
        d^ell_{00} for the requested ells, shape (len(mu), len(ells)).
    """
    # The m = n = 0 recursion has to be seeded at ell = m = 0, so ell = 0 and 1
    # are prepended to the requested multipoles and their two columns dropped.
    seeded_ells = jnp.concatenate([jnp.array([0, 1]), ells])
    return wigner_d_matrix(mu, seeded_ells, 0, 0)[:, 2:]
94
+
95
def d1n(mu, ells, n):
    """
    Compute Wigner d-matrix elements d^ell_{1n}(beta).

    Parameters:
    -----------
    mu : array
        Cosines of the rotation angle beta.
    ells : array
        Consecutive multipoles starting at ell = 2, i.e. [2, 3, ..., ellmax].
    n : int
        Second index (|n| <= 1).

    Returns:
    --------
    array
        d^ell_{1n} for the requested ells, shape (len(mu), len(ells)).
    """
    # The m = 1 recursion must start at ell = m = 1: prepend it, then drop
    # that column from the result.
    seeded_ells = jnp.concatenate([jnp.array([1]), ells])
    return wigner_d_matrix(mu, seeded_ells, 1, n)[:, 1:]
117
+
118
def d2n(mu, ells, n):
    """
    Compute Wigner d-matrix elements d^ell_{2n}(beta).

    Parameters:
    -----------
    mu : array
        Cosines of the rotation angle beta.
    ells : array
        Consecutive multipoles starting at ell = 2, i.e. [2, 3, ..., ellmax].
        The recursion for m = 2 is seeded at ell = 2, so no padding is needed.
    n : int
        Second index (|n| <= 2).

    Returns:
    --------
    array
        d^ell_{2n} for the requested ells, shape (len(mu), len(ells)).
    """
    return wigner_d_matrix(mu, ells, 2, n)
139
+
140
def d3n(mu, ells, n):
    """
    Compute Wigner d-matrix elements d^ell_{3n}(beta).

    Parameters:
    -----------
    mu : array
        Cosines of the rotation angle beta.
    ells : array
        Consecutive multipoles, assumed to start at ell = 2,
        i.e. [2, 3, ..., ellmax].
    n : int
        Second index (|n| <= 3).

    Returns:
    --------
    array
        d^ell_{3n}, shape (len(mu), len(ells)). The ell = 2 column is zero.
    """
    # The m = 3 recursion starts at ell = 3, so skip the leading ell = 2 entry
    # and fill its column with zeros afterwards.
    n_skipped = 1
    values = wigner_d_matrix(mu, ells[n_skipped:], 3, n)
    zero_block = jnp.zeros((mu.size, n_skipped))
    return jnp.concatenate([zero_block, values], axis=1)
163
+
164
def d4n(mu, ells, n):
    """
    Compute Wigner d-matrix elements d^ell_{4n}(beta).

    Parameters:
    -----------
    mu : array
        Cosines of the rotation angle beta.
    ells : array
        Consecutive multipoles, assumed to start at ell = 2,
        i.e. [2, 3, ..., ellmax].
    n : int
        Second index (|n| <= 4).

    Returns:
    --------
    array
        d^ell_{4n}, shape (len(mu), len(ells)). The ell = 2 and ell = 3
        columns are zero.
    """
    # The m = 4 recursion starts at ell = 4, so skip the leading ell = 2, 3
    # entries and fill their columns with zeros afterwards.
    n_skipped = 2
    values = wigner_d_matrix(mu, ells[n_skipped:], 4, n)
    zero_block = jnp.zeros((mu.size, n_skipped))
    return jnp.concatenate([zero_block, values], axis=1)
187
+
188
+ ### END OF WIGNER ROTATION FOR LENSING ###
189
+
190
+
191
+
192
def fast_interp(x, xp_min, xp_max, fp):
    """
    Fast 1D linear interpolation on a uniformly spaced grid.

    Avoids the ``searchsorted`` call used by ``jnp.interp``. The bracketing
    grid index is computed directly from the known, uniform grid spacing.

    Parameters:
    -----------
    x : float or array
        Query points for interpolation.
    xp_min : float
        First grid point, where ``fp[0]`` is sampled.
    xp_max : float
        Last grid point, where ``fp[-1]`` is sampled.
    fp : array
        1D array of function values on the uniform grid
        ``xp_min + k * (xp_max - xp_min) / (len(fp) - 1)`` for
        ``k = 0, ..., len(fp) - 1``, i.e. ``jnp.linspace(xp_min, xp_max, len(fp))``.

    Returns:
    --------
    float or array
        Interpolated values at the query points. Queries outside
        [xp_min, xp_max] are clamped to ``fp[0]`` / ``fp[-1]``.

    Notes:
    ------
    Credit: JAX issue #16182 (https://github.com/jax-ml/jax/issues/16182)
    A grid of len(fp) points has len(fp) - 1 intervals, so the fractional
    index is scaled by len(fp) - 1. This makes x == xp_min hit fp[0] and
    x == xp_max hit fp[-1] exactly.
    """
    # jnp.interp is slow because it uses searchsorted; with a uniform grid of
    # known range the bracketing index can be computed arithmetically instead.
    n = fp.shape[-1]
    # Fractional grid index: 0 at xp_min, n - 1 at xp_max.
    i = (x - xp_min) / (xp_max - xp_min) * (n - 1)
    # Clamp out-of-range queries to the grid ends.
    i = jnp.clip(i, 0.0, n - 1.0)
    # Left node of the bracketing interval, capped at n - 2 so that the right
    # node is always a valid index (x == xp_max then uses the last interval
    # with all of its weight on fp[-1]).
    i_lower = jnp.clip(jnp.floor(i).astype(jnp.int32), 0, n - 2)
    i_upper = i_lower + 1
    w_upper = i - i_lower
    w_lower = 1.0 - w_upper
    return w_lower * fp[i_lower] + w_upper * fp[i_upper]
233
+
234
+
235
def bilinear_interp(x, y, z, xq, yq):
    """
    Bilinear interpolation on a 2D rectilinear grid.

    Evaluates the bilinear interpolant of ``z`` at the query point (xq, yq),
    using the four grid values at the corners of the cell that contains it.

    Parameters:
    -----------
    x : array
        1D array of x-coordinates (must be sorted ascending).
    y : array
        1D array of y-coordinates (must be sorted ascending).
    z : array
        2D array of function values, shape (len(y), len(x)); ``z[j, i]`` is
        the value at (x[i], y[j]).
    xq : float
        Query x-coordinate.
    yq : float
        Query y-coordinate.

    Returns:
    --------
    float
        Interpolated value at (xq, yq).

    Notes:
    ------
    The cell index is clamped to the grid, but the fractional positions are
    not. Queries outside the grid are therefore linearly extrapolated from
    the nearest edge cell.
    """
    def lower_index(grid, q):
        # Left/bottom node of the cell bracketing q, kept within [0, size - 2]
        # so that both neighbouring nodes exist.
        return jnp.clip(jnp.searchsorted(grid, q) - 1, 0, grid.size - 2)

    col = lower_index(x, xq)
    row = lower_index(y, yq)

    # Fractional position of the query inside the cell along each axis.
    tx = (xq - x[col]) / (x[col + 1] - x[col])
    ty = (yq - y[row]) / (y[row + 1] - y[row])
    sx = 1 - tx
    sy = 1 - ty

    return (sx * sy * z[row, col]
            + tx * sy * z[row, col + 1]
            + sx * ty * z[row + 1, col]
            + tx * ty * z[row + 1, col + 1])
285
+
286
+
File without changes