classy-szfast 0.0.25.post3__py3-none-any.whl → 0.0.25.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- classy_szfast/classy_szfast.py +117 -114
- classy_szfast/cosmopower_jax.py +3 -2
- classy_szfast/emulators_meta_data.py +1 -1
- {classy_szfast-0.0.25.post3.dist-info → classy_szfast-0.0.25.post4.dist-info}/METADATA +1 -1
- {classy_szfast-0.0.25.post3.dist-info → classy_szfast-0.0.25.post4.dist-info}/RECORD +7 -7
- {classy_szfast-0.0.25.post3.dist-info → classy_szfast-0.0.25.post4.dist-info}/WHEEL +0 -0
- {classy_szfast-0.0.25.post3.dist-info → classy_szfast-0.0.25.post4.dist-info}/top_level.txt +0 -0
classy_szfast/classy_szfast.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
from .utils import *
|
2
2
|
from .config import *
|
3
3
|
import numpy as np
|
4
|
-
from .emulators_meta_data import emulator_dict, dofftlog_alphas, cp_l_max_scalars
|
4
|
+
from .emulators_meta_data import emulator_dict, dofftlog_alphas, cp_l_max_scalars, cosmo_model_list
|
5
5
|
from .cosmopower import cp_tt_nn, cp_te_nn, cp_ee_nn, cp_ee_nn, cp_pp_nn, cp_pknl_nn, cp_pkl_nn, cp_der_nn, cp_da_nn, cp_h_nn, cp_s8_nn, cp_pkl_fftlog_alphas_real_nn, cp_pkl_fftlog_alphas_imag_nn, cp_pkl_fftlog_alphas_nus
|
6
6
|
from .cosmopower_jax import cp_tt_nn_jax, cp_te_nn_jax, cp_ee_nn_jax, cp_ee_nn_jax, cp_pp_nn_jax, cp_pknl_nn_jax, cp_pkl_nn_jax, cp_der_nn_jax, cp_da_nn_jax, cp_h_nn_jax, cp_s8_nn_jax
|
7
7
|
from .pks_and_sigmas import *
|
@@ -76,13 +76,10 @@ class Class_szfast(object):
|
|
76
76
|
|
77
77
|
self.jax_mode = params_settings["jax"]
|
78
78
|
|
79
|
+
|
79
80
|
# print(f"JAX mode: {self.jax_mode}")
|
80
81
|
|
81
82
|
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
83
|
# cosmopower emulators
|
87
84
|
# self.cp_path_to_cosmopower_organization = path_to_cosmopower_organization + '/'
|
88
85
|
self.cp_tt_nn = cp_tt_nn
|
@@ -90,16 +87,70 @@ class Class_szfast(object):
|
|
90
87
|
self.cp_ee_nn = cp_ee_nn
|
91
88
|
self.cp_pp_nn = cp_pp_nn
|
92
89
|
self.cp_pknl_nn = cp_pknl_nn
|
93
|
-
self.cp_pkl_nn = cp_pkl_nn
|
94
90
|
self.cp_der_nn = cp_der_nn
|
95
91
|
|
92
|
+
self.cp_lmax = cp_l_max_scalars
|
93
|
+
|
94
|
+
cosmo_model_dict = {i: model for i, model in enumerate(cosmo_model_list)}
|
95
|
+
|
96
|
+
if (cosmo_model_dict[params_settings['cosmo_model']] == 'ede-v2'):
|
97
|
+
|
98
|
+
self.cszfast_pk_grid_zmax = 20.
|
99
|
+
self.cszfast_pk_grid_kmin = 5e-4
|
100
|
+
self.cszfast_pk_grid_kmax = 10.
|
101
|
+
self.cp_kmax = self.cszfast_pk_grid_kmax
|
102
|
+
self.cp_kmin = self.cszfast_pk_grid_kmin
|
103
|
+
# self.logger.info(f">>> using kmin = {self.cp_kmin}")
|
104
|
+
# self.logger.info(f">>> using kmax = {self.cp_kmax}")
|
105
|
+
# self.logger.info(f">>> using zmax = {self.cszfast_pk_grid_zmax}")
|
106
|
+
|
107
|
+
else:
|
108
|
+
|
109
|
+
self.cszfast_pk_grid_zmax = 5. # max z of our pk emulators (sept 23)
|
110
|
+
self.cszfast_pk_grid_kmin = 1e-4
|
111
|
+
self.cszfast_pk_grid_kmax = 50.
|
112
|
+
self.cp_kmax = self.cszfast_pk_grid_kmax
|
113
|
+
self.cp_kmin = self.cszfast_pk_grid_kmin
|
114
|
+
# self.logger.info(f">>> using kmin = {self.cp_kmin}")
|
115
|
+
# self.logger.info(f">>> using kmax = {self.cp_kmax}")
|
116
|
+
# self.logger.info(f">>> using zmax = {self.cszfast_pk_grid_zmax}")
|
96
117
|
|
97
118
|
if self.jax_mode:
|
98
119
|
self.cp_h_nn = cp_h_nn_jax
|
99
120
|
self.cp_da_nn = cp_da_nn_jax
|
121
|
+
self.cp_pkl_nn = cp_pkl_nn_jax
|
122
|
+
|
123
|
+
self.pi = jnp.pi
|
124
|
+
self.transpose = jnp.transpose
|
125
|
+
self.asarray = jnp.asarray
|
126
|
+
|
127
|
+
self.linspace = jnp.linspace
|
128
|
+
self.geomspace = jnp.geomspace
|
129
|
+
self.arange = jnp.arange
|
130
|
+
self.zeros = jnp.zeros
|
131
|
+
|
132
|
+
|
100
133
|
else:
|
101
134
|
self.cp_h_nn = cp_h_nn
|
102
135
|
self.cp_da_nn = cp_da_nn
|
136
|
+
self.cp_pkl_nn = cp_pkl_nn
|
137
|
+
|
138
|
+
self.pi = np.pi
|
139
|
+
self.transpose = np.transpose
|
140
|
+
|
141
|
+
self.linspace = np.linspace
|
142
|
+
self.geomspace = np.geomspace
|
143
|
+
self.arange = np.arange
|
144
|
+
self.zeros = np.zeros
|
145
|
+
self.asarray = np.asarray
|
146
|
+
|
147
|
+
|
148
|
+
|
149
|
+
self.cp_ls = self.arange(2,self.cp_lmax+1)
|
150
|
+
self.cp_predicted_tt_spectrum =self.zeros(self.cp_lmax)
|
151
|
+
self.cp_predicted_te_spectrum =self.zeros(self.cp_lmax)
|
152
|
+
self.cp_predicted_ee_spectrum =self.zeros(self.cp_lmax)
|
153
|
+
self.cp_predicted_pp_spectrum =self.zeros(self.cp_lmax)
|
103
154
|
|
104
155
|
self.cp_s8_nn = cp_s8_nn
|
105
156
|
|
@@ -114,20 +165,6 @@ class Class_szfast(object):
|
|
114
165
|
self.use_Amod = 0
|
115
166
|
self.Amod = 0
|
116
167
|
|
117
|
-
self.cp_lmax = cp_l_max_scalars
|
118
|
-
self.cp_ls = np.arange(2,self.cp_lmax+1)
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
cosmo_model_dict = {0: 'lcdm',
|
124
|
-
1: 'mnu',
|
125
|
-
2: 'neff',
|
126
|
-
3: 'wcdm',
|
127
|
-
4: 'ede',
|
128
|
-
5: 'mnu-3states',
|
129
|
-
6: 'ede-v2'
|
130
|
-
}
|
131
168
|
|
132
169
|
|
133
170
|
if cosmo_model_dict[params_settings['cosmo_model']] == 'ede-v2':
|
@@ -140,60 +177,31 @@ class Class_szfast(object):
|
|
140
177
|
self.cp_ndspl_k = 10
|
141
178
|
self.cp_nk = 5000
|
142
179
|
|
143
|
-
self.cp_predicted_tt_spectrum =np.zeros(self.cp_lmax)
|
144
|
-
self.cp_predicted_te_spectrum =np.zeros(self.cp_lmax)
|
145
|
-
self.cp_predicted_ee_spectrum =np.zeros(self.cp_lmax)
|
146
|
-
self.cp_predicted_pp_spectrum =np.zeros(self.cp_lmax)
|
147
|
-
|
148
180
|
|
149
181
|
self.cszfast_ldim = 20000 # used for the cls arrays
|
150
|
-
|
151
182
|
self.cszfast_pk_grid_nz = 100 # has to be same as narraySZ, i.e., ndim_redshifts; it is setup hereafter if ndim_redshifts is passed
|
152
183
|
|
153
184
|
|
154
|
-
|
155
|
-
if (cosmo_model_dict[params_settings['cosmo_model']] == 'ede-v2'):
|
156
|
-
|
157
|
-
self.cszfast_pk_grid_zmax = 20.
|
158
|
-
self.cszfast_pk_grid_kmin = 5e-4
|
159
|
-
self.cszfast_pk_grid_kmax = 10.
|
160
|
-
self.cp_kmax = self.cszfast_pk_grid_kmax
|
161
|
-
self.cp_kmin = self.cszfast_pk_grid_kmin
|
162
|
-
# self.logger.info(f">>> using kmin = {self.cp_kmin}")
|
163
|
-
# self.logger.info(f">>> using kmax = {self.cp_kmax}")
|
164
|
-
# self.logger.info(f">>> using zmax = {self.cszfast_pk_grid_zmax}")
|
165
|
-
|
166
|
-
else:
|
167
|
-
|
168
|
-
self.cszfast_pk_grid_zmax = 5. # max z of our pk emulators (sept 23)
|
169
|
-
self.cszfast_pk_grid_kmin = 1e-4
|
170
|
-
self.cszfast_pk_grid_kmax = 50.
|
171
|
-
self.cp_kmax = self.cszfast_pk_grid_kmax
|
172
|
-
self.cp_kmin = self.cszfast_pk_grid_kmin
|
173
|
-
# self.logger.info(f">>> using kmin = {self.cp_kmin}")
|
174
|
-
# self.logger.info(f">>> using kmax = {self.cp_kmax}")
|
175
|
-
# self.logger.info(f">>> using zmax = {self.cszfast_pk_grid_zmax}")
|
176
|
-
|
177
|
-
self.cszfast_pk_grid_z = np.linspace(0.,self.cszfast_pk_grid_zmax,self.cszfast_pk_grid_nz)
|
185
|
+
self.cszfast_pk_grid_z = self.linspace(0.,self.cszfast_pk_grid_zmax,self.cszfast_pk_grid_nz)
|
178
186
|
self.cszfast_pk_grid_ln1pz = np.log(1.+self.cszfast_pk_grid_z)
|
179
187
|
|
180
188
|
|
181
|
-
self.cszfast_pk_grid_k =
|
189
|
+
self.cszfast_pk_grid_k = self.geomspace(self.cp_kmin,self.cp_kmax,self.cp_nk)[::self.cp_ndspl_k]
|
182
190
|
|
183
191
|
self.cszfast_pk_grid_lnk = np.log(self.cszfast_pk_grid_k)
|
184
192
|
|
185
|
-
self.cszfast_pk_grid_nk = len(
|
193
|
+
self.cszfast_pk_grid_nk = len(self.geomspace(self.cp_kmin,self.cp_kmax,self.cp_nk)[::self.cp_ndspl_k]) # has to be same as ndimSZ, and the same as dimension of cosmopower pk emulators
|
186
194
|
|
187
195
|
for k,v in params_settings.items():
|
188
196
|
|
189
197
|
if k == 'ndim_redshifts':
|
190
198
|
|
191
199
|
self.cszfast_pk_grid_nz = v
|
192
|
-
self.cszfast_pk_grid_z =
|
200
|
+
self.cszfast_pk_grid_z = self.linspace(0.,self.cszfast_pk_grid_zmax,self.cszfast_pk_grid_nz)
|
193
201
|
self.cszfast_pk_grid_ln1pz = np.log(1.+self.cszfast_pk_grid_z)
|
194
202
|
|
195
|
-
self.cszfast_pk_grid_pknl_flat =
|
196
|
-
self.cszfast_pk_grid_pkl_flat =
|
203
|
+
self.cszfast_pk_grid_pknl_flat = self.zeros(self.cszfast_pk_grid_nz*self.cszfast_pk_grid_nk)
|
204
|
+
self.cszfast_pk_grid_pkl_flat = self.zeros(self.cszfast_pk_grid_nz*self.cszfast_pk_grid_nk)
|
197
205
|
|
198
206
|
if k == 'cosmo_model':
|
199
207
|
|
@@ -212,13 +220,12 @@ class Class_szfast(object):
|
|
212
220
|
|
213
221
|
else:
|
214
222
|
|
215
|
-
ls =
|
216
|
-
dls = ls*(ls+1.)/2./
|
223
|
+
ls = self.arange(2,self.cp_nk+2)[::self.cp_ndspl_k] # jan 10 ndspl
|
224
|
+
dls = ls*(ls+1.)/2./self.pi
|
217
225
|
self.pk_power_fac= (dls)**-1
|
218
226
|
|
219
227
|
|
220
|
-
self.cp_z_interp =
|
221
|
-
self.cp_z_interp_jax = jnp.linspace(0.,20.,5000)
|
228
|
+
self.cp_z_interp = self.linspace(0.,20.,5000)
|
222
229
|
|
223
230
|
self.csz_base = None
|
224
231
|
|
@@ -226,7 +233,7 @@ class Class_szfast(object):
|
|
226
233
|
self.cszfast_zgrid_zmin = 0.
|
227
234
|
self.cszfast_zgrid_zmax = 4.
|
228
235
|
self.cszfast_zgrid_nz = 250
|
229
|
-
self.cszfast_zgrid =
|
236
|
+
self.cszfast_zgrid = self.linspace(self.cszfast_zgrid_zmin,
|
230
237
|
self.cszfast_zgrid_zmax,
|
231
238
|
self.cszfast_zgrid_nz)
|
232
239
|
|
@@ -234,14 +241,14 @@ class Class_szfast(object):
|
|
234
241
|
self.cszfast_mgrid_mmin = 1e10
|
235
242
|
self.cszfast_mgrid_mmax = 1e15
|
236
243
|
self.cszfast_mgrid_nm = 50
|
237
|
-
self.cszfast_mgrid =
|
244
|
+
self.cszfast_mgrid = self.geomspace(self.cszfast_mgrid_mmin,
|
238
245
|
self.cszfast_mgrid_mmax,
|
239
246
|
self.cszfast_mgrid_nm)
|
240
247
|
|
241
248
|
self.cszfast_gas_pressure_xgrid_xmin = 1e-2
|
242
249
|
self.cszfast_gas_pressure_xgrid_xmax = 1e2
|
243
250
|
self.cszfast_gas_pressure_xgrid_nx = 100
|
244
|
-
self.cszfast_gas_pressure_xgrid =
|
251
|
+
self.cszfast_gas_pressure_xgrid = self.geomspace(self.cszfast_gas_pressure_xgrid_xmin,
|
245
252
|
self.cszfast_gas_pressure_xgrid_xmax,
|
246
253
|
self.cszfast_gas_pressure_xgrid_nx)
|
247
254
|
|
@@ -312,7 +319,7 @@ class Class_szfast(object):
|
|
312
319
|
creal = predicted_testing_alphas_creal
|
313
320
|
cimag = predicted_testing_alphas_cimag
|
314
321
|
Nmax = len(self.cszfast_pk_grid_k)
|
315
|
-
cnew =
|
322
|
+
cnew = self.zeros(Nmax+1,dtype=complex)
|
316
323
|
for i in range(Nmax+1):
|
317
324
|
if i<int(Nmax/2):
|
318
325
|
cnew[i] = complex(creal[i],cimag[i])
|
@@ -362,20 +369,20 @@ class Class_szfast(object):
|
|
362
369
|
|
363
370
|
nl = len(self.cp_predicted_tt_spectrum)
|
364
371
|
cls = {}
|
365
|
-
cls['ell'] =
|
366
|
-
cls['tt'] =
|
367
|
-
cls['te'] =
|
368
|
-
cls['ee'] =
|
369
|
-
cls['pp'] =
|
370
|
-
cls['bb'] =
|
371
|
-
lcp =
|
372
|
+
cls['ell'] = self.arange(20000)
|
373
|
+
cls['tt'] = self.zeros(20000)
|
374
|
+
cls['te'] = self.zeros(20000)
|
375
|
+
cls['ee'] = self.zeros(20000)
|
376
|
+
cls['pp'] = self.zeros(20000)
|
377
|
+
cls['bb'] = self.zeros(20000)
|
378
|
+
lcp = self.asarray(cls['ell'][2:nl+2])
|
372
379
|
|
373
380
|
# print('cosmo_model:',self.cosmo_model,nl)
|
374
381
|
if self.cosmo_model == 'ede-v2':
|
375
382
|
factor_ttteee = 1./lcp**2
|
376
383
|
factor_pp = 1./lcp**3
|
377
384
|
else:
|
378
|
-
factor_ttteee = 1./(lcp*(lcp+1.)/2./
|
385
|
+
factor_ttteee = 1./(lcp*(lcp+1.)/2./self.pi)
|
379
386
|
factor_pp = 1./(lcp*(lcp+1.))**2.
|
380
387
|
|
381
388
|
self.cp_predicted_tt_spectrum *= factor_ttteee
|
@@ -402,7 +409,7 @@ class Class_szfast(object):
|
|
402
409
|
cmb_cls_loaded = pickle.load(handle)
|
403
410
|
nl_cls_file = len(cmb_cls_loaded['ell'])
|
404
411
|
cls_ls = cmb_cls_loaded['ell']
|
405
|
-
dlfac = cls_ls*(cls_ls+1.)/2./
|
412
|
+
dlfac = cls_ls*(cls_ls+1.)/2./self.pi
|
406
413
|
cls_tt = cmb_cls_loaded['tt']*dlfac
|
407
414
|
cls_te = cmb_cls_loaded['te']*dlfac
|
408
415
|
cls_ee = cmb_cls_loaded['ee']*dlfac
|
@@ -421,8 +428,6 @@ class Class_szfast(object):
|
|
421
428
|
**params_values_dict):
|
422
429
|
|
423
430
|
z_arr = self.cszfast_pk_grid_z
|
424
|
-
|
425
|
-
|
426
431
|
k_arr = self.cszfast_pk_grid_k
|
427
432
|
|
428
433
|
# print(">>> z_arr:",z_arr)
|
@@ -464,26 +469,32 @@ class Class_szfast(object):
|
|
464
469
|
|
465
470
|
params_dict_pp = params_dict.copy()
|
466
471
|
params_dict_pp['z_pk_save_nonclass'] = [zp]
|
467
|
-
|
472
|
+
if self.jax_mode:
|
473
|
+
predicted_pk_spectrum_z.append(self.cp_pkl_nn[self.cosmo_model].predict(params_dict_pp))
|
474
|
+
else:
|
475
|
+
predicted_pk_spectrum_z.append(self.cp_pkl_nn[self.cosmo_model].predictions_np(params_dict_pp)[0])
|
468
476
|
|
469
477
|
# if abs(zp-0.5) < 0.01:
|
470
478
|
# print(">>> predicted_pk_spectrum_z:",predicted_pk_spectrum_z[-1])
|
471
479
|
# import pprint
|
472
480
|
# pprint.pprint(params_dict_pp)
|
473
481
|
|
474
|
-
predicted_pk_spectrum =
|
482
|
+
predicted_pk_spectrum = self.asarray(predicted_pk_spectrum_z)
|
475
483
|
|
476
484
|
|
477
485
|
pk = 10.**predicted_pk_spectrum
|
478
486
|
|
479
487
|
pk_re = pk*self.pk_power_fac
|
480
|
-
pk_re =
|
488
|
+
pk_re = self.transpose(pk_re)
|
481
489
|
|
482
490
|
# print(">>> pk_re:",pk_re)
|
483
491
|
# import sys
|
484
492
|
# sys.exit(0)
|
485
493
|
|
486
|
-
self.
|
494
|
+
if self.jax_mode:
|
495
|
+
self.pkl_interp = None
|
496
|
+
else:
|
497
|
+
self.pkl_interp = PowerSpectrumInterpolator(z_arr,k_arr,np.log(pk_re).T,logP=True)
|
487
498
|
|
488
499
|
self.cszfast_pk_grid_pk = pk_re
|
489
500
|
self.cszfast_pk_grid_pkl_flat = pk_re.flatten()
|
@@ -570,7 +581,7 @@ class Class_szfast(object):
|
|
570
581
|
s8z = self.cp_s8_nn[self.cosmo_model].predictions_np(params_dict)
|
571
582
|
# print(self.s8z)
|
572
583
|
self.s8z_interp = scipy.interpolate.interp1d(
|
573
|
-
|
584
|
+
self.linspace(0.,20.,5000),
|
574
585
|
s8z[0],
|
575
586
|
kind='linear',
|
576
587
|
axis=-1,
|
@@ -609,13 +620,13 @@ class Class_szfast(object):
|
|
609
620
|
params_dict_pp['z_pk_save_nonclass'] = [zp]
|
610
621
|
predicted_pk_spectrum_z.append(self.cp_pknl_nn[self.cosmo_model].predictions_np(params_dict_pp)[0])
|
611
622
|
|
612
|
-
predicted_pk_spectrum =
|
623
|
+
predicted_pk_spectrum = self.asarray(predicted_pk_spectrum_z)
|
613
624
|
|
614
625
|
|
615
626
|
pk = 10.**predicted_pk_spectrum
|
616
627
|
|
617
628
|
pk_re = pk*self.pk_power_fac
|
618
|
-
pk_re =
|
629
|
+
pk_re = self.transpose(pk_re)
|
619
630
|
|
620
631
|
|
621
632
|
self.pknl_interp = PowerSpectrumInterpolator(z_arr,k_arr,np.log(pk_re).T,logP=True)
|
@@ -631,16 +642,12 @@ class Class_szfast(object):
|
|
631
642
|
z_asked,
|
632
643
|
params_values_dict=None):
|
633
644
|
|
634
|
-
z_arr = self.cszfast_pk_grid_z
|
635
645
|
|
636
646
|
k_arr = self.cszfast_pk_grid_k
|
637
647
|
|
638
648
|
if params_values_dict:
|
639
|
-
|
640
649
|
params_values = params_values_dict.copy()
|
641
|
-
|
642
650
|
else:
|
643
|
-
|
644
651
|
params_values = self.params_for_emulators
|
645
652
|
|
646
653
|
update_params_with_defaults(params_values, self.emulator_dict[self.cosmo_model]['default'])
|
@@ -657,19 +664,21 @@ class Class_szfast(object):
|
|
657
664
|
|
658
665
|
predicted_pk_spectrum_z = []
|
659
666
|
|
660
|
-
z_asked = z_asked
|
661
667
|
params_dict_pp = params_dict.copy()
|
662
668
|
update_params_with_defaults(params_dict_pp, self.emulator_dict[self.cosmo_model]['default'])
|
663
669
|
|
664
670
|
params_dict_pp['z_pk_save_nonclass'] = [z_asked]
|
665
|
-
|
671
|
+
if self.jax_mode:
|
672
|
+
predicted_pk_spectrum_z.append(self.cp_pkl_nn[self.cosmo_model].predict(params_dict_pp))
|
673
|
+
else:
|
674
|
+
predicted_pk_spectrum_z.append(self.cp_pkl_nn[self.cosmo_model].predictions_np(params_dict_pp)[0])
|
666
675
|
|
667
|
-
predicted_pk_spectrum =
|
676
|
+
predicted_pk_spectrum = self.asarray(predicted_pk_spectrum_z)
|
668
677
|
|
669
678
|
|
670
679
|
pk = 10.**predicted_pk_spectrum
|
671
680
|
pk_re = pk*self.pk_power_fac
|
672
|
-
pk_re =
|
681
|
+
pk_re = self.transpose(pk_re)
|
673
682
|
|
674
683
|
|
675
684
|
return pk_re, k_arr
|
@@ -712,12 +721,12 @@ class Class_szfast(object):
|
|
712
721
|
params_dict_pp['z_pk_save_nonclass'] = [z_asked]
|
713
722
|
predicted_pk_spectrum_z.append(self.cp_pknl_nn[self.cosmo_model].predictions_np(params_dict_pp)[0])
|
714
723
|
|
715
|
-
predicted_pk_spectrum =
|
724
|
+
predicted_pk_spectrum = self.asarray(predicted_pk_spectrum_z)
|
716
725
|
|
717
726
|
|
718
727
|
pk = 10.**predicted_pk_spectrum
|
719
728
|
pk_re = pk*self.pk_power_fac
|
720
|
-
pk_re =
|
729
|
+
pk_re = self.transpose(pk_re)
|
721
730
|
|
722
731
|
|
723
732
|
return pk_re, k_arr
|
@@ -748,19 +757,10 @@ class Class_szfast(object):
|
|
748
757
|
# print("self.cp_predicted_hubble type:", type(self.cp_predicted_hubble))
|
749
758
|
# print("self.cp_predicted_hubble",self.cp_predicted_hubble)
|
750
759
|
|
751
|
-
# self.hz_interp = jscipy.interpolate.interp1d(
|
752
|
-
# self.cp_z_interp_jax,
|
753
|
-
# self.cp_predicted_hubble,
|
754
|
-
# kind='linear',
|
755
|
-
# axis=-1,
|
756
|
-
# copy=True,
|
757
|
-
# bounds_error=None,
|
758
|
-
# fill_value=np.nan,
|
759
|
-
# assume_sorted=False)
|
760
760
|
|
761
761
|
# Assuming `cp_z_interp` and `cp_predicted_hubble` are JAX arrays
|
762
762
|
def hz_interp(x):
|
763
|
-
return jnp.interp(x, self.
|
763
|
+
return jnp.interp(x, self.cp_z_interp, self.cp_predicted_hubble, left=jnp.nan, right=jnp.nan)
|
764
764
|
|
765
765
|
self.hz_interp = hz_interp
|
766
766
|
# exit()
|
@@ -803,10 +803,10 @@ class Class_szfast(object):
|
|
803
803
|
if self.cosmo_model == 'ede-v2':
|
804
804
|
# print('ede-v2 case')
|
805
805
|
self.cp_predicted_da = jnp.insert(self.cp_predicted_da, 0, 0)
|
806
|
-
self.cp_predicted_da *= (1.+self.
|
806
|
+
self.cp_predicted_da *= (1.+self.cp_z_interp)
|
807
807
|
|
808
808
|
def chi_interp(x):
|
809
|
-
return jnp.interp(x, self.
|
809
|
+
return jnp.interp(x, self.cp_z_interp, self.cp_predicted_da, left=jnp.nan, right=jnp.nan)
|
810
810
|
|
811
811
|
self.chi_interp = chi_interp
|
812
812
|
|
@@ -835,12 +835,12 @@ class Class_szfast(object):
|
|
835
835
|
def get_cmb_cls(self,ell_factor=True,Tcmb_uk = Tcmb_uk):
|
836
836
|
|
837
837
|
cls = {}
|
838
|
-
cls['ell'] =
|
839
|
-
cls['tt'] =
|
840
|
-
cls['te'] =
|
841
|
-
cls['ee'] =
|
842
|
-
cls['pp'] =
|
843
|
-
cls['bb'] =
|
838
|
+
cls['ell'] = self.arange(self.cszfast_ldim)
|
839
|
+
cls['tt'] = self.zeros(self.cszfast_ldim)
|
840
|
+
cls['te'] = self.zeros(self.cszfast_ldim)
|
841
|
+
cls['ee'] = self.zeros(self.cszfast_ldim)
|
842
|
+
cls['pp'] = self.zeros(self.cszfast_ldim)
|
843
|
+
cls['bb'] = self.zeros(self.cszfast_ldim)
|
844
844
|
cls['tt'][2:self.cp_lmax+1] = (Tcmb_uk)**2.*self.cp_predicted_tt_spectrum.copy()
|
845
845
|
cls['te'][2:self.cp_lmax+1] = (Tcmb_uk)**2.*self.cp_predicted_te_spectrum.copy()
|
846
846
|
cls['ee'][2:self.cp_lmax+1] = (Tcmb_uk)**2.*self.cp_predicted_ee_spectrum.copy()
|
@@ -848,12 +848,12 @@ class Class_szfast(object):
|
|
848
848
|
|
849
849
|
|
850
850
|
if ell_factor==False:
|
851
|
-
fac_l =
|
852
|
-
fac_l[2:self.cp_lmax+1] = 1./(cls['ell'][2:self.cp_lmax+1]*(cls['ell'][2:self.cp_lmax+1]+1.)/2./
|
851
|
+
fac_l = self.zeros(self.cszfast_ldim)
|
852
|
+
fac_l[2:self.cp_lmax+1] = 1./(cls['ell'][2:self.cp_lmax+1]*(cls['ell'][2:self.cp_lmax+1]+1.)/2./self.pi)
|
853
853
|
cls['tt'][2:self.cp_lmax+1] *= fac_l[2:self.cp_lmax+1]
|
854
854
|
cls['te'][2:self.cp_lmax+1] *= fac_l[2:self.cp_lmax+1]
|
855
855
|
cls['ee'][2:self.cp_lmax+1] *= fac_l[2:self.cp_lmax+1]
|
856
|
-
# cls['bb'] =
|
856
|
+
# cls['bb'] = self.zeros(self.cszfast_ldim)
|
857
857
|
return cls
|
858
858
|
|
859
859
|
|
@@ -899,7 +899,10 @@ class Class_szfast(object):
|
|
899
899
|
return np.array(self.hz_interp(z)*H_units_conv_factor[units])
|
900
900
|
|
901
901
|
def get_chi(self, z):
|
902
|
-
|
902
|
+
if self.jax_mode:
|
903
|
+
return jnp.array(self.chi_interp(z))
|
904
|
+
else:
|
905
|
+
return np.array(self.chi_interp(z))
|
903
906
|
|
904
907
|
def get_gas_pressure_profile_x(self,z,m,x):
|
905
908
|
return 0#np.vectorize(self.csz_base.get_pressure_P_over_P_delta_at_x_M_z_b12_200c)(x,m,z)
|
@@ -913,7 +916,7 @@ class Class_szfast(object):
|
|
913
916
|
|
914
917
|
|
915
918
|
def tabulate_gas_pressure_profile_k(self):
|
916
|
-
z_asked,m_asked,x_asked = 0.2,3e14,
|
919
|
+
z_asked,m_asked,x_asked = 0.2,3e14,self.geomspace(1e-3,1e2,500)
|
917
920
|
start = time.time()
|
918
921
|
px = self.get_gas_pressure_profile_x(z_asked,m_asked,x_asked)
|
919
922
|
end = time.time()
|
classy_szfast/cosmopower_jax.py
CHANGED
@@ -105,8 +105,9 @@ for mp in cosmo_model_list:
|
|
105
105
|
|
106
106
|
cp_pknl_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'PK/' + emulator_dict[mp]['PKNL'])
|
107
107
|
|
108
|
-
cp_pkl_nn_jax[mp] =
|
109
|
-
|
108
|
+
cp_pkl_nn_jax[mp] = CosmoPowerJAX_custom(probe='custom_log',filepath=path_to_emulators +'PK/' + emulator_dict[mp]['PKL'] + '.npz')
|
109
|
+
cp_pkl_nn_jax[mp].ten_to_predictions = False
|
110
|
+
|
110
111
|
cp_der_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'derived-parameters/' + emulator_dict[mp]['DER'])
|
111
112
|
|
112
113
|
# cp_da_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['DAZ'])
|
@@ -1,11 +1,11 @@
|
|
1
1
|
classy_szfast/__init__.py,sha256=E2thrL0Z9oXFfdzwcsu-xbOytudLFTlRlPqVFGlPPPg,279
|
2
2
|
classy_szfast/classy_sz.py,sha256=QmbwrSXInQLMvCDqsr7KPmtaU0KOiOt1Rb-cTKuulZw,22240
|
3
|
-
classy_szfast/classy_szfast.py,sha256=
|
3
|
+
classy_szfast/classy_szfast.py,sha256=H3DPD4_ZvVJ_FCVbEzc4rW8DYQHPZCCaGQZtQ1f-7oQ,37646
|
4
4
|
classy_szfast/config.py,sha256=cd7Z62-qnX_4FJWfUNqcyJVh-AdBiXrF8DcQGpyAUZM,274
|
5
5
|
classy_szfast/cosmopower.py,sha256=ooYK2BDOZSo3XtGHfPtjXHxr5UW-yVngLPkb5gpvTx8,2351
|
6
|
-
classy_szfast/cosmopower_jax.py,sha256=
|
6
|
+
classy_szfast/cosmopower_jax.py,sha256=NvwY1a9YoizOV0zKp0SEtiHQkhWRO-seGFzO3FIzLl0,5436
|
7
7
|
classy_szfast/cosmosis_classy_szfast_interface.py,sha256=zAnxvFtn73a5yS7jgs59zpWFEYKCIQyraYPs5hQ4Le8,11483
|
8
|
-
classy_szfast/emulators_meta_data.py,sha256
|
8
|
+
classy_szfast/emulators_meta_data.py,sha256=faA0iQqzfC5lOj1hu7wSEZUFfhBGLJqX9-xovMfTbr0,9723
|
9
9
|
classy_szfast/pks_and_sigmas.py,sha256=drtuujE1HhlrYY1hY92DyY5lXlYS1uE15MSuVI4uo6k,6625
|
10
10
|
classy_szfast/restore_nn.py,sha256=DqA9thhTRiGBDVb9zjhqcbF2W4V0AU0vrjJFhnLboU4,21075
|
11
11
|
classy_szfast/suppress_warnings.py,sha256=6wIBml2Sj9DyRGZlZWhuA9hqvpxqrNyYjuz6BPK_a6E,202
|
@@ -14,7 +14,7 @@ classy_szfast/custom_bias/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJW
|
|
14
14
|
classy_szfast/custom_bias/custom_bias.py,sha256=aR2t5RTIwv7P0m2bsEU0Eq6BTkj4pG10AebH6QpG4qM,486
|
15
15
|
classy_szfast/custom_profiles/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
16
16
|
classy_szfast/custom_profiles/custom_profiles.py,sha256=4LZwb2XoqwCyWNmW2s24Z7AJdmgVdaRG7yYaBYe-d9Q,1188
|
17
|
-
classy_szfast-0.0.25.
|
18
|
-
classy_szfast-0.0.25.
|
19
|
-
classy_szfast-0.0.25.
|
20
|
-
classy_szfast-0.0.25.
|
17
|
+
classy_szfast-0.0.25.post4.dist-info/METADATA,sha256=oNj9IJ8zELtcPCEOOi4wQzx9SfxVaMmdIl6QjyAFdkY,548
|
18
|
+
classy_szfast-0.0.25.post4.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
19
|
+
classy_szfast-0.0.25.post4.dist-info/top_level.txt,sha256=hRgqpilUck4lx2KkaWI2y9aCDKqF6pFfGHfNaoPFxv0,14
|
20
|
+
classy_szfast-0.0.25.post4.dist-info/RECORD,,
|
File without changes
|
File without changes
|