classy-szfast 0.0.25.post2__py3-none-any.whl → 0.0.25.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,7 @@
1
1
  from .utils import *
2
2
  from .config import *
3
3
  import numpy as np
4
- from .emulators_meta_data import emulator_dict, dofftlog_alphas, cp_l_max_scalars
4
+ from .emulators_meta_data import emulator_dict, dofftlog_alphas, cp_l_max_scalars, cosmo_model_list
5
5
  from .cosmopower import cp_tt_nn, cp_te_nn, cp_ee_nn, cp_ee_nn, cp_pp_nn, cp_pknl_nn, cp_pkl_nn, cp_der_nn, cp_da_nn, cp_h_nn, cp_s8_nn, cp_pkl_fftlog_alphas_real_nn, cp_pkl_fftlog_alphas_imag_nn, cp_pkl_fftlog_alphas_nus
6
6
  from .cosmopower_jax import cp_tt_nn_jax, cp_te_nn_jax, cp_ee_nn_jax, cp_ee_nn_jax, cp_pp_nn_jax, cp_pknl_nn_jax, cp_pkl_nn_jax, cp_der_nn_jax, cp_da_nn_jax, cp_h_nn_jax, cp_s8_nn_jax
7
7
  from .pks_and_sigmas import *
@@ -76,13 +76,10 @@ class Class_szfast(object):
76
76
 
77
77
  self.jax_mode = params_settings["jax"]
78
78
 
79
+
79
80
  # print(f"JAX mode: {self.jax_mode}")
80
81
 
81
82
 
82
-
83
-
84
-
85
-
86
83
  # cosmopower emulators
87
84
  # self.cp_path_to_cosmopower_organization = path_to_cosmopower_organization + '/'
88
85
  self.cp_tt_nn = cp_tt_nn
@@ -90,14 +87,70 @@ class Class_szfast(object):
90
87
  self.cp_ee_nn = cp_ee_nn
91
88
  self.cp_pp_nn = cp_pp_nn
92
89
  self.cp_pknl_nn = cp_pknl_nn
93
- self.cp_pkl_nn = cp_pkl_nn
94
90
  self.cp_der_nn = cp_der_nn
95
- self.cp_da_nn = cp_da_nn
91
+
92
+ self.cp_lmax = cp_l_max_scalars
93
+
94
+ cosmo_model_dict = {i: model for i, model in enumerate(cosmo_model_list)}
95
+
96
+ if (cosmo_model_dict[params_settings['cosmo_model']] == 'ede-v2'):
97
+
98
+ self.cszfast_pk_grid_zmax = 20.
99
+ self.cszfast_pk_grid_kmin = 5e-4
100
+ self.cszfast_pk_grid_kmax = 10.
101
+ self.cp_kmax = self.cszfast_pk_grid_kmax
102
+ self.cp_kmin = self.cszfast_pk_grid_kmin
103
+ # self.logger.info(f">>> using kmin = {self.cp_kmin}")
104
+ # self.logger.info(f">>> using kmax = {self.cp_kmax}")
105
+ # self.logger.info(f">>> using zmax = {self.cszfast_pk_grid_zmax}")
106
+
107
+ else:
108
+
109
+ self.cszfast_pk_grid_zmax = 5. # max z of our pk emulators (sept 23)
110
+ self.cszfast_pk_grid_kmin = 1e-4
111
+ self.cszfast_pk_grid_kmax = 50.
112
+ self.cp_kmax = self.cszfast_pk_grid_kmax
113
+ self.cp_kmin = self.cszfast_pk_grid_kmin
114
+ # self.logger.info(f">>> using kmin = {self.cp_kmin}")
115
+ # self.logger.info(f">>> using kmax = {self.cp_kmax}")
116
+ # self.logger.info(f">>> using zmax = {self.cszfast_pk_grid_zmax}")
96
117
 
97
118
  if self.jax_mode:
98
119
  self.cp_h_nn = cp_h_nn_jax
120
+ self.cp_da_nn = cp_da_nn_jax
121
+ self.cp_pkl_nn = cp_pkl_nn_jax
122
+
123
+ self.pi = jnp.pi
124
+ self.transpose = jnp.transpose
125
+ self.asarray = jnp.asarray
126
+
127
+ self.linspace = jnp.linspace
128
+ self.geomspace = jnp.geomspace
129
+ self.arange = jnp.arange
130
+ self.zeros = jnp.zeros
131
+
132
+
99
133
  else:
100
134
  self.cp_h_nn = cp_h_nn
135
+ self.cp_da_nn = cp_da_nn
136
+ self.cp_pkl_nn = cp_pkl_nn
137
+
138
+ self.pi = np.pi
139
+ self.transpose = np.transpose
140
+
141
+ self.linspace = np.linspace
142
+ self.geomspace = np.geomspace
143
+ self.arange = np.arange
144
+ self.zeros = np.zeros
145
+ self.asarray = np.asarray
146
+
147
+
148
+
149
+ self.cp_ls = self.arange(2,self.cp_lmax+1)
150
+ self.cp_predicted_tt_spectrum =self.zeros(self.cp_lmax)
151
+ self.cp_predicted_te_spectrum =self.zeros(self.cp_lmax)
152
+ self.cp_predicted_ee_spectrum =self.zeros(self.cp_lmax)
153
+ self.cp_predicted_pp_spectrum =self.zeros(self.cp_lmax)
101
154
 
102
155
  self.cp_s8_nn = cp_s8_nn
103
156
 
@@ -112,20 +165,6 @@ class Class_szfast(object):
112
165
  self.use_Amod = 0
113
166
  self.Amod = 0
114
167
 
115
- self.cp_lmax = cp_l_max_scalars
116
- self.cp_ls = np.arange(2,self.cp_lmax+1)
117
-
118
-
119
-
120
-
121
- cosmo_model_dict = {0: 'lcdm',
122
- 1: 'mnu',
123
- 2: 'neff',
124
- 3: 'wcdm',
125
- 4: 'ede',
126
- 5: 'mnu-3states',
127
- 6: 'ede-v2'
128
- }
129
168
 
130
169
 
131
170
  if cosmo_model_dict[params_settings['cosmo_model']] == 'ede-v2':
@@ -138,60 +177,31 @@ class Class_szfast(object):
138
177
  self.cp_ndspl_k = 10
139
178
  self.cp_nk = 5000
140
179
 
141
- self.cp_predicted_tt_spectrum =np.zeros(self.cp_lmax)
142
- self.cp_predicted_te_spectrum =np.zeros(self.cp_lmax)
143
- self.cp_predicted_ee_spectrum =np.zeros(self.cp_lmax)
144
- self.cp_predicted_pp_spectrum =np.zeros(self.cp_lmax)
145
-
146
180
 
147
181
  self.cszfast_ldim = 20000 # used for the cls arrays
148
-
149
182
  self.cszfast_pk_grid_nz = 100 # has to be same as narraySZ, i.e., ndim_redshifts; it is setup hereafter if ndim_redshifts is passed
150
183
 
151
184
 
152
-
153
- if (cosmo_model_dict[params_settings['cosmo_model']] == 'ede-v2'):
154
-
155
- self.cszfast_pk_grid_zmax = 20.
156
- self.cszfast_pk_grid_kmin = 5e-4
157
- self.cszfast_pk_grid_kmax = 10.
158
- self.cp_kmax = self.cszfast_pk_grid_kmax
159
- self.cp_kmin = self.cszfast_pk_grid_kmin
160
- # self.logger.info(f">>> using kmin = {self.cp_kmin}")
161
- # self.logger.info(f">>> using kmax = {self.cp_kmax}")
162
- # self.logger.info(f">>> using zmax = {self.cszfast_pk_grid_zmax}")
163
-
164
- else:
165
-
166
- self.cszfast_pk_grid_zmax = 5. # max z of our pk emulators (sept 23)
167
- self.cszfast_pk_grid_kmin = 1e-4
168
- self.cszfast_pk_grid_kmax = 50.
169
- self.cp_kmax = self.cszfast_pk_grid_kmax
170
- self.cp_kmin = self.cszfast_pk_grid_kmin
171
- # self.logger.info(f">>> using kmin = {self.cp_kmin}")
172
- # self.logger.info(f">>> using kmax = {self.cp_kmax}")
173
- # self.logger.info(f">>> using zmax = {self.cszfast_pk_grid_zmax}")
174
-
175
- self.cszfast_pk_grid_z = np.linspace(0.,self.cszfast_pk_grid_zmax,self.cszfast_pk_grid_nz)
185
+ self.cszfast_pk_grid_z = self.linspace(0.,self.cszfast_pk_grid_zmax,self.cszfast_pk_grid_nz)
176
186
  self.cszfast_pk_grid_ln1pz = np.log(1.+self.cszfast_pk_grid_z)
177
187
 
178
188
 
179
- self.cszfast_pk_grid_k = np.geomspace(self.cp_kmin,self.cp_kmax,self.cp_nk)[::self.cp_ndspl_k]
189
+ self.cszfast_pk_grid_k = self.geomspace(self.cp_kmin,self.cp_kmax,self.cp_nk)[::self.cp_ndspl_k]
180
190
 
181
191
  self.cszfast_pk_grid_lnk = np.log(self.cszfast_pk_grid_k)
182
192
 
183
- self.cszfast_pk_grid_nk = len(np.geomspace(self.cp_kmin,self.cp_kmax,self.cp_nk)[::self.cp_ndspl_k]) # has to be same as ndimSZ, and the same as dimension of cosmopower pk emulators
193
+ self.cszfast_pk_grid_nk = len(self.geomspace(self.cp_kmin,self.cp_kmax,self.cp_nk)[::self.cp_ndspl_k]) # has to be same as ndimSZ, and the same as dimension of cosmopower pk emulators
184
194
 
185
195
  for k,v in params_settings.items():
186
196
 
187
197
  if k == 'ndim_redshifts':
188
198
 
189
199
  self.cszfast_pk_grid_nz = v
190
- self.cszfast_pk_grid_z = np.linspace(0.,self.cszfast_pk_grid_zmax,self.cszfast_pk_grid_nz)
200
+ self.cszfast_pk_grid_z = self.linspace(0.,self.cszfast_pk_grid_zmax,self.cszfast_pk_grid_nz)
191
201
  self.cszfast_pk_grid_ln1pz = np.log(1.+self.cszfast_pk_grid_z)
192
202
 
193
- self.cszfast_pk_grid_pknl_flat = np.zeros(self.cszfast_pk_grid_nz*self.cszfast_pk_grid_nk)
194
- self.cszfast_pk_grid_pkl_flat = np.zeros(self.cszfast_pk_grid_nz*self.cszfast_pk_grid_nk)
203
+ self.cszfast_pk_grid_pknl_flat = self.zeros(self.cszfast_pk_grid_nz*self.cszfast_pk_grid_nk)
204
+ self.cszfast_pk_grid_pkl_flat = self.zeros(self.cszfast_pk_grid_nz*self.cszfast_pk_grid_nk)
195
205
 
196
206
  if k == 'cosmo_model':
197
207
 
@@ -210,13 +220,12 @@ class Class_szfast(object):
210
220
 
211
221
  else:
212
222
 
213
- ls = np.arange(2,self.cp_nk+2)[::self.cp_ndspl_k] # jan 10 ndspl
214
- dls = ls*(ls+1.)/2./np.pi
223
+ ls = self.arange(2,self.cp_nk+2)[::self.cp_ndspl_k] # jan 10 ndspl
224
+ dls = ls*(ls+1.)/2./self.pi
215
225
  self.pk_power_fac= (dls)**-1
216
226
 
217
227
 
218
- self.cp_z_interp = np.linspace(0.,20.,5000)
219
- self.cp_z_interp_jax = jnp.linspace(0.,20.,5000)
228
+ self.cp_z_interp = self.linspace(0.,20.,5000)
220
229
 
221
230
  self.csz_base = None
222
231
 
@@ -224,7 +233,7 @@ class Class_szfast(object):
224
233
  self.cszfast_zgrid_zmin = 0.
225
234
  self.cszfast_zgrid_zmax = 4.
226
235
  self.cszfast_zgrid_nz = 250
227
- self.cszfast_zgrid = np.linspace(self.cszfast_zgrid_zmin,
236
+ self.cszfast_zgrid = self.linspace(self.cszfast_zgrid_zmin,
228
237
  self.cszfast_zgrid_zmax,
229
238
  self.cszfast_zgrid_nz)
230
239
 
@@ -232,14 +241,14 @@ class Class_szfast(object):
232
241
  self.cszfast_mgrid_mmin = 1e10
233
242
  self.cszfast_mgrid_mmax = 1e15
234
243
  self.cszfast_mgrid_nm = 50
235
- self.cszfast_mgrid = np.geomspace(self.cszfast_mgrid_mmin,
244
+ self.cszfast_mgrid = self.geomspace(self.cszfast_mgrid_mmin,
236
245
  self.cszfast_mgrid_mmax,
237
246
  self.cszfast_mgrid_nm)
238
247
 
239
248
  self.cszfast_gas_pressure_xgrid_xmin = 1e-2
240
249
  self.cszfast_gas_pressure_xgrid_xmax = 1e2
241
250
  self.cszfast_gas_pressure_xgrid_nx = 100
242
- self.cszfast_gas_pressure_xgrid = np.geomspace(self.cszfast_gas_pressure_xgrid_xmin,
251
+ self.cszfast_gas_pressure_xgrid = self.geomspace(self.cszfast_gas_pressure_xgrid_xmin,
243
252
  self.cszfast_gas_pressure_xgrid_xmax,
244
253
  self.cszfast_gas_pressure_xgrid_nx)
245
254
 
@@ -310,7 +319,7 @@ class Class_szfast(object):
310
319
  creal = predicted_testing_alphas_creal
311
320
  cimag = predicted_testing_alphas_cimag
312
321
  Nmax = len(self.cszfast_pk_grid_k)
313
- cnew = np.zeros(Nmax+1,dtype=complex)
322
+ cnew = self.zeros(Nmax+1,dtype=complex)
314
323
  for i in range(Nmax+1):
315
324
  if i<int(Nmax/2):
316
325
  cnew[i] = complex(creal[i],cimag[i])
@@ -360,20 +369,20 @@ class Class_szfast(object):
360
369
 
361
370
  nl = len(self.cp_predicted_tt_spectrum)
362
371
  cls = {}
363
- cls['ell'] = np.arange(20000)
364
- cls['tt'] = np.zeros(20000)
365
- cls['te'] = np.zeros(20000)
366
- cls['ee'] = np.zeros(20000)
367
- cls['pp'] = np.zeros(20000)
368
- cls['bb'] = np.zeros(20000)
369
- lcp = np.asarray(cls['ell'][2:nl+2])
372
+ cls['ell'] = self.arange(20000)
373
+ cls['tt'] = self.zeros(20000)
374
+ cls['te'] = self.zeros(20000)
375
+ cls['ee'] = self.zeros(20000)
376
+ cls['pp'] = self.zeros(20000)
377
+ cls['bb'] = self.zeros(20000)
378
+ lcp = self.asarray(cls['ell'][2:nl+2])
370
379
 
371
380
  # print('cosmo_model:',self.cosmo_model,nl)
372
381
  if self.cosmo_model == 'ede-v2':
373
382
  factor_ttteee = 1./lcp**2
374
383
  factor_pp = 1./lcp**3
375
384
  else:
376
- factor_ttteee = 1./(lcp*(lcp+1.)/2./np.pi)
385
+ factor_ttteee = 1./(lcp*(lcp+1.)/2./self.pi)
377
386
  factor_pp = 1./(lcp*(lcp+1.))**2.
378
387
 
379
388
  self.cp_predicted_tt_spectrum *= factor_ttteee
@@ -400,7 +409,7 @@ class Class_szfast(object):
400
409
  cmb_cls_loaded = pickle.load(handle)
401
410
  nl_cls_file = len(cmb_cls_loaded['ell'])
402
411
  cls_ls = cmb_cls_loaded['ell']
403
- dlfac = cls_ls*(cls_ls+1.)/2./np.pi
412
+ dlfac = cls_ls*(cls_ls+1.)/2./self.pi
404
413
  cls_tt = cmb_cls_loaded['tt']*dlfac
405
414
  cls_te = cmb_cls_loaded['te']*dlfac
406
415
  cls_ee = cmb_cls_loaded['ee']*dlfac
@@ -419,8 +428,6 @@ class Class_szfast(object):
419
428
  **params_values_dict):
420
429
 
421
430
  z_arr = self.cszfast_pk_grid_z
422
-
423
-
424
431
  k_arr = self.cszfast_pk_grid_k
425
432
 
426
433
  # print(">>> z_arr:",z_arr)
@@ -462,26 +469,32 @@ class Class_szfast(object):
462
469
 
463
470
  params_dict_pp = params_dict.copy()
464
471
  params_dict_pp['z_pk_save_nonclass'] = [zp]
465
- predicted_pk_spectrum_z.append(self.cp_pkl_nn[self.cosmo_model].predictions_np(params_dict_pp)[0])
472
+ if self.jax_mode:
473
+ predicted_pk_spectrum_z.append(self.cp_pkl_nn[self.cosmo_model].predict(params_dict_pp))
474
+ else:
475
+ predicted_pk_spectrum_z.append(self.cp_pkl_nn[self.cosmo_model].predictions_np(params_dict_pp)[0])
466
476
 
467
477
  # if abs(zp-0.5) < 0.01:
468
478
  # print(">>> predicted_pk_spectrum_z:",predicted_pk_spectrum_z[-1])
469
479
  # import pprint
470
480
  # pprint.pprint(params_dict_pp)
471
481
 
472
- predicted_pk_spectrum = np.asarray(predicted_pk_spectrum_z)
482
+ predicted_pk_spectrum = self.asarray(predicted_pk_spectrum_z)
473
483
 
474
484
 
475
485
  pk = 10.**predicted_pk_spectrum
476
486
 
477
487
  pk_re = pk*self.pk_power_fac
478
- pk_re = np.transpose(pk_re)
488
+ pk_re = self.transpose(pk_re)
479
489
 
480
490
  # print(">>> pk_re:",pk_re)
481
491
  # import sys
482
492
  # sys.exit(0)
483
493
 
484
- self.pkl_interp = PowerSpectrumInterpolator(z_arr,k_arr,np.log(pk_re).T,logP=True)
494
+ if self.jax_mode:
495
+ self.pkl_interp = None
496
+ else:
497
+ self.pkl_interp = PowerSpectrumInterpolator(z_arr,k_arr,np.log(pk_re).T,logP=True)
485
498
 
486
499
  self.cszfast_pk_grid_pk = pk_re
487
500
  self.cszfast_pk_grid_pkl_flat = pk_re.flatten()
@@ -568,7 +581,7 @@ class Class_szfast(object):
568
581
  s8z = self.cp_s8_nn[self.cosmo_model].predictions_np(params_dict)
569
582
  # print(self.s8z)
570
583
  self.s8z_interp = scipy.interpolate.interp1d(
571
- np.linspace(0.,20.,5000),
584
+ self.linspace(0.,20.,5000),
572
585
  s8z[0],
573
586
  kind='linear',
574
587
  axis=-1,
@@ -607,13 +620,13 @@ class Class_szfast(object):
607
620
  params_dict_pp['z_pk_save_nonclass'] = [zp]
608
621
  predicted_pk_spectrum_z.append(self.cp_pknl_nn[self.cosmo_model].predictions_np(params_dict_pp)[0])
609
622
 
610
- predicted_pk_spectrum = np.asarray(predicted_pk_spectrum_z)
623
+ predicted_pk_spectrum = self.asarray(predicted_pk_spectrum_z)
611
624
 
612
625
 
613
626
  pk = 10.**predicted_pk_spectrum
614
627
 
615
628
  pk_re = pk*self.pk_power_fac
616
- pk_re = np.transpose(pk_re)
629
+ pk_re = self.transpose(pk_re)
617
630
 
618
631
 
619
632
  self.pknl_interp = PowerSpectrumInterpolator(z_arr,k_arr,np.log(pk_re).T,logP=True)
@@ -629,16 +642,12 @@ class Class_szfast(object):
629
642
  z_asked,
630
643
  params_values_dict=None):
631
644
 
632
- z_arr = self.cszfast_pk_grid_z
633
645
 
634
646
  k_arr = self.cszfast_pk_grid_k
635
647
 
636
648
  if params_values_dict:
637
-
638
649
  params_values = params_values_dict.copy()
639
-
640
650
  else:
641
-
642
651
  params_values = self.params_for_emulators
643
652
 
644
653
  update_params_with_defaults(params_values, self.emulator_dict[self.cosmo_model]['default'])
@@ -655,19 +664,21 @@ class Class_szfast(object):
655
664
 
656
665
  predicted_pk_spectrum_z = []
657
666
 
658
- z_asked = z_asked
659
667
  params_dict_pp = params_dict.copy()
660
668
  update_params_with_defaults(params_dict_pp, self.emulator_dict[self.cosmo_model]['default'])
661
669
 
662
670
  params_dict_pp['z_pk_save_nonclass'] = [z_asked]
663
- predicted_pk_spectrum_z.append(self.cp_pkl_nn[self.cosmo_model].predictions_np(params_dict_pp)[0])
671
+ if self.jax_mode:
672
+ predicted_pk_spectrum_z.append(self.cp_pkl_nn[self.cosmo_model].predict(params_dict_pp))
673
+ else:
674
+ predicted_pk_spectrum_z.append(self.cp_pkl_nn[self.cosmo_model].predictions_np(params_dict_pp)[0])
664
675
 
665
- predicted_pk_spectrum = np.asarray(predicted_pk_spectrum_z)
676
+ predicted_pk_spectrum = self.asarray(predicted_pk_spectrum_z)
666
677
 
667
678
 
668
679
  pk = 10.**predicted_pk_spectrum
669
680
  pk_re = pk*self.pk_power_fac
670
- pk_re = np.transpose(pk_re)
681
+ pk_re = self.transpose(pk_re)
671
682
 
672
683
 
673
684
  return pk_re, k_arr
@@ -710,12 +721,12 @@ class Class_szfast(object):
710
721
  params_dict_pp['z_pk_save_nonclass'] = [z_asked]
711
722
  predicted_pk_spectrum_z.append(self.cp_pknl_nn[self.cosmo_model].predictions_np(params_dict_pp)[0])
712
723
 
713
- predicted_pk_spectrum = np.asarray(predicted_pk_spectrum_z)
724
+ predicted_pk_spectrum = self.asarray(predicted_pk_spectrum_z)
714
725
 
715
726
 
716
727
  pk = 10.**predicted_pk_spectrum
717
728
  pk_re = pk*self.pk_power_fac
718
- pk_re = np.transpose(pk_re)
729
+ pk_re = self.transpose(pk_re)
719
730
 
720
731
 
721
732
  return pk_re, k_arr
@@ -746,19 +757,10 @@ class Class_szfast(object):
746
757
  # print("self.cp_predicted_hubble type:", type(self.cp_predicted_hubble))
747
758
  # print("self.cp_predicted_hubble",self.cp_predicted_hubble)
748
759
 
749
- # self.hz_interp = jscipy.interpolate.interp1d(
750
- # self.cp_z_interp_jax,
751
- # self.cp_predicted_hubble,
752
- # kind='linear',
753
- # axis=-1,
754
- # copy=True,
755
- # bounds_error=None,
756
- # fill_value=np.nan,
757
- # assume_sorted=False)
758
760
 
759
761
  # Assuming `cp_z_interp` and `cp_predicted_hubble` are JAX arrays
760
762
  def hz_interp(x):
761
- return jnp.interp(x, self.cp_z_interp_jax, self.cp_predicted_hubble, left=jnp.nan, right=jnp.nan)
763
+ return jnp.interp(x, self.cp_z_interp, self.cp_predicted_hubble, left=jnp.nan, right=jnp.nan)
762
764
 
763
765
  self.hz_interp = hz_interp
764
766
  # exit()
@@ -785,45 +787,60 @@ class Class_szfast(object):
785
787
  update_params_with_defaults(params_values, self.emulator_dict[self.cosmo_model]['default'])
786
788
 
787
789
  params_dict = {}
788
-
789
790
  for k,v in zip(params_values.keys(),params_values.values()):
790
-
791
791
  params_dict[k]=[v]
792
792
 
793
793
  if 'm_ncdm' in params_dict.keys():
794
794
  if isinstance(params_dict['m_ncdm'][0],str):
795
795
  params_dict['m_ncdm'] = [float(params_dict['m_ncdm'][0].split(',')[0])]
796
796
 
797
- # deal with different scaling of DA in different model from emulator training
798
- if self.cosmo_model == 'ede-v2':
797
+ if self.jax_mode:
798
+ # print("JAX MODE in chi")
799
+ # self.cp_da_nn[self.cosmo_model].log = False
800
+ self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].predict(params_dict)
801
+ # print("self.cp_predicted_da",self.cp_predicted_da)
802
+
803
+ if self.cosmo_model == 'ede-v2':
804
+ # print('ede-v2 case')
805
+ self.cp_predicted_da = jnp.insert(self.cp_predicted_da, 0, 0)
806
+ self.cp_predicted_da *= (1.+self.cp_z_interp)
807
+
808
+ def chi_interp(x):
809
+ return jnp.interp(x, self.cp_z_interp, self.cp_predicted_da, left=jnp.nan, right=jnp.nan)
810
+
811
+ self.chi_interp = chi_interp
799
812
 
800
- self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].ten_to_predictions_np(params_dict)[0]
801
- self.cp_predicted_da = np.insert(self.cp_predicted_da, 0, 0)
802
-
803
813
  else:
804
-
805
- self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].predictions_np(params_dict)[0]
806
-
814
+ # deal with different scaling of DA in different model from emulator training
815
+ if self.cosmo_model == 'ede-v2':
807
816
 
808
- self.chi_interp = scipy.interpolate.interp1d(
809
- self.cp_z_interp,
810
- self.cp_predicted_da*(1.+self.cp_z_interp),
811
- kind='linear',
812
- axis=-1,
813
- copy=True,
814
- bounds_error=None,
815
- fill_value=np.nan,
816
- assume_sorted=False)
817
+ self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].ten_to_predictions_np(params_dict)[0]
818
+ self.cp_predicted_da = np.insert(self.cp_predicted_da, 0, 0)
819
+
820
+ else:
821
+
822
+ self.cp_predicted_da = self.cp_da_nn[self.cosmo_model].predictions_np(params_dict)[0]
823
+
824
+
825
+ self.chi_interp = scipy.interpolate.interp1d(
826
+ self.cp_z_interp,
827
+ self.cp_predicted_da*(1.+self.cp_z_interp),
828
+ kind='linear',
829
+ axis=-1,
830
+ copy=True,
831
+ bounds_error=None,
832
+ fill_value=np.nan,
833
+ assume_sorted=False)
817
834
 
818
835
  def get_cmb_cls(self,ell_factor=True,Tcmb_uk = Tcmb_uk):
819
836
 
820
837
  cls = {}
821
- cls['ell'] = np.arange(self.cszfast_ldim)
822
- cls['tt'] = np.zeros(self.cszfast_ldim)
823
- cls['te'] = np.zeros(self.cszfast_ldim)
824
- cls['ee'] = np.zeros(self.cszfast_ldim)
825
- cls['pp'] = np.zeros(self.cszfast_ldim)
826
- cls['bb'] = np.zeros(self.cszfast_ldim)
838
+ cls['ell'] = self.arange(self.cszfast_ldim)
839
+ cls['tt'] = self.zeros(self.cszfast_ldim)
840
+ cls['te'] = self.zeros(self.cszfast_ldim)
841
+ cls['ee'] = self.zeros(self.cszfast_ldim)
842
+ cls['pp'] = self.zeros(self.cszfast_ldim)
843
+ cls['bb'] = self.zeros(self.cszfast_ldim)
827
844
  cls['tt'][2:self.cp_lmax+1] = (Tcmb_uk)**2.*self.cp_predicted_tt_spectrum.copy()
828
845
  cls['te'][2:self.cp_lmax+1] = (Tcmb_uk)**2.*self.cp_predicted_te_spectrum.copy()
829
846
  cls['ee'][2:self.cp_lmax+1] = (Tcmb_uk)**2.*self.cp_predicted_ee_spectrum.copy()
@@ -831,12 +848,12 @@ class Class_szfast(object):
831
848
 
832
849
 
833
850
  if ell_factor==False:
834
- fac_l = np.zeros(self.cszfast_ldim)
835
- fac_l[2:self.cp_lmax+1] = 1./(cls['ell'][2:self.cp_lmax+1]*(cls['ell'][2:self.cp_lmax+1]+1.)/2./np.pi)
851
+ fac_l = self.zeros(self.cszfast_ldim)
852
+ fac_l[2:self.cp_lmax+1] = 1./(cls['ell'][2:self.cp_lmax+1]*(cls['ell'][2:self.cp_lmax+1]+1.)/2./self.pi)
836
853
  cls['tt'][2:self.cp_lmax+1] *= fac_l[2:self.cp_lmax+1]
837
854
  cls['te'][2:self.cp_lmax+1] *= fac_l[2:self.cp_lmax+1]
838
855
  cls['ee'][2:self.cp_lmax+1] *= fac_l[2:self.cp_lmax+1]
839
- # cls['bb'] = np.zeros(self.cszfast_ldim)
856
+ # cls['bb'] = self.zeros(self.cszfast_ldim)
840
857
  return cls
841
858
 
842
859
 
@@ -882,7 +899,10 @@ class Class_szfast(object):
882
899
  return np.array(self.hz_interp(z)*H_units_conv_factor[units])
883
900
 
884
901
  def get_chi(self, z):
885
- return np.array(self.chi_interp(z))
902
+ if self.jax_mode:
903
+ return jnp.array(self.chi_interp(z))
904
+ else:
905
+ return np.array(self.chi_interp(z))
886
906
 
887
907
  def get_gas_pressure_profile_x(self,z,m,x):
888
908
  return 0#np.vectorize(self.csz_base.get_pressure_P_over_P_delta_at_x_M_z_b12_200c)(x,m,z)
@@ -896,7 +916,7 @@ class Class_szfast(object):
896
916
 
897
917
 
898
918
  def tabulate_gas_pressure_profile_k(self):
899
- z_asked,m_asked,x_asked = 0.2,3e14,np.geomspace(1e-3,1e2,500)
919
+ z_asked,m_asked,x_asked = 0.2,3e14,self.geomspace(1e-3,1e2,500)
900
920
  start = time.time()
901
921
  px = self.get_gas_pressure_profile_x(z_asked,m_asked,x_asked)
902
922
  end = time.time()
@@ -1,5 +1,6 @@
1
1
  from .config import path_to_class_sz_data
2
2
  import numpy as np
3
+ import jax.numpy as jnp
3
4
  from .restore_nn import Restore_NN
4
5
  from .restore_nn import Restore_PCAplusNN
5
6
  from .suppress_warnings import suppress_warnings
@@ -20,6 +21,74 @@ cp_h_nn_jax = {}
20
21
  cp_s8_nn_jax = {}
21
22
 
22
23
 
24
+ class CosmoPowerJAX_custom(CPJ):
25
+ def __init__(self, *args, **kwargs):
26
+ super().__init__(*args, **kwargs)
27
+ self.ten_to_predictions = True
28
+ if 'ten_to_predictions' in kwargs.keys():
29
+ self.ten_to_predictions = kwargs['ten_to_predictions']
30
+
31
+ def _predict(self, weights, hyper_params, param_train_mean, param_train_std,
32
+ feature_train_mean, feature_train_std, input_vec):
33
+ """ Forward pass through pre-trained network.
34
+ In its current form, it does not make use of high-level frameworks like
35
+ FLAX et similia; rather, it simply loops over the network layers.
36
+ In future work this can be improved, especially if speed is a problem.
37
+
38
+ Parameters
39
+ ----------
40
+ weights : array
41
+ The stored weights of the neural network.
42
+ hyper_params : array
43
+ The stored hyperparameters of the activation function for each layer.
44
+ param_train_mean : array
45
+ The stored mean of the training cosmological parameters.
46
+ param_train_std : array
47
+ The stored standard deviation of the training cosmological parameters.
48
+ feature_train_mean : array
49
+ The stored mean of the training features.
50
+ feature_train_std : array
51
+ The stored standard deviation of the training features.
52
+ input_vec : array of shape (n_samples, n_parameters) or (n_parameters)
53
+ The cosmological parameters given as input to the network.
54
+
55
+ Returns
56
+ -------
57
+ predictions : array
58
+ The prediction of the trained neural network.
59
+ """
60
+ act = []
61
+ # Standardise
62
+ layer_out = [(input_vec - param_train_mean)/param_train_std]
63
+
64
+ # Loop over layers
65
+ for i in range(len(weights[:-1])):
66
+ w, b = weights[i]
67
+ alpha, beta = hyper_params[i]
68
+ act.append(jnp.dot(layer_out[-1], w.T) + b)
69
+ layer_out.append(self._activation(act[-1], alpha, beta))
70
+
71
+ # Final layer prediction (no activations)
72
+ w, b = weights[-1]
73
+ if self.probe == 'custom_log' or self.probe == 'custom_pca':
74
+ # in original CP models, we assumed a full final bias vector...
75
+ preds = jnp.dot(layer_out[-1], w.T) + b
76
+ else:
77
+ # ... unlike in cpjax, where we used only a single bias vector
78
+ preds = jnp.dot(layer_out[-1], w.T) + b[-1]
79
+
80
+ # Undo the standardisation
81
+ preds = preds * feature_train_std + feature_train_mean
82
+ if self.log == True:
83
+ if self.ten_to_predictions:
84
+ preds = 10**preds
85
+ else:
86
+ preds = (preds@self.pca_matrix)*self.training_std + self.training_mean
87
+ if self.probe == 'cmb_pp':
88
+ preds = 10**preds
89
+ predictions = preds.squeeze()
90
+ return predictions
91
+
23
92
  for mp in cosmo_model_list:
24
93
  folder, version = split_emulator_string(mp)
25
94
  # print(folder, version)
@@ -36,18 +105,24 @@ for mp in cosmo_model_list:
36
105
 
37
106
  cp_pknl_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'PK/' + emulator_dict[mp]['PKNL'])
38
107
 
39
- cp_pkl_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'PK/' + emulator_dict[mp]['PKL'])
40
-
108
+ cp_pkl_nn_jax[mp] = CosmoPowerJAX_custom(probe='custom_log',filepath=path_to_emulators +'PK/' + emulator_dict[mp]['PKL'] + '.npz')
109
+ cp_pkl_nn_jax[mp].ten_to_predictions = False
110
+
41
111
  cp_der_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'derived-parameters/' + emulator_dict[mp]['DER'])
42
112
 
43
- cp_da_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['DAZ'])
113
+ # cp_da_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['DAZ'])
114
+
44
115
 
116
+ cp_da_nn_jax[mp] = CosmoPowerJAX_custom(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['DAZ'] + '.npz')
117
+ if mp != 'ede-v2':
118
+ cp_da_nn_jax[mp].ten_to_predictions = False
119
+ # print(cp_da_nn_jax[mp].parameters)
45
120
  # print(path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'])
46
- emulator_custom = CPJ(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'] + '.npz')
121
+ # emulator_custom = CPJ(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'] + '.npz')
47
122
  # print(emulator_custom.parameters)
48
123
  # exit()
49
124
 
50
- cp_h_nn_jax[mp] = CPJ(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'] + '.npz')
125
+ cp_h_nn_jax[mp] = CosmoPowerJAX_custom(probe='custom_log',filepath=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['HZ'] + '.npz')
51
126
 
52
127
  cp_s8_nn_jax[mp] = Restore_NN(restore_filename=path_to_emulators + 'growth-and-distances/' + emulator_dict[mp]['S8Z'])
53
128
 
@@ -33,7 +33,7 @@ cosmopower_derived_params_names = ['100*theta_s',
33
33
  'ra_star',
34
34
  'rs_drag']
35
35
 
36
- cp_l_max_scalars = 11000 # max multipole of train ing data
36
+ cp_l_max_scalars = 11000 # max multipole of training data
37
37
 
38
38
  cosmo_model_list = [
39
39
  'lcdm',
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: classy_szfast
3
- Version: 0.0.25.post2
3
+ Version: 0.0.25.post4
4
4
  Summary: The accelerator of the class_sz code from https://github.com/CLASS-SZ
5
5
  Maintainer-email: Boris Bolliet <bb667@cam.ac.uk>
6
6
  License: MIT
@@ -1,11 +1,11 @@
1
1
  classy_szfast/__init__.py,sha256=E2thrL0Z9oXFfdzwcsu-xbOytudLFTlRlPqVFGlPPPg,279
2
2
  classy_szfast/classy_sz.py,sha256=QmbwrSXInQLMvCDqsr7KPmtaU0KOiOt1Rb-cTKuulZw,22240
3
- classy_szfast/classy_szfast.py,sha256=2eRVEjpTY9xkbHtIUZjZjKQJEcT_0-5ku0n63VYlXLo,36639
3
+ classy_szfast/classy_szfast.py,sha256=H3DPD4_ZvVJ_FCVbEzc4rW8DYQHPZCCaGQZtQ1f-7oQ,37646
4
4
  classy_szfast/config.py,sha256=cd7Z62-qnX_4FJWfUNqcyJVh-AdBiXrF8DcQGpyAUZM,274
5
5
  classy_szfast/cosmopower.py,sha256=ooYK2BDOZSo3XtGHfPtjXHxr5UW-yVngLPkb5gpvTx8,2351
6
- classy_szfast/cosmopower_jax.py,sha256=C7NzfMFs9sL8rKuDdXdmwxk0UzHqNJnVjZENak-EPQA,2151
6
+ classy_szfast/cosmopower_jax.py,sha256=NvwY1a9YoizOV0zKp0SEtiHQkhWRO-seGFzO3FIzLl0,5436
7
7
  classy_szfast/cosmosis_classy_szfast_interface.py,sha256=zAnxvFtn73a5yS7jgs59zpWFEYKCIQyraYPs5hQ4Le8,11483
8
- classy_szfast/emulators_meta_data.py,sha256=-lHneGhSJ2481S48viz_bNeCyAGu1Ogee0jFEB8B618,9724
8
+ classy_szfast/emulators_meta_data.py,sha256=faA0iQqzfC5lOj1hu7wSEZUFfhBGLJqX9-xovMfTbr0,9723
9
9
  classy_szfast/pks_and_sigmas.py,sha256=drtuujE1HhlrYY1hY92DyY5lXlYS1uE15MSuVI4uo6k,6625
10
10
  classy_szfast/restore_nn.py,sha256=DqA9thhTRiGBDVb9zjhqcbF2W4V0AU0vrjJFhnLboU4,21075
11
11
  classy_szfast/suppress_warnings.py,sha256=6wIBml2Sj9DyRGZlZWhuA9hqvpxqrNyYjuz6BPK_a6E,202
@@ -14,7 +14,7 @@ classy_szfast/custom_bias/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJW
14
14
  classy_szfast/custom_bias/custom_bias.py,sha256=aR2t5RTIwv7P0m2bsEU0Eq6BTkj4pG10AebH6QpG4qM,486
15
15
  classy_szfast/custom_profiles/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
16
  classy_szfast/custom_profiles/custom_profiles.py,sha256=4LZwb2XoqwCyWNmW2s24Z7AJdmgVdaRG7yYaBYe-d9Q,1188
17
- classy_szfast-0.0.25.post2.dist-info/METADATA,sha256=SHRC4v8N4uo_BSIHn9D8_KrReCq3coWsZCQIpLqaKmQ,548
18
- classy_szfast-0.0.25.post2.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
19
- classy_szfast-0.0.25.post2.dist-info/top_level.txt,sha256=hRgqpilUck4lx2KkaWI2y9aCDKqF6pFfGHfNaoPFxv0,14
20
- classy_szfast-0.0.25.post2.dist-info/RECORD,,
17
+ classy_szfast-0.0.25.post4.dist-info/METADATA,sha256=oNj9IJ8zELtcPCEOOi4wQzx9SfxVaMmdIl6QjyAFdkY,548
18
+ classy_szfast-0.0.25.post4.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
19
+ classy_szfast-0.0.25.post4.dist-info/top_level.txt,sha256=hRgqpilUck4lx2KkaWI2y9aCDKqF6pFfGHfNaoPFxv0,14
20
+ classy_szfast-0.0.25.post4.dist-info/RECORD,,