classy-szfast 0.0.25.post2__py3-none-any.whl → 0.0.25.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- classy_szfast/classy_szfast.py +152 -132
- classy_szfast/cosmopower_jax.py +80 -5
- classy_szfast/emulators_meta_data.py +1 -1
- {classy_szfast-0.0.25.post2.dist-info → classy_szfast-0.0.25.post4.dist-info}/METADATA +1 -1
- {classy_szfast-0.0.25.post2.dist-info → classy_szfast-0.0.25.post4.dist-info}/RECORD +7 -7
- {classy_szfast-0.0.25.post2.dist-info → classy_szfast-0.0.25.post4.dist-info}/WHEEL +0 -0
- {classy_szfast-0.0.25.post2.dist-info → classy_szfast-0.0.25.post4.dist-info}/top_level.txt +0 -0
classy_szfast/classy_szfast.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
from .utils import *
|
2
2
|
from .config import *
|
3
3
|
import numpy as np
|
4
|
-
from .emulators_meta_data import emulator_dict, dofftlog_alphas, cp_l_max_scalars
|
4
|
+
from .emulators_meta_data import emulator_dict, dofftlog_alphas, cp_l_max_scalars, cosmo_model_list
|
5
5
|
from .cosmopower import cp_tt_nn, cp_te_nn, cp_ee_nn, cp_ee_nn, cp_pp_nn, cp_pknl_nn, cp_pkl_nn, cp_der_nn, cp_da_nn, cp_h_nn, cp_s8_nn, cp_pkl_fftlog_alphas_real_nn, cp_pkl_fftlog_alphas_imag_nn, cp_pkl_fftlog_alphas_nus
|
6
6
|
from .cosmopower_jax import cp_tt_nn_jax, cp_te_nn_jax, cp_ee_nn_jax, cp_ee_nn_jax, cp_pp_nn_jax, cp_pknl_nn_jax, cp_pkl_nn_jax, cp_der_nn_jax, cp_da_nn_jax, cp_h_nn_jax, cp_s8_nn_jax
|
7
7
|
from .pks_and_sigmas import *
|
@@ -76,13 +76,10 @@ class Class_szfast(object):
|
|
76
76
|
|
77
77
|
self.jax_mode = params_settings["jax"]
|
78
78
|
|
79
|
+
|
79
80
|
# print(f"JAX mode: {self.jax_mode}")
|
80
81
|
|
81
82
|
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
83
|
# cosmopower emulators
|
87
84
|
# self.cp_path_to_cosmopower_organization = path_to_cosmopower_organization + '/'
|
88
85
|
self.cp_tt_nn = cp_tt_nn
|
@@ -90,14 +87,70 @@ class Class_szfast(object):
|
|
90
87
|
self.cp_ee_nn = cp_ee_nn
|
91
88
|
self.cp_pp_nn = cp_pp_nn
|
92
89
|
self.cp_pknl_nn = cp_pknl_nn
|
93
|
-
self.cp_pkl_nn = cp_pkl_nn
|
94
90
|
self.cp_der_nn = cp_der_nn
|
95
|
-
|
91
|
+
|
92
|
+
self.cp_lmax = cp_l_max_scalars
|
93
|
+
|
94
|
+
cosmo_model_dict = {i: model for i, model in enumerate(cosmo_model_list)}
|
95
|
+
|
96
|
+
if (cosmo_model_dict[params_settings['cosmo_model']] == 'ede-v2'):
|
97
|
+
|
98
|
+
self.cszfast_pk_grid_zmax = 20.
|
99
|
+
self.cszfast_pk_grid_kmin = 5e-4
|
100
|
+
self.cszfast_pk_grid_kmax = 10.
|
101
|
+
self.cp_kmax = self.cszfast_pk_grid_kmax
|
102
|
+
self.cp_kmin = self.cszfast_pk_grid_kmin
|
103
|
+
# self.logger.info(f">>> using kmin = {self.cp_kmin}")
|
104
|
+
# self.logger.info(f">>> using kmax = {self.cp_kmax}")
|
105
|
+
# self.logger.info(f">>> using zmax = {self.cszfast_pk_grid_zmax}")
|
106
|
+
|
107
|
+
else:
|
108
|
+
|
109
|
+
self.cszfast_pk_grid_zmax = 5. # max z of our pk emulators (sept 23)
|
110
|
+
self.cszfast_pk_grid_kmin = 1e-4
|
111
|
+
self.cszfast_pk_grid_kmax = 50.
|
112
|
+
self.cp_kmax = self.cszfast_pk_grid_kmax
|
113
|
+
self.cp_kmin = self.cszfast_pk_grid_kmin
|
114
|
+
# self.logger.info(f">>> using kmin = {self.cp_kmin}")
|
115
|
+
# self.logger.info(f">>> using kmax = {self.cp_kmax}")
|
116
|
+
# self.logger.info(f">>> using zmax = {self.cszfast_pk_grid_zmax}")
|
96
117
|
|
97
118
|
if self.jax_mode:
|
98
119
|
self.cp_h_nn = cp_h_nn_jax
|
120
|
+
self.cp_da_nn = cp_da_nn_jax
|
121
|
+
self.cp_pkl_nn = cp_pkl_nn_jax
|
122
|
+
|
123
|
+
self.pi = jnp.pi
|
124
|
+
self.transpose = jnp.transpose
|
125
|
+
self.asarray = jnp.asarray
|
126
|
+
|
127
|
+
self.linspace = jnp.linspace
|
128
|
+
self.geomspace = jnp.geomspace
|
129
|
+
self.arange = jnp.arange
|
130
|
+
self.zeros = jnp.zeros
|
131
|
+
|
132
|
+
|
99
133
|
else:
|
100
134
|
self.cp_h_nn = cp_h_nn
|
135
|
+
self.cp_da_nn = cp_da_nn
|
136
|
+
self.cp_pkl_nn = cp_pkl_nn
|
137
|
+
|
138
|
+
self.pi = np.pi
|
139
|
+
self.transpose = np.transpose
|
140
|
+
|
141
|
+
self.linspace = np.linspace
|
142
|
+
self.geomspace = np.geomspace
|
143
|
+
self.arange = np.arange
|
144
|
+
self.zeros = np.zeros
|
145
|
+
self.asarray = np.asarray
|
146
|
+
|
147
|
+
|
148
|
+
|
149
|
+
self.cp_ls = self.arange(2,self.cp_lmax+1)
|
150
|
+
self.cp_predicted_tt_spectrum =self.zeros(self.cp_lmax)
|
151
|
+
self.cp_predicted_te_spectrum =self.zeros(self.cp_lmax)
|
152
|
+
self.cp_predicted_ee_spectrum =self.zeros(self.cp_lmax)
|
153
|
+
self.cp_predicted_pp_spectrum =self.zeros(self.cp_lmax)
|
101
154
|
|
102
155
|
self.cp_s8_nn = cp_s8_nn
|
103
156
|
|
@@ -112,20 +165,6 @@ class Class_szfast(object):
|
|
112
165
|
self.use_Amod = 0
|
113
166
|
self.Amod = 0
|
114
167
|
|
115
|
-
self.cp_lmax = cp_l_max_scalars
|
116
|
-
self.cp_ls = np.arange(2,self.cp_lmax+1)
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
cosmo_model_dict = {0: 'lcdm',
|
122
|
-
1: 'mnu',
|
123
|
-
2: 'neff',
|
124
|
-
3: 'wcdm',
|
125
|
-
4: 'ede',
|
126
|
-
5: 'mnu-3states',
|
127
|
-
6: 'ede-v2'
|
128
|
-
}
|
129
168
|
|
130
169
|
|
131
170
|
if cosmo_model_dict[params_settings['cosmo_model']] == 'ede-v2':
|
@@ -138,60 +177,31 @@ class Class_szfast(object):
|
|
138
177
|
self.cp_ndspl_k = 10
|
139
178
|
self.cp_nk = 5000
|
140
179
|
|
141
|
-
self.cp_predicted_tt_spectrum =np.zeros(self.cp_lmax)
|
142
|
-
self.cp_predicted_te_spectrum =np.zeros(self.cp_lmax)
|
143
|
-
self.cp_predicted_ee_spectrum =np.zeros(self.cp_lmax)
|
144
|
-
self.cp_predicted_pp_spectrum =np.zeros(self.cp_lmax)
|
145
|
-
|
146
180
|
|
147
181
|
self.cszfast_ldim = 20000 # used for the cls arrays
|
148
|
-
|
149
182
|
self.cszfast_pk_grid_nz = 100 # has to be same as narraySZ, i.e., ndim_redshifts; it is setup hereafter if ndim_redshifts is passed
|
150
183
|
|
151
184
|
|
152
|
-
|
153
|
-
if (cosmo_model_dict[params_settings['cosmo_model']] == 'ede-v2'):
|
154
|
-
|
155
|
-
self.cszfast_pk_grid_zmax = 20.
|
156
|
-
self.cszfast_pk_grid_kmin = 5e-4
|
157
|
-
self.cszfast_pk_grid_kmax = 10.
|
158
|
-
self.cp_kmax = self.cszfast_pk_grid_kmax
|
159
|
-
self.cp_kmin = self.cszfast_pk_grid_kmin
|
160
|
-
# self.logger.info(f">>> using kmin = {self.cp_kmin}")
|
161
|
-
# self.logger.info(f">>> using kmax = {self.cp_kmax}")
|
162
|
-
# self.logger.info(f">>> using zmax = {self.cszfast_pk_grid_zmax}")
|
163
|
-
|
164
|
-
else:
|
165
|
-
|
166
|
-
self.cszfast_pk_grid_zmax = 5. # max z of our pk emulators (sept 23)
|
167
|
-
self.cszfast_pk_grid_kmin = 1e-4
|
168
|
-
self.cszfast_pk_grid_kmax = 50.
|
169
|
-
self.cp_kmax = self.cszfast_pk_grid_kmax
|
170
|
-
self.cp_kmin = self.cszfast_pk_grid_kmin
|
171
|
-
# self.logger.info(f">>> using kmin = {self.cp_kmin}")
|
172
|
-
# self.logger.info(f">>> using kmax = {self.cp_kmax}")
|
173
|
-
# self.logger.info(f">>> using zmax = {self.cszfast_pk_grid_zmax}")
|
174
|
-
|
175
|
-
self.cszfast_pk_grid_z = np.linspace(0.,self.cszfast_pk_grid_zmax,self.cszfast_pk_grid_nz)
|
185
|
+
self.cszfast_pk_grid_z = self.linspace(0.,self.cszfast_pk_grid_zmax,self.cszfast_pk_grid_nz)
|
176
186
|
self.cszfast_pk_grid_ln1pz = np.log(1.+self.cszfast_pk_grid_z)
|
177
187
|
|
178
188
|
|
179
|
-
self.cszfast_pk_grid_k =
|
189
|
+
self.cszfast_pk_grid_k = self.geomspace(self.cp_kmin,self.cp_kmax,self.cp_nk)[::self.cp_ndspl_k]
|
180
190
|
|
181
191
|
self.cszfast_pk_grid_lnk = np.log(self.cszfast_pk_grid_k)
|
182
192
|
|
183
|
-
self.cszfast_pk_grid_nk = len(
|
193
|
+
self.cszfast_pk_grid_nk = len(self.geomspace(self.cp_kmin,self.cp_kmax,self.cp_nk)[::self.cp_ndspl_k]) # has to be same as ndimSZ, and the same as dimension of cosmopower pk emulators
|
184
194
|
|
185
195
|
for k,v in params_settings.items():
|
186
196
|
|
187
197
|
if k == 'ndim_redshifts':
|
188
198
|
|
189
199
|
self.cszfast_pk_grid_nz = v
|
190
|
-
self.cszfast_pk_grid_z =
|
200
|
+
self.cszfast_pk_grid_z = self.linspace(0.,self.cszfast_pk_grid_zmax,self.cszfast_pk_grid_nz)
|
191
201
|
self.cszfast_pk_grid_ln1pz = np.log(1.+self.cszfast_pk_grid_z)
|
192
202
|
|
193
|
-
self.cszfast_pk_grid_pknl_flat =
|
194
|
-
self.cszfast_pk_grid_pkl_flat =
|
203
|
+
self.cszfast_pk_grid_pknl_flat = self.zeros(self.cszfast_pk_grid_nz*self.cszfast_pk_grid_nk)
|
204
|
+
self.cszfast_pk_grid_pkl_flat = self.zeros(self.cszfast_pk_grid_nz*self.cszfast_pk_grid_nk)
|
195
205
|
|
196
206
|
if k == 'cosmo_model':
|
197
207
|
|
@@ -210,13 +220,12 @@ class Class_szfast(object):
|
|
210
220
|
|
211
221
|
else:
|
212
222
|
|
213
|
-
ls =
|
214
|
-
dls = ls*(ls+1.)/2./
|
223
|
+
ls = self.arange(2,self.cp_nk+2)[::self.cp_ndspl_k] # jan 10 ndspl
|
224
|
+
dls = ls*(ls+1.)/2./self.pi
|
215
225
|
self.pk_power_fac= (dls)**-1
|
216
226
|
|
217
227
|
|
218
|
-
self.cp_z_interp =
|
219
|
-
self.cp_z_interp_jax = jnp.linspace(0.,20.,5000)
|
228
|
+
self.cp_z_interp = self.linspace(0.,20.,5000)
|
220
229
|
|
221
230
|
self.csz_base = None
|
222
231
|
|
@@ -224,7 +233,7 @@ class Class_szfast(object):
|
|
224
233
|
self.cszfast_zgrid_zmin = 0.
|
225
234
|
self.cszfast_zgrid_zmax = 4.
|
226
235
|
self.cszfast_zgrid_nz = 250
|
227
|
-
self.cszfast_zgrid =
|
236
|
+
self.cszfast_zgrid = self.linspace(self.cszfast_zgrid_zmin,
|
228
237
|
self.cszfast_zgrid_zmax,
|
229
238
|
self.cszfast_zgrid_nz)
|
230
239
|
|
@@ -232,14 +241,14 @@ class Class_szfast(object):
|
|
232
241
|
self.cszfast_mgrid_mmin = 1e10
|
233
242
|
self.cszfast_mgrid_mmax = 1e15
|
234
243
|
self.cszfast_mgrid_nm = 50
|
235
|
-
self.cszfast_mgrid =
|
244
|
+
self.cszfast_mgrid = self.geomspace(self.cszfast_mgrid_mmin,
|
236
245
|
self.cszfast_mgrid_mmax,
|
237
246
|
self.cszfast_mgrid_nm)
|
238
247
|
|
239
248
|
self.cszfast_gas_pressure_xgrid_xmin = 1e-2
|
240
249
|
self.cszfast_gas_pressure_xgrid_xmax = 1e2
|
241
250
|
self.cszfast_gas_pressure_xgrid_nx = 100
|
242
|
-
self.cszfast_gas_pressure_xgrid =
|
251
|
+
self.cszfast_gas_pressure_xgrid = self.geomspace(self.cszfast_gas_pressure_xgrid_xmin,
|
243
252
|
self.cszfast_gas_pressure_xgrid_xmax,
|
244
253
|
self.cszfast_gas_pressure_xgrid_nx)
|
245
254
|
|
@@ -310,7 +319,7 @@ class Class_szfast(object):
|
|
310
319
|
creal = predicted_testing_alphas_creal
|
311
320
|
cimag = predicted_testing_alphas_cimag
|
312
321
|
Nmax = len(self.cszfast_pk_grid_k)
|
313
|
-
cnew =
|
322
|
+
cnew = self.zeros(Nmax+1,dtype=complex)
|
314
323
|
for i in range(Nmax+1):
|
315
324
|
if i<int(Nmax/2):
|
316
325
|
cnew[i] = complex(creal[i],cimag[i])
|
@@ -360,20 +369,20 @@ class Class_szfast(object):
|
|
360
369
|
|
361
370
|
nl = len(self.cp_predicted_tt_spectrum)
|
362
371
|
cls = {}
|
363
|
-
cls['ell'] =
|
364
|
-
cls['tt'] =
|
365
|
-
cls['te'] =
|
366
|
-
cls['ee'] =
|
367
|
-
cls['pp'] =
|
368
|
-
cls['bb'] =
|
369
|
-
lcp =
|
372
|
+
cls['ell'] = self.arange(20000)
|
373
|
+
cls['tt'] = self.zeros(20000)
|
374
|
+
cls['te'] = self.zeros(20000)
|
375
|
+
cls['ee'] = self.zeros(20000)
|
376
|
+
cls['pp'] = self.zeros(20000)
|
377
|
+
cls['bb'] = self.zeros(20000)
|
378
|
+
lcp = self.asarray(cls['ell'][2:nl+2])
|
370
379
|
|
371
380
|
# print('cosmo_model:',self.cosmo_model,nl)
|
372
381
|
if self.cosmo_model == 'ede-v2':
|
373
382
|
factor_ttteee = 1./lcp**2
|
374
383
|
factor_pp = 1./lcp**3
|
375
384
|
else:
|
376
|
-
factor_ttteee = 1./(lcp*(lcp+1.)/2./
|
385
|
+
factor_ttteee = 1./(lcp*(lcp+1.)/2./self.pi)
|
377
386
|
factor_pp = 1./(lcp*(lcp+1.))**2.
|
378
387
|
|
379
388
|
self.cp_predicted_tt_spectrum *= factor_ttteee
|
@@ -400,7 +409,7 @@ class Class_szfast(object):
|
|
400
409
|
cmb_cls_loaded = pickle.load(handle)
|
401
410
|
nl_cls_file = len(cmb_cls_loaded['ell'])
|
402
411
|
cls_ls = cmb_cls_loaded['ell']
|
403
|
-
dlfac = cls_ls*(cls_ls+1.)/2./
|
412
|
+
dlfac = cls_ls*(cls_ls+1.)/2./self.pi
|
404
413
|
cls_tt = cmb_cls_loaded['tt']*dlfac
|
405
414
|
cls_te = cmb_cls_loaded['te']*dlfac
|
406
415
|
cls_ee = cmb_cls_loaded['ee']*dlfac
|
@@ -419,8 +428,6 @@ class Class_szfast(object):
|
|
419
428
|
**params_values_dict):
|
420
429
|
|
421
430
|
z_arr = self.cszfast_pk_grid_z
|
422
|
-
|
423
|
-
|
424
431
|
k_arr = self.cszfast_pk_grid_k
|
425
432
|
|
426
433
|
# print(">>> z_arr:",z_arr)
|
@@ -462,26 +469,32 @@ class Class_szfast(object):
|
|
462
469
|
|
463
470
|
params_dict_pp = params_dict.copy()
|
464
471
|
params_dict_pp['z_pk_save_nonclass'] = [zp]
|
465
|
-
|
472
|
+
if self.jax_mode:
|
473
|
+
predicted_pk_spectrum_z.append(self.cp_pkl_nn[self.cosmo_model].predict(params_dict_pp))
|
474
|
+
else:
|
475
|
+
predicted_pk_spectrum_z.append(self.cp_pkl_nn[self.cosmo_model].predictions_np(params_dict_pp)[0])
|
466
476
|
|
467
477
|
# if abs(zp-0.5) < 0.01:
|
468
478
|
# print(">>> predicted_pk_spectrum_z:",predicted_pk_spectrum_z[-1])
|
469
479
|
# import pprint
|
470
480
|
# pprint.pprint(params_dict_pp)
|
471
481
|
|
472
|
-
predicted_pk_spectrum =
|
482
|
+
predicted_pk_spectrum = self.asarray(predicted_pk_spectrum_z)
|
473
483
|
|
474
484
|
|
475
485
|
pk = 10.**predicted_pk_spectrum
|
476
486
|
|
477
487
|
pk_re = pk*self.pk_power_fac
|
478
|
-
pk_re =
|
488
|
+
pk_re = self.transpose(pk_re)
|
479
489
|
|
480
490
|
# print(">>> pk_re:",pk_re)
|
481
491
|
# import sys
|
482
492
|
# sys.exit(0)
|
483
493
|
|
484
|
-
self.
|
494
|
+
if self.jax_mode:
|
495
|
+
self.pkl_interp = None
|
496
|
+
else:
|
497
|
+
self.pkl_interp = PowerSpectrumInterpolator(z_arr,k_arr,np.log(pk_re).T,logP=True)
|
485
498
|
|
486
499
|
self.cszfast_pk_grid_pk = pk_re
|
487
500
|
self.cszfast_pk_grid_pkl_flat = pk_re.flatten()
|
@@ -568,7 +581,7 @@ class Class_szfast(object):
|
|
568
581
|
s8z = self.cp_s8_nn[self.cosmo_model].predictions_np(params_dict)
|
569
582
|
# print(self.s8z)
|
570
583
|
self.s8z_interp = scipy.interpolate.interp1d(
|
571
|
-
|
584
|
+
self.linspace(0.,20.,5000),
|
572
585
|
s8z[0],
|
573
586
|
kind='linear',
|
574
587
|
axis=-1,
|
@@ -607,13 +620,13 @@ class Class_szfast(object):
|
|
607
620
|
params_dict_pp['z_pk_save_nonclass'] = [zp]
|
608
621
|
predicted_pk_spectrum_z.append(self.cp_pknl_nn[self.cosmo_model].predictions_np(params_dict_pp)[0])
|
609
622
|
|
610
|
-
predicted_pk_spectrum =
|
623
|
+
predicted_pk_spectrum = self.asarray(predicted_pk_spectrum_z)
|
611
624
|
|
612
625
|
|
613
626
|
pk = 10.**predicted_pk_spectrum
|
614
627
|
|
615
628
|
pk_re = pk*self.pk_power_fac
|
616
|
-
pk_re =
|
629
|
+
pk_re = self.transpose(pk_re)
|
617
630
|
|
618
631
|
|
619
632
|
self.pknl_interp = PowerSpectrumInterpolator(z_arr,k_arr,np.log(pk_re).T,logP=True)
|
@@ -629,16 +642,12 @@ class Class_szfast(object):
|
|
629
642
|
z_asked,
|
630
643
|
params_values_dict=None):
|
631
644
|
|
632
|
-
z_arr = self.cszfast_pk_grid_z
|
633
645
|
|
634
646
|
k_arr = self.cszfast_pk_grid_k
|
635
647
|
|
636
648
|
if params_values_dict:
|
637
|
-
|
638
649
|
params_values = params_values_dict.copy()
|
639
|
-
|
640
650
|
else:
|
641
|
-
|
642
651
|
params_values = self.params_for_emulators
|
643
652
|
|
644
653
|
update_params_with_defaults(params_values, self.emulator_dict[self.cosmo_model]['default'])
|
@@ -655,19 +664,21 @@ class Class_szfast(object):
|
|
655
664
|
|
656
665
|
predicted_pk_spectrum_z = []
|
657
666
|
|
658
|
-
z_asked = z_asked
|
659
667
|
params_dict_pp = params_dict.copy()
|
660
668
|
update_params_with_defaults(params_dict_pp, self.emulator_dict[self.cosmo_model]['default'])
|
661
669
|
|
662
670
|
params_dict_pp['z_pk_save_nonclass'] = [z_asked]
|
663
|
-
|
671
|
+
if self.jax_mode:
|
672
|
+
predicted_pk_spectrum_z.append(self.cp_pkl_nn[self.cosmo_model].predict(params_dict_pp))
|
673
|
+
else:
|
674
|
+
predicted_pk_spectrum_z.append(self.cp_pkl_nn[self.cosmo_model].predictions_np(params_dict_pp)[0])
|
664
675
|
|
665
|
-
predicted_pk_spectrum =
|
676
|
+
predicted_pk_spectrum = self.asarray(predicted_pk_spectrum_z)
|
666
677
|
|
667
678
|
|
668
679
|
pk = 10.**predicted_pk_spectrum
|
669
680
|
pk_re = pk*self.pk_power_fac
|
670
|
-
pk_re =
|
681
|
+
pk_re = self.transpose(pk_re)
|
671
682
|
|
672
683
|
|
673
684
|
return pk_re, k_arr
|
@@ -710,12 +721,12 @@ class Class_szfast(object):
|
|
710
721
|
params_dict_pp['z_pk_save_nonclass'] = [z_asked]
|
711
722
|
predicted_pk_spectrum_z.append(self.cp_pknl_nn[self.cosmo_model].predictions_np(params_dict_pp)[0])
|
712
723
|
|
713
|
-
predicted_pk_spectrum =
|
724
|
+
predicted_pk_spectrum = self.asarray(predicted_pk_spectrum_z)
|
714
725
|
|
715
726
|
|
716
727
|
pk = 10.**predicted_pk_spectrum
|
717
728
|
pk_re = pk*self.pk_power_fac
|
718
|
-
pk_re =
|
729
|
+
pk_re = self.transpose(pk_re)
|
719
730
|
|
720
731
|
|
721
732
|
return pk_re, k_arr
|
@@ -746,19 +757,10 @@ class Class_szfast(object):
|
|
746
757
|
# print("self.cp_predicted_hubble type:", type(self.cp_predicted_hubble))
|
747
758
|
# print("self.cp_predicted_hubble",self.cp_predicted_hubble)
|
748
759
|
|
749
|
-
# self.hz_interp = jscipy.interpolate.interp1d(
|
750
|
-
# self.cp_z_interp_jax,
|
751
|
-
# self.cp_predicted_hubble,
|
752
|
-
# kind='linear',
|
753
|
-
# axis=-1,
|
754
|
-
# copy=True,
|
755
|
-
# bounds_error=None,
|
756
|
-
# fill_value=np.nan,
|
757
|
-
# assume_sorted=False)
|
758
760
|
|
759
761
|
# Assuming `cp_z_interp` and `cp_predicted_hubble` are JAX arrays
|
760
762
|
def hz_interp(x):
|
761
|
-
return jnp.interp(x, self.
|
763
|
+
return jnp.interp(x, self.cp_z_interp, self.cp_predicted_hubble, left=jnp.nan, right=jnp.nan)
|
762
764
|
|
763
765
|
self.hz_interp = hz_interp
|
764
766
|
# exit()
|
@@ -785,45 +787,60 @@ class Class_szfast(object):
|
|
785
787
|
update_params_with_defaults(params_values, self.emulator_dict[self.cosmo_model]['default'])
|
786
788
|
|
787
789
|
params_dict = {}
|
788
|
-
|
789
790
|
for k,v in zip(params_values.keys(),params_values.values()):
|
790
|
-
|
791
791
|
params_dict[k]=[v]
|
792
792
|
|
793
793
|
if 'm_ncdm' in params_dict.keys():
|
794
794
|
if isinstance(params_dict['m_ncdm'][0],str):
|
795
795
|
params_dict['m_ncdm'] = [float(params_dict['m_ncdm'][0].split(',')[0])]
|
796
796
|
|
797
|
-
|
798
|
-
|
797
|
+
if self.jax_mode:
|
798
|
+
# print("JAX MODE in chi")
|
799
|
+
# self.cp_da_nn[self.cosmo_model].log = False
|
800
|
+
self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].predict(params_dict)
|
801
|
+
# print("self.cp_predicted_da",self.cp_predicted_da)
|
802
|
+
|
803
|
+
if self.cosmo_model == 'ede-v2':
|
804
|
+
# print('ede-v2 case')
|
805
|
+
self.cp_predicted_da = jnp.insert(self.cp_predicted_da, 0, 0)
|
806
|
+
self.cp_predicted_da *= (1.+self.cp_z_interp)
|
807
|
+
|
808
|
+
def chi_interp(x):
|
809
|
+
return jnp.interp(x, self.cp_z_interp, self.cp_predicted_da, left=jnp.nan, right=jnp.nan)
|
810
|
+
|
811
|
+
self.chi_interp = chi_interp
|
799
812
|
|
800
|
-
self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].ten_to_predictions_np(params_dict)[0]
|
801
|
-
self.cp_predicted_da = np.insert(self.cp_predicted_da, 0, 0)
|
802
|
-
|
803
813
|
else:
|
804
|
-
|
805
|
-
|
806
|
-
|
814
|
+
# deal with different scaling of DA in different model from emulator training
|
815
|
+
if self.cosmo_model == 'ede-v2':
|
807
816
|
|
808
|
-
|
809
|
-
|
810
|
-
|
811
|
-
|
812
|
-
|
813
|
-
|
814
|
-
|
815
|
-
|
816
|
-
|
817
|
+
self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].ten_to_predictions_np(params_dict)[0]
|
818
|
+
self.cp_predicted_da = np.insert(self.cp_predicted_da, 0, 0)
|
819
|
+
|
820
|
+
else:
|
821
|
+
|
822
|
+
self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].predictions_np(params_dict)[0]
|
823
|
+
|
824
|
+
|
825
|
+
self.chi_interp = scipy.interpolate.interp1d(
|
826
|
+
self.cp_z_interp,
|
827
|
+
self.cp_predicted_da*(1.+self.cp_z_interp),
|
828
|
+
kind='linear',
|
829
|
+
axis=-1,
|
830
|
+
copy=True,
|
831
|
+
bounds_error=None,
|
832
|
+
fill_value=np.nan,
|
833
|
+
assume_sorted=False)
|
817
834
|
|
818
835
|
def get_cmb_cls(self,ell_factor=True,Tcmb_uk = Tcmb_uk):
|
819
836
|
|
820
837
|
cls = {}
|
821
|
-
cls['ell'] =
|
822
|
-
cls['tt'] =
|
823
|
-
cls['te'] =
|
824
|
-
cls['ee'] =
|
825
|
-
cls['pp'] =
|
826
|
-
cls['bb'] =
|
838
|
+
cls['ell'] = self.arange(self.cszfast_ldim)
|
839
|
+
cls['tt'] = self.zeros(self.cszfast_ldim)
|
840
|
+
cls['te'] = self.zeros(self.cszfast_ldim)
|
841
|
+
cls['ee'] = self.zeros(self.cszfast_ldim)
|
842
|
+
cls['pp'] = self.zeros(self.cszfast_ldim)
|
843
|
+
cls['bb'] = self.zeros(self.cszfast_ldim)
|
827
844
|
cls['tt'][2:self.cp_lmax+1] = (Tcmb_uk)**2.*self.cp_predicted_tt_spectrum.copy()
|
828
845
|
cls['te'][2:self.cp_lmax+1] = (Tcmb_uk)**2.*self.cp_predicted_te_spectrum.copy()
|
829
846
|
cls['ee'][2:self.cp_lmax+1] = (Tcmb_uk)**2.*self.cp_predicted_ee_spectrum.copy()
|
@@ -831,12 +848,12 @@ class Class_szfast(object):
|
|
831
848
|
|
832
849
|
|
833
850
|
if ell_factor==False:
|
834
|
-
fac_l =
|
835
|
-
fac_l[2:self.cp_lmax+1] = 1./(cls['ell'][2:self.cp_lmax+1]*(cls['ell'][2:self.cp_lmax+1]+1.)/2./
|
851
|
+
fac_l = self.zeros(self.cszfast_ldim)
|
852
|
+
fac_l[2:self.cp_lmax+1] = 1./(cls['ell'][2:self.cp_lmax+1]*(cls['ell'][2:self.cp_lmax+1]+1.)/2./self.pi)
|
836
853
|
cls['tt'][2:self.cp_lmax+1] *= fac_l[2:self.cp_lmax+1]
|
837
854
|
cls['te'][2:self.cp_lmax+1] *= fac_l[2:self.cp_lmax+1]
|
838
855
|
cls['ee'][2:self.cp_lmax+1] *= fac_l[2:self.cp_lmax+1]
|
839
|
-
# cls['bb'] =
|
856
|
+
# cls['bb'] = self.zeros(self.cszfast_ldim)
|
840
857
|
return cls
|
841
858
|
|
842
859
|
|
@@ -882,7 +899,10 @@ class Class_szfast(object):
|
|
882
899
|
return np.array(self.hz_interp(z)*H_units_conv_factor[units])
|
883
900
|
|
884
901
|
def get_chi(self, z):
|
885
|
-
|
902
|
+
if self.jax_mode:
|
903
|
+
return jnp.array(self.chi_interp(z))
|
904
|
+
else:
|
905
|
+
return np.array(self.chi_interp(z))
|
886
906
|
|
887
907
|
def get_gas_pressure_profile_x(self,z,m,x):
|
888
908
|
return 0#np.vectorize(self.csz_base.get_pressure_P_over_P_delta_at_x_M_z_b12_200c)(x,m,z)
|
@@ -896,7 +916,7 @@ class Class_szfast(object):
|
|
896
916
|
|
897
917
|
|
898
918
|
def tabulate_gas_pressure_profile_k(self):
|
899
|
-
z_asked,m_asked,x_asked = 0.2,3e14,
|
919
|
+
z_asked,m_asked,x_asked = 0.2,3e14,self.geomspace(1e-3,1e2,500)
|
900
920
|
start = time.time()
|
901
921
|
px = self.get_gas_pressure_profile_x(z_asked,m_asked,x_asked)
|
902
922
|
end = time.time()
|
classy_szfast/cosmopower_jax.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
from .config import path_to_class_sz_data
|
2
2
|
import numpy as np
|
3
|
+
import jax.numpy as jnp
|
3
4
|
from .restore_nn import Restore_NN
|
4
5
|
from .restore_nn import Restore_PCAplusNN
|
5
6
|
from .suppress_warnings import suppress_warnings
|
@@ -20,6 +21,74 @@ cp_h_nn_jax = {}
|
|
20
21
|
cp_s8_nn_jax = {}
|
21
22
|
|
22
23
|
|
24
|
+
class CosmoPowerJAX_custom(CPJ):
|
25
|
+
def __init__(self, *args, **kwargs):
|
26
|
+
super().__init__(*args, **kwargs)
|
27
|
+
self.ten_to_predictions = True
|
28
|
+
if 'ten_to_predictions' in kwargs.keys():
|
29
|
+
self.ten_to_predictions = kwargs['ten_to_predictions']
|
30
|
+
|
31
|
+
def _predict(self, weights, hyper_params, param_train_mean, param_train_std,
|
32
|
+
feature_train_mean, feature_train_std, input_vec):
|
33
|
+
""" Forward pass through pre-trained network.
|
34
|
+
In its current form, it does not make use of high-level frameworks like
|
35
|
+
FLAX et similia; rather, it simply loops over the network layers.
|
36
|
+
In future work this can be improved, especially if speed is a problem.
|
37
|
+
|
38
|
+
Parameters
|
39
|
+
----------
|
40
|
+
weights : array
|
41
|
+
The stored weights of the neural network.
|
42
|
+
hyper_params : array
|
43
|
+
The stored hyperparameters of the activation function for each layer.
|
44
|
+
param_train_mean : array
|
45
|
+
The stored mean of the training cosmological parameters.
|
46
|
+
param_train_std : array
|
47
|
+
The stored standard deviation of the training cosmological parameters.
|
48
|
+
feature_train_mean : array
|
49
|
+
The stored mean of the training features.
|
50
|
+
feature_train_std : array
|
51
|
+
The stored standard deviation of the training features.
|
52
|
+
input_vec : array of shape (n_samples, n_parameters) or (n_parameters)
|
53
|
+
The cosmological parameters given as input to the network.
|
54
|
+
|
55
|
+
Returns
|
56
|
+
-------
|
57
|
+
predictions : array
|
58
|
+
The prediction of the trained neural network.
|
59
|
+
"""
|
60
|
+
act = []
|
61
|
+
# Standardise
|
62
|
+
layer_out = [(input_vec - param_train_mean)/param_train_std]
|
63
|
+
|
64
|
+
# Loop over layers
|
65
|
+
for i in range(len(weights[:-1])):
|
66
|
+
w, b = weights[i]
|
67
|
+
alpha, beta = hyper_params[i]
|
68
|
+
act.append(jnp.dot(layer_out[-1], w.T) + b)
|
69
|
+
layer_out.append(self._activation(act[-1], alpha, beta))
|
70
|
+
|
71
|
+
# Final layer prediction (no activations)
|
72
|
+
w, b = weights[-1]
|
73
|
+
if self.probe == 'custom_log' or self.probe == 'custom_pca':
|
74
|
+
# in original CP models, we assumed a full final bias vector...
|
75
|
+
preds = jnp.dot(layer_out[-1], w.T) + b
|
76
|
+
else:
|
77
|
+
# ... unlike in cpjax, where we used only a single bias vector
|
78
|
+
preds = jnp.dot(layer_out[-1], w.T) + b[-1]
|
79
|
+
|
80
|
+
# Undo the standardisation
|
81
|
+
preds = preds * feature_train_std + feature_train_mean
|
82
|
+
if self.log == True:
|
83
|
+
if self.ten_to_predictions:
|
84
|
+
preds = 10**preds
|
85
|
+
else:
|
86
|
+
preds = (preds@self.pca_matrix)*self.training_std + self.training_mean
|
87
|
+
if self.probe == 'cmb_pp':
|
88
|
+
preds = 10**preds
|
89
|
+
predictions = preds.squeeze()
|
90
|
+
return predictions
|
91
|
+
|
23
92
|
for mp in cosmo_model_list:
|
24
93
|
folder, version = split_emulator_string(mp)
|
25
94
|
# print(folder, version)
|
@@ -36,18 +105,24 @@ for mp in cosmo_model_list:
|
|
36
105
|
|
37
106
|
cp_pknl_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'PK/' + emulator_dict[mp]['PKNL'])
|
38
107
|
|
39
|
-
cp_pkl_nn_jax[mp] =
|
40
|
-
|
108
|
+
cp_pkl_nn_jax[mp] = CosmoPowerJAX_custom(probe='custom_log',filepath=path_to_emulators +'PK/' + emulator_dict[mp]['PKL'] + '.npz')
|
109
|
+
cp_pkl_nn_jax[mp].ten_to_predictions = False
|
110
|
+
|
41
111
|
cp_der_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'derived-parameters/' + emulator_dict[mp]['DER'])
|
42
112
|
|
43
|
-
cp_da_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['DAZ'])
|
113
|
+
# cp_da_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['DAZ'])
|
114
|
+
|
44
115
|
|
116
|
+
cp_da_nn_jax[mp] = CosmoPowerJAX_custom(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['DAZ'] + '.npz')
|
117
|
+
if mp != 'ede-v2':
|
118
|
+
cp_da_nn_jax[mp].ten_to_predictions = False
|
119
|
+
# print(cp_da_nn_jax[mp].parameters)
|
45
120
|
# print(path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'])
|
46
|
-
emulator_custom = CPJ(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'] + '.npz')
|
121
|
+
# emulator_custom = CPJ(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'] + '.npz')
|
47
122
|
# print(emulator_custom.parameters)
|
48
123
|
# exit()
|
49
124
|
|
50
|
-
cp_h_nn_jax[mp] =
|
125
|
+
cp_h_nn_jax[mp] = CosmoPowerJAX_custom(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'] + '.npz')
|
51
126
|
|
52
127
|
cp_s8_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['S8Z'])
|
53
128
|
|
@@ -1,11 +1,11 @@
|
|
1
1
|
classy_szfast/__init__.py,sha256=E2thrL0Z9oXFfdzwcsu-xbOytudLFTlRlPqVFGlPPPg,279
|
2
2
|
classy_szfast/classy_sz.py,sha256=QmbwrSXInQLMvCDqsr7KPmtaU0KOiOt1Rb-cTKuulZw,22240
|
3
|
-
classy_szfast/classy_szfast.py,sha256=
|
3
|
+
classy_szfast/classy_szfast.py,sha256=H3DPD4_ZvVJ_FCVbEzc4rW8DYQHPZCCaGQZtQ1f-7oQ,37646
|
4
4
|
classy_szfast/config.py,sha256=cd7Z62-qnX_4FJWfUNqcyJVh-AdBiXrF8DcQGpyAUZM,274
|
5
5
|
classy_szfast/cosmopower.py,sha256=ooYK2BDOZSo3XtGHfPtjXHxr5UW-yVngLPkb5gpvTx8,2351
|
6
|
-
classy_szfast/cosmopower_jax.py,sha256=
|
6
|
+
classy_szfast/cosmopower_jax.py,sha256=NvwY1a9YoizOV0zKp0SEtiHQkhWRO-seGFzO3FIzLl0,5436
|
7
7
|
classy_szfast/cosmosis_classy_szfast_interface.py,sha256=zAnxvFtn73a5yS7jgs59zpWFEYKCIQyraYPs5hQ4Le8,11483
|
8
|
-
classy_szfast/emulators_meta_data.py,sha256
|
8
|
+
classy_szfast/emulators_meta_data.py,sha256=faA0iQqzfC5lOj1hu7wSEZUFfhBGLJqX9-xovMfTbr0,9723
|
9
9
|
classy_szfast/pks_and_sigmas.py,sha256=drtuujE1HhlrYY1hY92DyY5lXlYS1uE15MSuVI4uo6k,6625
|
10
10
|
classy_szfast/restore_nn.py,sha256=DqA9thhTRiGBDVb9zjhqcbF2W4V0AU0vrjJFhnLboU4,21075
|
11
11
|
classy_szfast/suppress_warnings.py,sha256=6wIBml2Sj9DyRGZlZWhuA9hqvpxqrNyYjuz6BPK_a6E,202
|
@@ -14,7 +14,7 @@ classy_szfast/custom_bias/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJW
|
|
14
14
|
classy_szfast/custom_bias/custom_bias.py,sha256=aR2t5RTIwv7P0m2bsEU0Eq6BTkj4pG10AebH6QpG4qM,486
|
15
15
|
classy_szfast/custom_profiles/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
16
16
|
classy_szfast/custom_profiles/custom_profiles.py,sha256=4LZwb2XoqwCyWNmW2s24Z7AJdmgVdaRG7yYaBYe-d9Q,1188
|
17
|
-
classy_szfast-0.0.25.
|
18
|
-
classy_szfast-0.0.25.
|
19
|
-
classy_szfast-0.0.25.
|
20
|
-
classy_szfast-0.0.25.
|
17
|
+
classy_szfast-0.0.25.post4.dist-info/METADATA,sha256=oNj9IJ8zELtcPCEOOi4wQzx9SfxVaMmdIl6QjyAFdkY,548
|
18
|
+
classy_szfast-0.0.25.post4.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
19
|
+
classy_szfast-0.0.25.post4.dist-info/top_level.txt,sha256=hRgqpilUck4lx2KkaWI2y9aCDKqF6pFfGHfNaoPFxv0,14
|
20
|
+
classy_szfast-0.0.25.post4.dist-info/RECORD,,
|
File without changes
|
File without changes
|