classy-szfast 0.0.25.post3__tar.gz → 0.0.25.post4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post4}/PKG-INFO +1 -1
  2. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post4}/classy_szfast/classy_szfast.py +117 -114
  3. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post4}/classy_szfast/cosmopower_jax.py +3 -2
  4. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post4}/classy_szfast/emulators_meta_data.py +1 -1
  5. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post4}/classy_szfast.egg-info/PKG-INFO +1 -1
  6. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post4}/pyproject.toml +1 -1
  7. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post4}/README.md +0 -0
  8. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post4}/classy_szfast/__init__.py +0 -0
  9. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post4}/classy_szfast/classy_sz.py +0 -0
  10. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post4}/classy_szfast/config.py +0 -0
  11. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post4}/classy_szfast/cosmopower.py +0 -0
  12. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post4}/classy_szfast/cosmosis_classy_szfast_interface.py +0 -0
  13. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post4}/classy_szfast/custom_bias/__init__.py +0 -0
  14. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post4}/classy_szfast/custom_bias/custom_bias.py +0 -0
  15. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post4}/classy_szfast/custom_profiles/__init__.py +0 -0
  16. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post4}/classy_szfast/custom_profiles/custom_profiles.py +0 -0
  17. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post4}/classy_szfast/pks_and_sigmas.py +0 -0
  18. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post4}/classy_szfast/restore_nn.py +0 -0
  19. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post4}/classy_szfast/suppress_warnings.py +0 -0
  20. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post4}/classy_szfast/utils.py +0 -0
  21. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post4}/classy_szfast.egg-info/SOURCES.txt +0 -0
  22. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post4}/classy_szfast.egg-info/dependency_links.txt +0 -0
  23. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post4}/classy_szfast.egg-info/requires.txt +0 -0
  24. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post4}/classy_szfast.egg-info/top_level.txt +0 -0
  25. {classy_szfast-0.0.25.post3 → classy_szfast-0.0.25.post4}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: classy_szfast
3
- Version: 0.0.25.post3
3
+ Version: 0.0.25.post4
4
4
  Summary: The accelerator of the class_sz code from https://github.com/CLASS-SZ
5
5
  Maintainer-email: Boris Bolliet <bb667@cam.ac.uk>
6
6
  License: MIT
@@ -1,7 +1,7 @@
1
1
  from .utils import *
2
2
  from .config import *
3
3
  import numpy as np
4
- from .emulators_meta_data import emulator_dict, dofftlog_alphas, cp_l_max_scalars
4
+ from .emulators_meta_data import emulator_dict, dofftlog_alphas, cp_l_max_scalars, cosmo_model_list
5
5
  from .cosmopower import cp_tt_nn, cp_te_nn, cp_ee_nn, cp_ee_nn, cp_pp_nn, cp_pknl_nn, cp_pkl_nn, cp_der_nn, cp_da_nn, cp_h_nn, cp_s8_nn, cp_pkl_fftlog_alphas_real_nn, cp_pkl_fftlog_alphas_imag_nn, cp_pkl_fftlog_alphas_nus
6
6
  from .cosmopower_jax import cp_tt_nn_jax, cp_te_nn_jax, cp_ee_nn_jax, cp_ee_nn_jax, cp_pp_nn_jax, cp_pknl_nn_jax, cp_pkl_nn_jax, cp_der_nn_jax, cp_da_nn_jax, cp_h_nn_jax, cp_s8_nn_jax
7
7
  from .pks_and_sigmas import *
@@ -76,13 +76,10 @@ class Class_szfast(object):
76
76
 
77
77
  self.jax_mode = params_settings["jax"]
78
78
 
79
+
79
80
  # print(f"JAX mode: {self.jax_mode}")
80
81
 
81
82
 
82
-
83
-
84
-
85
-
86
83
  # cosmopower emulators
87
84
  # self.cp_path_to_cosmopower_organization = path_to_cosmopower_organization + '/'
88
85
  self.cp_tt_nn = cp_tt_nn
@@ -90,16 +87,70 @@ class Class_szfast(object):
90
87
  self.cp_ee_nn = cp_ee_nn
91
88
  self.cp_pp_nn = cp_pp_nn
92
89
  self.cp_pknl_nn = cp_pknl_nn
93
- self.cp_pkl_nn = cp_pkl_nn
94
90
  self.cp_der_nn = cp_der_nn
95
91
 
92
+ self.cp_lmax = cp_l_max_scalars
93
+
94
+ cosmo_model_dict = {i: model for i, model in enumerate(cosmo_model_list)}
95
+
96
+ if (cosmo_model_dict[params_settings['cosmo_model']] == 'ede-v2'):
97
+
98
+ self.cszfast_pk_grid_zmax = 20.
99
+ self.cszfast_pk_grid_kmin = 5e-4
100
+ self.cszfast_pk_grid_kmax = 10.
101
+ self.cp_kmax = self.cszfast_pk_grid_kmax
102
+ self.cp_kmin = self.cszfast_pk_grid_kmin
103
+ # self.logger.info(f">>> using kmin = {self.cp_kmin}")
104
+ # self.logger.info(f">>> using kmax = {self.cp_kmax}")
105
+ # self.logger.info(f">>> using zmax = {self.cszfast_pk_grid_zmax}")
106
+
107
+ else:
108
+
109
+ self.cszfast_pk_grid_zmax = 5. # max z of our pk emulators (sept 23)
110
+ self.cszfast_pk_grid_kmin = 1e-4
111
+ self.cszfast_pk_grid_kmax = 50.
112
+ self.cp_kmax = self.cszfast_pk_grid_kmax
113
+ self.cp_kmin = self.cszfast_pk_grid_kmin
114
+ # self.logger.info(f">>> using kmin = {self.cp_kmin}")
115
+ # self.logger.info(f">>> using kmax = {self.cp_kmax}")
116
+ # self.logger.info(f">>> using zmax = {self.cszfast_pk_grid_zmax}")
96
117
 
97
118
  if self.jax_mode:
98
119
  self.cp_h_nn = cp_h_nn_jax
99
120
  self.cp_da_nn = cp_da_nn_jax
121
+ self.cp_pkl_nn = cp_pkl_nn_jax
122
+
123
+ self.pi = jnp.pi
124
+ self.transpose = jnp.transpose
125
+ self.asarray = jnp.asarray
126
+
127
+ self.linspace = jnp.linspace
128
+ self.geomspace = jnp.geomspace
129
+ self.arange = jnp.arange
130
+ self.zeros = jnp.zeros
131
+
132
+
100
133
  else:
101
134
  self.cp_h_nn = cp_h_nn
102
135
  self.cp_da_nn = cp_da_nn
136
+ self.cp_pkl_nn = cp_pkl_nn
137
+
138
+ self.pi = np.pi
139
+ self.transpose = np.transpose
140
+
141
+ self.linspace = np.linspace
142
+ self.geomspace = np.geomspace
143
+ self.arange = np.arange
144
+ self.zeros = np.zeros
145
+ self.asarray = np.asarray
146
+
147
+
148
+
149
+ self.cp_ls = self.arange(2,self.cp_lmax+1)
150
+ self.cp_predicted_tt_spectrum =self.zeros(self.cp_lmax)
151
+ self.cp_predicted_te_spectrum =self.zeros(self.cp_lmax)
152
+ self.cp_predicted_ee_spectrum =self.zeros(self.cp_lmax)
153
+ self.cp_predicted_pp_spectrum =self.zeros(self.cp_lmax)
103
154
 
104
155
  self.cp_s8_nn = cp_s8_nn
105
156
 
@@ -114,20 +165,6 @@ class Class_szfast(object):
114
165
  self.use_Amod = 0
115
166
  self.Amod = 0
116
167
 
117
- self.cp_lmax = cp_l_max_scalars
118
- self.cp_ls = np.arange(2,self.cp_lmax+1)
119
-
120
-
121
-
122
-
123
- cosmo_model_dict = {0: 'lcdm',
124
- 1: 'mnu',
125
- 2: 'neff',
126
- 3: 'wcdm',
127
- 4: 'ede',
128
- 5: 'mnu-3states',
129
- 6: 'ede-v2'
130
- }
131
168
 
132
169
 
133
170
  if cosmo_model_dict[params_settings['cosmo_model']] == 'ede-v2':
@@ -140,60 +177,31 @@ class Class_szfast(object):
140
177
  self.cp_ndspl_k = 10
141
178
  self.cp_nk = 5000
142
179
 
143
- self.cp_predicted_tt_spectrum =np.zeros(self.cp_lmax)
144
- self.cp_predicted_te_spectrum =np.zeros(self.cp_lmax)
145
- self.cp_predicted_ee_spectrum =np.zeros(self.cp_lmax)
146
- self.cp_predicted_pp_spectrum =np.zeros(self.cp_lmax)
147
-
148
180
 
149
181
  self.cszfast_ldim = 20000 # used for the cls arrays
150
-
151
182
  self.cszfast_pk_grid_nz = 100 # has to be same as narraySZ, i.e., ndim_redshifts; it is setup hereafter if ndim_redshifts is passed
152
183
 
153
184
 
154
-
155
- if (cosmo_model_dict[params_settings['cosmo_model']] == 'ede-v2'):
156
-
157
- self.cszfast_pk_grid_zmax = 20.
158
- self.cszfast_pk_grid_kmin = 5e-4
159
- self.cszfast_pk_grid_kmax = 10.
160
- self.cp_kmax = self.cszfast_pk_grid_kmax
161
- self.cp_kmin = self.cszfast_pk_grid_kmin
162
- # self.logger.info(f">>> using kmin = {self.cp_kmin}")
163
- # self.logger.info(f">>> using kmax = {self.cp_kmax}")
164
- # self.logger.info(f">>> using zmax = {self.cszfast_pk_grid_zmax}")
165
-
166
- else:
167
-
168
- self.cszfast_pk_grid_zmax = 5. # max z of our pk emulators (sept 23)
169
- self.cszfast_pk_grid_kmin = 1e-4
170
- self.cszfast_pk_grid_kmax = 50.
171
- self.cp_kmax = self.cszfast_pk_grid_kmax
172
- self.cp_kmin = self.cszfast_pk_grid_kmin
173
- # self.logger.info(f">>> using kmin = {self.cp_kmin}")
174
- # self.logger.info(f">>> using kmax = {self.cp_kmax}")
175
- # self.logger.info(f">>> using zmax = {self.cszfast_pk_grid_zmax}")
176
-
177
- self.cszfast_pk_grid_z = np.linspace(0.,self.cszfast_pk_grid_zmax,self.cszfast_pk_grid_nz)
185
+ self.cszfast_pk_grid_z = self.linspace(0.,self.cszfast_pk_grid_zmax,self.cszfast_pk_grid_nz)
178
186
  self.cszfast_pk_grid_ln1pz = np.log(1.+self.cszfast_pk_grid_z)
179
187
 
180
188
 
181
- self.cszfast_pk_grid_k = np.geomspace(self.cp_kmin,self.cp_kmax,self.cp_nk)[::self.cp_ndspl_k]
189
+ self.cszfast_pk_grid_k = self.geomspace(self.cp_kmin,self.cp_kmax,self.cp_nk)[::self.cp_ndspl_k]
182
190
 
183
191
  self.cszfast_pk_grid_lnk = np.log(self.cszfast_pk_grid_k)
184
192
 
185
- self.cszfast_pk_grid_nk = len(np.geomspace(self.cp_kmin,self.cp_kmax,self.cp_nk)[::self.cp_ndspl_k]) # has to be same as ndimSZ, and the same as dimension of cosmopower pk emulators
193
+ self.cszfast_pk_grid_nk = len(self.geomspace(self.cp_kmin,self.cp_kmax,self.cp_nk)[::self.cp_ndspl_k]) # has to be same as ndimSZ, and the same as dimension of cosmopower pk emulators
186
194
 
187
195
  for k,v in params_settings.items():
188
196
 
189
197
  if k == 'ndim_redshifts':
190
198
 
191
199
  self.cszfast_pk_grid_nz = v
192
- self.cszfast_pk_grid_z = np.linspace(0.,self.cszfast_pk_grid_zmax,self.cszfast_pk_grid_nz)
200
+ self.cszfast_pk_grid_z = self.linspace(0.,self.cszfast_pk_grid_zmax,self.cszfast_pk_grid_nz)
193
201
  self.cszfast_pk_grid_ln1pz = np.log(1.+self.cszfast_pk_grid_z)
194
202
 
195
- self.cszfast_pk_grid_pknl_flat = np.zeros(self.cszfast_pk_grid_nz*self.cszfast_pk_grid_nk)
196
- self.cszfast_pk_grid_pkl_flat = np.zeros(self.cszfast_pk_grid_nz*self.cszfast_pk_grid_nk)
203
+ self.cszfast_pk_grid_pknl_flat = self.zeros(self.cszfast_pk_grid_nz*self.cszfast_pk_grid_nk)
204
+ self.cszfast_pk_grid_pkl_flat = self.zeros(self.cszfast_pk_grid_nz*self.cszfast_pk_grid_nk)
197
205
 
198
206
  if k == 'cosmo_model':
199
207
 
@@ -212,13 +220,12 @@ class Class_szfast(object):
212
220
 
213
221
  else:
214
222
 
215
- ls = np.arange(2,self.cp_nk+2)[::self.cp_ndspl_k] # jan 10 ndspl
216
- dls = ls*(ls+1.)/2./np.pi
223
+ ls = self.arange(2,self.cp_nk+2)[::self.cp_ndspl_k] # jan 10 ndspl
224
+ dls = ls*(ls+1.)/2./self.pi
217
225
  self.pk_power_fac= (dls)**-1
218
226
 
219
227
 
220
- self.cp_z_interp = np.linspace(0.,20.,5000)
221
- self.cp_z_interp_jax = jnp.linspace(0.,20.,5000)
228
+ self.cp_z_interp = self.linspace(0.,20.,5000)
222
229
 
223
230
  self.csz_base = None
224
231
 
@@ -226,7 +233,7 @@ class Class_szfast(object):
226
233
  self.cszfast_zgrid_zmin = 0.
227
234
  self.cszfast_zgrid_zmax = 4.
228
235
  self.cszfast_zgrid_nz = 250
229
- self.cszfast_zgrid = np.linspace(self.cszfast_zgrid_zmin,
236
+ self.cszfast_zgrid = self.linspace(self.cszfast_zgrid_zmin,
230
237
  self.cszfast_zgrid_zmax,
231
238
  self.cszfast_zgrid_nz)
232
239
 
@@ -234,14 +241,14 @@ class Class_szfast(object):
234
241
  self.cszfast_mgrid_mmin = 1e10
235
242
  self.cszfast_mgrid_mmax = 1e15
236
243
  self.cszfast_mgrid_nm = 50
237
- self.cszfast_mgrid = np.geomspace(self.cszfast_mgrid_mmin,
244
+ self.cszfast_mgrid = self.geomspace(self.cszfast_mgrid_mmin,
238
245
  self.cszfast_mgrid_mmax,
239
246
  self.cszfast_mgrid_nm)
240
247
 
241
248
  self.cszfast_gas_pressure_xgrid_xmin = 1e-2
242
249
  self.cszfast_gas_pressure_xgrid_xmax = 1e2
243
250
  self.cszfast_gas_pressure_xgrid_nx = 100
244
- self.cszfast_gas_pressure_xgrid = np.geomspace(self.cszfast_gas_pressure_xgrid_xmin,
251
+ self.cszfast_gas_pressure_xgrid = self.geomspace(self.cszfast_gas_pressure_xgrid_xmin,
245
252
  self.cszfast_gas_pressure_xgrid_xmax,
246
253
  self.cszfast_gas_pressure_xgrid_nx)
247
254
 
@@ -312,7 +319,7 @@ class Class_szfast(object):
312
319
  creal = predicted_testing_alphas_creal
313
320
  cimag = predicted_testing_alphas_cimag
314
321
  Nmax = len(self.cszfast_pk_grid_k)
315
- cnew = np.zeros(Nmax+1,dtype=complex)
322
+ cnew = self.zeros(Nmax+1,dtype=complex)
316
323
  for i in range(Nmax+1):
317
324
  if i<int(Nmax/2):
318
325
  cnew[i] = complex(creal[i],cimag[i])
@@ -362,20 +369,20 @@ class Class_szfast(object):
362
369
 
363
370
  nl = len(self.cp_predicted_tt_spectrum)
364
371
  cls = {}
365
- cls['ell'] = np.arange(20000)
366
- cls['tt'] = np.zeros(20000)
367
- cls['te'] = np.zeros(20000)
368
- cls['ee'] = np.zeros(20000)
369
- cls['pp'] = np.zeros(20000)
370
- cls['bb'] = np.zeros(20000)
371
- lcp = np.asarray(cls['ell'][2:nl+2])
372
+ cls['ell'] = self.arange(20000)
373
+ cls['tt'] = self.zeros(20000)
374
+ cls['te'] = self.zeros(20000)
375
+ cls['ee'] = self.zeros(20000)
376
+ cls['pp'] = self.zeros(20000)
377
+ cls['bb'] = self.zeros(20000)
378
+ lcp = self.asarray(cls['ell'][2:nl+2])
372
379
 
373
380
  # print('cosmo_model:',self.cosmo_model,nl)
374
381
  if self.cosmo_model == 'ede-v2':
375
382
  factor_ttteee = 1./lcp**2
376
383
  factor_pp = 1./lcp**3
377
384
  else:
378
- factor_ttteee = 1./(lcp*(lcp+1.)/2./np.pi)
385
+ factor_ttteee = 1./(lcp*(lcp+1.)/2./self.pi)
379
386
  factor_pp = 1./(lcp*(lcp+1.))**2.
380
387
 
381
388
  self.cp_predicted_tt_spectrum *= factor_ttteee
@@ -402,7 +409,7 @@ class Class_szfast(object):
402
409
  cmb_cls_loaded = pickle.load(handle)
403
410
  nl_cls_file = len(cmb_cls_loaded['ell'])
404
411
  cls_ls = cmb_cls_loaded['ell']
405
- dlfac = cls_ls*(cls_ls+1.)/2./np.pi
412
+ dlfac = cls_ls*(cls_ls+1.)/2./self.pi
406
413
  cls_tt = cmb_cls_loaded['tt']*dlfac
407
414
  cls_te = cmb_cls_loaded['te']*dlfac
408
415
  cls_ee = cmb_cls_loaded['ee']*dlfac
@@ -421,8 +428,6 @@ class Class_szfast(object):
421
428
  **params_values_dict):
422
429
 
423
430
  z_arr = self.cszfast_pk_grid_z
424
-
425
-
426
431
  k_arr = self.cszfast_pk_grid_k
427
432
 
428
433
  # print(">>> z_arr:",z_arr)
@@ -464,26 +469,32 @@ class Class_szfast(object):
464
469
 
465
470
  params_dict_pp = params_dict.copy()
466
471
  params_dict_pp['z_pk_save_nonclass'] = [zp]
467
- predicted_pk_spectrum_z.append(self.cp_pkl_nn[self.cosmo_model].predictions_np(params_dict_pp)[0])
472
+ if self.jax_mode:
473
+ predicted_pk_spectrum_z.append(self.cp_pkl_nn[self.cosmo_model].predict(params_dict_pp))
474
+ else:
475
+ predicted_pk_spectrum_z.append(self.cp_pkl_nn[self.cosmo_model].predictions_np(params_dict_pp)[0])
468
476
 
469
477
  # if abs(zp-0.5) < 0.01:
470
478
  # print(">>> predicted_pk_spectrum_z:",predicted_pk_spectrum_z[-1])
471
479
  # import pprint
472
480
  # pprint.pprint(params_dict_pp)
473
481
 
474
- predicted_pk_spectrum = np.asarray(predicted_pk_spectrum_z)
482
+ predicted_pk_spectrum = self.asarray(predicted_pk_spectrum_z)
475
483
 
476
484
 
477
485
  pk = 10.**predicted_pk_spectrum
478
486
 
479
487
  pk_re = pk*self.pk_power_fac
480
- pk_re = np.transpose(pk_re)
488
+ pk_re = self.transpose(pk_re)
481
489
 
482
490
  # print(">>> pk_re:",pk_re)
483
491
  # import sys
484
492
  # sys.exit(0)
485
493
 
486
- self.pkl_interp = PowerSpectrumInterpolator(z_arr,k_arr,np.log(pk_re).T,logP=True)
494
+ if self.jax_mode:
495
+ self.pkl_interp = None
496
+ else:
497
+ self.pkl_interp = PowerSpectrumInterpolator(z_arr,k_arr,np.log(pk_re).T,logP=True)
487
498
 
488
499
  self.cszfast_pk_grid_pk = pk_re
489
500
  self.cszfast_pk_grid_pkl_flat = pk_re.flatten()
@@ -570,7 +581,7 @@ class Class_szfast(object):
570
581
  s8z = self.cp_s8_nn[self.cosmo_model].predictions_np(params_dict)
571
582
  # print(self.s8z)
572
583
  self.s8z_interp = scipy.interpolate.interp1d(
573
- np.linspace(0.,20.,5000),
584
+ self.linspace(0.,20.,5000),
574
585
  s8z[0],
575
586
  kind='linear',
576
587
  axis=-1,
@@ -609,13 +620,13 @@ class Class_szfast(object):
609
620
  params_dict_pp['z_pk_save_nonclass'] = [zp]
610
621
  predicted_pk_spectrum_z.append(self.cp_pknl_nn[self.cosmo_model].predictions_np(params_dict_pp)[0])
611
622
 
612
- predicted_pk_spectrum = np.asarray(predicted_pk_spectrum_z)
623
+ predicted_pk_spectrum = self.asarray(predicted_pk_spectrum_z)
613
624
 
614
625
 
615
626
  pk = 10.**predicted_pk_spectrum
616
627
 
617
628
  pk_re = pk*self.pk_power_fac
618
- pk_re = np.transpose(pk_re)
629
+ pk_re = self.transpose(pk_re)
619
630
 
620
631
 
621
632
  self.pknl_interp = PowerSpectrumInterpolator(z_arr,k_arr,np.log(pk_re).T,logP=True)
@@ -631,16 +642,12 @@ class Class_szfast(object):
631
642
  z_asked,
632
643
  params_values_dict=None):
633
644
 
634
- z_arr = self.cszfast_pk_grid_z
635
645
 
636
646
  k_arr = self.cszfast_pk_grid_k
637
647
 
638
648
  if params_values_dict:
639
-
640
649
  params_values = params_values_dict.copy()
641
-
642
650
  else:
643
-
644
651
  params_values = self.params_for_emulators
645
652
 
646
653
  update_params_with_defaults(params_values, self.emulator_dict[self.cosmo_model]['default'])
@@ -657,19 +664,21 @@ class Class_szfast(object):
657
664
 
658
665
  predicted_pk_spectrum_z = []
659
666
 
660
- z_asked = z_asked
661
667
  params_dict_pp = params_dict.copy()
662
668
  update_params_with_defaults(params_dict_pp, self.emulator_dict[self.cosmo_model]['default'])
663
669
 
664
670
  params_dict_pp['z_pk_save_nonclass'] = [z_asked]
665
- predicted_pk_spectrum_z.append(self.cp_pkl_nn[self.cosmo_model].predictions_np(params_dict_pp)[0])
671
+ if self.jax_mode:
672
+ predicted_pk_spectrum_z.append(self.cp_pkl_nn[self.cosmo_model].predict(params_dict_pp))
673
+ else:
674
+ predicted_pk_spectrum_z.append(self.cp_pkl_nn[self.cosmo_model].predictions_np(params_dict_pp)[0])
666
675
 
667
- predicted_pk_spectrum = np.asarray(predicted_pk_spectrum_z)
676
+ predicted_pk_spectrum = self.asarray(predicted_pk_spectrum_z)
668
677
 
669
678
 
670
679
  pk = 10.**predicted_pk_spectrum
671
680
  pk_re = pk*self.pk_power_fac
672
- pk_re = np.transpose(pk_re)
681
+ pk_re = self.transpose(pk_re)
673
682
 
674
683
 
675
684
  return pk_re, k_arr
@@ -712,12 +721,12 @@ class Class_szfast(object):
712
721
  params_dict_pp['z_pk_save_nonclass'] = [z_asked]
713
722
  predicted_pk_spectrum_z.append(self.cp_pknl_nn[self.cosmo_model].predictions_np(params_dict_pp)[0])
714
723
 
715
- predicted_pk_spectrum = np.asarray(predicted_pk_spectrum_z)
724
+ predicted_pk_spectrum = self.asarray(predicted_pk_spectrum_z)
716
725
 
717
726
 
718
727
  pk = 10.**predicted_pk_spectrum
719
728
  pk_re = pk*self.pk_power_fac
720
- pk_re = np.transpose(pk_re)
729
+ pk_re = self.transpose(pk_re)
721
730
 
722
731
 
723
732
  return pk_re, k_arr
@@ -748,19 +757,10 @@ class Class_szfast(object):
748
757
  # print("self.cp_predicted_hubble type:", type(self.cp_predicted_hubble))
749
758
  # print("self.cp_predicted_hubble",self.cp_predicted_hubble)
750
759
 
751
- # self.hz_interp = jscipy.interpolate.interp1d(
752
- # self.cp_z_interp_jax,
753
- # self.cp_predicted_hubble,
754
- # kind='linear',
755
- # axis=-1,
756
- # copy=True,
757
- # bounds_error=None,
758
- # fill_value=np.nan,
759
- # assume_sorted=False)
760
760
 
761
761
  # Assuming `cp_z_interp` and `cp_predicted_hubble` are JAX arrays
762
762
  def hz_interp(x):
763
- return jnp.interp(x, self.cp_z_interp_jax, self.cp_predicted_hubble, left=jnp.nan, right=jnp.nan)
763
+ return jnp.interp(x, self.cp_z_interp, self.cp_predicted_hubble, left=jnp.nan, right=jnp.nan)
764
764
 
765
765
  self.hz_interp = hz_interp
766
766
  # exit()
@@ -803,10 +803,10 @@ class Class_szfast(object):
803
803
  if self.cosmo_model == 'ede-v2':
804
804
  # print('ede-v2 case')
805
805
  self.cp_predicted_da = jnp.insert(self.cp_predicted_da, 0, 0)
806
- self.cp_predicted_da *= (1.+self.cp_z_interp_jax)
806
+ self.cp_predicted_da *= (1.+self.cp_z_interp)
807
807
 
808
808
  def chi_interp(x):
809
- return jnp.interp(x, self.cp_z_interp_jax, self.cp_predicted_da, left=jnp.nan, right=jnp.nan)
809
+ return jnp.interp(x, self.cp_z_interp, self.cp_predicted_da, left=jnp.nan, right=jnp.nan)
810
810
 
811
811
  self.chi_interp = chi_interp
812
812
 
@@ -835,12 +835,12 @@ class Class_szfast(object):
835
835
  def get_cmb_cls(self,ell_factor=True,Tcmb_uk = Tcmb_uk):
836
836
 
837
837
  cls = {}
838
- cls['ell'] = np.arange(self.cszfast_ldim)
839
- cls['tt'] = np.zeros(self.cszfast_ldim)
840
- cls['te'] = np.zeros(self.cszfast_ldim)
841
- cls['ee'] = np.zeros(self.cszfast_ldim)
842
- cls['pp'] = np.zeros(self.cszfast_ldim)
843
- cls['bb'] = np.zeros(self.cszfast_ldim)
838
+ cls['ell'] = self.arange(self.cszfast_ldim)
839
+ cls['tt'] = self.zeros(self.cszfast_ldim)
840
+ cls['te'] = self.zeros(self.cszfast_ldim)
841
+ cls['ee'] = self.zeros(self.cszfast_ldim)
842
+ cls['pp'] = self.zeros(self.cszfast_ldim)
843
+ cls['bb'] = self.zeros(self.cszfast_ldim)
844
844
  cls['tt'][2:self.cp_lmax+1] = (Tcmb_uk)**2.*self.cp_predicted_tt_spectrum.copy()
845
845
  cls['te'][2:self.cp_lmax+1] = (Tcmb_uk)**2.*self.cp_predicted_te_spectrum.copy()
846
846
  cls['ee'][2:self.cp_lmax+1] = (Tcmb_uk)**2.*self.cp_predicted_ee_spectrum.copy()
@@ -848,12 +848,12 @@ class Class_szfast(object):
848
848
 
849
849
 
850
850
  if ell_factor==False:
851
- fac_l = np.zeros(self.cszfast_ldim)
852
- fac_l[2:self.cp_lmax+1] = 1./(cls['ell'][2:self.cp_lmax+1]*(cls['ell'][2:self.cp_lmax+1]+1.)/2./np.pi)
851
+ fac_l = self.zeros(self.cszfast_ldim)
852
+ fac_l[2:self.cp_lmax+1] = 1./(cls['ell'][2:self.cp_lmax+1]*(cls['ell'][2:self.cp_lmax+1]+1.)/2./self.pi)
853
853
  cls['tt'][2:self.cp_lmax+1] *= fac_l[2:self.cp_lmax+1]
854
854
  cls['te'][2:self.cp_lmax+1] *= fac_l[2:self.cp_lmax+1]
855
855
  cls['ee'][2:self.cp_lmax+1] *= fac_l[2:self.cp_lmax+1]
856
- # cls['bb'] = np.zeros(self.cszfast_ldim)
856
+ # cls['bb'] = self.zeros(self.cszfast_ldim)
857
857
  return cls
858
858
 
859
859
 
@@ -899,7 +899,10 @@ class Class_szfast(object):
899
899
  return np.array(self.hz_interp(z)*H_units_conv_factor[units])
900
900
 
901
901
  def get_chi(self, z):
902
- return np.array(self.chi_interp(z))
902
+ if self.jax_mode:
903
+ return jnp.array(self.chi_interp(z))
904
+ else:
905
+ return np.array(self.chi_interp(z))
903
906
 
904
907
  def get_gas_pressure_profile_x(self,z,m,x):
905
908
  return 0#np.vectorize(self.csz_base.get_pressure_P_over_P_delta_at_x_M_z_b12_200c)(x,m,z)
@@ -913,7 +916,7 @@ class Class_szfast(object):
913
916
 
914
917
 
915
918
  def tabulate_gas_pressure_profile_k(self):
916
- z_asked,m_asked,x_asked = 0.2,3e14,np.geomspace(1e-3,1e2,500)
919
+ z_asked,m_asked,x_asked = 0.2,3e14,self.geomspace(1e-3,1e2,500)
917
920
  start = time.time()
918
921
  px = self.get_gas_pressure_profile_x(z_asked,m_asked,x_asked)
919
922
  end = time.time()
@@ -105,8 +105,9 @@ for mp in cosmo_model_list:
105
105
 
106
106
  cp_pknl_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'PK/' + emulator_dict[mp]['PKNL'])
107
107
 
108
- cp_pkl_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'PK/' + emulator_dict[mp]['PKL'])
109
-
108
+ cp_pkl_nn_jax[mp] = CosmoPowerJAX_custom(probe='custom_log',filepath=path_to_emulators +'PK/' + emulator_dict[mp]['PKL'] + '.npz')
109
+ cp_pkl_nn_jax[mp].ten_to_predictions = False
110
+
110
111
  cp_der_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'derived-parameters/' + emulator_dict[mp]['DER'])
111
112
 
112
113
  # cp_da_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['DAZ'])
@@ -33,7 +33,7 @@ cosmopower_derived_params_names = ['100*theta_s',
33
33
  'ra_star',
34
34
  'rs_drag']
35
35
 
36
- cp_l_max_scalars = 11000 # max multipole of train ing data
36
+ cp_l_max_scalars = 11000 # max multipole of training data
37
37
 
38
38
  cosmo_model_list = [
39
39
  'lcdm',
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: classy_szfast
3
- Version: 0.0.25.post3
3
+ Version: 0.0.25.post4
4
4
  Summary: The accelerator of the class_sz code from https://github.com/CLASS-SZ
5
5
  Maintainer-email: Boris Bolliet <bb667@cam.ac.uk>
6
6
  License: MIT
@@ -3,7 +3,7 @@ requires = ["setuptools", "wheel"]
3
3
  build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
- version = "0.0.25.post3"
6
+ version = "0.0.25.post4"
7
7
  license = { text = "MIT" }
8
8
  name = "classy_szfast"
9
9
  maintainers = [{name = "Boris Bolliet",email="bb667@cam.ac.uk"}]