classy-szfast 0.0.25.post3__tar.gz → 0.0.25.post5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post5}/PKG-INFO +1 -1
  2. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post5}/classy_szfast/classy_szfast.py +166 -130
  3. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post5}/classy_szfast/cosmopower_jax.py +3 -2
  4. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post5}/classy_szfast/emulators_meta_data.py +19 -7
  5. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post5}/classy_szfast/utils.py +10 -0
  6. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post5}/classy_szfast.egg-info/PKG-INFO +1 -1
  7. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post5}/pyproject.toml +1 -1
  8. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post5}/README.md +0 -0
  9. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post5}/classy_szfast/__init__.py +0 -0
  10. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post5}/classy_szfast/classy_sz.py +0 -0
  11. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post5}/classy_szfast/config.py +0 -0
  12. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post5}/classy_szfast/cosmopower.py +0 -0
  13. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post5}/classy_szfast/cosmosis_classy_szfast_interface.py +0 -0
  14. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post5}/classy_szfast/custom_bias/__init__.py +0 -0
  15. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post5}/classy_szfast/custom_bias/custom_bias.py +0 -0
  16. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post5}/classy_szfast/custom_profiles/__init__.py +0 -0
  17. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post5}/classy_szfast/custom_profiles/custom_profiles.py +0 -0
  18. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post5}/classy_szfast/pks_and_sigmas.py +0 -0
  19. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post5}/classy_szfast/restore_nn.py +0 -0
  20. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post5}/classy_szfast/suppress_warnings.py +0 -0
  21. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post5}/classy_szfast.egg-info/SOURCES.txt +0 -0
  22. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post5}/classy_szfast.egg-info/dependency_links.txt +0 -0
  23. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post5}/classy_szfast.egg-info/requires.txt +0 -0
  24. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post5}/classy_szfast.egg-info/top_level.txt +0 -0
  25. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post5}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: classy_szfast
3
- Version: 0.0.25.post3
3
+ Version: 0.0.25.post5
4
4
  Summary: The accelerator of the class_sz code from https://github.com/CLASS-SZ
5
5
  Maintainer-email: Boris Bolliet <bb667@cam.ac.uk>
6
6
  License: MIT
@@ -1,7 +1,8 @@
1
1
  from .utils import *
2
+ from .utils import Const
2
3
  from .config import *
3
4
  import numpy as np
4
- from .emulators_meta_data import emulator_dict, dofftlog_alphas, cp_l_max_scalars
5
+ from .emulators_meta_data import emulator_dict, dofftlog_alphas, cp_l_max_scalars, cosmo_model_list
5
6
  from .cosmopower import cp_tt_nn, cp_te_nn, cp_ee_nn, cp_ee_nn, cp_pp_nn, cp_pknl_nn, cp_pkl_nn, cp_der_nn, cp_da_nn, cp_h_nn, cp_s8_nn, cp_pkl_fftlog_alphas_real_nn, cp_pkl_fftlog_alphas_imag_nn, cp_pkl_fftlog_alphas_nus
6
7
  from .cosmopower_jax import cp_tt_nn_jax, cp_te_nn_jax, cp_ee_nn_jax, cp_ee_nn_jax, cp_pp_nn_jax, cp_pknl_nn_jax, cp_pkl_nn_jax, cp_der_nn_jax, cp_da_nn_jax, cp_h_nn_jax, cp_s8_nn_jax
7
8
  from .pks_and_sigmas import *
@@ -76,13 +77,10 @@ class Class_szfast(object):
76
77
 
77
78
  self.jax_mode = params_settings["jax"]
78
79
 
80
+
79
81
  # print(f"JAX mode: {self.jax_mode}")
80
82
 
81
83
 
82
-
83
-
84
-
85
-
86
84
  # cosmopower emulators
87
85
  # self.cp_path_to_cosmopower_organization = path_to_cosmopower_organization + '/'
88
86
  self.cp_tt_nn = cp_tt_nn
@@ -90,16 +88,80 @@ class Class_szfast(object):
90
88
  self.cp_ee_nn = cp_ee_nn
91
89
  self.cp_pp_nn = cp_pp_nn
92
90
  self.cp_pknl_nn = cp_pknl_nn
93
- self.cp_pkl_nn = cp_pkl_nn
94
91
  self.cp_der_nn = cp_der_nn
95
92
 
93
+ self.cp_lmax = cp_l_max_scalars
94
+
95
+ cosmo_model_dict = {i: model for i, model in enumerate(cosmo_model_list)}
96
+
97
+ if (cosmo_model_dict[params_settings['cosmo_model']] == 'ede-v2'):
98
+
99
+ self.cszfast_pk_grid_zmax = 20.
100
+ self.cszfast_pk_grid_kmin = 5e-4
101
+ self.cszfast_pk_grid_kmax = 10.
102
+ self.cp_kmax = self.cszfast_pk_grid_kmax
103
+ self.cp_kmin = self.cszfast_pk_grid_kmin
104
+ # self.logger.info(f">>> using kmin = {self.cp_kmin}")
105
+ # self.logger.info(f">>> using kmax = {self.cp_kmax}")
106
+ # self.logger.info(f">>> using zmax = {self.cszfast_pk_grid_zmax}")
107
+
108
+ else:
109
+
110
+ self.cszfast_pk_grid_zmax = 5. # max z of our pk emulators (sept 23)
111
+ self.cszfast_pk_grid_kmin = 1e-4
112
+ self.cszfast_pk_grid_kmax = 50.
113
+ self.cp_kmax = self.cszfast_pk_grid_kmax
114
+ self.cp_kmin = self.cszfast_pk_grid_kmin
115
+ # self.logger.info(f">>> using kmin = {self.cp_kmin}")
116
+ # self.logger.info(f">>> using kmax = {self.cp_kmax}")
117
+ # self.logger.info(f">>> using zmax = {self.cszfast_pk_grid_zmax}")
96
118
 
97
119
  if self.jax_mode:
98
120
  self.cp_h_nn = cp_h_nn_jax
99
121
  self.cp_da_nn = cp_da_nn_jax
122
+ self.cp_pkl_nn = cp_pkl_nn_jax
123
+
124
+ self.pi = jnp.pi
125
+ self.transpose = jnp.transpose
126
+ self.asarray = jnp.asarray
127
+ self.log = jnp.log
128
+ self.pow = jnp.power
129
+
130
+ self.sigma_B = 2. * self.pow(self.pi,5) * self.pow(Const._k_B_,4) / 15. / self.pow(Const._h_P_,3) / self.pow(Const._c_,2)
131
+
132
+ self.linspace = jnp.linspace
133
+ self.geomspace = jnp.geomspace
134
+ self.arange = jnp.arange
135
+ self.zeros = jnp.zeros
136
+ self.gradient = jnp.gradient
137
+
138
+
100
139
  else:
101
140
  self.cp_h_nn = cp_h_nn
102
141
  self.cp_da_nn = cp_da_nn
142
+ self.cp_pkl_nn = cp_pkl_nn
143
+
144
+ self.pi = np.pi
145
+ self.transpose = np.transpose
146
+ self.pow = np.power
147
+
148
+ self.sigma_B = 2. * self.pow(self.pi,5) * self.pow(Const._k_B_,4) / 15. / self.pow(Const._h_P_,3) / self.pow(Const._c_,2)
149
+
150
+
151
+ self.linspace = np.linspace
152
+ self.geomspace = np.geomspace
153
+ self.arange = np.arange
154
+ self.zeros = np.zeros
155
+ self.asarray = np.asarray
156
+ self.log = np.log
157
+ self.gradient = np.gradient
158
+
159
+
160
+ self.cp_ls = self.arange(2,self.cp_lmax+1)
161
+ self.cp_predicted_tt_spectrum =self.zeros(self.cp_lmax)
162
+ self.cp_predicted_te_spectrum =self.zeros(self.cp_lmax)
163
+ self.cp_predicted_ee_spectrum =self.zeros(self.cp_lmax)
164
+ self.cp_predicted_pp_spectrum =self.zeros(self.cp_lmax)
103
165
 
104
166
  self.cp_s8_nn = cp_s8_nn
105
167
 
@@ -114,20 +176,6 @@ class Class_szfast(object):
114
176
  self.use_Amod = 0
115
177
  self.Amod = 0
116
178
 
117
- self.cp_lmax = cp_l_max_scalars
118
- self.cp_ls = np.arange(2,self.cp_lmax+1)
119
-
120
-
121
-
122
-
123
- cosmo_model_dict = {0: 'lcdm',
124
- 1: 'mnu',
125
- 2: 'neff',
126
- 3: 'wcdm',
127
- 4: 'ede',
128
- 5: 'mnu-3states',
129
- 6: 'ede-v2'
130
- }
131
179
 
132
180
 
133
181
  if cosmo_model_dict[params_settings['cosmo_model']] == 'ede-v2':
@@ -140,60 +188,31 @@ class Class_szfast(object):
140
188
  self.cp_ndspl_k = 10
141
189
  self.cp_nk = 5000
142
190
 
143
- self.cp_predicted_tt_spectrum =np.zeros(self.cp_lmax)
144
- self.cp_predicted_te_spectrum =np.zeros(self.cp_lmax)
145
- self.cp_predicted_ee_spectrum =np.zeros(self.cp_lmax)
146
- self.cp_predicted_pp_spectrum =np.zeros(self.cp_lmax)
147
-
148
191
 
149
192
  self.cszfast_ldim = 20000 # used for the cls arrays
150
-
151
193
  self.cszfast_pk_grid_nz = 100 # has to be same as narraySZ, i.e., ndim_redshifts; it is setup hereafter if ndim_redshifts is passed
152
194
 
153
195
 
196
+ self.cszfast_pk_grid_z = self.linspace(0.,self.cszfast_pk_grid_zmax,self.cszfast_pk_grid_nz)
197
+ self.cszfast_pk_grid_ln1pz = self.log(1.+self.cszfast_pk_grid_z)
154
198
 
155
- if (cosmo_model_dict[params_settings['cosmo_model']] == 'ede-v2'):
156
-
157
- self.cszfast_pk_grid_zmax = 20.
158
- self.cszfast_pk_grid_kmin = 5e-4
159
- self.cszfast_pk_grid_kmax = 10.
160
- self.cp_kmax = self.cszfast_pk_grid_kmax
161
- self.cp_kmin = self.cszfast_pk_grid_kmin
162
- # self.logger.info(f">>> using kmin = {self.cp_kmin}")
163
- # self.logger.info(f">>> using kmax = {self.cp_kmax}")
164
- # self.logger.info(f">>> using zmax = {self.cszfast_pk_grid_zmax}")
165
-
166
- else:
167
-
168
- self.cszfast_pk_grid_zmax = 5. # max z of our pk emulators (sept 23)
169
- self.cszfast_pk_grid_kmin = 1e-4
170
- self.cszfast_pk_grid_kmax = 50.
171
- self.cp_kmax = self.cszfast_pk_grid_kmax
172
- self.cp_kmin = self.cszfast_pk_grid_kmin
173
- # self.logger.info(f">>> using kmin = {self.cp_kmin}")
174
- # self.logger.info(f">>> using kmax = {self.cp_kmax}")
175
- # self.logger.info(f">>> using zmax = {self.cszfast_pk_grid_zmax}")
176
199
 
177
- self.cszfast_pk_grid_z = np.linspace(0.,self.cszfast_pk_grid_zmax,self.cszfast_pk_grid_nz)
178
- self.cszfast_pk_grid_ln1pz = np.log(1.+self.cszfast_pk_grid_z)
179
-
180
-
181
- self.cszfast_pk_grid_k = np.geomspace(self.cp_kmin,self.cp_kmax,self.cp_nk)[::self.cp_ndspl_k]
200
+ self.cszfast_pk_grid_k = self.geomspace(self.cp_kmin,self.cp_kmax,self.cp_nk)[::self.cp_ndspl_k]
182
201
 
183
- self.cszfast_pk_grid_lnk = np.log(self.cszfast_pk_grid_k)
202
+ self.cszfast_pk_grid_lnk = self.log(self.cszfast_pk_grid_k)
184
203
 
185
- self.cszfast_pk_grid_nk = len(np.geomspace(self.cp_kmin,self.cp_kmax,self.cp_nk)[::self.cp_ndspl_k]) # has to be same as ndimSZ, and the same as dimension of cosmopower pk emulators
204
+ self.cszfast_pk_grid_nk = len(self.geomspace(self.cp_kmin,self.cp_kmax,self.cp_nk)[::self.cp_ndspl_k]) # has to be same as ndimSZ, and the same as dimension of cosmopower pk emulators
186
205
 
187
206
  for k,v in params_settings.items():
188
207
 
189
208
  if k == 'ndim_redshifts':
190
209
 
191
210
  self.cszfast_pk_grid_nz = v
192
- self.cszfast_pk_grid_z = np.linspace(0.,self.cszfast_pk_grid_zmax,self.cszfast_pk_grid_nz)
193
- self.cszfast_pk_grid_ln1pz = np.log(1.+self.cszfast_pk_grid_z)
211
+ self.cszfast_pk_grid_z = self.linspace(0.,self.cszfast_pk_grid_zmax,self.cszfast_pk_grid_nz)
212
+ self.cszfast_pk_grid_ln1pz = self.log(1.+self.cszfast_pk_grid_z)
194
213
 
195
- self.cszfast_pk_grid_pknl_flat = np.zeros(self.cszfast_pk_grid_nz*self.cszfast_pk_grid_nk)
196
- self.cszfast_pk_grid_pkl_flat = np.zeros(self.cszfast_pk_grid_nz*self.cszfast_pk_grid_nk)
214
+ self.cszfast_pk_grid_pknl_flat = self.zeros(self.cszfast_pk_grid_nz*self.cszfast_pk_grid_nk)
215
+ self.cszfast_pk_grid_pkl_flat = self.zeros(self.cszfast_pk_grid_nz*self.cszfast_pk_grid_nk)
197
216
 
198
217
  if k == 'cosmo_model':
199
218
 
@@ -212,13 +231,12 @@ class Class_szfast(object):
212
231
 
213
232
  else:
214
233
 
215
- ls = np.arange(2,self.cp_nk+2)[::self.cp_ndspl_k] # jan 10 ndspl
216
- dls = ls*(ls+1.)/2./np.pi
234
+ ls = self.arange(2,self.cp_nk+2)[::self.cp_ndspl_k] # jan 10 ndspl
235
+ dls = ls*(ls+1.)/2./self.pi
217
236
  self.pk_power_fac= (dls)**-1
218
237
 
219
238
 
220
- self.cp_z_interp = np.linspace(0.,20.,5000)
221
- self.cp_z_interp_jax = jnp.linspace(0.,20.,5000)
239
+ self.cp_z_interp = self.linspace(0.,20.,5000)
222
240
 
223
241
  self.csz_base = None
224
242
 
@@ -226,7 +244,7 @@ class Class_szfast(object):
226
244
  self.cszfast_zgrid_zmin = 0.
227
245
  self.cszfast_zgrid_zmax = 4.
228
246
  self.cszfast_zgrid_nz = 250
229
- self.cszfast_zgrid = np.linspace(self.cszfast_zgrid_zmin,
247
+ self.cszfast_zgrid = self.linspace(self.cszfast_zgrid_zmin,
230
248
  self.cszfast_zgrid_zmax,
231
249
  self.cszfast_zgrid_nz)
232
250
 
@@ -234,19 +252,41 @@ class Class_szfast(object):
234
252
  self.cszfast_mgrid_mmin = 1e10
235
253
  self.cszfast_mgrid_mmax = 1e15
236
254
  self.cszfast_mgrid_nm = 50
237
- self.cszfast_mgrid = np.geomspace(self.cszfast_mgrid_mmin,
255
+ self.cszfast_mgrid = self.geomspace(self.cszfast_mgrid_mmin,
238
256
  self.cszfast_mgrid_mmax,
239
257
  self.cszfast_mgrid_nm)
240
258
 
241
259
  self.cszfast_gas_pressure_xgrid_xmin = 1e-2
242
260
  self.cszfast_gas_pressure_xgrid_xmax = 1e2
243
261
  self.cszfast_gas_pressure_xgrid_nx = 100
244
- self.cszfast_gas_pressure_xgrid = np.geomspace(self.cszfast_gas_pressure_xgrid_xmin,
262
+ self.cszfast_gas_pressure_xgrid = self.geomspace(self.cszfast_gas_pressure_xgrid_xmin,
245
263
  self.cszfast_gas_pressure_xgrid_xmax,
246
264
  self.cszfast_gas_pressure_xgrid_nx)
247
265
 
248
266
  self.params_for_emulators = {}
249
267
 
268
+
269
+ def get_all_relevant_params(self,params_values_dict=None):
270
+ if params_values_dict:
271
+ params_values = params_values_dict.copy()
272
+ else:
273
+ params_values = self.params_for_emulators
274
+ update_params_with_defaults(params_values, self.emulator_dict[self.cosmo_model]['default'])
275
+ params_values['h'] = params_values['H0']/100.
276
+ params_values['Omega_b'] = params_values['omega_b']/params_values['h']**2.
277
+ params_values['Omega_cdm'] = params_values['omega_cdm']/params_values['h']**2.
278
+ params_values['Omega0_g'] = (4.*self.sigma_B/Const._c_*pow(params_values['T_cmb'],4.)) / (3.*Const._c_*Const._c_*1.e10*params_values['h']*params_values['h']/Const._Mpc_over_m_/Const._Mpc_over_m_/8./self.pi/Const._G_)
279
+ params_values['Omega0_ur'] = params_values['N_ur']*7./8.*self.pow(4./11.,4./3.)*params_values['Omega0_g']
280
+ params_values['Omega0_ncdm'] = params_values['deg_ncdm']*params_values['m_ncdm']/(93.14*params_values['h']*params_values['h']) ## valid only in standard cases, default T_ncdm etc
281
+ params_values['Omega_Lambda'] = 1. - params_values['Omega0_g'] - params_values['Omega_b'] - params_values['Omega_cdm'] - params_values['Omega0_ncdm'] - params_values['Omega0_ur']
282
+ params_values['Omega0_m'] = params_values['Omega_cdm'] + params_values['Omega_b'] + params_values['Omega0_ncdm']
283
+ params_values['Omega0_r'] = params_values['Omega0_ur']+params_values['Omega0_g']
284
+ params_values['Omega0_m_nonu'] = params_values['Omega0_m'] - params_values['Omega0_ncdm']
285
+ params_values['Omega0_cb'] = params_values['Omega0_m_nonu']
286
+ return params_values
287
+
288
+
289
+
250
290
  def find_As(self,params_cp):
251
291
 
252
292
  sigma_8_asked = params_cp["sigma8"]
@@ -312,7 +352,7 @@ class Class_szfast(object):
312
352
  creal = predicted_testing_alphas_creal
313
353
  cimag = predicted_testing_alphas_cimag
314
354
  Nmax = len(self.cszfast_pk_grid_k)
315
- cnew = np.zeros(Nmax+1,dtype=complex)
355
+ cnew = self.zeros(Nmax+1,dtype=complex)
316
356
  for i in range(Nmax+1):
317
357
  if i<int(Nmax/2):
318
358
  cnew[i] = complex(creal[i],cimag[i])
@@ -362,20 +402,20 @@ class Class_szfast(object):
362
402
 
363
403
  nl = len(self.cp_predicted_tt_spectrum)
364
404
  cls = {}
365
- cls['ell'] = np.arange(20000)
366
- cls['tt'] = np.zeros(20000)
367
- cls['te'] = np.zeros(20000)
368
- cls['ee'] = np.zeros(20000)
369
- cls['pp'] = np.zeros(20000)
370
- cls['bb'] = np.zeros(20000)
371
- lcp = np.asarray(cls['ell'][2:nl+2])
405
+ cls['ell'] = self.arange(20000)
406
+ cls['tt'] = self.zeros(20000)
407
+ cls['te'] = self.zeros(20000)
408
+ cls['ee'] = self.zeros(20000)
409
+ cls['pp'] = self.zeros(20000)
410
+ cls['bb'] = self.zeros(20000)
411
+ lcp = self.asarray(cls['ell'][2:nl+2])
372
412
 
373
413
  # print('cosmo_model:',self.cosmo_model,nl)
374
414
  if self.cosmo_model == 'ede-v2':
375
415
  factor_ttteee = 1./lcp**2
376
416
  factor_pp = 1./lcp**3
377
417
  else:
378
- factor_ttteee = 1./(lcp*(lcp+1.)/2./np.pi)
418
+ factor_ttteee = 1./(lcp*(lcp+1.)/2./self.pi)
379
419
  factor_pp = 1./(lcp*(lcp+1.))**2.
380
420
 
381
421
  self.cp_predicted_tt_spectrum *= factor_ttteee
@@ -402,7 +442,7 @@ class Class_szfast(object):
402
442
  cmb_cls_loaded = pickle.load(handle)
403
443
  nl_cls_file = len(cmb_cls_loaded['ell'])
404
444
  cls_ls = cmb_cls_loaded['ell']
405
- dlfac = cls_ls*(cls_ls+1.)/2./np.pi
445
+ dlfac = cls_ls*(cls_ls+1.)/2./self.pi
406
446
  cls_tt = cmb_cls_loaded['tt']*dlfac
407
447
  cls_te = cmb_cls_loaded['te']*dlfac
408
448
  cls_ee = cmb_cls_loaded['ee']*dlfac
@@ -421,8 +461,6 @@ class Class_szfast(object):
421
461
  **params_values_dict):
422
462
 
423
463
  z_arr = self.cszfast_pk_grid_z
424
-
425
-
426
464
  k_arr = self.cszfast_pk_grid_k
427
465
 
428
466
  # print(">>> z_arr:",z_arr)
@@ -464,26 +502,32 @@ class Class_szfast(object):
464
502
 
465
503
  params_dict_pp = params_dict.copy()
466
504
  params_dict_pp['z_pk_save_nonclass'] = [zp]
467
- predicted_pk_spectrum_z.append(self.cp_pkl_nn[self.cosmo_model].predictions_np(params_dict_pp)[0])
505
+ if self.jax_mode:
506
+ predicted_pk_spectrum_z.append(self.cp_pkl_nn[self.cosmo_model].predict(params_dict_pp))
507
+ else:
508
+ predicted_pk_spectrum_z.append(self.cp_pkl_nn[self.cosmo_model].predictions_np(params_dict_pp)[0])
468
509
 
469
510
  # if abs(zp-0.5) < 0.01:
470
511
  # print(">>> predicted_pk_spectrum_z:",predicted_pk_spectrum_z[-1])
471
512
  # import pprint
472
513
  # pprint.pprint(params_dict_pp)
473
514
 
474
- predicted_pk_spectrum = np.asarray(predicted_pk_spectrum_z)
515
+ predicted_pk_spectrum = self.asarray(predicted_pk_spectrum_z)
475
516
 
476
517
 
477
518
  pk = 10.**predicted_pk_spectrum
478
519
 
479
520
  pk_re = pk*self.pk_power_fac
480
- pk_re = np.transpose(pk_re)
521
+ pk_re = self.transpose(pk_re)
481
522
 
482
523
  # print(">>> pk_re:",pk_re)
483
524
  # import sys
484
525
  # sys.exit(0)
485
526
 
486
- self.pkl_interp = PowerSpectrumInterpolator(z_arr,k_arr,np.log(pk_re).T,logP=True)
527
+ if self.jax_mode:
528
+ self.pkl_interp = None
529
+ else:
530
+ self.pkl_interp = PowerSpectrumInterpolator(z_arr,k_arr,self.log(pk_re).T,logP=True)
487
531
 
488
532
  self.cszfast_pk_grid_pk = pk_re
489
533
  self.cszfast_pk_grid_pkl_flat = pk_re.flatten()
@@ -491,9 +535,7 @@ class Class_szfast(object):
491
535
  return pk_re, k_arr, z_arr
492
536
 
493
537
 
494
- def calculate_sigma(self,
495
-
496
- **params_values_dict):
538
+ def calculate_sigma(self,**params_values_dict):
497
539
 
498
540
  params_values = params_values_dict.copy()
499
541
 
@@ -509,7 +551,7 @@ class Class_szfast(object):
509
551
 
510
552
  R, var[:,iz] = TophatVar(k, lowring=True)(P[:,iz], extrap=True)
511
553
 
512
- dvar[:,iz] = np.gradient(var[:,iz], R)
554
+ dvar[:,iz] = self.gradient(var[:,iz], R)
513
555
 
514
556
  # print(k)
515
557
  # print(R)
@@ -517,11 +559,11 @@ class Class_szfast(object):
517
559
  # exit(0)
518
560
 
519
561
 
520
- self.cszfast_pk_grid_lnr = np.log(R)
562
+ self.cszfast_pk_grid_lnr = self.log(R)
521
563
  self.cszfast_pk_grid_sigma2 = var
522
564
 
523
565
  self.cszfast_pk_grid_sigma2_flat = var.flatten()
524
- self.cszfast_pk_grid_lnsigma2_flat = 0.5*np.log(var.flatten())
566
+ self.cszfast_pk_grid_lnsigma2_flat = 0.5*self.log(var.flatten())
525
567
 
526
568
  self.cszfast_pk_grid_dsigma2 = dvar
527
569
  self.cszfast_pk_grid_dsigma2_flat = dvar.flatten()
@@ -570,7 +612,7 @@ class Class_szfast(object):
570
612
  s8z = self.cp_s8_nn[self.cosmo_model].predictions_np(params_dict)
571
613
  # print(self.s8z)
572
614
  self.s8z_interp = scipy.interpolate.interp1d(
573
- np.linspace(0.,20.,5000),
615
+ self.linspace(0.,20.,5000),
574
616
  s8z[0],
575
617
  kind='linear',
576
618
  axis=-1,
@@ -609,16 +651,16 @@ class Class_szfast(object):
609
651
  params_dict_pp['z_pk_save_nonclass'] = [zp]
610
652
  predicted_pk_spectrum_z.append(self.cp_pknl_nn[self.cosmo_model].predictions_np(params_dict_pp)[0])
611
653
 
612
- predicted_pk_spectrum = np.asarray(predicted_pk_spectrum_z)
654
+ predicted_pk_spectrum = self.asarray(predicted_pk_spectrum_z)
613
655
 
614
656
 
615
657
  pk = 10.**predicted_pk_spectrum
616
658
 
617
659
  pk_re = pk*self.pk_power_fac
618
- pk_re = np.transpose(pk_re)
660
+ pk_re = self.transpose(pk_re)
619
661
 
620
662
 
621
- self.pknl_interp = PowerSpectrumInterpolator(z_arr,k_arr,np.log(pk_re).T,logP=True)
663
+ self.pknl_interp = PowerSpectrumInterpolator(z_arr,k_arr,self.log(pk_re).T,logP=True)
622
664
 
623
665
 
624
666
  self.cszfast_pk_grid_pknl = pk_re
@@ -627,20 +669,18 @@ class Class_szfast(object):
627
669
  return pk_re, k_arr, z_arr
628
670
 
629
671
 
672
+
673
+
630
674
  def calculate_pkl_at_z(self,
631
675
  z_asked,
632
676
  params_values_dict=None):
633
677
 
634
- z_arr = self.cszfast_pk_grid_z
635
678
 
636
679
  k_arr = self.cszfast_pk_grid_k
637
680
 
638
681
  if params_values_dict:
639
-
640
682
  params_values = params_values_dict.copy()
641
-
642
683
  else:
643
-
644
684
  params_values = self.params_for_emulators
645
685
 
646
686
  update_params_with_defaults(params_values, self.emulator_dict[self.cosmo_model]['default'])
@@ -657,19 +697,21 @@ class Class_szfast(object):
657
697
 
658
698
  predicted_pk_spectrum_z = []
659
699
 
660
- z_asked = z_asked
661
700
  params_dict_pp = params_dict.copy()
662
701
  update_params_with_defaults(params_dict_pp, self.emulator_dict[self.cosmo_model]['default'])
663
702
 
664
703
  params_dict_pp['z_pk_save_nonclass'] = [z_asked]
665
- predicted_pk_spectrum_z.append(self.cp_pkl_nn[self.cosmo_model].predictions_np(params_dict_pp)[0])
704
+ if self.jax_mode:
705
+ predicted_pk_spectrum_z.append(self.cp_pkl_nn[self.cosmo_model].predict(params_dict_pp))
706
+ else:
707
+ predicted_pk_spectrum_z.append(self.cp_pkl_nn[self.cosmo_model].predictions_np(params_dict_pp)[0])
666
708
 
667
- predicted_pk_spectrum = np.asarray(predicted_pk_spectrum_z)
709
+ predicted_pk_spectrum = self.asarray(predicted_pk_spectrum_z)
668
710
 
669
711
 
670
712
  pk = 10.**predicted_pk_spectrum
671
713
  pk_re = pk*self.pk_power_fac
672
- pk_re = np.transpose(pk_re)
714
+ pk_re = self.transpose(pk_re)
673
715
 
674
716
 
675
717
  return pk_re, k_arr
@@ -712,12 +754,12 @@ class Class_szfast(object):
712
754
  params_dict_pp['z_pk_save_nonclass'] = [z_asked]
713
755
  predicted_pk_spectrum_z.append(self.cp_pknl_nn[self.cosmo_model].predictions_np(params_dict_pp)[0])
714
756
 
715
- predicted_pk_spectrum = np.asarray(predicted_pk_spectrum_z)
757
+ predicted_pk_spectrum = self.asarray(predicted_pk_spectrum_z)
716
758
 
717
759
 
718
760
  pk = 10.**predicted_pk_spectrum
719
761
  pk_re = pk*self.pk_power_fac
720
- pk_re = np.transpose(pk_re)
762
+ pk_re = self.transpose(pk_re)
721
763
 
722
764
 
723
765
  return pk_re, k_arr
@@ -748,19 +790,10 @@ class Class_szfast(object):
748
790
  # print("self.cp_predicted_hubble type:", type(self.cp_predicted_hubble))
749
791
  # print("self.cp_predicted_hubble",self.cp_predicted_hubble)
750
792
 
751
- # self.hz_interp = jscipy.interpolate.interp1d(
752
- # self.cp_z_interp_jax,
753
- # self.cp_predicted_hubble,
754
- # kind='linear',
755
- # axis=-1,
756
- # copy=True,
757
- # bounds_error=None,
758
- # fill_value=np.nan,
759
- # assume_sorted=False)
760
793
 
761
794
  # Assuming `cp_z_interp` and `cp_predicted_hubble` are JAX arrays
762
795
  def hz_interp(x):
763
- return jnp.interp(x, self.cp_z_interp_jax, self.cp_predicted_hubble, left=jnp.nan, right=jnp.nan)
796
+ return jnp.interp(x, self.cp_z_interp, self.cp_predicted_hubble, left=jnp.nan, right=jnp.nan)
764
797
 
765
798
  self.hz_interp = hz_interp
766
799
  # exit()
@@ -803,10 +836,10 @@ class Class_szfast(object):
803
836
  if self.cosmo_model == 'ede-v2':
804
837
  # print('ede-v2 case')
805
838
  self.cp_predicted_da = jnp.insert(self.cp_predicted_da, 0, 0)
806
- self.cp_predicted_da *= (1.+self.cp_z_interp_jax)
839
+ self.cp_predicted_da *= (1.+self.cp_z_interp)
807
840
 
808
841
  def chi_interp(x):
809
- return jnp.interp(x, self.cp_z_interp_jax, self.cp_predicted_da, left=jnp.nan, right=jnp.nan)
842
+ return jnp.interp(x, self.cp_z_interp, self.cp_predicted_da, left=jnp.nan, right=jnp.nan)
810
843
 
811
844
  self.chi_interp = chi_interp
812
845
 
@@ -835,12 +868,12 @@ class Class_szfast(object):
835
868
  def get_cmb_cls(self,ell_factor=True,Tcmb_uk = Tcmb_uk):
836
869
 
837
870
  cls = {}
838
- cls['ell'] = np.arange(self.cszfast_ldim)
839
- cls['tt'] = np.zeros(self.cszfast_ldim)
840
- cls['te'] = np.zeros(self.cszfast_ldim)
841
- cls['ee'] = np.zeros(self.cszfast_ldim)
842
- cls['pp'] = np.zeros(self.cszfast_ldim)
843
- cls['bb'] = np.zeros(self.cszfast_ldim)
871
+ cls['ell'] = self.arange(self.cszfast_ldim)
872
+ cls['tt'] = self.zeros(self.cszfast_ldim)
873
+ cls['te'] = self.zeros(self.cszfast_ldim)
874
+ cls['ee'] = self.zeros(self.cszfast_ldim)
875
+ cls['pp'] = self.zeros(self.cszfast_ldim)
876
+ cls['bb'] = self.zeros(self.cszfast_ldim)
844
877
  cls['tt'][2:self.cp_lmax+1] = (Tcmb_uk)**2.*self.cp_predicted_tt_spectrum.copy()
845
878
  cls['te'][2:self.cp_lmax+1] = (Tcmb_uk)**2.*self.cp_predicted_te_spectrum.copy()
846
879
  cls['ee'][2:self.cp_lmax+1] = (Tcmb_uk)**2.*self.cp_predicted_ee_spectrum.copy()
@@ -848,21 +881,21 @@ class Class_szfast(object):
848
881
 
849
882
 
850
883
  if ell_factor==False:
851
- fac_l = np.zeros(self.cszfast_ldim)
852
- fac_l[2:self.cp_lmax+1] = 1./(cls['ell'][2:self.cp_lmax+1]*(cls['ell'][2:self.cp_lmax+1]+1.)/2./np.pi)
884
+ fac_l = self.zeros(self.cszfast_ldim)
885
+ fac_l[2:self.cp_lmax+1] = 1./(cls['ell'][2:self.cp_lmax+1]*(cls['ell'][2:self.cp_lmax+1]+1.)/2./self.pi)
853
886
  cls['tt'][2:self.cp_lmax+1] *= fac_l[2:self.cp_lmax+1]
854
887
  cls['te'][2:self.cp_lmax+1] *= fac_l[2:self.cp_lmax+1]
855
888
  cls['ee'][2:self.cp_lmax+1] *= fac_l[2:self.cp_lmax+1]
856
- # cls['bb'] = np.zeros(self.cszfast_ldim)
889
+ # cls['bb'] = self.zeros(self.cszfast_ldim)
857
890
  return cls
858
891
 
859
892
 
860
893
  def get_pknl_at_k_and_z(self,k_asked,z_asked):
861
894
  # def get_pkl_at_k_and_z(self,k_asked,z_asked,method = 'cloughtocher'):
862
895
  # if method == 'linear':
863
- # pk = self.pkl_linearnd_interp(z_asked,np.log(k_asked))
896
+ # pk = self.pkl_linearnd_interp(z_asked,self.log(k_asked))
864
897
  # elif method == 'cloughtocher':
865
- # pk = self.pkl_cloughtocher_interp(z_asked,np.log(k_asked))
898
+ # pk = self.pkl_cloughtocher_interp(z_asked,self.log(k_asked))
866
899
  # return np.exp(pk)
867
900
  return self.pknl_interp.P(z_asked,k_asked)
868
901
 
@@ -870,18 +903,18 @@ class Class_szfast(object):
870
903
  # def get_pkl_at_k_and_z(self,k_asked,z_asked,method = 'cloughtocher'):
871
904
  def get_pkl_at_k_and_z(self,k_asked,z_asked):
872
905
  # if method == 'linear':
873
- # pk = self.pknl_linearnd_interp(z_asked,np.log(k_asked))
906
+ # pk = self.pknl_linearnd_interp(z_asked,self.log(k_asked))
874
907
  # elif method == 'cloughtocher':
875
- # pk = self.pknl_cloughtocher_interp(z_asked,np.log(k_asked))
908
+ # pk = self.pknl_cloughtocher_interp(z_asked,self.log(k_asked))
876
909
  # return np.exp(pk)
877
910
  return self.pkl_interp.P(z_asked,k_asked)
878
911
 
879
912
  # function used to overwrite the classy function in fast mode.
880
913
  def get_sigma_at_r_and_z(self,r_asked,z_asked):
881
914
  # if method == 'linear':
882
- # pk = self.pknl_linearnd_interp(z_asked,np.log(k_asked))
915
+ # pk = self.pknl_linearnd_interp(z_asked,self.log(k_asked))
883
916
  # elif method == 'cloughtocher':
884
- # pk = self.pknl_cloughtocher_interp(z_asked,np.log(k_asked))
917
+ # pk = self.pknl_cloughtocher_interp(z_asked,self.log(k_asked))
885
918
  # return np.exp(pk)
886
919
  k = self.cszfast_pk_grid_k
887
920
  P_at_z = self.get_pkl_at_k_and_z(k,z_asked)
@@ -899,7 +932,10 @@ class Class_szfast(object):
899
932
  return np.array(self.hz_interp(z)*H_units_conv_factor[units])
900
933
 
901
934
  def get_chi(self, z):
902
- return np.array(self.chi_interp(z))
935
+ if self.jax_mode:
936
+ return jnp.array(self.chi_interp(z))
937
+ else:
938
+ return np.array(self.chi_interp(z))
903
939
 
904
940
  def get_gas_pressure_profile_x(self,z,m,x):
905
941
  return 0#np.vectorize(self.csz_base.get_pressure_P_over_P_delta_at_x_M_z_b12_200c)(x,m,z)
@@ -913,7 +949,7 @@ class Class_szfast(object):
913
949
 
914
950
 
915
951
  def tabulate_gas_pressure_profile_k(self):
916
- z_asked,m_asked,x_asked = 0.2,3e14,np.geomspace(1e-3,1e2,500)
952
+ z_asked,m_asked,x_asked = 0.2,3e14,self.geomspace(1e-3,1e2,500)
917
953
  start = time.time()
918
954
  px = self.get_gas_pressure_profile_x(z_asked,m_asked,x_asked)
919
955
  end = time.time()
@@ -105,8 +105,9 @@ for mp in cosmo_model_list:
105
105
 
106
106
  cp_pknl_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'PK/' + emulator_dict[mp]['PKNL'])
107
107
 
108
- cp_pkl_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'PK/' + emulator_dict[mp]['PKL'])
109
-
108
+ cp_pkl_nn_jax[mp] = CosmoPowerJAX_custom(probe='custom_log',filepath=path_to_emulators +'PK/' + emulator_dict[mp]['PKL'] + '.npz')
109
+ cp_pkl_nn_jax[mp].ten_to_predictions = False
110
+
110
111
  cp_der_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'derived-parameters/' + emulator_dict[mp]['DER'])
111
112
 
112
113
  # cp_da_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['DAZ'])
@@ -33,7 +33,7 @@ cosmopower_derived_params_names = ['100*theta_s',
33
33
  'ra_star',
34
34
  'rs_drag']
35
35
 
36
- cp_l_max_scalars = 11000 # max multipole of train ing data
36
+ cp_l_max_scalars = 11000 # max multipole of training data
37
37
 
38
38
  cosmo_model_list = [
39
39
  'lcdm',
@@ -87,6 +87,8 @@ emulator_dict['lcdm']['default']['n_s'] = 0.9665
87
87
  emulator_dict['lcdm']['default']['N_ur'] = 2.0328
88
88
  emulator_dict['lcdm']['default']['N_ncdm'] = 1
89
89
  emulator_dict['lcdm']['default']['m_ncdm'] = 0.06
90
+ emulator_dict['lcdm']['default']['deg_ncdm'] = 1
91
+ emulator_dict['lcdm']['default']['T_cmb'] = 2.7255
90
92
 
91
93
  emulator_dict['mnu']['TT'] = 'TT_mnu_v1'
92
94
  emulator_dict['mnu']['TE'] = 'TE_mnu_v1'
@@ -107,7 +109,9 @@ emulator_dict['mnu']['default']['omega_cdm'] = 0.11933
107
109
  emulator_dict['mnu']['default']['n_s'] = 0.9665
108
110
  emulator_dict['mnu']['default']['N_ur'] = 2.0328
109
111
  emulator_dict['mnu']['default']['N_ncdm'] = 1
112
+ emulator_dict['mnu']['default']['deg_ncdm'] = 1
110
113
  emulator_dict['mnu']['default']['m_ncdm'] = 0.06
114
+ emulator_dict['mnu']['default']['T_cmb'] = 2.7255
111
115
 
112
116
  emulator_dict['neff']['TT'] = 'TT_neff_v1'
113
117
  emulator_dict['neff']['TE'] = 'TE_neff_v1'
@@ -128,8 +132,9 @@ emulator_dict['neff']['default']['omega_cdm'] = 0.11933
128
132
  emulator_dict['neff']['default']['n_s'] = 0.9665
129
133
  emulator_dict['neff']['default']['N_ur'] = 2.0328 # this is the default value in class v2 to get Neff = 3.046
130
134
  emulator_dict['neff']['default']['N_ncdm'] = 1
135
+ emulator_dict['neff']['default']['deg_ncdm'] = 1
131
136
  emulator_dict['neff']['default']['m_ncdm'] = 0.06
132
-
137
+ emulator_dict['neff']['default']['T_cmb'] = 2.7255
133
138
 
134
139
  emulator_dict['wcdm']['TT'] = 'TT_w_v1'
135
140
  emulator_dict['wcdm']['TE'] = 'TE_w_v1'
@@ -151,6 +156,9 @@ emulator_dict['wcdm']['default']['n_s'] = 0.9665
151
156
  emulator_dict['wcdm']['default']['N_ur'] = 2.0328 # this is the default value in class v2 to get Neff = 3.046
152
157
  emulator_dict['wcdm']['default']['N_ncdm'] = 1
153
158
  emulator_dict['wcdm']['default']['m_ncdm'] = 0.06
159
+ emulator_dict['wcdm']['default']['deg_ncdm'] = 1
160
+ emulator_dict['wcdm']['default']['T_cmb'] = 2.7255
161
+
154
162
 
155
163
  emulator_dict['ede']['TT'] = 'TT_v1'
156
164
  emulator_dict['ede']['TE'] = 'TE_v1'
@@ -174,9 +182,10 @@ emulator_dict['ede']['default']['log10z_c'] = 3.562 # e.g. from https://github.c
174
182
  emulator_dict['ede']['default']['thetai_scf'] = 2.83 # e.g. from https://github.com/mwt5345/class_ede/blob/master/class/notebooks-ede/2-CMB-Comparison.ipynb
175
183
  emulator_dict['ede']['default']['r'] = 0.
176
184
  emulator_dict['ede']['default']['N_ur'] = 0.00641 # this is the default value in class v2 to get Neff = 3.046
177
- emulator_dict['ede']['default']['N_ncdm'] = 3
185
+ emulator_dict['ede']['default']['N_ncdm'] = 1 ### use equivalence with deg_ncdm = 3
186
+ emulator_dict['ede']['default']['deg_ncdm'] = 3 ### use equivalence with deg_ncdm = 3 for faster computation
178
187
  emulator_dict['ede']['default']['m_ncdm'] = 0.02
179
-
188
+ emulator_dict['ede']['default']['T_cmb'] = 2.7255
180
189
 
181
190
  emulator_dict['mnu-3states']['TT'] = 'TT_v1'
182
191
  emulator_dict['mnu-3states']['TE'] = 'TE_v1'
@@ -196,8 +205,10 @@ emulator_dict['mnu-3states']['default']['omega_b'] = 0.02242
196
205
  emulator_dict['mnu-3states']['default']['omega_cdm'] = 0.11933
197
206
  emulator_dict['mnu-3states']['default']['n_s'] = 0.9665
198
207
  emulator_dict['mnu-3states']['default']['N_ur'] = 0.00641 # this is the default value in class v2 to get Neff = 3.046
199
- emulator_dict['mnu-3states']['default']['N_ncdm'] = 3
208
+ emulator_dict['mnu-3states']['default']['N_ncdm'] = 1 ### use equivalence with deg_ncdm = 3
209
+ emulator_dict['mnu-3states']['default']['deg_ncdm'] = 3 ### use equivalence with deg_ncdm = 3 for faster computation
200
210
  emulator_dict['mnu-3states']['default']['m_ncdm'] = 0.02
211
+ emulator_dict['mnu-3states']['default']['T_cmb'] = 2.7255
201
212
 
202
213
  emulator_dict['ede-v2']['TT'] = 'TT_v2'
203
214
  emulator_dict['ede-v2']['TE'] = 'TE_v2'
@@ -222,9 +233,10 @@ emulator_dict['ede-v2']['default']['log10z_c'] = 3.562 # e.g. from https://githu
222
233
  emulator_dict['ede-v2']['default']['thetai_scf'] = 2.83 # e.g. from https://github.com/mwt5345/class_ede/blob/master/class/notebooks-ede/2-CMB-Comparison.ipynb
223
234
  emulator_dict['ede-v2']['default']['r'] = 0.
224
235
  emulator_dict['ede-v2']['default']['N_ur'] = 0.00441 # this is the default value in class v3 to get Neff = 3.044
225
- emulator_dict['ede-v2']['default']['N_ncdm'] = 3
236
+ emulator_dict['ede-v2']['default']['N_ncdm'] = 1 ### use equivalence with deg_ncdm = 3
237
+ emulator_dict['ede-v2']['default']['deg_ncdm'] = 3 ### use equivalence with deg_ncdm = 3 for faster computation
226
238
  emulator_dict['ede-v2']['default']['m_ncdm'] = 0.02
227
-
239
+ emulator_dict['ede-v2']['default']['T_cmb'] = 2.7255
228
240
 
229
241
 
230
242
 
@@ -60,3 +60,13 @@ class Const:
60
60
  c_km_s = 299792.458 # speed of light
61
61
  h_J_s = 6.626070040e-34 # Planck's constant
62
62
  kB_J_K = 1.38064852e-23 # Boltzmann constant
63
+
64
+ _c_ = 2.99792458e8 # c in m/s
65
+ _Mpc_over_m_ = 3.085677581282e22 # conversion factor from meters to megaparsecs
66
+ _Gyr_over_Mpc_ = 3.06601394e2 # conversion factor from megaparsecs to gigayears
67
+ _G_ = 6.67428e-11 # Newton constant in m^3/Kg/s^2
68
+ _eV_ = 1.602176487e-19 # 1 eV expressed in J
69
+
70
+ # parameters entering in Stefan-Boltzmann constant sigma_B
71
+ _k_B_ = 1.3806504e-23
72
+ _h_P_ = 6.62606896e-34
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: classy_szfast
3
- Version: 0.0.25.post3
3
+ Version: 0.0.25.post5
4
4
  Summary: The accelerator of the class_sz code from https://github.com/CLASS-SZ
5
5
  Maintainer-email: Boris Bolliet <bb667@cam.ac.uk>
6
6
  License: MIT
@@ -3,7 +3,7 @@ requires = ["setuptools", "wheel"]
3
3
  build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
- version = "0.0.25.post3"
6
+ version = "0.0.25.post5"
7
7
  license = { text = "MIT" }
8
8
  name = "classy_szfast"
9
9
  maintainers = [{name = "Boris Bolliet",email="bb667@cam.ac.uk"}]