avoidcorr 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,44 @@
1
+ Metadata-Version: 2.4
2
+ Name: avoidcorr
3
+ Version: 0.1.0
4
+ Summary: Predict avoidance biased 21-cm power spectrum from full k power spectrum.
5
+ Author: satyapan
6
+ Author-email: satyapan.iiserm@gmail.com
7
+ Classifier: Programming Language :: Python :: 2
8
+ Classifier: Programming Language :: Python :: 2.7
9
+ Classifier: Programming Language :: Python :: 3
10
+ Classifier: Programming Language :: Python :: 3.4
11
+ Classifier: Programming Language :: Python :: 3.5
12
+ Classifier: Programming Language :: Python :: 3.6
13
+ Classifier: Programming Language :: Python :: 3.7
14
+ Classifier: Programming Language :: Python :: 3.8
15
+ Classifier: Programming Language :: Python :: 3.9
16
+ Classifier: Programming Language :: Python :: 3.10
17
+ Classifier: Programming Language :: Python :: 3.11
18
+ Classifier: Programming Language :: Python :: 3.12
19
+ Classifier: Programming Language :: Python :: 3.13
20
+ Classifier: Programming Language :: Python :: 3.14
21
+ Description-Content-Type: text/markdown
22
+
23
+ # Avoidcorr: Predict avoidance power spectrum from full k power spectrum
24
+
25
+ avoidcorr is a Python package for predicting the avoidance-biased 21-cm power spectrum from a full-k power spectrum. It can use shape-difference degeneracies from a database to estimate the biased power spectrum corresponding to an input power spectrum and an arbitrary input mask in k-space.
26
+
27
+ # Dependencies
28
+
29
+ avoidcorr requires the following Python libraries:
30
+ - numpy
31
+ - matplotlib
32
+ - scipy
33
+ - astropy
34
+ - tqdm
35
+ - multiprocessing (part of the Python standard library; no separate installation needed)
36
+
37
+ # Installation
38
+ avoidcorr can be installed via pip:
39
+ ```
40
+ pip install avoidcorr
41
+ ```
42
+
43
+ # Documentation
44
+ A step-by-step guide is available on the project's wiki page.
@@ -0,0 +1,22 @@
1
+ # Avoidcorr: Predict avoidance power spectrum from full k power spectrum
2
+
3
+ avoidcorr is a Python package for predicting the avoidance-biased 21-cm power spectrum from a full-k power spectrum. It can use shape-difference degeneracies from a database to estimate the biased power spectrum corresponding to an input power spectrum and an arbitrary input mask in k-space.
4
+
5
+ # Dependencies
6
+
7
+ avoidcorr requires the following Python libraries:
8
+ - numpy
9
+ - matplotlib
10
+ - scipy
11
+ - astropy
12
+ - tqdm
13
+ - multiprocessing (part of the Python standard library; no separate installation needed)
14
+
15
+ # Installation
16
+ avoidcorr can be installed via pip:
17
+ ```
18
+ pip install avoidcorr
19
+ ```
20
+
21
+ # Documentation
22
+ A step-by-step guide is available on the project's wiki page.
@@ -0,0 +1 @@
1
+ from .predictor import *
@@ -0,0 +1,55 @@
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ from numpy.fft import fftn, fftshift, fftfreq
4
+ from astropy.cosmology import Planck15 as cosmo
5
+ import astropy.constants as const
6
+
7
def legendre_P(ell, x):
    """
    Evaluate the Legendre polynomial of degree ell.

    Arguments:
        ell (int): Non-negative polynomial degree.
        x (float or numpy.ndarray): Point(s) at which to evaluate the polynomial.
    Returns:
        float or numpy.ndarray: P_ell(x), with the shape of x.
    """
    # A unit coefficient vector selects the single basis polynomial P_ell.
    coeff = np.zeros(ell + 1)
    coeff[ell] = 1.0
    return np.polynomial.legendre.legval(x, coeff)


def estimate_multipoles(box, boxlen, kb, k_c, los_axis=2, n_multipoles=6):
    """
    Estimate the even (ell >= 2) Legendre multipoles of a box's power spectrum in k bins.

    Multipole j of the output corresponds to ell = 2*(j+1); the monopole is not
    included. Each multipole is multiplied by k_c**3 / (2*pi**2), i.e. it is in the
    same dimensionless Delta^2 convention as the spherically averaged spectrum.

    Arguments:
        box (numpy.ndarray): 3D real-space field.
        boxlen (sequence of 3 floats): Physical side lengths of the box, one per axis.
        kb (array-like): Increasing k-bin edges; a mode with kb[i] <= k < kb[i+1] is in bin i.
        k_c (array-like): k-bin centres, one per bin (len(kb) - 1 values).
        los_axis (int): Axis used as the line of sight, mu = |k_los| / k. Default 2.
        n_multipoles (int): Number of even multipoles to compute. Default 6.
    Returns:
        list of numpy.ndarray: n_multipoles arrays of length len(kb) - 1.
        Bins containing no modes are NaN.
    """
    box = np.asarray(box)
    boxlen = np.asarray(boxlen, dtype=float)
    kb = np.asarray(kb, dtype=float)
    k_c = np.asarray(k_c, dtype=float)
    nx, ny, nz = box.shape
    # No fftshift is needed: k and the FFT share the same (unshifted) ordering,
    # which is all the binning below relies on.
    freqs = [2 * np.pi * np.fft.fftfreq(n, d=length / n)
             for n, length in zip((nx, ny, nz), boxlen)]
    kx, ky, kz = np.meshgrid(*freqs, indexing='ij')
    k = np.sqrt(kx**2 + ky**2 + kz**2)
    kpar = (kx, ky, kz)[los_axis]
    V = float(np.prod(boxlen))
    dk = (box - np.mean(box)) * V / box.size
    pk = np.abs(np.fft.fftn(dk))**2 / V
    # Skip the DC mode (mu undefined) and any non-finite power.
    good = (k > 0) & np.isfinite(pk)
    kvals, pkvals = k[good], pk[good]
    muvals = np.abs(kpar[good]) / kvals
    nb = len(kb) - 1
    which = np.digitize(kvals, kb) - 1
    ok = (which >= 0) & (which < nb)
    which, muvals, pkvals = which[ok], muvals[ok], pkvals[ok]
    counts = np.bincount(which, minlength=nb)
    filled = counts > 0
    fac = k_c**3 / (2 * np.pi**2)
    multipoles = []
    for j in range(n_multipoles):
        ell = 2 * (j + 1)
        # Per-bin sum of P(k) * P_ell(mu); dividing by the mode count gives the bin mean.
        weighted = np.bincount(which, weights=pkvals * legendre_P(ell, muvals), minlength=nb)
        pk_ell = np.full(nb, np.nan)
        pk_ell[filled] = fac[filled] * (2 * ell + 1) * weighted[filled] / counts[filled]
        multipoles.append(pk_ell)
    return multipoles
52
+
53
def mu_theta(z, theta=np.pi/2):
    """
    Return the mu = k_par / k value of the foreground-wedge boundary at redshift z.

    The boundary line is k_par = sin(theta) * D_M(z) H(z) / (c (1 + z)) * k_perp,
    evaluated with the Planck15 cosmology, where D_M is the comoving transverse
    distance. Converting that slope s into mu gives s / sqrt(1 + s^2).

    Arguments:
        z (float): Redshift.
        theta (float): Angle in radians entering through sin(theta). Default pi/2.
    Returns:
        float: mu at the wedge boundary.
    """
    wedge_slope = (cosmo.comoving_transverse_distance(z) * cosmo.H(z)
                   / (const.c * (1 + z))).decompose().value
    slope = np.sin(theta) * wedge_slope
    return slope / np.sqrt(1 + slope**2)
@@ -0,0 +1,345 @@
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ from numpy.fft import fftn, fftshift, fftfreq
4
+ from astropy.cosmology import Planck15 as cosmo
5
+ from tqdm import tqdm
6
+ import astropy.units as u
7
+ import astropy.constants as const
8
+ from matplotlib.colors import Normalize, LogNorm
9
+ import multiprocessing as mp
10
+ from functools import partial
11
+ from scipy.interpolate import interp1d
12
+
13
+ from .funcs import *
14
+ from .ps_estimator import *
15
+
16
class Database:
    """
    Store a training database of power spectra and multipoles across multiple redshifts.

    Arguments:
        ps_path (str): Path to the directory where power spectra and multipole files are stored or loaded from.
        z_vals (array-like): Redshift values corresponding to each snapshot in the database.
        k_c (numpy.ndarray): k-bin centres. Required when loading pre-computed power spectra (compute_ps_kwargs=None).
        delta2s_str (str): Filename prefix for the Delta^2 arrays. Default 'delta2s'.
        multipoles_str (str): Filename prefix for the multipole arrays. Default 'multipoles_list'.
        box_shape (tuple): Shape of the simulation boxes. Default (128, 128, 128).
        boxlen (tuple): Physical side lengths of the boxes in Mpc. Default (250, 250, 250).
        compute_ps_kwargs (dict): If provided, compute power spectra on the fly rather than loading from disk.
            Required keys: 'ids' (list of simulation IDs), 'simres_path' (path to simulation results).
            Optional keys: 'n_multipoles' (int, default 6), 'parallel' (bool, default True), 'n_threads' (int, default 12), and any kwargs accepted by PS1D.
    Raises:
        ValueError: If required inputs for the chosen mode are missing.
    """
    def __init__(self, ps_path, z_vals, k_c=None, delta2s_str='delta2s', multipoles_str='multipoles_list', box_shape=(128,128,128), boxlen=(250,250,250), compute_ps_kwargs=None):
        self.ps_path = ps_path
        self.z_vals = z_vals
        self.delta2s_str = delta2s_str
        self.multipoles_str = multipoles_str
        self.box_shape = box_shape
        self.boxlen = boxlen
        if compute_ps_kwargs is not None:
            ids = compute_ps_kwargs.get('ids')
            simres_path = compute_ps_kwargs.get('simres_path')
            if ids is None or simres_path is None or boxlen is None:
                raise ValueError("ids and simres_path must be provided in compute_ps_kwargs, and boxlen must not be None")
            n_multipoles = compute_ps_kwargs.get('n_multipoles', 6)
            parallel = compute_ps_kwargs.get('parallel', True)
            n_threads = compute_ps_kwargs.get('n_threads', 12)
            reserved = ('ids', 'simres_path', 'n_multipoles', 'parallel', 'n_threads')
            # Everything that is not a Database option is forwarded to PS1D.
            ps_est_kwargs = {key: val for key, val in compute_ps_kwargs.items() if key not in reserved}
            self.compute_ps_allz(ids, simres_path, boxlen, n_multipoles=n_multipoles, parallel=parallel, n_threads=n_threads, **ps_est_kwargs)
        else:
            # Check the cheap precondition before reading any files from disk.
            if k_c is None:
                raise ValueError("k_c must be provided if not computing PS")
            self._load_allz()
            self.k_c = k_c
        self.N_z = len(z_vals)
        self.N_k = len(self.k_c)
        self.filter_ps()
        self.norm_ps()

    def _ps_file(self, prefix, z_val):
        """Return the .npy path for the array named `prefix` at redshift `z_val`."""
        return self.ps_path + f'/{prefix}_z{z_val}.npy'

    def _load_allz(self):
        """Load the Delta^2 and multipole arrays for every redshift in self.z_vals."""
        self.delta2s_allz = [np.load(self._ps_file(self.delta2s_str, z_val)) for z_val in self.z_vals]
        self.multipoles_allz = [np.load(self._ps_file(self.multipoles_str, z_val)) for z_val in self.z_vals]

    def get_zero_mask(self, delta2s):
        """
        Flag power spectra that are unusable.

        Arguments:
            delta2s (numpy.ndarray): 2D array with one power spectrum per row.
        Returns:
            numpy.ndarray: Boolean per row, True if the row is all zeros or contains any NaN.
        """
        return np.all(delta2s == 0, axis=1) | np.any(np.isnan(delta2s), axis=1)

    def filter_ps(self):
        """Drop the rows flagged by get_zero_mask from the Delta^2 and multipole arrays at every redshift."""
        self.zero_mask_allz = [self.get_zero_mask(delta2s) for delta2s in self.delta2s_allz]
        self.delta2s_filt_allz = [delta2s[~bad] for delta2s, bad in zip(self.delta2s_allz, self.zero_mask_allz)]
        self.multipoles_filt_allz = [mults[~bad] for mults, bad in zip(self.multipoles_allz, self.zero_mask_allz)]
        self.N_filts = [len(delta2s) for delta2s in self.delta2s_filt_allz]

    def norm_ps(self):
        """Divide each filtered spectrum and its multipoles by the spectrum's NaN-ignoring sum over k."""
        self.norm_factors_allz = [np.nansum(delta2s, axis=1) for delta2s in self.delta2s_filt_allz]
        self.delta2s_norm_allz = [delta2s / norms[:, None]
                                  for delta2s, norms in zip(self.delta2s_filt_allz, self.norm_factors_allz)]
        self.multipoles_norm_allz = [mults / norms[:, None, None]
                                     for mults, norms in zip(self.multipoles_filt_allz, self.norm_factors_allz)]

    def process_z(self, z_val, ids, simres_path, boxlen, n_multipoles, **kwargs):
        """
        Compute and save Delta^2, multipoles and k_c for all simulations at one redshift.

        Arguments:
            z_val (float): Redshift; also used in the input and output file names.
            ids (iterable): Simulation IDs; reads simres_path/simres_id{id}_z{z_val}.npy.
            simres_path (str): Directory containing the simulation boxes.
            boxlen (tuple): Physical side lengths of the boxes.
            n_multipoles (int): Number of even multipoles to estimate.
            **kwargs: Passed to PS1D.
        Raises:
            ValueError: If ids is empty.
        """
        ids = list(ids)
        if not ids:
            raise ValueError("ids must contain at least one simulation ID")
        delta2s = []
        multipoles = []
        for sim_id in ids:
            bt = np.load(simres_path + f'/simres_id{sim_id}_z{z_val}.npy')
            ps1d = PS1D(bt, boxlen, z_val, **kwargs)
            multipoles.append(estimate_multipoles(bt, boxlen, ps1d.kb, ps1d.k_c, n_multipoles=n_multipoles))
            delta2s.append(ps1d.delta2)
        np.save(self._ps_file(self.delta2s_str, z_val), np.array(delta2s))
        np.save(self._ps_file(self.multipoles_str, z_val), np.array(multipoles))
        np.save(self._ps_file('k_c', z_val), ps1d.k_c)

    def compute_ps_allz(self, ids, simres_path, boxlen, n_multipoles=6, parallel=True, n_threads=12, **kwargs):
        """
        Run process_z for every redshift (optionally in a process pool), then load the results.

        Arguments:
            ids, simres_path, boxlen, n_multipoles, **kwargs: Forwarded to process_z.
            parallel (bool): If True, use a multiprocessing.Pool with n_threads workers.
            n_threads (int): Number of worker processes when parallel=True.
        """
        if parallel:
            worker = partial(self.process_z, ids=ids, simres_path=simres_path, boxlen=boxlen, n_multipoles=n_multipoles, **kwargs)
            with mp.Pool(n_threads) as pool:
                pool.map(worker, self.z_vals)
        else:
            for z_val in self.z_vals:
                self.process_z(z_val, ids, simres_path, boxlen, n_multipoles, **kwargs)
        self._load_allz()
        # NOTE(review): k_c is read from the first redshift only; PS1D bins can differ
        # between redshifts (e.g. log_bins with avoidance) — confirm they match.
        self.k_c = np.load(self._ps_file('k_c', self.z_vals[0]))

    def plot_ps(self, z_idxs=None, normed=True, **kwargs):
        """
        Plot every stored power spectrum, coloured by redshift.

        Arguments:
            z_idxs (list of int): Indices into z_vals to plot. Default: all.
            normed (bool): Plot normalised spectra if True, otherwise the filtered raw spectra.
            **kwargs: Passed to ax.plot.
        Returns:
            tuple: (fig, ax) of the created figure.
        """
        if z_idxs is None:
            z_idxs = list(range(self.N_z))
        z_vals_plot = [self.z_vals[i] for i in z_idxs]
        norm = Normalize(vmin=min(z_vals_plot), vmax=max(z_vals_plot))
        cmap = plt.cm.viridis
        delta2s_plot = self.delta2s_norm_allz if normed else self.delta2s_filt_allz
        fig, ax = plt.subplots(figsize=(8, 6))
        for z_idx, z_val in zip(z_idxs, z_vals_plot):
            # Colour by the redshift value itself so the curves agree with the colorbar.
            color = cmap(norm(z_val))
            for delta2 in delta2s_plot[z_idx]:
                ax.plot(self.k_c, delta2, color=color, **kwargs)
        ax.set_xlabel(r'$k$ [Mpc$^{-1}$]')
        ax.set_ylabel(r'$\Delta^2(k)$ [mK$^2$]')
        ax.set_xscale('log')
        ax.set_yscale('log')
        sm = plt.cm.ScalarMappable(cmap='viridis', norm=norm)
        fig.colorbar(ax=ax, mappable=sm, label='Redshift')
        fig.tight_layout()
        return fig, ax
125
+
126
+
127
class BiasPredictor:
    """
    Predicts the avoidance biased power spectrum from a full k power spectrum.

    Two modes:
        - Analytic mode (no database): when the multipoles are known, use them directly to predict the biased power spectrum.
        - Database mode: when the multipoles are unknown, pick the database spectrum whose normalised shape is closest
          (minimum mean squared difference) to the input and use its multipoles to predict the biased power spectrum.

    Arguments:
        z_val (float): Redshift of the input power spectrum.
        box_shape (tuple): Shape of the simulation box from which input power spectrum is derived, e.g. (128, 128, 128). Required for estimating allowed k range when not using a database.
        boxlen (tuple): Physical side lengths of the box in Mpc from which input power spectrum is derived. Required for estimating allowed k range when not using a database.
        database (Database): Training database instance. Default None.
        mu_min (float, None or 'horizon'): Minimum mu value. Default 'horizon': computes the horizon wedge cut at z_val.
        kpar_range (tuple or list of tuples): One or more (kpar_min, kpar_max) ranges defining rectangular avoidance boxes in (kperp, kpar) space. Default None (no kpar cut).
        kperp_range (tuple or list of tuples): One or more (kperp_min, kperp_max) ranges. Must have the same number of entries as kpar_range. Default None.
        combz (bool): If True and using a database, combine all redshifts into one training set. Default False.
        k_c_inp (array-like): k-bin centres of the input power spectrum. Must match the length of the arrays passed to estimate_avoidance_ps or predict. Default None.
        interpolate (bool): If True and using a database, interpolate between the database k-grid and k_c_inp. Requires k_c_inp. Default False.
    Raises:
        ValueError: If kpar_range and kperp_range have different numbers of boxes, or interpolate=True without k_c_inp.
    """
    def __init__(self, z_val, box_shape=None, boxlen=None, database=None, mu_min='horizon', kpar_range=None, kperp_range=None, combz=False, k_c_inp=None, interpolate=False):
        self.z_val = z_val
        self.box_shape = box_shape
        self.boxlen = boxlen
        self.database = database
        self.interpolate = interpolate
        self.combz = combz
        # Stored as float arrays so boolean masks can index them (lists cannot).
        self.k_c_inp = None if k_c_inp is None else np.asarray(k_c_inp, dtype=float)
        self.k_c = self.k_c_inp
        if isinstance(mu_min, str) and mu_min == 'horizon':
            self.mu_min = mu_theta(self.z_val, theta=np.pi/2)
        else:
            self.mu_min = mu_min
        self.kpar_ranges = self._normalize_range(kpar_range)
        self.kperp_ranges = self._normalize_range(kperp_range)
        if len(self.kpar_ranges) != len(self.kperp_ranges):
            raise ValueError("kpar_range and kperp_range must have the same number of boxes")
        self.N_boxes = len(self.kpar_ranges)
        # The accessible-k mask needs both the box geometry and a k grid to evaluate on.
        if box_shape is not None and boxlen is not None and self.k_c is not None:
            self.mask_avoid = self.accessible_kmask(self.k_c)
            self.k_c_masked = self.k_c[self.mask_avoid]
        if self.database is not None:
            self.box_shape = self.database.box_shape
            self.boxlen = self.database.boxlen
            # Use the database snapshot closest in redshift.
            self.z_idx = np.argmin(np.abs(np.array(self.database.z_vals) - self.z_val))
            self.k_c = np.asarray(self.database.k_c, dtype=float)
            if self.combz:
                self.delta2s_norm = np.concatenate(self.database.delta2s_norm_allz, axis=0)
                self.multipoles_norm = np.concatenate(self.database.multipoles_norm_allz, axis=0)
                self.norm_factors = np.concatenate(self.database.norm_factors_allz, axis=0)
            else:
                self.delta2s_norm = self.database.delta2s_norm_allz[self.z_idx].copy()
                self.multipoles_norm = self.database.multipoles_norm_allz[self.z_idx].copy()
                self.norm_factors = self.database.norm_factors_allz[self.z_idx].copy()
            if self.interpolate:
                if self.k_c_inp is None:
                    raise ValueError("k_c_inp must be provided when interpolate=True")
                # Predictions are interpolated back to k_c_inp, so the mask lives on that grid.
                self.mask_avoid = self.accessible_kmask(self.k_c_inp)
                self.k_c_masked = self.k_c_inp[self.mask_avoid]
            else:
                self.mask_avoid = self.accessible_kmask(self.k_c)
                self.k_c_masked = self.k_c[self.mask_avoid]
        self.mu_bounds()

    @staticmethod
    def _normalize_range(r):
        """Return a list of ranges: [None] for None, the list itself for a list of tuples, else [r]."""
        if r is None:
            return [None]
        if isinstance(r, (list, tuple)) and len(r) > 0 and isinstance(r[0], (list, tuple)):
            return list(r)
        return [r]

    def _mu_limits(self, k, kpar_r, kperp_r):
        """
        Return (mu_lo, mu_hi), clipped to [0, 1], for one avoidance box evaluated on the grid k.

        Arguments:
            k (array-like or None): k values; None yields a 0-d NaN grid (k-independent limits only).
            kpar_r (tuple or None): (kpar_min, kpar_max); either bound may be None.
            kperp_r (tuple or None): (kperp_min, kperp_max); either bound may be None.
        """
        k = np.asarray(k, dtype=float)
        mu_lo = np.zeros_like(k)
        mu_hi = np.ones_like(k)
        # k == 0 produces inf/nan ratios; they are resolved by the clip and comparisons downstream.
        with np.errstate(divide='ignore', invalid='ignore'):
            if self.mu_min is not None:
                mu_lo = np.maximum(mu_lo, self.mu_min * np.ones_like(k))
            if kpar_r is not None:
                kpar_min, kpar_max = kpar_r
                if kpar_min is not None:
                    mu_lo = np.maximum(mu_lo, kpar_min / k)
                if kpar_max is not None:
                    mu_hi = np.minimum(mu_hi, kpar_max / k)
            if kperp_r is not None:
                kperp_min, kperp_max = kperp_r
                # mu = sqrt(1 - (kperp/k)^2): an upper kperp bound raises mu_lo, a lower one caps mu_hi.
                if kperp_max is not None:
                    mu_lo = np.maximum(mu_lo, np.sqrt(np.maximum(0.0, 1.0 - (kperp_max / k)**2)))
                if kperp_min is not None:
                    mu_hi = np.minimum(mu_hi, np.sqrt(np.maximum(0.0, 1.0 - (kperp_min / k)**2)))
        return np.clip(mu_lo, 0.0, 1.0), np.clip(mu_hi, 0.0, 1.0)

    def mu_bounds(self):
        """Compute the per-box mu limits on self.k_c into self.mu_lo_boxes and self.mu_hi_boxes."""
        self.mu_lo_boxes = []
        self.mu_hi_boxes = []
        for kpar_r, kperp_r in zip(self.kpar_ranges, self.kperp_ranges):
            mu_lo, mu_hi = self._mu_limits(self.k_c, kpar_r, kperp_r)
            self.mu_lo_boxes.append(mu_lo)
            self.mu_hi_boxes.append(mu_hi)

    def coeff_range(self, ell):
        """
        Return the average of P_ell(mu) over the union of accessible mu intervals, per k.

        Uses integral of P_ell dmu = (P_{ell+1} - P_{ell-1}) / (2 ell + 1). k values with no
        accessible interval give NaN.

        Arguments:
            ell (int): Legendre degree.
        Returns:
            numpy.ndarray: Coefficient per k (shape of self.k_c).
        """
        if ell == 0:
            return np.ones_like(self.mu_lo_boxes[0])
        num = np.zeros_like(self.mu_lo_boxes[0])
        denom = np.zeros_like(self.mu_lo_boxes[0])
        for mu_lo, mu_hi in zip(self.mu_lo_boxes, self.mu_hi_boxes):
            valid = mu_hi > mu_lo
            Ihi = (legendre_P(ell+1, mu_hi) - legendre_P(ell-1, mu_hi)) / (2*ell+1)
            Ilo = (legendre_P(ell+1, mu_lo) - legendre_P(ell-1, mu_lo)) / (2*ell+1)
            num += np.where(valid, Ihi - Ilo, 0.0)
            denom += np.where(valid, mu_hi - mu_lo, 0.0)
        with np.errstate(divide='ignore', invalid='ignore'):
            return num / denom

    def estimate_avoidance_ps(self, ps1d_delta2, multipoles, masked=False):
        """
        Predict avoidance biased power spectrum.

        Arguments:
            ps1d_delta2 (array-like): Input full k power spectrum.
            multipoles (numpy.ndarray or list): Input even multipoles, with shape (n_multipoles, n_k). Multipole j corresponds to ell = 2*(j+1).
            masked (bool): If True, return only the accessible k values selected by self.mask_avoid. Default False.
        Returns:
            numpy.ndarray: Predicted biased power spectrum. The full array if masked=False; otherwise only the
            accessible entries, whose k values are self.k_c_masked.
        """
        # np.array copies and accepts lists / integer arrays (list.copy() + '+=' would extend a list).
        out = np.array(ps1d_delta2, dtype=float)
        for j, multipole in enumerate(multipoles):
            ell = 2 * (j + 1)
            out += self.coeff_range(ell) * np.asarray(multipole, dtype=float)
        if masked:
            out = out[self.mask_avoid]
        return out

    def accessible_kmask(self, k_c, los_axis=2):
        """
        Return a boolean mask of k values reachable inside at least one avoidance box on the box grid.

        Arguments:
            k_c (array-like): k values to test.
            los_axis (int): Line-of-sight axis of the box. Default 2.
        Returns:
            numpy.ndarray: Boolean mask with the shape of k_c.
        """
        k_c = np.asarray(k_c, dtype=float)
        Nlos = self.box_shape[los_axis]
        Llos = self.boxlen[los_axis]
        kpar_max_grid = np.pi * Nlos / Llos
        ax_perp = [a for a in (0, 1, 2) if a != los_axis]
        Np0, Np1 = self.box_shape[ax_perp[0]], self.box_shape[ax_perp[1]]
        Lp0, Lp1 = self.boxlen[ax_perp[0]], self.boxlen[ax_perp[1]]
        kperp_max_grid = np.sqrt((np.pi*Np0/Lp0)**2 + (np.pi*Np1/Lp1)**2)
        mask = np.zeros(k_c.shape, dtype=bool)
        for kpar_r, kperp_r in zip(self.kpar_ranges, self.kperp_ranges):
            mu_lo, mu_hi = self._mu_limits(k_c, kpar_r, kperp_r)
            ok_mu = mu_hi > mu_lo
            # The lowest kpar needed (mu_lo * k) must exist on the grid.
            ok_kpar = (mu_lo * k_c) <= kpar_max_grid
            if kpar_r is not None and kpar_r[0] is not None:
                ok_kpar &= (kpar_r[0] <= kpar_max_grid)
            ok_kperp = np.ones(k_c.shape, dtype=bool)
            if kperp_r is not None and kperp_r[0] is not None:
                ok_kperp &= (kperp_r[0] <= kperp_max_grid)
            mask |= (ok_mu & ok_kpar & ok_kperp)
        return mask

    def min_rmse(self, delta2_inp):
        """
        Find the database spectrum whose normalised shape is closest to the input.

        Arguments:
            delta2_inp (numpy.ndarray): Input spectrum on the database k grid.
        Returns:
            tuple: (best_idx, norm_factor); (None, 0) if the input's NaN-ignoring sum is zero.
        """
        norm_factor = np.nansum(delta2_inp)
        if norm_factor == 0:
            return None, 0
        target_shape = delta2_inp / norm_factor
        mse = np.nanmean((self.delta2s_norm - target_shape)**2, axis=1)
        best_idx = np.argmin(mse)
        return best_idx, norm_factor

    def interpolate_ps(self, k_old, k_new, delta2, log=True):
        """
        Interpolate a spectrum from k_old to k_new, extrapolating outside the range.

        Arguments:
            k_old, k_new (numpy.ndarray): Source and target k grids.
            delta2 (numpy.ndarray): Spectrum on k_old.
            log (bool): Interpolate in log-log space (non-positive values become NaN). Default True.
        Returns:
            numpy.ndarray: Spectrum on k_new.
        """
        if log:
            f = interp1d(np.log(k_old), np.log(np.where(delta2 > 0, delta2, np.nan)),
                         bounds_error=False, fill_value="extrapolate")
            return np.exp(f(np.log(k_new)))
        f = interp1d(k_old, delta2, bounds_error=False, fill_value="extrapolate")
        return f(k_new)

    def predict(self, delta2_inp, plot=True, log_interp=True):
        """
        Predict avoidance biased power spectrum using the training database.

        Arguments:
            delta2_inp (array-like): Input full k power spectrum.
            plot (bool): If True, plot the input and predicted power spectra. Default True.
            log_interp (bool): If True, interpolate in log-log space when interpolate=True. Default True.
        Returns:
            numpy.ndarray: Predicted biased power spectrum at the accessible k values (self.k_c_masked).
        Raises:
            ValueError: If the input spectrum sums to zero.
        """
        delta2_inp = np.asarray(delta2_inp, dtype=float)
        if self.interpolate:
            delta2_inp = self.interpolate_ps(self.k_c_inp, self.k_c, delta2_inp, log=log_interp)
        best_idx, norm_factor = self.min_rmse(delta2_inp)
        if norm_factor == 0:
            raise ValueError("Input PS has zero norm.")
        delta2_norm = delta2_inp/norm_factor
        delta2_pred = self.estimate_avoidance_ps(delta2_norm, self.multipoles_norm[best_idx])*norm_factor
        if self.interpolate:
            delta2_pred = self.interpolate_ps(self.k_c, self.k_c_inp, delta2_pred, log=log_interp)
        self.delta2_pred = delta2_pred
        self.delta2_pred_masked = delta2_pred[self.mask_avoid]
        if plot:
            self.plot(delta2_inp, self.delta2_pred_masked)
        return self.delta2_pred_masked

    def plot(self, delta2_inp, delta2_pred_masked, ax=None):
        """
        Plot the input full k power spectrum alongside the predicted biased power spectrum.

        Arguments:
            delta2_inp (numpy.ndarray): Input full-sky power spectrum on the k_c grid.
            delta2_pred_masked (numpy.ndarray): Predicted biased power spectrum at accessible k values (k_c_masked).
            ax (matplotlib.axes.Axes): Axes to plot on. If None, a new figure is created.
        Returns:
            matplotlib.axes.Axes: The axes drawn on.
        """
        if ax is None:
            fig, ax = plt.subplots()
        ax.plot(self.k_c, delta2_inp, '-o', label='Input', color='red')
        ax.plot(self.k_c_masked, delta2_pred_masked, '-s', label='Predicted', color='blue')
        ax.set_xlabel(r'$k$ [Mpc$^{-1}$]')
        ax.set_ylabel(r'$\Delta^2(k)$ [mK$^2$]')
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.legend()
        return ax
@@ -0,0 +1,149 @@
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ from numpy.fft import fftn, fftshift, fftfreq
4
+ from astropy.cosmology import Planck15 as cosmo
5
+ import astropy.units as u
6
+ import astropy.constants as const
7
+ import pickle
8
+ from matplotlib.colors import Normalize, LogNorm
9
+
10
class PS1D:
    """
    Computes the spherically-averaged 1D dimensionless power spectrum, with optional foreground avoidance masking and/or rectangular cuts in (kperp, kpar) space.

    Arguments:
        box (numpy.ndarray): 3D brightness temperature box (axis 2 is treated as the line of sight).
        boxlen (tuple): Physical side lengths of the box in Mpc, e.g. (Lx, Ly, Lz).
        z_val (float): Redshift of the box, used for the horizon avoidance cut.
        n_bins (int): Number of k bins. Default 50.
        avoidance (bool): If True, mask modes inside the foreground wedge. Default False.
        keep_pk (bool): If True, store the unmasked 3D power spectrum array as self.pk. Default False.
        log_bins (bool): If True, use logarithmically-spaced k bins. Default False.
        keep_box (bool): If True, keep the input box in memory after computing. Default False.
        kperp_range (tuple or list of tuples): One or more (kperp_min, kperp_max) ranges to include.
            Pass a list of tuples to union multiple rectangular boxes in (kperp, kpar) space.
        kpar_range (tuple or list of tuples): One or more (kpar_min, kpar_max) ranges to include.
            Must have the same number of entries as kperp_range.
        kmin (float): Minimum k for binning (overrides auto-detected minimum).
        kmax (float): Maximum k for binning (overrides auto-detected maximum).
        k_edges (array-like): Explicit bin edges. If provided, overrides n_bins, kmin, kmax.
        apply_window (bool): If True, apply a Blackman-Harris window along the LOS axis. Default False.
    Raises:
        ValueError: If kperp_range and kpar_range have different numbers of boxes.
    """
    def __init__(self, box, boxlen, z_val, n_bins=50, avoidance=False, keep_pk=False, log_bins=False, keep_box=False, kperp_range=None, kpar_range=None, kmin=None, kmax=None, k_edges=None, apply_window=False):
        self.box, self.boxlen, self.n_bins = box, boxlen, n_bins
        self.avoidance, self.z_val, self.keep_pk = avoidance, z_val, keep_pk
        self.log_bins, self.keep_box = log_bins, keep_box
        self.kmin, self.kmax = kmin, kmax
        self.k_edges = k_edges
        self.apply_window = apply_window
        self.kperp_ranges = self._normalize_range(kperp_range)
        self.kpar_ranges = self._normalize_range(kpar_range)
        if len(self.kperp_ranges) != len(self.kpar_ranges):
            raise ValueError("kperp_range and kpar_range must have the same number of boxes")
        self.compute_ps()

    @staticmethod
    def _normalize_range(r):
        """Return a list of ranges: [None] for None, the list itself for a list of tuples, else [r]."""
        if r is None:
            return [None]
        if isinstance(r, (list, tuple)) and len(r) > 0 and isinstance(r[0], (list, tuple)):
            return list(r)
        return [r]

    def _excluded_modes(self, kperp, kpar):
        """Return a boolean array, True for modes to drop: inside the wedge (if avoidance) or outside every range box."""
        if self.avoidance:
            factor_cosmo = (cosmo.comoving_transverse_distance(self.z_val) * cosmo.H(self.z_val) / (const.c * (1 + self.z_val))).decompose().value
            mask = kpar < factor_cosmo * kperp
        else:
            mask = np.zeros(kperp.shape, dtype=bool)
        has_ranges = any(kperp_r is not None or kpar_r is not None
                         for kperp_r, kpar_r in zip(self.kperp_ranges, self.kpar_ranges))
        if has_ranges:
            in_any_box = np.zeros(kperp.shape, dtype=bool)
            for kperp_r, kpar_r in zip(self.kperp_ranges, self.kpar_ranges):
                in_box = np.ones(kperp.shape, dtype=bool)
                if kperp_r is not None:
                    kp_min, kp_max = kperp_r
                    if kp_min is not None:
                        in_box &= (kperp >= kp_min)
                    if kp_max is not None:
                        in_box &= (kperp <= kp_max)
                if kpar_r is not None:
                    kl_min, kl_max = kpar_r
                    if kl_min is not None:
                        in_box &= (kpar >= kl_min)
                    if kl_max is not None:
                        in_box &= (kpar <= kl_max)
                in_any_box |= in_box
            mask |= ~in_any_box
        return mask

    def _bin_edges(self, k, pk):
        """Return (edges, centres) of the k bins, honouring k_edges, kmin, kmax and log_bins."""
        if self.k_edges is not None:
            kb = np.array(self.k_edges)
            self.n_bins = len(kb) - 1
        elif self.log_bins:
            kpos = k[(k > 0) & (~np.isnan(pk))]
            kmin_val = self.kmin if self.kmin is not None else kpos.min()
            kmax_val = self.kmax if self.kmax is not None else kpos.max()
            kb = np.logspace(np.log10(kmin_val), np.log10(kmax_val), self.n_bins + 1)
        else:
            kmin_val = self.kmin if self.kmin is not None else k.min()
            kmax_val = self.kmax if self.kmax is not None else k.max()
            kb = np.linspace(kmin_val, kmax_val, self.n_bins + 1)
        if self.log_bins:
            k_c = np.sqrt(kb[1:] * kb[:-1])
        else:
            k_c = 0.5 * (kb[1:] + kb[:-1])
        return kb, k_c

    def compute_ps(self):
        """
        Compute the binned power spectrum and store it on the instance.

        Sets self.kb, self.k_c, self.ps1d, self.err, self.delta2, self.err_delta2
        (and self.pk if keep_pk). Bins without usable modes are NaN.
        """
        box, boxlen = self.box, self.boxlen
        nx, ny, nz = box.shape
        fx = fftfreq(nx, d=boxlen[0]/nx)
        fy = fftfreq(ny, d=boxlen[1]/ny)
        fz = fftfreq(nz, d=boxlen[2]/nz)
        kx, ky, kz = np.meshgrid(fx*2*np.pi, fy*2*np.pi, fz*2*np.pi, indexing='ij')
        kx, ky, kz = fftshift(kx), fftshift(ky), fftshift(kz)
        kperp, kpar = np.sqrt(kx**2 + ky**2), np.abs(kz)
        k = np.sqrt(kperp**2 + kpar**2)
        mask = self._excluded_modes(kperp, kpar)
        V = boxlen[0] * boxlen[1] * boxlen[2]
        dk = (box - np.mean(box)) * V / (nx * ny * nz)
        if self.apply_window:
            # Local import: blackmanharris is not among this module's top-level imports.
            from scipy.signal.windows import blackmanharris
            freq_window = blackmanharris(nz)
            dk = dk * freq_window[None, None, :]
        fk = fftshift(fftn(dk))
        pk = np.abs(fk)**2 / V
        if self.apply_window:
            # Compensate for the mean power removed by the window.
            pk /= np.sum(freq_window**2) / nz
        if self.keep_pk:
            # Copy so the stored spectrum is not overwritten by the NaN masking below.
            self.pk = pk.copy()
        pk[mask] = np.nan
        kb, self.k_c = self._bin_edges(k, pk)
        self.kb = kb
        # Skip the k = 0 (DC) mode: mean subtraction removes its power, so it would only bias the lowest bin.
        usable = (k > 0) & ~np.isnan(pk)
        n_k = len(self.k_c)
        self.ps1d = np.full(n_k, np.nan)
        self.err = np.full(n_k, np.nan)
        for i in range(n_k):
            sel = usable & (k >= kb[i]) & (k < kb[i+1])
            n_modes = np.count_nonzero(sel)
            if n_modes > 0:
                pk_mean = np.mean(pk[sel])
                self.ps1d[i] = pk_mean
                # Cosmic-variance error of a mean over n_modes modes.
                self.err[i] = pk_mean * np.sqrt(2 / n_modes)
        to_delta2 = self.k_c**3 / (2 * np.pi**2)
        self.delta2 = self.ps1d * to_delta2
        self.err_delta2 = self.err * to_delta2
        if not self.keep_box:
            self.box = None

    def plot(self, ax=None, **kwargs):
        """
        Plot the dimensionless power spectrum.

        Arguments:
            ax: Axis to plot on. If None, a new figure is created.
            **kwargs: Additional keyword arguments passed to ax.errorbar.
        Returns:
            matplotlib.axes.Axes: The axes drawn on.
        """
        if ax is None:
            fig, ax = plt.subplots()
        ax.errorbar(self.k_c, self.delta2, yerr=self.err_delta2, fmt='-o', capsize=2, **kwargs)
        ax.set_xlabel(r'$k$ [Mpc$^{-1}$]')
        ax.set_ylabel(r'$\Delta^2(k)$ [$\mathrm{mK}^2$]')
        if self.log_bins:
            ax.set_xscale('log')
        ax.set_yscale('log')
        return ax

    def save(self, filename):
        """
        Pickle this object to `filename`.

        Note: restoring uses pickle.load, which can execute arbitrary code; only load trusted files.
        """
        with open(filename, 'wb') as outp:
            pickle.dump(self, outp, pickle.HIGHEST_PROTOCOL)
@@ -0,0 +1,10 @@
1
[project]
name = "avoidcorr"
version = "0.1.0"
description = "Predict avoidance biased 21-cm power spectrum from full k power spectrum."
authors = [{name = "satyapan", email = "satyapan.iiserm@gmail.com"}]
readme = "README.md"
# The package uses f-strings, so Python 2 and Python < 3.6 are not supported.
requires-python = ">=3.6"
# Third-party packages imported by the package modules
# (multiprocessing and functools are part of the standard library).
dependencies = [
    "numpy",
    "scipy",
    "matplotlib",
    "astropy",
    "tqdm",
]

[build-system]
requires = ["poetry-core>=2.0.0,<3.0.0"]
build-backend = "poetry.core.masonry.api"